Source code for fm_weck.zenodo

# This file is part of fm-weck: executing fm-tools in containerized environments.
# https://gitlab.com/sosy-lab/software/fm-weck
#
# SPDX-FileCopyrightText: 2024 Dirk Beyer <https://www.sosy-lab.org>
#
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
from datetime import datetime
from pathlib import Path

import httpx
from dotenv import find_dotenv, load_dotenv, set_key
from tqdm import tqdm

from fm_weck.config import Config
from fm_weck.exceptions import ZenodoError
from fm_weck.image_mgr import ImageMgr


class TokenAuth(httpx.Auth):
    """
    httpx authentication scheme that adds a bearer token to outgoing requests.
    """

    def __init__(self, token: str):
        # The raw API token; it is formatted into the header in auth_flow.
        self.token = token

    def auth_flow(self, request):
        """
        Set the ``Authorization`` header on the request and yield it back to httpx.
        """
        bearer_value = f"Bearer {self.token}"
        request.headers["Authorization"] = bearer_value
        yield request


[docs] class ZenodoSession: def __init__(self, access_token: str | None = None, is_sandbox: bool = False): if is_sandbox: self.base_url = "https://sandbox.zenodo.org/api" else: self.base_url = "https://zenodo.org/api" self.deposit_url = self.base_url + "/deposit/depositions" self.download_url = self.base_url + "/records" self.access_token = access_token if self.access_token is not None: self.client = httpx.Client(auth=TokenAuth(self.access_token)) else: self.client = httpx.Client() def _require_access_token(self): if self.access_token is None: raise ZenodoError("An API access token is required to interact with the Zenodo API.")
[docs] def publish(self, image_info: tuple): """ Publish an image on Zenodo. """ self._require_access_token() deposition_id = self.upload(image_info) publish_url = self.deposit_url + "/" + deposition_id + "/" + "actions/publish" r = httpx.post(publish_url) if r.status_code >= 400: self.delete_deposition(deposition_id) raise ZenodoError(f"Error trying to publish an image: {r.json()['message']}") return r.json()["doi"]
[docs] def upload(self, image_info: tuple[str, str, Path]) -> str: """ Upload an image to Zenodo and leave it in draft mode (unpublished). """ self._require_access_token() image_name = image_info[0] image_version = image_info[1] image_path = image_info[2] # Upload the actual file upload_name = image_name + "_" + image_version + ".tar" upload_name = upload_name.replace("/", "-") metadata = self.get_upload_metadata(upload_name, image_version) record_id = self.check_existing_records(metadata) logging.debug(f"Record ID: {record_id}") # Check if we are adding a new version or creating a new record if record_id: r = self.client.post(self.deposit_url + "/" + record_id + "/actions/newversion") else: logging.debug(f"{metadata=}") r = self.client.post( self.deposit_url, json=metadata, headers={"Content-Type": "application/json", "Authorization": f"Bearer {self.access_token}"}, ) if r.status_code >= 400: logging.debug(f"{r.status_code} - {r.text}") raise ZenodoError(f"Error trying to upload the deposition metadata to Zenodo: {r.json()['message']}") bucket_url = r.json()["links"]["bucket"] deposition_id = str(r.json()["id"]) if record_id: self.prepare_new_version_upload(metadata, deposition_id) r = self.upload_with_progress_bar(image_path, upload_name, bucket_url) if r.status_code >= 400: self.delete_deposition(deposition_id) raise ZenodoError(f"Error trying to upload image to Zenodo: {r.json()['message']}") return deposition_id
[docs] def upload_with_progress_bar(self, image_path, upload_name, bucket_url): file_size = os.path.getsize(image_path) with ( open(image_path, "rb") as image, tqdm(total=file_size, unit="B", unit_scale=True, desc="Uploading") as progress_bar, ): def file_iterator(): chunk_size = 1024 * 1024 * 5 # 5 MB while True: chunk = image.read(chunk_size) if not chunk: break yield chunk progress_bar.update(len(chunk)) r = self.client.put( f"{bucket_url}/{upload_name}", data=file_iterator(), headers={"Content-Length": str(file_size)}, timeout=300, ) return r
[docs] def download(self, doi: str, download_path: str | None, no_progress_bar: bool = False) -> Path: """ Download an image from Zenodo using its DOI. """ download_url = self.resolve_download_info(doi) image_name = download_url.split("/")[-2] + ".tar" if no_progress_bar: r = self.client.get(download_url) if r.status_code >= 400: raise ZenodoError(f"Error trying to download image: {r.json()['message']}") content = r.content else: content = self.download_with_progress_bar(download_url) return self.save_image(image_name, content, download_path)
[docs] def download_with_progress_bar(self, url: str): content = bytearray() with self.client.stream("GET", url) as response: response.raise_for_status() total_size = int(response.headers.get("Content-Length", 0)) chunk_size = 1024 * 1024 * 5 # 5 MB with tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading") as progress_bar: for chunk in response.iter_bytes(chunk_size): content.extend(chunk) progress_bar.update(len(chunk)) return content
[docs] def resolve_download_info(self, doi: str) -> str: """ Extract the download URL from the deposition metadata. :return: The download URL of the image. """ doi_id = doi.split(".")[-1] record_url = self.download_url + "/" + doi_id r = self.client.get(record_url) if r.status_code >= 400: raise ZenodoError(f"Error trying to access the image record: {r.json()['message']}") file_name = r.json()["files"][0]["key"] return record_url + "/files/" + file_name + "/content"
[docs] def save_image(self, image_name: str, image: bytes, save_path): save_path = Path().cwd() / image_name if not save_path else Path().cwd() / Path(save_path) / image_name save_path.parent.mkdir(parents=True, exist_ok=True) save_path.write_bytes(image) return save_path
[docs] def delete_deposition(self, deposition_id: str): r = self.client.delete(self.deposit_url + "/" + deposition_id) if r.status_code >= 400: raise ZenodoError(f"Error trying to delete a deposition resource from Zenodo: {r.json()['message']}")
[docs] def check_existing_records(self, metadata: dict): """ Check if a record with the same name already exists in the depositions. If it exists, prompt whether the image should be uploaded as a new version. :return: The deposition ID in case of a positive prompt response, or an empty string. """ r = self.client.get(self.deposit_url) if r.status_code >= 400: raise ZenodoError(f"Error trying to access the image record: {r.json()['message']}") for record in r.json(): if metadata["metadata"]["title"] == record["title"]: while True: response = ( input( ( "The record already exists. Do you want to upload this as a new version " "of the existing record?\n" "(y)es: upload as a new version of an existing record\n" "(n)o: create a new record\n" "(a)bort: cancel the upload\n" ) ) .strip() .lower() ) if response == "y": return str(record["id"]) elif response == "n": break elif response == "a": exit(0) return ""
[docs] def prepare_new_version_upload(self, metadata: dict, deposition_id: str) -> None: """ Set up the metadata for the new version and delete the old versions' files from that version. """ r = self.client.put( self.deposit_url + "/" + deposition_id, data=json.dumps(metadata), # type: ignore ) if r.status_code >= 400: raise ZenodoError(f"Error trying to upload the new version metadata to Zenodo: {r.json()['message']}") for file in r.json()["files"]: r = self.client.delete( self.deposit_url + "/" + deposition_id + "/files/" + file["id"], ) if r.status_code >= 400: raise ZenodoError( f"Error trying to delete old version files from the new version: {r.json()['message']}" )
[docs] def get_upload_metadata(self, image_name: str, image_version: str) -> dict: description = ( f"This is an image file.\n" f"<p>In case of manual download, the image can be imported locally using the command: " f"<code>[docker|podman] load --input {image_name}</code></p>" ) upload_metadata = { "metadata": { "title": image_name, "version": image_version, "upload_type": "software", "creators": [{"name": "fm-weck"}], "publication_date": datetime.now().date().isoformat(), "access_right": "open", "license": "Apache-2.0", "description": description, } } return upload_metadata
class ZenodoMgr(object):
    """
    Zenodo API manager.

    Args:
        engine (object): The fm-weck engine object.
        access_token (str, optional): The access token for authentication with the API.
            When falsy, the token is looked up via ``load_zenodo_token``.
        image (str, optional): The image to be published. Required if no doi is given. Defaults to None.
        doi (str, optional): The DOI of the image to be pulled. Required if no image is given. Defaults to None.
        is_sandbox (bool, optional): Use the Zenodo sandbox instead of the production instance.
        config (Config, optional): The fm-weck configuration, used to locate the token file.

    Raises:
        ValueError: If a chosen image is invalid or does not exist.
    """

    def __init__(
        self,
        engine,
        access_token: str | None,
        image: str | None,
        doi: str | None,
        is_sandbox: bool = False,
        config: Config | None = None,
    ):
        access_token = access_token or ZenodoMgr.load_zenodo_token(config)
        self.session = ZenodoSession(access_token, is_sandbox)
        self.image = image
        self.engine = engine
        # Always define the attribute so publish()/push() never hit an AttributeError.
        self.image_info = None
        if image:
            self.image_info = ImageMgr().prepare_image_for_zenodo(engine, image)
        self.doi = doi

    def publish(self):
        """
        Publish an image on Zenodo.

        Raises:
            ValueError: If no image was given.
            ZenodoError: If Zenodo API returns an error.
        """
        if not self.image:
            raise ValueError("No image specified for publishing.")

        print("Publishing image to Zenodo...")
        doi = self.session.publish(self.image_info)
        print("Image successfully published.")
        ImageMgr().tag_image(self.engine, self.image, doi)

    def push(self):
        """
        Upload an image to Zenodo and leave it in draft mode (unpublished).

        Raises:
            ValueError: If no image was given.
            ZenodoError: If Zenodo API returns an error.
        """
        if not self.image:
            raise ValueError("No image specified for pushing.")

        logging.info(f"Uploading image to Zenodo... {self.image_info}")
        self.session.upload(self.image_info)
        print("Image successfully uploaded.")

    def pull(self, download_path: str | None, force_pull: bool = False):
        """
        Download an image from Zenodo using its DOI.

        The downloaded tar file is loaded into the engine and removed afterwards.

        Raises:
            ValueError: If no DOI is provided.
            ZenodoError: If Zenodo API returns an error.
        """
        if not self.doi:
            raise ValueError("No DOI provided.")

        # NOTE(review): this relies on check_if_doi_exists_locally returning a falsy value
        # when the image is already present locally -- confirm against ImageMgr.
        doi = ImageMgr().check_if_doi_exists_locally(self.engine, self.doi)
        if not force_pull and not doi:
            print("Image already exists locally. Pulling aborted.")
            return

        print("Pulling image from Zenodo...")
        tar_path = self.session.download(self.doi, download_path)
        ImageMgr().load_image(self.engine, tar_path, self.doi)
        print("Image successfully pulled.")
        tar_path.unlink()

    @staticmethod
    def _resolve_token_path(config) -> Path | None:
        """
        Locate the Zenodo token file.

        A ``.zenodo_token`` file found by ``find_dotenv`` takes precedence; otherwise the
        ``defaults.token_location`` config entry is resolved relative to the current working
        directory. Returns ``None`` if neither source yields a path.
        """
        discovered = find_dotenv(filename=".zenodo_token")
        if discovered:
            return Path(discovered)
        if config is None:
            return None
        defaults = config.get("defaults") or {}
        location = defaults.get("token_location")
        if not location:
            return None
        return Path.cwd() / location

    @staticmethod
    def save_zenodo_token(config, token):
        """
        Store the Zenodo token in the token file after asking for confirmation on stdin.

        Raises:
            ZenodoError: If no token file location can be determined.
        """
        token_path = ZenodoMgr._resolve_token_path(config)
        if token_path is None:
            raise ZenodoError("No Zenodo token location is configured.")

        print(f"This action will overwrite the token at: {token_path}\nDo you wish to continue? [y/N]")
        if input().strip().lower() == "y":
            # The configured location may sit in a directory that does not exist yet.
            token_path.parent.mkdir(parents=True, exist_ok=True)
            load_dotenv(token_path)
            set_key(token_path, "ZENODO_TOKEN", token)
            print(f"Token saved to: {token_path}")
        else:
            print("Action aborted.")

    @staticmethod
    def load_zenodo_token(config) -> str | None:
        """
        Load the Zenodo token from the Zenodo token file.

        The token file (if one can be located) is loaded into the environment and the
        ``ZENODO_TOKEN`` environment variable is read afterwards.

        Returns:
            The token, or ``None`` if ``ZENODO_TOKEN`` is not set. No exception is raised for a
            missing token; operations that need one fail later in the session.
        """
        token_path = ZenodoMgr._resolve_token_path(config)
        if token_path is not None:
            load_dotenv(token_path)

        token = os.getenv("ZENODO_TOKEN")
        if token and token_path is not None:
            print(f"Using token from: {token_path}\n")
        return token