# This file is part of fm-weck: executing fm-tools in containerized environments.
# https://gitlab.com/sosy-lab/software/fm-weck
#
# SPDX-FileCopyrightText: 2024 Dirk Beyer <https://www.sosy-lab.org>
#
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
from datetime import datetime
from pathlib import Path
import httpx
from dotenv import find_dotenv, load_dotenv, set_key
from tqdm import tqdm
from fm_weck.config import Config
from fm_weck.exceptions import ZenodoError
from fm_weck.image_mgr import ImageMgr
class TokenAuth(httpx.Auth):
    """
    httpx authentication hook that adds a bearer token to outgoing requests.
    """

    def __init__(self, token: str):
        # The raw token; it is formatted into the header value in auth_flow.
        self.token = token

    def auth_flow(self, request):
        """
        Set the ``Authorization`` header on the request and hand it back to httpx.
        """
        bearer_value = f"Bearer {self.token}"
        request.headers["Authorization"] = bearer_value
        yield request
[docs]
class ZenodoSession:
def __init__(self, access_token: str | None = None, is_sandbox: bool = False):
if is_sandbox:
self.base_url = "https://sandbox.zenodo.org/api"
else:
self.base_url = "https://zenodo.org/api"
self.deposit_url = self.base_url + "/deposit/depositions"
self.download_url = self.base_url + "/records"
self.access_token = access_token
if self.access_token is not None:
self.client = httpx.Client(auth=TokenAuth(self.access_token))
else:
self.client = httpx.Client()
def _require_access_token(self):
if self.access_token is None:
raise ZenodoError("An API access token is required to interact with the Zenodo API.")
[docs]
def publish(self, image_info: tuple):
"""
Publish an image on Zenodo.
"""
self._require_access_token()
deposition_id = self.upload(image_info)
publish_url = self.deposit_url + "/" + deposition_id + "/" + "actions/publish"
r = httpx.post(publish_url)
if r.status_code >= 400:
self.delete_deposition(deposition_id)
raise ZenodoError(f"Error trying to publish an image: {r.json()['message']}")
return r.json()["doi"]
[docs]
def upload(self, image_info: tuple[str, str, Path]) -> str:
"""
Upload an image to Zenodo and leave it in draft mode (unpublished).
"""
self._require_access_token()
image_name = image_info[0]
image_version = image_info[1]
image_path = image_info[2]
# Upload the actual file
upload_name = image_name + "_" + image_version + ".tar"
upload_name = upload_name.replace("/", "-")
metadata = self.get_upload_metadata(upload_name, image_version)
record_id = self.check_existing_records(metadata)
logging.debug(f"Record ID: {record_id}")
# Check if we are adding a new version or creating a new record
if record_id:
r = self.client.post(self.deposit_url + "/" + record_id + "/actions/newversion")
else:
logging.debug(f"{metadata=}")
r = self.client.post(
self.deposit_url,
json=metadata,
headers={"Content-Type": "application/json", "Authorization": f"Bearer {self.access_token}"},
)
if r.status_code >= 400:
logging.debug(f"{r.status_code} - {r.text}")
raise ZenodoError(f"Error trying to upload the deposition metadata to Zenodo: {r.json()['message']}")
bucket_url = r.json()["links"]["bucket"]
deposition_id = str(r.json()["id"])
if record_id:
self.prepare_new_version_upload(metadata, deposition_id)
r = self.upload_with_progress_bar(image_path, upload_name, bucket_url)
if r.status_code >= 400:
self.delete_deposition(deposition_id)
raise ZenodoError(f"Error trying to upload image to Zenodo: {r.json()['message']}")
return deposition_id
[docs]
def upload_with_progress_bar(self, image_path, upload_name, bucket_url):
file_size = os.path.getsize(image_path)
with (
open(image_path, "rb") as image,
tqdm(total=file_size, unit="B", unit_scale=True, desc="Uploading") as progress_bar,
):
def file_iterator():
chunk_size = 1024 * 1024 * 5 # 5 MB
while True:
chunk = image.read(chunk_size)
if not chunk:
break
yield chunk
progress_bar.update(len(chunk))
r = self.client.put(
f"{bucket_url}/{upload_name}",
data=file_iterator(),
headers={"Content-Length": str(file_size)},
timeout=300,
)
return r
[docs]
def download(self, doi: str, download_path: str | None, no_progress_bar: bool = False) -> Path:
"""
Download an image from Zenodo using its DOI.
"""
download_url = self.resolve_download_info(doi)
image_name = download_url.split("/")[-2] + ".tar"
if no_progress_bar:
r = self.client.get(download_url)
if r.status_code >= 400:
raise ZenodoError(f"Error trying to download image: {r.json()['message']}")
content = r.content
else:
content = self.download_with_progress_bar(download_url)
return self.save_image(image_name, content, download_path)
[docs]
def download_with_progress_bar(self, url: str):
content = bytearray()
with self.client.stream("GET", url) as response:
response.raise_for_status()
total_size = int(response.headers.get("Content-Length", 0))
chunk_size = 1024 * 1024 * 5 # 5 MB
with tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading") as progress_bar:
for chunk in response.iter_bytes(chunk_size):
content.extend(chunk)
progress_bar.update(len(chunk))
return content
[docs]
def resolve_download_info(self, doi: str) -> str:
"""
Extract the download URL from the deposition metadata.
:return: The download URL of the image.
"""
doi_id = doi.split(".")[-1]
record_url = self.download_url + "/" + doi_id
r = self.client.get(record_url)
if r.status_code >= 400:
raise ZenodoError(f"Error trying to access the image record: {r.json()['message']}")
file_name = r.json()["files"][0]["key"]
return record_url + "/files/" + file_name + "/content"
[docs]
def save_image(self, image_name: str, image: bytes, save_path):
save_path = Path().cwd() / image_name if not save_path else Path().cwd() / Path(save_path) / image_name
save_path.parent.mkdir(parents=True, exist_ok=True)
save_path.write_bytes(image)
return save_path
[docs]
def delete_deposition(self, deposition_id: str):
r = self.client.delete(self.deposit_url + "/" + deposition_id)
if r.status_code >= 400:
raise ZenodoError(f"Error trying to delete a deposition resource from Zenodo: {r.json()['message']}")
[docs]
def check_existing_records(self, metadata: dict):
"""
Check if a record with the same name already exists in the depositions.
If it exists, prompt whether the image should be uploaded as a new version.
:return: The deposition ID in case of a positive prompt response, or an empty string.
"""
r = self.client.get(self.deposit_url)
if r.status_code >= 400:
raise ZenodoError(f"Error trying to access the image record: {r.json()['message']}")
for record in r.json():
if metadata["metadata"]["title"] == record["title"]:
while True:
response = (
input(
(
"The record already exists. Do you want to upload this as a new version "
"of the existing record?\n"
"(y)es: upload as a new version of an existing record\n"
"(n)o: create a new record\n"
"(a)bort: cancel the upload\n"
)
)
.strip()
.lower()
)
if response == "y":
return str(record["id"])
elif response == "n":
break
elif response == "a":
exit(0)
return ""
[docs]
def prepare_new_version_upload(self, metadata: dict, deposition_id: str) -> None:
"""
Set up the metadata for the new version and delete the old versions' files from that version.
"""
r = self.client.put(
self.deposit_url + "/" + deposition_id,
data=json.dumps(metadata), # type: ignore
)
if r.status_code >= 400:
raise ZenodoError(f"Error trying to upload the new version metadata to Zenodo: {r.json()['message']}")
for file in r.json()["files"]:
r = self.client.delete(
self.deposit_url + "/" + deposition_id + "/files/" + file["id"],
)
if r.status_code >= 400:
raise ZenodoError(
f"Error trying to delete old version files from the new version: {r.json()['message']}"
)
[docs]
class ZenodoMgr(object):
"""
Zenodo API manager.
Args:
engine (object): The fm-weck engine object.
access_token (str, optional): The access token for authentication with the API. Defaults to None.
image (str, optional): The image to be published. Required if no doi is given. Defaults to None.
doi (str, optional): The DOI of the image to be pulled. Required if no image is given. Defaults to None.
Raises:
ValueError: If a chosen image is invalid or does not exist.
"""
def __init__(
self,
engine,
access_token: str | None,
image: str | None,
doi: str | None,
is_sandbox: bool = False,
config: Config | None = None,
):
access_token = access_token or ZenodoMgr.load_zenodo_token(config)
self.session = ZenodoSession(access_token, is_sandbox)
self.image = image
self.engine = engine
if image:
self.image_info = ImageMgr().prepare_image_for_zenodo(engine, image)
self.doi = doi
[docs]
def publish(self):
"""
Publish an image on Zenodo.
Raises:
ZenodoError: If Zenodo API returns an error.
"""
if self.image is None:
raise ValueError("No image specified for publishing.")
print("Publishing image to Zenodo...")
doi = self.session.publish(self.image_info)
print("Image successfully published.")
ImageMgr().tag_image(self.engine, self.image, doi)
[docs]
def push(self):
"""
Upload an image to Zenodo and leave it in draft mode (unpublished).
Raises:
ZenodoError: If Zenodo API returns an error.
"""
if self.image is None:
raise ValueError("No image specified for pushing.")
logging.info(f"Uploading image to Zenodo... {self.image_info}")
self.session.upload(self.image_info)
print("Image successfully uploaded.")
[docs]
def pull(self, download_path: str | None, force_pull: bool = False):
"""
Download an image from Zenodo using its DOI.
Raises:
ValueError: If no DOI is provided.
ZenodoError: If Zenodo API returns an error.
"""
if self.doi:
doi = ImageMgr().check_if_doi_exists_locally(self.engine, self.doi)
else:
raise ValueError("No DOI provided.")
if not force_pull and not doi:
print("Image already exists locally. Pulling aborted.")
return
print("Pulling image from Zenodo...")
tar_path = self.session.download(self.doi, download_path)
ImageMgr().load_image(self.engine, tar_path, self.doi)
print("Image successfully pulled.")
tar_path.unlink()
[docs]
@staticmethod
def save_zenodo_token(config, token):
token_path = find_dotenv(filename=".zenodo_token")
token_path = Path.cwd() / config.get("defaults").get("token_location") if not token_path else Path(token_path)
print(f"This action will overwrite the token at: {token_path}\nDo you wish to continue? [y/N]")
if input().lower() == "y":
load_dotenv(token_path)
set_key(token_path, "ZENODO_TOKEN", token)
print(f"Token saved to: {token_path}")
else:
print("Action aborted.")
[docs]
@staticmethod
def load_zenodo_token(config):
"""
Load the Zenodo token from the Zenodo config file.
Raises:
ZenodoError: If no Zenodo token is found.
"""
token_path = find_dotenv(filename=".zenodo_token")
token_path = Path.cwd() / config.get("defaults").get("token_location") if not token_path else Path(token_path)
load_dotenv(Path(token_path))
token = os.getenv("ZENODO_TOKEN")
if token:
print(f"Using token from: {token_path}\n")
return token