# Source code for quantem.data.registry

"""Dataset registry and loader backed by Hugging Face Hub.

Metadata lives in JSON sidecar files on HF Hub, not hardcoded here.
The registry is discovered at runtime by listing the repo contents.
"""

import json
import os
import tempfile

import numpy as np
from huggingface_hub import HfApi, hf_hub_download

from quantem.data.schema import SCHEMA_VERSION, validate, make_template

# HF Hub dataset repository (repo_type="dataset") that backs this registry.
REPO_ID = "bobleesj/quantem-data"


def _api() -> HfApi:
    """Return a fresh Hugging Face Hub API client."""
    client = HfApi()
    return client


def available(technique: str | None = None) -> list[str]:
    """List available dataset names.

    Parameters
    ----------
    technique : str, optional
        Filter by technique (e.g. ``"4dstem"``, ``"hrtem"``).

    Returns
    -------
    list of str
        Sorted dataset names (excluding placeholders).
    """
    api = _api()
    files = api.list_repo_files(repo_id=REPO_ID, repo_type="dataset")
    names = []
    for f in files:
        # Only the JSON metadata sidecars define datasets.
        # (The previous `f != "README.md"` guard was dead code: "README.md"
        # can never end with ".json".)
        if not f.endswith(".json"):
            continue
        if technique and not f.startswith(f"{technique}/"):
            continue
        name = os.path.basename(f).removesuffix(".json")
        # Placeholder sidecars reserve a name but carry no data.
        if name.startswith("placeholder_"):
            continue
        names.append(name)
    return sorted(names)
def info(name: str) -> dict:
    """Return metadata for a dataset (downloads the JSON sidecar).

    Parameters
    ----------
    name : str
        Dataset name (see ``available()``).
    """
    return _download_metadata(name)
def load(name: str, metadata: bool = False):
    """Download (with caching) and return a dataset as a NumPy array.

    Files are cached in ``~/.cache/huggingface/`` and only downloaded once.

    Parameters
    ----------
    name : str
        Dataset name (see ``available()``).
    metadata : bool
        If True, return ``(array, metadata_dict)`` instead of just the array.

    Returns
    -------
    np.ndarray or (np.ndarray, dict)
    """
    meta = _download_metadata(name)
    # "_npy_path" is attached by _download_metadata and points at the data file.
    local = hf_hub_download(
        repo_id=REPO_ID,
        filename=meta["_npy_path"],
        repo_type="dataset",
    )
    arr = np.load(local)
    return (arr, meta) if metadata else arr
def load_raw(name: str) -> str:
    """Download an original instrument file and return the local path.

    Raw files (e.g. ``.h5``, ``.dm4``, ``.mrc``) are stored in the ``raw/``
    folder. This returns the cached local path for use with h5py,
    hyperspy, etc.

    Parameters
    ----------
    name : str
        Raw file name (without extension), e.g. ``"arina_lamella_master"``.

    Returns
    -------
    str
        Local file path (cached in ``~/.cache/huggingface/``).
    """
    api = _api()
    files = api.list_repo_files(repo_id=REPO_ID, repo_type="dataset")
    # Find matching file in raw/ (prefix match on the basename)
    matches = [
        f for f in files
        if f.startswith("raw/") and os.path.basename(f).startswith(name)
    ]
    if not matches:
        raw_files = [f for f in files if f.startswith("raw/") and not f.endswith(".json")]
        raise KeyError(
            f"Raw file {name!r} not found. Available raw files: "
            f"{[os.path.basename(f) for f in raw_files]}"
        )
    # Prefer an exact stem match (name + extension) over a mere prefix match,
    # so e.g. name="scan" picks raw/scan.h5 rather than raw/scan_old.h5.
    # (The old code claimed to do this but always took the first match.)
    path = next(
        (f for f in matches if os.path.splitext(os.path.basename(f))[0] == name),
        matches[0],
    )
    local = hf_hub_download(
        repo_id=REPO_ID,
        filename=path,
        repo_type="dataset",
    )
    return local
def preview_upload(
    data,
    name: str,
    technique: str,
    metadata: dict | str | None = None,
    description: str = "",
    contributor: str = "",
    license: str = "CC-BY-4.0",
) -> list[str]:
    """Validate and preview an upload without actually uploading.

    Checks naming convention, metadata schema, technique, and shape.
    Prints a summary of what would be uploaded.
    Returns a list of errors (empty if everything is valid).

    Parameters
    ----------
    data : array_like or str
        NumPy array, or path to a ``.npy`` file.
    name : str
        Dataset name.
    technique : str
        Technique folder.
    metadata : dict or str, optional
        Full metadata dict, or path to a JSON file.
    description : str
        One-line description (used if metadata is None).
    contributor : str
        Who is uploading (used if metadata is None).
    license : str
        License string (default ``"CC-BY-4.0"``).

    Returns
    -------
    list of str
        Error messages. Empty list means the upload is valid.
    """
    import re

    errors: list[str] = []

    # Resolve data: unreadable input is fatal — print and bail out early.
    if isinstance(data, str):
        try:
            arr = np.load(data)
        except Exception as e:
            errors.append(f"Cannot load data file: {e}")
            print("Upload preview: FAILED")
            for e in errors:
                print(f" - {e}")
            return errors
    else:
        try:
            arr = np.asarray(data, dtype=np.float32)
        except Exception as e:
            errors.append(f"Cannot convert data to array: {e}")
            print("Upload preview: FAILED")
            for e in errors:
                print(f" - {e}")
            return errors

    # Naming convention checks (each violation gets its own message).
    if not re.match(r"^[a-z][a-z0-9_]*$", name):
        errors.append(
            f"Name {name!r} violates naming convention: "
            "must be lowercase, underscores only, start with a letter"
        )
    if name != name.lower():
        errors.append(f"Name must be lowercase: {name!r}")
    if "-" in name:
        errors.append(f"Use underscores, not hyphens: {name!r}")
    if re.search(r"\d{4}", name):
        errors.append(f"Don't include year in name (put it in metadata): {name!r}")
    if re.search(r"\d+x\d+", name):
        errors.append(f"Don't include resolution in name (put it in metadata): {name!r}")

    # Resolve metadata: build a template, read a JSON file, or copy the dict.
    if metadata is None:
        meta = make_template(
            name=name,
            technique=technique,
            shape=arr.shape,
            dtype=str(arr.dtype),
            description=description,
            contributor=contributor,
            license=license,
        )
    elif isinstance(metadata, str):
        try:
            with open(metadata) as f:
                meta = json.load(f)
        except Exception as e:
            errors.append(f"Cannot read metadata file: {e}")
            meta = {}  # keep going so schema validation reports what's missing
    else:
        meta = dict(metadata)

    meta.setdefault("schema_version", SCHEMA_VERSION)
    meta.setdefault("name", name)
    meta.setdefault("technique", technique)

    # Schema validation
    schema_errors = validate(meta)
    errors.extend(schema_errors)

    # Shape match between declared metadata and the actual array.
    meta_shape = meta.get("data", {}).get("shape")
    if meta_shape and tuple(meta_shape) != tuple(arr.shape):
        errors.append(
            f"Metadata shape {meta_shape} does not match array shape {list(arr.shape)}"
        )

    # Check for name collision with datasets already on the Hub.
    try:
        existing = available()
        if name in existing:
            errors.append(
                f"Dataset {name!r} already exists. Choose a different name "
                "or use update_metadata() to modify existing metadata."
            )
    except Exception:
        pass  # offline — skip collision check

    # Print summary
    size_mb = arr.nbytes / (1024 * 1024)
    print("Upload preview")
    print(f" Name: {name}")
    print(f" Technique: {technique}")
    print(f" Shape: {list(arr.shape)}")
    print(f" Dtype: {arr.dtype}")
    print(f" Size: {size_mb:.1f} MB")
    print(f" Description: {meta.get('description', '') or '(empty)'}")
    print(f" Contributor: {meta.get('attribution', {}).get('contributor', '') or '(empty)'}")
    print(f" License: {meta.get('attribution', {}).get('license', '')}")
    print(f" Destination: {REPO_ID}/{technique}/{name}.npy")

    if errors:
        print(f"\n ERRORS ({len(errors)}):")
        for err in errors:
            print(f" - {err}")
    else:
        print("\n Ready to upload. Run upload() to create a PR.")

    return errors
def upload(
    data,
    name: str,
    technique: str,
    metadata: dict | str | None = None,
    description: str = "",
    contributor: str = "",
    license: str = "CC-BY-4.0",
    token: str | None = None,
    create_pr: bool = True,
):
    """Upload a dataset with metadata to HF Hub.

    By default creates a Pull Request for review. Set ``create_pr=False``
    to commit directly (requires write access).

    Parameters
    ----------
    data : array_like or str
        NumPy array, or path to a ``.npy`` file.
    name : str
        Dataset name (becomes the filename, e.g. ``"arina_lamella_32x32"``).
    technique : str
        Category folder (``"4dstem"``, ``"hrtem"``, ``"eels"``, etc.).
    metadata : dict or str, optional
        Full metadata dict, or path to a JSON file. If None, a template
        is created from the other parameters.
    description : str
        One-line description (used if metadata is None).
    contributor : str
        Who is uploading (used if metadata is None).
    license : str
        License string (default ``"CC-BY-4.0"``).
    token : str, optional
        HF token. If None, uses cached login.
    create_pr : bool
        If True (default), create a Pull Request instead of committing
        directly. The PR can be reviewed and merged on HF Hub.
    """
    from huggingface_hub import CommitOperationAdd

    # Resolve data
    if isinstance(data, str):
        arr = np.load(data)
    else:
        arr = np.asarray(data, dtype=np.float32)

    # Resolve metadata
    if metadata is None:
        meta = make_template(
            name=name,
            technique=technique,
            shape=arr.shape,
            dtype=str(arr.dtype),
            description=description,
            contributor=contributor,
            license=license,
        )
    elif isinstance(metadata, str):
        with open(metadata) as f:
            meta = json.load(f)
    else:
        meta = dict(metadata)

    # Ensure required fields
    meta.setdefault("schema_version", SCHEMA_VERSION)
    meta.setdefault("name", name)
    meta.setdefault("technique", technique)

    # Validate before any network traffic — fail fast on bad metadata.
    errors = validate(meta)
    if errors:
        raise ValueError(
            "Metadata validation failed:\n" + "\n".join(f" - {e}" for e in errors)
        )

    # Check shape matches
    meta_shape = meta.get("data", {}).get("shape")
    if meta_shape and tuple(meta_shape) != tuple(arr.shape):
        raise ValueError(
            f"Metadata shape {meta_shape} does not match array shape {list(arr.shape)}"
        )

    # Upload both files in a single atomic commit (data + sidecar never
    # land separately).
    api = _api()
    npy_remote = f"{technique}/{name}.npy"
    json_remote = f"{technique}/{name}.json"
    with tempfile.TemporaryDirectory() as tmpdir:
        npy_local = os.path.join(tmpdir, f"{name}.npy")
        json_local = os.path.join(tmpdir, f"{name}.json")
        np.save(npy_local, arr)
        with open(json_local, "w") as f:
            json.dump(meta, f, indent=2)
        operations = [
            CommitOperationAdd(path_in_repo=npy_remote, path_or_fileobj=npy_local),
            CommitOperationAdd(path_in_repo=json_remote, path_or_fileobj=json_local),
        ]
        commit_msg = f"Add {name} ({technique}, {list(arr.shape)}, {arr.nbytes / (1024**2):.1f} MB)"
        # create_commit must run inside the tempdir context: it reads the
        # local files while uploading.
        result = api.create_commit(
            repo_id=REPO_ID,
            repo_type="dataset",
            operations=operations,
            commit_message=commit_msg,
            token=token,
            create_pr=create_pr,
        )
    if create_pr:
        print(f"Created PR to add {name} ({arr.nbytes / (1024**2):.1f} MB)")
        print(f"Review: {result.pr_url}")
    else:
        print(f"Uploaded {name} ({arr.nbytes / (1024**2):.1f} MB) to {REPO_ID}")
def update_metadata(name: str, updates: dict, token: str | None = None, create_pr: bool = True):
    """Update metadata fields for an existing dataset.

    Downloads the current JSON, merges your changes, re-uploads. By
    default creates a Pull Request for review.

    Parameters
    ----------
    name : str
        Dataset name.
    updates : dict
        Fields to update. Nested dicts are merged (not replaced).
    token : str, optional
        HF token. If None, uses cached login.
    create_pr : bool
        If True (default), create a Pull Request instead of committing
        directly. The PR can be reviewed and merged on HF Hub.
    """
    from huggingface_hub import CommitOperationAdd

    meta = _download_metadata(name)
    remote_json = f"{meta['technique']}/{name}.json"

    # Merge nested dicts field-by-field instead of replacing them wholesale.
    _deep_merge(meta, updates)
    meta.pop("_npy_path", None)  # internal field — never persisted

    problems = validate(meta)
    if problems:
        raise ValueError(
            "Updated metadata is invalid:\n" + "\n".join(f" - {e}" for e in problems)
        )

    api = _api()
    with tempfile.TemporaryDirectory() as tmpdir:
        local_json = os.path.join(tmpdir, f"{name}.json")
        with open(local_json, "w") as fh:
            json.dump(meta, fh, indent=2)
        ops = [
            CommitOperationAdd(path_in_repo=remote_json, path_or_fileobj=local_json),
        ]
        result = api.create_commit(
            repo_id=REPO_ID,
            repo_type="dataset",
            operations=ops,
            commit_message=f"Update metadata for {name}",
            token=token,
            create_pr=create_pr,
        )
    if create_pr:
        print(f"Created PR to update metadata for {name}")
        print(f"Review: {result.pr_url}")
    else:
        print(f"Updated metadata for {name}")
def list_files(technique: str | None = None) -> list[dict]:
    """List all files on HF Hub with details.

    Parameters
    ----------
    technique : str, optional
        Filter by technique folder (e.g. ``"4dstem"``). If None, lists all.

    Returns
    -------
    list of dict
        Each dict has ``path``, ``size_mb``, ``type`` (``"data"``/``"metadata"``).
    """
    repo = _api().repo_info(repo_id=REPO_ID, repo_type="dataset", files_metadata=True)
    prefix = f"{technique}/" if technique else None
    rows = []
    for sibling in repo.siblings:
        rpath = sibling.rfilename
        # Repo housekeeping files are not datasets.
        if rpath in (".gitattributes", "README.md"):
            continue
        if prefix and not rpath.startswith(prefix):
            continue
        kind = "metadata" if os.path.splitext(rpath)[1] == ".json" else "data"
        mb = sibling.size / (1024 * 1024) if sibling.size else 0
        rows.append({"path": rpath, "size_mb": round(mb, 2), "type": kind})
    rows.sort(key=lambda row: row["path"])
    return rows
def _deep_merge(base: dict, updates: dict): """Merge updates into base, recursing into nested dicts.""" for key, val in updates.items(): if key in base and isinstance(base[key], dict) and isinstance(val, dict): _deep_merge(base[key], val) else: base[key] = val def _download_metadata(name: str) -> dict: """Download and parse the JSON sidecar for a dataset.""" # Try each technique folder api = _api() files = api.list_repo_files(repo_id=REPO_ID, repo_type="dataset") json_file = None npy_file = None for f in files: basename = os.path.basename(f) if basename == f"{name}.json": json_file = f elif basename == f"{name}.npy": npy_file = f if json_file is None: # Fall back: maybe only .npy exists (legacy, no sidecar) if npy_file is not None: return { "name": name, "description": "", "_npy_path": npy_file, } names = available() raise KeyError( f"Unknown dataset {name!r}. Available: {names}" ) local = hf_hub_download( repo_id=REPO_ID, filename=json_file, repo_type="dataset", ) with open(local) as f: meta = json.load(f) # Attach the npy path for load() if npy_file: meta["_npy_path"] = npy_file else: # Infer from json path meta["_npy_path"] = json_file.replace(".json", ".npy") return meta