"""
Bin: Interactive calibration-aware 4D-STEM binning widget.
This widget is designed as a preprocessing + quality-control step before
`Show4DSTEM` analysis. It lets you interactively choose binning extent for
scan and detector axes, preview resulting shapes/calibration, and compare
BF/ADF virtual images before/after binning.
"""
from __future__ import annotations
import json
import math
import pathlib
import time
from typing import Self
import anywidget
import numpy as np
import traitlets
from quantem.widget.array_utils import to_numpy
from quantem.widget.io import IOResult
from quantem.widget.json_state import build_json_header, resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.bin_batch import _bin_axis_torch as _bin_axis_standalone, _binned_axis_shape
from quantem.widget.tool_parity import (
bind_tool_runtime_api,
build_tool_groups,
normalize_tool_groups,
)
try:
import torch
_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False
try:
import h5py
_HAS_H5PY = True
except ImportError:
h5py = None # type: ignore[assignment]
_HAS_H5PY = False
try:
import hdf5plugin # noqa: F401 — registers bitshuffle/LZ4 HDF5 filters
except ImportError:
pass
try:
from quantem.core.config import validate_device
_HAS_VALIDATE_DEVICE = True
except Exception:
_HAS_VALIDATE_DEVICE = False
# Supported units for calibration extraction from quantem Dataset objects
_REAL_UNITS = {"Å", "angstrom", "A", "nm"}  # real-space units (nm is rescaled to Å on extraction)
_K_UNITS = {"mrad", "1/Å", "1/A"}  # detector/reciprocal-space units, kept as-is
# Frontend assets for the anywidget view; the class falls back to stubs when missing.
_BIN_ESM = pathlib.Path(__file__).parent / "static" / "bin.js"
_BIN_CSS = pathlib.Path(__file__).parent / "static" / "bin.css"
def _as_pair(value: float | tuple[float, float] | list[float] | None, default: float) -> tuple[float, float]:
"""Normalize scalar/pair value to a `(row, col)` pair."""
if value is None:
return (float(default), float(default))
if isinstance(value, (tuple, list)):
if len(value) != 2:
raise ValueError("Expected a scalar or a 2-tuple/list")
return (float(value[0]), float(value[1]))
return (float(value), float(value))
def _qc_stats_torch(image) -> list[float]:
"""Compute compact quality metrics for a 2D torch tensor.
Returns [mean, min, max, std, snr, contrast_1_99].
"""
if image.numel() == 0:
return [0.0] * 6
arr = image.float()
mean = float(arr.mean())
amin = float(arr.min())
amax = float(arr.max())
std = float(arr.std(unbiased=False))
snr = float(mean / (std + 1e-12))
q = torch.quantile(arr.flatten(), torch.tensor([0.01, 0.99], device=arr.device))
p1 = float(q[0])
p99 = float(q[1])
contrast = float((p99 - p1) / (abs(p99) + abs(p1) + 1e-12))
return [mean, amin, amax, std, snr, contrast]
def _discover_arina_chunks(master_path):
    """Parse a Dectris arina master HDF5 and return chunk metadata.

    Parameters
    ----------
    master_path : str or pathlib.Path
        Path to the ``*_master.h5`` file.

    Returns
    -------
    dict
        Keys: ``chunks`` (list of (file_path, dataset_path, n_frames)),
        ``total_frames``, ``det_rows``, ``det_cols``,
        ``beam_center_x``, ``beam_center_y``, ``ntrigger``.
    """
    master_path = pathlib.Path(master_path).resolve()
    master_dir = master_path.parent
    chunks = []
    total_frames = 0
    det_rows = None
    det_cols = None
    with h5py.File(master_path, "r") as f:
        data_group = f["/entry/data"]
        # Iterate chunk entries in name order so frames stay sequential.
        for key in sorted(data_group.keys()):
            link = data_group.get(key, getlink=True)
            if isinstance(link, h5py.ExternalLink):
                # External-link chunks live in sibling files relative to the master.
                chunk_file = str(master_dir / link.filename)
                chunk_ds_path = link.path
            else:
                # Dataset stored inline in the master file itself.
                chunk_file = str(master_path)
                chunk_ds_path = f"/entry/data/{key}"
            with h5py.File(chunk_file, "r") as cf:
                ds = cf[chunk_ds_path]
                n_frames = int(ds.shape[0])
                # Detector geometry taken from the first 3D chunk encountered.
                if det_rows is None and ds.ndim >= 3:
                    det_rows = int(ds.shape[1])
                    det_cols = int(ds.shape[2])
            total_frames += n_frames
            chunks.append((chunk_file, chunk_ds_path, n_frames))
        # Optional detector metadata — absent fields stay None.
        beam_center_x = None
        beam_center_y = None
        ntrigger = None
        det = f.get("/entry/instrument/detector")
        if det is not None:
            if "beam_center_x" in det:
                beam_center_x = float(np.asarray(det["beam_center_x"]))
            if "beam_center_y" in det:
                beam_center_y = float(np.asarray(det["beam_center_y"]))
            spec = det.get("detectorSpecific")
            if spec is not None and "ntrigger" in spec:
                ntrigger = int(np.asarray(spec["ntrigger"]))
    return {
        "chunks": chunks,
        "total_frames": total_frames,
        "det_rows": det_rows,
        "det_cols": det_cols,
        "beam_center_x": beam_center_x,
        "beam_center_y": beam_center_y,
        "ntrigger": ntrigger,
    }
def _is_arina_master(h5f):
    """Return True if the open HDF5 file looks like a Dectris arina master."""
    if "/entry/data" not in h5f:
        return False
    group = h5f["/entry/data"]
    # A master file is recognized by at least one external-link chunk entry.
    return any(
        isinstance(group.get(name, getlink=True), h5py.ExternalLink)
        for name in group.keys()
    )
class Bin(anywidget.AnyWidget):
"""
Interactive 4D-STEM binning widget with calibration tracking and BF/ADF QC.
Parameters
----------
data : Dataset4dstem or array_like
4D array `(scan_rows, scan_cols, det_rows, det_cols)` or flattened 3D
`(N, det_rows, det_cols)` with explicit `scan_shape`.
If a quantem dataset object is provided, calibration is auto-extracted.
scan_shape : tuple[int, int], optional
Required for flattened 3D input.
pixel_size : float or tuple[float, float], optional
Real-space sampling in Å/px for `(row, col)`.
k_pixel_size : float or tuple[float, float], optional
Detector sampling in mrad/px (or reciprocal-space units) for `(row, col)`.
center : tuple[float, float], optional
Detector center `(row, col)` used for BF/ADF preview masks.
bin_mode : {"mean", "sum"}, default "mean"
Reduction mode for block binning.
edge_mode : {"crop", "pad", "error"}, default "crop"
How non-divisible dimensions are handled.
bf_radius_ratio : float, default 0.125
BF disk radius as fraction of `min(det_rows, det_cols)`.
adf_inner_ratio : float, default 0.30
ADF annulus inner radius as fraction of detector size.
adf_outer_ratio : float, default 0.45
ADF annulus outer radius as fraction of detector size.
Notes
-----
- Real-space calibration is multiplied by scan bin factors.
- Detector-space calibration is multiplied by detector bin factors.
- BF/ADF previews are recomputed after every parameter change.
"""
_esm = _BIN_ESM if _BIN_ESM.exists() else "export function render() {}"
_css = _BIN_CSS if _BIN_CSS.exists() else ""
# Original shape
scan_rows = traitlets.Int(1).tag(sync=True)
scan_cols = traitlets.Int(1).tag(sync=True)
det_rows = traitlets.Int(1).tag(sync=True)
det_cols = traitlets.Int(1).tag(sync=True)
# Current bin factors
scan_bin_row = traitlets.Int(1).tag(sync=True)
scan_bin_col = traitlets.Int(1).tag(sync=True)
det_bin_row = traitlets.Int(1).tag(sync=True)
det_bin_col = traitlets.Int(1).tag(sync=True)
# UI hint maxima
max_scan_bin_row = traitlets.Int(1).tag(sync=True)
max_scan_bin_col = traitlets.Int(1).tag(sync=True)
max_det_bin_row = traitlets.Int(1).tag(sync=True)
max_det_bin_col = traitlets.Int(1).tag(sync=True)
# Binning behavior
bin_mode = traitlets.Unicode("mean").tag(sync=True) # "mean" | "sum"
edge_mode = traitlets.Unicode("crop").tag(sync=True) # "crop" | "pad" | "error"
device = traitlets.Unicode("cpu").tag(sync=True)
# Binned shape
binned_scan_rows = traitlets.Int(1).tag(sync=True)
binned_scan_cols = traitlets.Int(1).tag(sync=True)
binned_det_rows = traitlets.Int(1).tag(sync=True)
binned_det_cols = traitlets.Int(1).tag(sync=True)
# Calibration (original)
pixel_size_row = traitlets.Float(1.0).tag(sync=True)
pixel_size_col = traitlets.Float(1.0).tag(sync=True)
pixel_unit = traitlets.Unicode("px").tag(sync=True)
pixel_calibrated = traitlets.Bool(False).tag(sync=True)
k_pixel_size_row = traitlets.Float(1.0).tag(sync=True)
k_pixel_size_col = traitlets.Float(1.0).tag(sync=True)
k_unit = traitlets.Unicode("px").tag(sync=True)
k_calibrated = traitlets.Bool(False).tag(sync=True)
# Calibration (binned)
binned_pixel_size_row = traitlets.Float(1.0).tag(sync=True)
binned_pixel_size_col = traitlets.Float(1.0).tag(sync=True)
binned_k_pixel_size_row = traitlets.Float(1.0).tag(sync=True)
binned_k_pixel_size_col = traitlets.Float(1.0).tag(sync=True)
# Detector center for BF/ADF preview masks
center_row = traitlets.Float(0.0).tag(sync=True)
center_col = traitlets.Float(0.0).tag(sync=True)
# BF/ADF mask settings (fractions of detector size)
bf_radius_ratio = traitlets.Float(0.125).tag(sync=True)
adf_inner_ratio = traitlets.Float(0.30).tag(sync=True)
adf_outer_ratio = traitlets.Float(0.45).tag(sync=True)
# Scan position for DP exploration — compound [row, col], [-1, -1] = mean DP
_scan_position = traitlets.List(traitlets.Int(), default_value=[-1, -1]).tag(sync=True)
# Preview data as float32 bytes
original_bf_bytes = traitlets.Bytes(b"").tag(sync=True)
original_adf_bytes = traitlets.Bytes(b"").tag(sync=True)
binned_bf_bytes = traitlets.Bytes(b"").tag(sync=True)
binned_adf_bytes = traitlets.Bytes(b"").tag(sync=True)
# Mean diffraction pattern (detector-space preview)
original_mean_dp_bytes = traitlets.Bytes(b"").tag(sync=True)
binned_mean_dp_bytes = traitlets.Bytes(b"").tag(sync=True)
# Per-position diffraction pattern bytes (sent when _scan_position >= 0)
_position_dp_bytes = traitlets.Bytes(b"").tag(sync=True)
_binned_position_dp_bytes = traitlets.Bytes(b"").tag(sync=True)
# Binned detector center (for JS overlay positioning)
binned_center_row = traitlets.Float(0.0).tag(sync=True)
binned_center_col = traitlets.Float(0.0).tag(sync=True)
# Preview stats [mean, min, max, std, snr, contrast]
original_bf_stats = traitlets.List(traitlets.Float(), default_value=[0.0] * 6).tag(sync=True)
original_adf_stats = traitlets.List(traitlets.Float(), default_value=[0.0] * 6).tag(sync=True)
binned_bf_stats = traitlets.List(traitlets.Float(), default_value=[0.0] * 6).tag(sync=True)
binned_adf_stats = traitlets.List(traitlets.Float(), default_value=[0.0] * 6).tag(sync=True)
# Status
status_message = traitlets.Unicode("").tag(sync=True)
status_level = traitlets.Unicode("ok").tag(sync=True) # "ok" | "warn" | "error"
# Display
title = traitlets.Unicode("").tag(sync=True)
cmap = traitlets.Unicode("inferno").tag(sync=True)
log_scale = traitlets.Bool(False).tag(sync=True)
auto_contrast = traitlets.Bool(False).tag(sync=True)
show_fft = traitlets.Bool(False).tag(sync=True)
show_controls = traitlets.Bool(True).tag(sync=True)
disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
# Export (trait-triggered .npy download)
_npy_export_requested = traitlets.Bool(False).tag(sync=True)
_npy_export_data = traitlets.Bytes(b"").tag(sync=True)
    @classmethod
    def _normalize_tool_groups(cls, tool_groups):
        # Delegate to the shared tool-parity helper, scoped to this widget's name.
        return normalize_tool_groups("Bin", tool_groups)
    @classmethod
    def _build_disabled_tools(
        cls,
        disabled_tools=None,
        disable_display: bool = False,
        disable_binning: bool = False,
        disable_mask: bool = False,
        disable_preview: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_all: bool = False,
    ):
        """Combine an explicit group list with per-group `disable_*` flags.

        Delegates to the shared `build_tool_groups` helper; `disable_all`
        acts as a blanket flag covering every group.
        """
        return build_tool_groups(
            "Bin",
            tool_groups=disabled_tools,
            all_flag=disable_all,
            flag_map={
                "display": disable_display,
                "binning": disable_binning,
                "mask": disable_mask,
                "preview": disable_preview,
                "stats": disable_stats,
                "export": disable_export,
            },
        )
    @classmethod
    def _build_hidden_tools(
        cls,
        hidden_tools=None,
        hide_display: bool = False,
        hide_binning: bool = False,
        hide_mask: bool = False,
        hide_preview: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_all: bool = False,
    ):
        """Combine an explicit group list with per-group `hide_*` flags.

        Mirrors `_build_disabled_tools`, but for hiding tool groups in the UI.
        """
        return build_tool_groups(
            "Bin",
            tool_groups=hidden_tools,
            all_flag=hide_all,
            flag_map={
                "display": hide_display,
                "binning": hide_binning,
                "mask": hide_mask,
                "preview": hide_preview,
                "stats": hide_stats,
                "export": hide_export,
            },
        )
    @traitlets.validate("disabled_tools")
    def _validate_disabled_tools(self, proposal):
        # Canonicalize incoming group names through the shared normalizer.
        return self._normalize_tool_groups(proposal["value"])
    @traitlets.validate("hidden_tools")
    def _validate_hidden_tools(self, proposal):
        # Same normalization as disabled_tools — both traits share one vocabulary.
        return self._normalize_tool_groups(proposal["value"])
@traitlets.validate("scan_bin_row", "scan_bin_col", "det_bin_row", "det_bin_col")
def _validate_bin_factor(self, proposal):
value = int(proposal["value"])
if value < 1:
raise traitlets.TraitError("Binning factors must be >= 1")
return value
@traitlets.validate("bin_mode")
def _validate_bin_mode(self, proposal):
value = str(proposal["value"]).lower()
if value not in {"mean", "sum"}:
raise traitlets.TraitError("bin_mode must be 'mean' or 'sum'")
return value
@traitlets.validate("edge_mode")
def _validate_edge_mode(self, proposal):
value = str(proposal["value"]).lower()
if value not in {"crop", "pad", "error"}:
raise traitlets.TraitError("edge_mode must be 'crop', 'pad', or 'error'")
return value
@traitlets.validate("device")
def _validate_device_name(self, proposal):
value = str(proposal["value"]).strip().lower()
if not value:
raise traitlets.TraitError("device must be a non-empty string")
return value
@traitlets.validate("bf_radius_ratio", "adf_inner_ratio", "adf_outer_ratio")
def _validate_ratio(self, proposal):
value = float(proposal["value"])
if value < 0:
raise traitlets.TraitError("Ratios must be >= 0")
return value
    def __init__(
        self,
        data,
        scan_shape: tuple[int, int] | None = None,
        pixel_size: float | tuple[float, float] | None = None,
        k_pixel_size: float | tuple[float, float] | None = None,
        center: tuple[float, float] | None = None,
        bin_mode: str = "mean",
        edge_mode: str = "crop",
        bf_radius_ratio: float = 0.125,
        adf_inner_ratio: float = 0.30,
        adf_outer_ratio: float = 0.45,
        title: str = "",
        cmap: str = "inferno",
        log_scale: bool = False,
        show_controls: bool = True,
        disabled_tools: list[str] | None = None,
        hidden_tools: list[str] | None = None,
        disable_display: bool = False,
        disable_binning: bool = False,
        disable_mask: bool = False,
        disable_preview: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_all: bool = False,
        hide_display: bool = False,
        hide_binning: bool = False,
        hide_mask: bool = False,
        hide_preview: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_all: bool = False,
        device: str | None = None,
        state: dict | str | pathlib.Path | None = None,
        **kwargs,
    ):
        """Build the widget from in-memory data; see the class docstring for parameters."""
        super().__init__(**kwargs)
        self.widget_version = resolve_widget_version()
        self.bin_mode = bin_mode
        self.edge_mode = edge_mode
        self.cmap = cmap
        self.log_scale = log_scale
        self.show_controls = show_controls
        self.bf_radius_ratio = float(bf_radius_ratio)
        self.adf_inner_ratio = float(adf_inner_ratio)
        self.adf_outer_ratio = float(adf_outer_ratio)
        # Check if data is an IOResult and extract metadata
        if isinstance(data, IOResult):
            if not title and data.title:
                title = data.title
            data = data.data
        # Dataset-like duck typing: `array`, `sampling`, `units`
        dataset_pixel: float | tuple[float, float] | None = None
        dataset_k: float | tuple[float, float] | None = None
        dataset_pixel_unit = "px"
        dataset_k_unit = "px"
        pixel_calibrated = False
        k_calibrated = False
        if hasattr(data, "sampling") and hasattr(data, "array"):
            if not title and hasattr(data, "name") and data.name:
                title = str(data.name)
            units = list(getattr(data, "units", ["pixels"] * 4))
            sampling = list(getattr(data, "sampling", [1.0] * 4))
            # Real-space calibration from axes 0/1; nm values rescaled to Å.
            if len(units) >= 2 and len(sampling) >= 2 and units[0] in _REAL_UNITS:
                sy = float(sampling[0])
                sx = float(sampling[1])
                if units[0] == "nm":
                    sy *= 10.0
                if units[1] == "nm":
                    sx *= 10.0
                dataset_pixel = (sy, sx)
                dataset_pixel_unit = "Å"
                pixel_calibrated = True
            # Detector-space calibration from axes 2/3; unit string kept as-is.
            if len(units) >= 4 and len(sampling) >= 4 and units[2] in _K_UNITS:
                ky = float(sampling[2])
                kx = float(sampling[3])
                dataset_k = (ky, kx)
                dataset_k_unit = units[2]
                k_calibrated = True
            data = data.array
        self.title = title
        # Manual kwargs override extracted calibration
        p_row, p_col = _as_pair(pixel_size if pixel_size is not None else dataset_pixel, 1.0)
        k_row, k_col = _as_pair(k_pixel_size if k_pixel_size is not None else dataset_k, 1.0)
        if pixel_size is not None:
            pixel_calibrated = True
            dataset_pixel_unit = "Å"
        if k_pixel_size is not None:
            k_calibrated = True
            dataset_k_unit = "mrad"
        self.pixel_size_row = p_row
        self.pixel_size_col = p_col
        self.pixel_unit = dataset_pixel_unit if pixel_calibrated else "px"
        self.pixel_calibrated = pixel_calibrated
        self.k_pixel_size_row = k_row
        self.k_pixel_size_col = k_col
        self.k_unit = dataset_k_unit if k_calibrated else "px"
        self.k_calibrated = k_calibrated
        # Normalize input to float32 4D NumPy, then move to torch (compute is torch-only).
        data_np = to_numpy(data, dtype=np.float32)
        if data_np.ndim == 4:
            scan_r, scan_c, det_r, det_c = data_np.shape
            data4d = data_np
        elif data_np.ndim == 3:
            # Flattened (N, det_rows, det_cols): infer a square scan grid if possible.
            if scan_shape is None:
                n = int(data_np.shape[0])
                side = int(math.isqrt(n))
                if side * side != n:
                    raise ValueError(
                        f"Cannot infer square scan_shape from flattened N={n}. Provide scan_shape=(rows, cols)."
                    )
                scan_shape = (side, side)
            if int(scan_shape[0]) * int(scan_shape[1]) != int(data_np.shape[0]):
                raise ValueError(
                    f"scan_shape={scan_shape} does not match flattened length {data_np.shape[0]}"
                )
            scan_r, scan_c = int(scan_shape[0]), int(scan_shape[1])
            det_r, det_c = int(data_np.shape[1]), int(data_np.shape[2])
            data4d = data_np.reshape(scan_r, scan_c, det_r, det_c)
        else:
            raise ValueError(
                f"Expected 4D array (scan_rows, scan_cols, det_rows, det_cols) or flattened 3D (N, det_rows, det_cols), got {data_np.ndim}D"
            )
        self.scan_rows = scan_r
        self.scan_cols = scan_c
        self.det_rows = det_r
        self.det_cols = det_c
        if not _HAS_TORCH:
            raise ImportError("Bin requires torch. Install PyTorch to use this widget.")
        # Resolve the torch device (auto when None); None result means init failed.
        device_str = self._resolve_torch_device(requested=device, numel=int(data_np.size))
        if device_str is None:
            requested = "auto" if device is None else str(device)
            raise ValueError(f"Unable to initialize torch device '{requested}'")
        self._device = torch.device(device_str)
        # Explicit copy so from_numpy never aliases caller-owned (possibly read-only) memory.
        data4d_writable = np.array(data4d, dtype=np.float32, copy=True)
        self._data_torch = torch.from_numpy(data4d_writable).to(self._device)
        self.device = device_str
        # Slider maxima (UI hint only)
        self.max_scan_bin_row = max(1, scan_r)
        self.max_scan_bin_col = max(1, scan_c)
        self.max_det_bin_row = max(1, det_r)
        self.max_det_bin_col = max(1, det_c)
        # Default detector center: geometric middle of the detector.
        if center is None:
            self.center_row = det_r / 2.0
            self.center_col = det_c / 2.0
        else:
            self.center_row = float(center[0])
            self.center_col = float(center[1])
        # Placeholders; _recompute_previews() below fills in the real previews.
        self._binned_data_torch = self._data_torch
        self._original_bf_torch = torch.zeros((self.scan_rows, self.scan_cols), dtype=torch.float32)
        self._original_adf_torch = torch.zeros((self.scan_rows, self.scan_cols), dtype=torch.float32)
        self._binned_bf_torch = torch.zeros((self.scan_rows, self.scan_cols), dtype=torch.float32)
        self._binned_adf_torch = torch.zeros((self.scan_rows, self.scan_cols), dtype=torch.float32)
        # Any of these traits changing triggers a full preview recompute.
        self.observe(
            self._on_params_changed,
            names=[
                "scan_bin_row",
                "scan_bin_col",
                "det_bin_row",
                "det_bin_col",
                "bin_mode",
                "edge_mode",
                "center_row",
                "center_col",
                "bf_radius_ratio",
                "adf_inner_ratio",
                "adf_outer_ratio",
            ],
        )
        self.observe(self._on_position_changed, names=["_scan_position"])
        self.observe(self._on_npy_export, names=["_npy_export_requested"])
        self._recompute_previews()
        self.disabled_tools = self._build_disabled_tools(
            disabled_tools=disabled_tools,
            disable_display=disable_display,
            disable_binning=disable_binning,
            disable_mask=disable_mask,
            disable_preview=disable_preview,
            disable_stats=disable_stats,
            disable_export=disable_export,
            disable_all=disable_all,
        )
        self.hidden_tools = self._build_hidden_tools(
            hidden_tools=hidden_tools,
            hide_display=hide_display,
            hide_binning=hide_binning,
            hide_mask=hide_mask,
            hide_preview=hide_preview,
            hide_stats=hide_stats,
            hide_export=hide_export,
            hide_all=hide_all,
        )
        # Optionally restore a saved UI state; file paths must carry the envelope.
        if state is not None:
            if isinstance(state, (str, pathlib.Path)):
                state = unwrap_state_payload(
                    json.loads(pathlib.Path(state).read_text()),
                    require_envelope=True,
                )
            else:
                state = unwrap_state_payload(state)
            self.load_state_dict(state)
# ------------------------------------------------------------------
# Alternative constructors
# ------------------------------------------------------------------
    @classmethod
    def file_info(cls, path, det_bin_row=2, det_bin_col=2, edge_mode="crop"):
        """Print file summary without loading data.

        Parameters
        ----------
        path : str or pathlib.Path
            Path to a master or saved HDF5 file.
        det_bin_row : int
            Detector row binning factor (for estimating binned size).
        det_bin_col : int
            Detector column binning factor (for estimating binned size).
        edge_mode : {"crop", "pad", "error"}
            Edge handling (for estimating binned size).
        """
        if not _HAS_H5PY:
            raise ImportError("h5py is required. Install: pip install h5py")
        path = pathlib.Path(path).resolve()
        disk_bytes = path.stat().st_size
        disk_gb = disk_bytes / 1_000_000_000.0
        # Probe only — open briefly to sniff the format, then close.
        with h5py.File(path, "r") as f:
            is_arina = _is_arina_master(f)
        if is_arina:
            info = _discover_arina_chunks(path)
            total = info["total_frames"]
            det_r, det_c = info["det_rows"], info["det_cols"]
            # Assume a square scan; otherwise fall back to an Nx1 layout.
            side = int(math.isqrt(total))
            scan_r, scan_c = (side, side) if side * side == total else (total, 1)
            raw_gb = total * det_r * det_c * 2 / 1e9  # uint16
            binned_det_r, _ = _binned_axis_shape(det_r, det_bin_row, edge_mode, axis=0)
            binned_det_c, _ = _binned_axis_shape(det_c, det_bin_col, edge_mode, axis=1)
            mem_gb = scan_r * scan_c * binned_det_r * binned_det_c * 4 / 1e9
            # sum chunk file sizes (master file itself is tiny)
            chunk_paths = {c[0] for c in info["chunks"]}
            disk_gb = sum(pathlib.Path(p).stat().st_size for p in chunk_paths) / 1e9
            lines = [
                f"File: {path.name}",
                f"Format: Dectris arina master ({len(info['chunks'])} chunks)",
                f"Disk: {disk_gb:.2f} GB (bitshuffle+LZ4, {len(chunk_paths)} files)",
                f"Raw: {raw_gb:.1f} GB (uint16 uncompressed)",
                f"Scan: {scan_r} x {scan_c} ({total:,} frames)",
                f"Detector: {det_r} x {det_c}",
                f"Binned: {scan_r} x {scan_c} x {binned_det_r} x {binned_det_c} (det {det_bin_row}x{det_bin_col})",
                f"Memory: {mem_gb:.1f} GB (float32 after binning)",
            ]
        else:
            # Saved binned file: a single "data" dataset with shape metadata.
            with h5py.File(path, "r") as f:
                ds = f["data"]
                shape = tuple(int(v) for v in ds.shape)
            mem_gb = np.prod(shape) * 4 / 1e9
            lines = [
                f"File: {path.name}",
                f"Format: Saved binned HDF5",
                f"Disk: {disk_gb:.2f} GB (bitshuffle+LZ4)",
                f"Shape: {shape}",
                f"Memory: {mem_gb:.1f} GB (float32)",
            ]
        print("\n".join(lines))
    @classmethod
    def from_file(
        cls,
        path,
        scan_shape=None,
        det_bin_row=2,
        det_bin_col=2,
        bin_mode="mean",
        edge_mode="crop",
        frames_per_batch=1000,
        pixel_size=None,
        k_pixel_size=None,
        center=None,
        title="",
        device=None,
        **kwargs,
    ) -> "Bin":
        """Load an HDF5 file into a Bin widget.

        Auto-detects the file format:

        - **Dectris arina master** (``*_master.h5`` with external-link chunks):
          streams frames in small batches, bins detector axes on the fly, and
          assembles only the binned result in memory.
        - **Saved binned file** (written by :meth:`save_h5`): loads the 4D
          dataset and calibration directly — no re-binning needed.

        Parameters
        ----------
        path : str or pathlib.Path
            Path to the HDF5 file.
        scan_shape : tuple[int, int], optional
            Scan grid ``(rows, cols)``. Inferred as square if omitted.
            Ignored when loading a saved binned file.
        det_bin_row : int
            Detector row binning factor (arina path only).
        det_bin_col : int
            Detector column binning factor (arina path only).
        bin_mode : {"mean", "sum"}
            Reduction mode for binning (arina path only).
        edge_mode : {"crop", "pad", "error"}
            How non-divisible detector edges are handled (arina path only).
        frames_per_batch : int
            Frames read per I/O batch (arina path only).
        pixel_size : float or tuple[float, float], optional
            Real-space sampling in Å/px. Overrides saved calibration.
        k_pixel_size : float or tuple[float, float], optional
            Detector sampling in mrad/px. Overrides saved calibration.
        center : tuple[float, float], optional
            Detector center ``(row, col)``. Overrides saved calibration.
        title : str
            Widget title.
        device : str, optional
            Torch device.
        **kwargs
            Forwarded to widget init (``cmap``, ``log_scale``, etc.).

        Returns
        -------
        Bin
            Widget ready for interactive exploration.
        """
        if not _HAS_H5PY:
            raise ImportError("h5py is required for Bin.from_file(). Install: pip install h5py")
        if not _HAS_TORCH:
            raise ImportError("torch is required for Bin.from_file()")
        path = pathlib.Path(path).resolve()
        # -- auto-detect format ------------------------------------------------
        with h5py.File(path, "r") as probe:
            is_arina = _is_arina_master(probe)
        if is_arina:
            return cls._from_arina(
                path,
                scan_shape=scan_shape,
                det_bin_row=det_bin_row,
                det_bin_col=det_bin_col,
                bin_mode=bin_mode,
                edge_mode=edge_mode,
                frames_per_batch=frames_per_batch,
                pixel_size=pixel_size,
                k_pixel_size=k_pixel_size,
                center=center,
                title=title,
                device=device,
                **kwargs,
            )
        return cls._from_binned_h5(
            path,
            pixel_size=pixel_size,
            k_pixel_size=k_pixel_size,
            center=center,
            title=title,
            device=device,
            **kwargs,
        )
    @classmethod
    def _from_binned_h5(cls, path, pixel_size=None, k_pixel_size=None,
                        center=None, title="", device=None, **kwargs):
        """Load a saved binned HDF5 (written by :meth:`save_h5`).

        Uses the same zero-copy construction as ``_from_arina`` to avoid
        doubling memory for large datasets.
        """
        t0 = time.time()
        with h5py.File(path, "r") as f:
            ds = f["data"]
            data_np = np.asarray(ds, dtype=np.float32)
            attrs = dict(ds.attrs)
        elapsed = time.time() - t0
        print(f" Loaded {path.name}: {data_np.shape} in {elapsed:.1f}s")
        scan_r, scan_c = int(data_np.shape[0]), int(data_np.shape[1])
        det_r, det_c = int(data_np.shape[2]), int(data_np.shape[3])
        # calibration from file, overridable by explicit kwargs
        if pixel_size is None and "pixel_size_row" in attrs:
            pixel_size = (float(attrs["pixel_size_row"]), float(attrs["pixel_size_col"]))
        if k_pixel_size is None and "k_pixel_size_row" in attrs:
            k_pixel_size = (float(attrs["k_pixel_size_row"]), float(attrs["k_pixel_size_col"]))
        if center is None and "center_row" in attrs:
            center = (float(attrs["center_row"]), float(attrs["center_col"]))
        if not title and "title" in attrs:
            title = str(attrs["title"])
        pixel_unit = str(attrs.get("pixel_unit", "px"))
        k_unit = str(attrs.get("k_unit", "px"))
        # resolve device (explicit > project validator > mps/cuda/cpu fallback)
        numel = int(data_np.size)
        if device is not None:
            device_str = str(device).strip().lower()
        elif _HAS_VALIDATE_DEVICE:
            device_str, _ = validate_device(None)
        else:
            device_str = (
                "mps"
                if torch.backends.mps.is_available()
                else "cuda"
                if torch.cuda.is_available()
                else "cpu"
            )
        # Very large tensors (> 2**31-1 elements) are forced to CPU on MPS.
        if device_str == "mps" and numel > 2**31 - 1:
            device_str = "cpu"
        # Probe allocation — fails fast if the chosen device is unusable.
        torch.zeros(1, device=torch.device(device_str))
        # numpy → torch without an extra copy (from_numpy shares memory on cpu)
        data_torch = torch.from_numpy(data_np)
        if device_str != "cpu":
            data_torch = data_torch.to(torch.device(device_str))
        # build widget directly — same pattern as _from_arina
        cmap = kwargs.pop("cmap", "inferno")
        log_scale = kwargs.pop("log_scale", False)
        show_controls = kwargs.pop("show_controls", True)
        bf_radius_ratio = float(kwargs.pop("bf_radius_ratio", 0.125))
        adf_inner_ratio = float(kwargs.pop("adf_inner_ratio", 0.30))
        adf_outer_ratio = float(kwargs.pop("adf_outer_ratio", 0.45))
        # Bypass cls.__init__ (which would re-copy the data) and wire traits manually.
        inst = cls.__new__(cls)
        anywidget.AnyWidget.__init__(inst, **kwargs)
        inst.widget_version = resolve_widget_version()
        inst.bin_mode = "mean"
        inst.edge_mode = "crop"
        inst.title = title
        inst.cmap = cmap
        inst.log_scale = log_scale
        inst.show_controls = show_controls
        inst.bf_radius_ratio = bf_radius_ratio
        inst.adf_inner_ratio = adf_inner_ratio
        inst.adf_outer_ratio = adf_outer_ratio
        # calibration — saved units win; explicit kwargs imply Å / mrad
        p_row, p_col = _as_pair(pixel_size, 1.0)
        k_row, k_col = _as_pair(k_pixel_size, 1.0)
        inst.pixel_size_row = p_row
        inst.pixel_size_col = p_col
        inst.pixel_unit = pixel_unit if pixel_unit != "px" else ("Å" if pixel_size is not None else "px")
        inst.pixel_calibrated = pixel_unit != "px" or pixel_size is not None
        inst.k_pixel_size_row = k_row
        inst.k_pixel_size_col = k_col
        inst.k_unit = k_unit if k_unit != "px" else ("mrad" if k_pixel_size is not None else "px")
        inst.k_calibrated = k_unit != "px" or k_pixel_size is not None
        inst.scan_rows = scan_r
        inst.scan_cols = scan_c
        inst.det_rows = det_r
        inst.det_cols = det_c
        inst._device = torch.device(device_str)
        inst._data_torch = data_torch
        inst.device = device_str
        inst.max_scan_bin_row = max(1, scan_r)
        inst.max_scan_bin_col = max(1, scan_c)
        inst.max_det_bin_row = max(1, det_r)
        inst.max_det_bin_col = max(1, det_c)
        if center is not None:
            inst.center_row = float(center[0])
            inst.center_col = float(center[1])
        else:
            inst.center_row = det_r / 2.0
            inst.center_col = det_c / 2.0
        # Placeholders; _recompute_previews() below fills in the real previews.
        inst._binned_data_torch = inst._data_torch
        inst._original_bf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst._original_adf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst._binned_bf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst._binned_adf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst.observe(
            inst._on_params_changed,
            names=[
                "scan_bin_row", "scan_bin_col", "det_bin_row", "det_bin_col",
                "bin_mode", "edge_mode", "center_row", "center_col",
                "bf_radius_ratio", "adf_inner_ratio", "adf_outer_ratio",
            ],
        )
        inst.observe(inst._on_position_changed, names=["_scan_position"])
        inst.observe(inst._on_npy_export, names=["_npy_export_requested"])
        inst.disabled_tools = cls._build_disabled_tools()
        inst.hidden_tools = cls._build_hidden_tools()
        inst._recompute_previews()
        return inst
    @classmethod
    def _from_arina(cls, path, scan_shape=None, det_bin_row=2, det_bin_col=2,
                    bin_mode="mean", edge_mode="crop", frames_per_batch=1000,
                    pixel_size=None, k_pixel_size=None, center=None,
                    title="", device=None, **kwargs):
        """Stream arina master HDF5, bin detector axes on the fly."""
        # -- discover chunks --------------------------------------------------
        info = _discover_arina_chunks(path)
        chunks = info["chunks"]
        total_frames = info["total_frames"]
        det_r = info["det_rows"]
        det_c = info["det_cols"]
        # -- scan shape --------------------------------------------------------
        if scan_shape is None:
            side = int(math.isqrt(total_frames))
            if side * side != total_frames:
                raise ValueError(
                    f"Cannot infer square scan_shape from {total_frames} frames. "
                    f"Provide scan_shape=(rows, cols)."
                )
            scan_shape = (side, side)
        scan_r, scan_c = int(scan_shape[0]), int(scan_shape[1])
        # -- binned detector shape ---------------------------------------------
        binned_det_r, _ = _binned_axis_shape(det_r, det_bin_row, edge_mode, axis=0)
        binned_det_c, _ = _binned_axis_shape(det_c, det_bin_col, edge_mode, axis=1)
        # -- resolve torch device ----------------------------------------------
        numel = scan_r * scan_c * binned_det_r * binned_det_c
        if device is not None:
            device_str = str(device).strip().lower()
        elif _HAS_VALIDATE_DEVICE:
            device_str, _ = validate_device(None)
        else:
            device_str = (
                "mps"
                if torch.backends.mps.is_available()
                else "cuda"
                if torch.cuda.is_available()
                else "cpu"
            )
        # Very large tensors (> 2**31-1 elements) are forced to CPU on MPS.
        if device_str == "mps" and numel > 2**31 - 1:
            device_str = "cpu"
        # Probe allocation — fails fast if the chosen device is unusable.
        torch.zeros(1, device=torch.device(device_str))
        # -- pre-allocate output -----------------------------------------------
        # Only the detector-binned result is held in memory, never the raw frames.
        output = torch.zeros(
            (scan_r, scan_c, binned_det_r, binned_det_c),
            dtype=torch.float32,
            device=torch.device(device_str),
        )
        # -- stream chunks & bin detector axes ---------------------------------
        frame_idx = 0
        t0 = time.time()
        for chunk_file, chunk_ds_path, chunk_n_frames in chunks:
            with h5py.File(chunk_file, "r") as cf:
                ds = cf[chunk_ds_path]
                for batch_start in range(0, chunk_n_frames, frames_per_batch):
                    batch_end = min(batch_start + frames_per_batch, chunk_n_frames)
                    batch = torch.from_numpy(
                        np.asarray(ds[batch_start:batch_end], dtype=np.float32)
                    )
                    # bin detector rows (axis 1) then cols (axis 2)
                    binned = _bin_axis_standalone(batch, axis=1, factor=det_bin_row, mode=bin_mode, edge=edge_mode)
                    binned = _bin_axis_standalone(binned, axis=2, factor=det_bin_col, mode=bin_mode, edge=edge_mode)
                    # scatter into output using vectorized indexing
                    n_batch = binned.shape[0]
                    gi = torch.arange(frame_idx, frame_idx + n_batch, dtype=torch.long)
                    # Map flat frame indices onto the row-major scan grid.
                    rows = gi // scan_c
                    cols = gi % scan_c
                    # Drop any trailing frames that fall outside the scan grid.
                    mask = rows < scan_r
                    if mask.any():
                        output[rows[mask], cols[mask]] = binned[mask].to(device=output.device)
                    frame_idx += n_batch
                    # Single-line progress report, overwritten in place via \r.
                    elapsed = time.time() - t0
                    pct = frame_idx / total_frames * 100
                    fps = frame_idx / max(elapsed, 1e-6)
                    remaining = (total_frames - frame_idx) / max(fps, 1e-6)
                    print(
                        f"\r Loading: {frame_idx:,}/{total_frames:,} frames "
                        f"({pct:.1f}%) {elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining",
                        end="",
                        flush=True,
                    )
        elapsed_total = time.time() - t0
        print(f"\n Done: {total_frames:,} frames in {elapsed_total:.1f}s")
        # -- build Bin instance without copying data ---------------------------
        cmap = kwargs.pop("cmap", "inferno")
        log_scale = kwargs.pop("log_scale", False)
        show_controls = kwargs.pop("show_controls", True)
        bf_radius_ratio = float(kwargs.pop("bf_radius_ratio", 0.125))
        adf_inner_ratio = float(kwargs.pop("adf_inner_ratio", 0.30))
        adf_outer_ratio = float(kwargs.pop("adf_outer_ratio", 0.45))
        # Bypass cls.__init__ (which would re-copy the data) and wire traits manually.
        inst = cls.__new__(cls)
        anywidget.AnyWidget.__init__(inst, **kwargs)
        inst.widget_version = resolve_widget_version()
        inst.bin_mode = bin_mode
        inst.edge_mode = edge_mode
        inst.title = title
        inst.cmap = cmap
        inst.log_scale = log_scale
        inst.show_controls = show_controls
        inst.bf_radius_ratio = bf_radius_ratio
        inst.adf_inner_ratio = adf_inner_ratio
        inst.adf_outer_ratio = adf_outer_ratio
        # calibration — explicit kwargs imply Å / mrad units
        p_row, p_col = _as_pair(pixel_size, 1.0)
        k_row, k_col = _as_pair(k_pixel_size, 1.0)
        inst.pixel_size_row = p_row
        inst.pixel_size_col = p_col
        inst.pixel_unit = "Å" if pixel_size is not None else "px"
        inst.pixel_calibrated = pixel_size is not None
        inst.k_pixel_size_row = k_row
        inst.k_pixel_size_col = k_col
        inst.k_unit = "mrad" if k_pixel_size is not None else "px"
        inst.k_calibrated = k_pixel_size is not None
        # shape — the widget sees the already-binned detector dimensions
        inst.scan_rows = scan_r
        inst.scan_cols = scan_c
        inst.det_rows = binned_det_r
        inst.det_cols = binned_det_c
        # torch data
        inst._device = torch.device(device_str)
        inst._data_torch = output
        inst.device = device_str
        # slider maxima
        inst.max_scan_bin_row = max(1, scan_r)
        inst.max_scan_bin_col = max(1, scan_c)
        inst.max_det_bin_row = max(1, binned_det_r)
        inst.max_det_bin_col = max(1, binned_det_c)
        # detector center: explicit kwarg > file metadata (rescaled by the
        # detector bin factors) > geometric middle
        if center is not None:
            inst.center_row = float(center[0])
            inst.center_col = float(center[1])
        elif info["beam_center_y"] is not None and info["beam_center_x"] is not None:
            inst.center_row = float(info["beam_center_y"]) / det_bin_row
            inst.center_col = float(info["beam_center_x"]) / det_bin_col
        else:
            inst.center_row = binned_det_r / 2.0
            inst.center_col = binned_det_c / 2.0
        # internal preview tensors (placeholders until _recompute_previews)
        inst._binned_data_torch = inst._data_torch
        inst._original_bf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst._original_adf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst._binned_bf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        inst._binned_adf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
        # observers
        inst.observe(
            inst._on_params_changed,
            names=[
                "scan_bin_row", "scan_bin_col", "det_bin_row", "det_bin_col",
                "bin_mode", "edge_mode", "center_row", "center_col",
                "bf_radius_ratio", "adf_inner_ratio", "adf_outer_ratio",
            ],
        )
        inst.observe(inst._on_position_changed, names=["_scan_position"])
        inst.observe(inst._on_npy_export, names=["_npy_export_requested"])
        inst.disabled_tools = cls._build_disabled_tools()
        inst.hidden_tools = cls._build_hidden_tools()
        inst._recompute_previews()
        return inst
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def state_dict(self) -> dict:
    """Serialize every user-adjustable setting into a plain JSON-able dict."""
    scalar_keys = (
        "title",
        "scan_bin_row", "scan_bin_col", "det_bin_row", "det_bin_col",
        "bin_mode", "edge_mode",
        "center_row", "center_col",
        "bf_radius_ratio", "adf_inner_ratio", "adf_outer_ratio",
        "cmap", "log_scale", "auto_contrast", "show_fft", "show_controls",
    )
    state = {key: getattr(self, key) for key in scalar_keys}
    # Copy the tool sequences so callers cannot mutate widget state in place.
    state["disabled_tools"] = list(self.disabled_tools)
    state["hidden_tools"] = list(self.hidden_tools)
    return state
[docs]
def save(self, path: str | pathlib.Path) -> None:
    """Write the current widget state to *path* as a quantem JSON state file."""
    payload = self.state_dict()
    save_state_file(path, "Bin", payload)
def save_h5(
    self,
    path: str | pathlib.Path,
    source_file: str = "",
) -> pathlib.Path:
    """Save the current binned 4D data to HDF5 with bitshuffle + LZ4.

    The file can be reloaded with ``Bin.from_file(path)`` — calibration,
    center, and provenance metadata are stored as dataset attributes.

    Parameters
    ----------
    path : str or pathlib.Path
        Output ``.h5`` file path.
    source_file : str
        Optional provenance string (e.g. original master file path).

    Returns
    -------
    pathlib.Path
        The written file path.
    """
    if not _HAS_H5PY:
        raise ImportError("h5py is required for save_h5(). Install: pip install h5py")
    out_path = pathlib.Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    arr = self._binned_data_torch.detach().cpu().numpy().astype(np.float32, copy=False)
    # Best-effort compression: fall back to uncompressed storage when the
    # hdf5plugin filters are unavailable.
    try:
        compression_kwargs = dict(hdf5plugin.Bitshuffle(cname="lz4"))
    except Exception:
        compression_kwargs = {}
    # All dataset attributes reloaded by from_file(), in one place.
    attrs: dict = {
        "pixel_size_row": float(self.binned_pixel_size_row),
        "pixel_size_col": float(self.binned_pixel_size_col),
        "pixel_unit": self.pixel_unit,
        "k_pixel_size_row": float(self.binned_k_pixel_size_row),
        "k_pixel_size_col": float(self.binned_k_pixel_size_col),
        "k_unit": self.k_unit,
        "center_row": float(self.binned_center_row),
        "center_col": float(self.binned_center_col),
        "title": self.title,
        "bin_factors": [
            int(self.scan_bin_row),
            int(self.scan_bin_col),
            int(self.det_bin_row),
            int(self.det_bin_col),
        ],
    }
    if source_file:
        attrs["source_file"] = str(source_file)
    start = time.time()
    with h5py.File(out_path, "w") as handle:
        dataset = handle.create_dataset(
            "data",
            data=arr,
            chunks=True,
            **compression_kwargs,
        )
        for key, value in attrs.items():
            dataset.attrs[key] = value
    elapsed = time.time() - start
    size_gb = float(out_path.stat().st_size) / 1_000_000_000.0
    print(f" Saved {out_path.name}: {arr.shape} in {elapsed:.1f}s ({size_gb:.2f} GB)")
    return out_path
[docs]
def load_state_dict(self, state: dict) -> None:
    """Apply a previously saved state dict; keys this widget lacks are ignored."""
    for name in state:
        if hasattr(self, name):
            setattr(self, name, state[name])
@property
def result(self):
    """Current binned data as torch tensor."""
    binned = self._binned_data_torch
    return binned
[docs]
def get_binned_data(self, copy: bool = True, as_numpy: bool = False):
    """Return current binned data.

    By default the torch tensor is returned; ``as_numpy=True`` converts to a
    float32 numpy array. ``copy=False`` hands back the internal buffer.
    """
    tensor = self._binned_data_torch
    if not as_numpy:
        return tensor.clone() if copy else tensor
    arr = tensor.detach().cpu().numpy().astype(np.float32, copy=False)
    if copy:
        arr = arr.copy()
    return arr
[docs]
def set_data(
    self,
    data,
    scan_shape: tuple[int, int] | None = None,
    pixel_size: float | tuple[float, float] | None = None,
    k_pixel_size: float | tuple[float, float] | None = None,
    center: tuple[float, float] | None = None,
) -> Self:
    """Replace the 4D data while preserving display settings.

    Parameters
    ----------
    data
        4D array ``(scan_r, scan_c, det_r, det_c)``, flattened 3D array
        ``(N, det_r, det_c)``, or an object exposing ``sampling``/``units``/
        ``array`` attributes (quantem Dataset-like), from which calibration
        is extracted.
    scan_shape : tuple[int, int], optional
        Scan grid for 3D input; inferred as a square grid when omitted.
    pixel_size, k_pixel_size : float or (row, col) pair, optional
        Explicit real/reciprocal calibrations; these override any
        Dataset-derived values.
    center : (row, col), optional
        Detector center; defaults to the detector midpoint.

    Returns
    -------
    Self
        This widget, allowing fluent chaining.

    Raises
    ------
    ValueError
        If a flattened scan cannot be matched to ``scan_shape``, or the
        input is neither 3D nor 4D.
    """
    # Calibration candidates extracted from a Dataset-like input (if any).
    dataset_pixel: float | tuple[float, float] | None = None
    dataset_k: float | tuple[float, float] | None = None
    dataset_pixel_unit = "px"
    dataset_k_unit = "px"
    pixel_calibrated = False
    k_calibrated = False
    if hasattr(data, "sampling") and hasattr(data, "array"):
        units = list(getattr(data, "units", ["pixels"] * 4))
        sampling = list(getattr(data, "sampling", [1.0] * 4))
        # Real-space calibration from the scan axes; nm is converted to Å.
        if len(units) >= 2 and len(sampling) >= 2 and units[0] in _REAL_UNITS:
            sy, sx = float(sampling[0]), float(sampling[1])
            if units[0] == "nm":
                sy *= 10.0
            if units[1] == "nm":
                sx *= 10.0
            dataset_pixel = (sy, sx)
            dataset_pixel_unit = "Å"
            pixel_calibrated = True
        # Reciprocal-space calibration from the detector axes.
        if len(units) >= 4 and len(sampling) >= 4 and units[2] in _K_UNITS:
            ky, kx = float(sampling[2]), float(sampling[3])
            dataset_k = (ky, kx)
            dataset_k_unit = units[2]
            k_calibrated = True
        data = data.array
    # Explicit arguments take precedence over Dataset-derived calibration.
    p_row, p_col = _as_pair(pixel_size if pixel_size is not None else dataset_pixel, 1.0)
    k_row, k_col = _as_pair(k_pixel_size if k_pixel_size is not None else dataset_k, 1.0)
    if pixel_size is not None:
        pixel_calibrated = True
        dataset_pixel_unit = "Å"
    if k_pixel_size is not None:
        k_calibrated = True
        dataset_k_unit = "mrad"
    self.pixel_size_row = p_row
    self.pixel_size_col = p_col
    self.pixel_unit = dataset_pixel_unit if pixel_calibrated else "px"
    self.pixel_calibrated = pixel_calibrated
    self.k_pixel_size_row = k_row
    self.k_pixel_size_col = k_col
    self.k_unit = dataset_k_unit if k_calibrated else "px"
    self.k_calibrated = k_calibrated
    data_np = to_numpy(data, dtype=np.float32)
    if data_np.ndim == 4:
        scan_r, scan_c, det_r, det_c = data_np.shape
        data4d = data_np
    elif data_np.ndim == 3:
        # Flattened scan: recover a square (rows, cols) or require it
        # explicitly from the caller.
        if scan_shape is None:
            n = int(data_np.shape[0])
            side = int(math.isqrt(n))
            if side * side != n:
                raise ValueError(
                    f"Cannot infer square scan_shape from flattened N={n}. Provide scan_shape=(rows, cols)."
                )
            scan_shape = (side, side)
        if int(scan_shape[0]) * int(scan_shape[1]) != int(data_np.shape[0]):
            raise ValueError(f"scan_shape={scan_shape} does not match flattened length {data_np.shape[0]}")
        scan_r, scan_c = int(scan_shape[0]), int(scan_shape[1])
        det_r, det_c = int(data_np.shape[1]), int(data_np.shape[2])
        data4d = data_np.reshape(scan_r, scan_c, det_r, det_c)
    else:
        raise ValueError(f"Expected 4D or flattened 3D array, got {data_np.ndim}D")
    self.scan_rows = scan_r
    self.scan_cols = scan_c
    self.det_rows = det_r
    self.det_cols = det_c
    # Slider maxima track the new extents.
    self.max_scan_bin_row = max(1, scan_r)
    self.max_scan_bin_col = max(1, scan_c)
    self.max_det_bin_row = max(1, det_r)
    self.max_det_bin_col = max(1, det_c)
    # Copy so the torch tensor owns writable memory independent of the input.
    data4d_writable = np.array(data4d, dtype=np.float32, copy=True)
    self._data_torch = torch.from_numpy(data4d_writable).to(self._device)
    if center is not None:
        self.center_row = float(center[0])
        self.center_col = float(center[1])
    else:
        self.center_row = det_r / 2.0
        self.center_col = det_c / 2.0
    # Reset bin factors; previews are rebuilt from scratch below.
    self.scan_bin_row = 1
    self.scan_bin_col = 1
    self.det_bin_row = 1
    self.det_bin_col = 1
    self._binned_data_torch = self._data_torch
    self._original_bf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
    self._original_adf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
    self._binned_bf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
    self._binned_adf_torch = torch.zeros((scan_r, scan_c), dtype=torch.float32)
    self.original_mean_dp_bytes = b""
    self.binned_mean_dp_bytes = b""
    self._recompute_previews()
    return self
[docs]
def to_show4dstem(self, **kwargs):
    """Build a `Show4DSTEM` viewer from the current binned data.

    Notes
    -----
    `Show4DSTEM` currently accepts scalar calibrations; row/col calibrations
    are collapsed using arithmetic mean for compatibility.
    """
    from quantem.widget.show4dstem import Show4DSTEM
    mean_pixel = float((self.binned_pixel_size_row + self.binned_pixel_size_col) / 2.0)
    pixel_size = kwargs.pop("pixel_size", mean_pixel)
    if self.k_calibrated:
        default_k = float((self.binned_k_pixel_size_row + self.binned_k_pixel_size_col) / 2.0)
    else:
        default_k = None
    k_pixel_size = kwargs.pop("k_pixel_size", default_k)
    # NOTE(review): default center divides the raw center by the detector bin
    # factors without clamping, while _recompute_previews clamps into
    # binned_center_row/col — confirm whether the unclamped value is intended.
    default_center = (
        float(self.center_row / self.det_bin_row),
        float(self.center_col / self.det_bin_col),
    )
    center = kwargs.pop("center", default_center)
    default_bf = float(self.bf_radius_ratio * min(self.binned_det_rows, self.binned_det_cols))
    bf_radius = kwargs.pop("bf_radius", default_bf)
    return Show4DSTEM(
        self.get_binned_data(copy=False),
        pixel_size=pixel_size,
        k_pixel_size=k_pixel_size,
        center=center,
        bf_radius=bf_radius,
        **kwargs,
    )
[docs]
def save_image(
    self,
    path: str | pathlib.Path,
    view: str = "binned_bf",
    cmap: str = "inferno",
    scale_mode: str = "linear",
    format: str | None = None,
    include_metadata: bool = True,
    metadata_path: str | pathlib.Path | None = None,
    dpi: int = 300,
) -> pathlib.Path:
    """Save one Bin preview view (or 2x2 grid) as PNG/PDF with metadata."""
    from PIL import Image
    target = pathlib.Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    fmt = self._resolve_export_format(target, format)
    image, render_meta = self._render_view_image(view=view, cmap=cmap, scale_mode=scale_mode)
    if fmt == "png":
        image.save(target, format="PNG", dpi=(dpi, dpi))
    else:
        image.save(target, format="PDF", resolution=dpi)
    if not include_metadata:
        return target
    # Sidecar JSON with full provenance: factors, shapes, calibration, render.
    if metadata_path is not None:
        meta_path = pathlib.Path(metadata_path)
    else:
        meta_path = target.with_suffix(".json")
    metadata = dict(build_json_header("Bin"))
    metadata.update(
        {
            "view": view,
            "format": fmt,
            "export_kind": "single_view_image",
            "path": str(target),
            "bin_factors": {
                "scan_row": int(self.scan_bin_row),
                "scan_col": int(self.scan_bin_col),
                "det_row": int(self.det_bin_row),
                "det_col": int(self.det_bin_col),
            },
            "bin_mode": self.bin_mode,
            "edge_mode": self.edge_mode,
            "shape": {
                "input": [int(self.scan_rows), int(self.scan_cols), int(self.det_rows), int(self.det_cols)],
                "output": [
                    int(self.binned_scan_rows),
                    int(self.binned_scan_cols),
                    int(self.binned_det_rows),
                    int(self.binned_det_cols),
                ],
            },
            "calibration": {
                "pixel_size": [float(self.pixel_size_row), float(self.pixel_size_col)],
                "pixel_size_binned": [
                    float(self.binned_pixel_size_row),
                    float(self.binned_pixel_size_col),
                ],
                "k_pixel_size": [float(self.k_pixel_size_row), float(self.k_pixel_size_col)],
                "k_pixel_size_binned": [
                    float(self.binned_k_pixel_size_row),
                    float(self.binned_k_pixel_size_col),
                ],
            },
            "render": render_meta,
        }
    )
    meta_path.write_text(json.dumps(metadata, indent=2))
    return target
[docs]
def save_zip(
    self,
    path: str | pathlib.Path,
    cmap: str = "inferno",
    scale_mode: str = "linear",
    include_arrays: bool = False,
) -> pathlib.Path:
    """Export all Bin previews + metadata in a ZIP bundle."""
    import io
    import zipfile
    bundle_path = pathlib.Path(path)
    bundle_path.parent.mkdir(parents=True, exist_ok=True)
    panel_names = ["original_bf", "original_adf", "binned_bf", "binned_adf", "grid"]
    manifest = {
        **build_json_header("Bin"),
        "format": "zip",
        "export_kind": "multi_panel_bundle",
        "include_arrays": bool(include_arrays),
        "bin_factors": {
            "scan_row": int(self.scan_bin_row),
            "scan_col": int(self.scan_bin_col),
            "det_row": int(self.det_bin_row),
            "det_col": int(self.det_bin_col),
        },
        "panels": {},
    }
    with zipfile.ZipFile(bundle_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        # One PNG per panel; per-panel render metadata goes into the manifest.
        for name in panel_names:
            image, render_meta = self._render_view_image(view=name, cmap=cmap, scale_mode=scale_mode)
            png_buf = io.BytesIO()
            image.save(png_buf, format="PNG")
            archive.writestr(f"{name}.png", png_buf.getvalue())
            manifest["panels"][name] = render_meta
        if include_arrays:
            # Also bundle the raw float32 preview arrays for offline analysis.
            tensors = {
                "original_bf.npy": self._original_bf_torch,
                "original_adf.npy": self._original_adf_torch,
                "binned_bf.npy": self._binned_bf_torch,
                "binned_adf.npy": self._binned_adf_torch,
            }
            for entry_name, tensor in tensors.items():
                npy_buf = io.BytesIO()
                np.save(npy_buf, tensor.detach().cpu().numpy().astype(np.float32, copy=False))
                archive.writestr(entry_name, npy_buf.getvalue())
        archive.writestr("metadata.json", json.dumps(manifest, indent=2))
    return bundle_path
[docs]
def save_gif(
    self,
    path: str | pathlib.Path,
    channel: str = "bf",
    cmap: str = "inferno",
    scale_mode: str = "linear",
    duration_ms: int = 800,
    loop: int = 0,
    include_metadata: bool = True,
    metadata_path: str | pathlib.Path | None = None,
) -> pathlib.Path:
    """Save a two-frame GIF comparing original vs binned BF/ADF previews."""
    from PIL import Image
    selected = str(channel).strip().lower()
    if selected not in {"bf", "adf"}:
        raise ValueError("channel must be 'bf' or 'adf'")
    before_view = "original_bf" if selected == "bf" else "original_adf"
    after_view = "binned_bf" if selected == "bf" else "binned_adf"
    before_img, before_meta = self._render_view_image(view=before_view, cmap=cmap, scale_mode=scale_mode)
    after_img, after_meta = self._render_view_image(view=after_view, cmap=cmap, scale_mode=scale_mode)
    gif_path = pathlib.Path(path)
    gif_path.parent.mkdir(parents=True, exist_ok=True)
    frame_ms = max(10, int(duration_ms))
    loop_count = max(0, int(loop))
    before_img.save(
        gif_path,
        format="GIF",
        save_all=True,
        append_images=[after_img],
        duration=frame_ms,
        loop=loop_count,
    )
    if include_metadata:
        meta_path = pathlib.Path(metadata_path) if metadata_path is not None else gif_path.with_suffix(".json")
        metadata = {
            **build_json_header("Bin"),
            "format": "gif",
            "export_kind": "before_after_animation",
            "path": str(gif_path),
            "channel": selected,
            "duration_ms": int(frame_ms),
            "loop": int(loop_count),
            "display": {
                "left": before_meta,
                "right": after_meta,
            },
            "bin_factors": {
                "scan_row": int(self.scan_bin_row),
                "scan_col": int(self.scan_bin_col),
                "det_row": int(self.det_bin_row),
                "det_col": int(self.det_bin_col),
            },
            "bin_mode": self.bin_mode,
            "edge_mode": self.edge_mode,
        }
        meta_path.write_text(json.dumps(metadata, indent=2))
    return gif_path
[docs]
def summary(self) -> None:
    """Print compact binning + calibration summary."""
    header = self.title or "Bin"
    out = [header, "═" * 32, f"Device: {self.device}"]
    out.append(
        f"Shape: ({self.scan_rows}, {self.scan_cols}, {self.det_rows}, {self.det_cols})"
        f" -> ({self.binned_scan_rows}, {self.binned_scan_cols}, {self.binned_det_rows}, {self.binned_det_cols})"
    )
    out.append(
        "Factors: "
        f"scan=({self.scan_bin_row}, {self.scan_bin_col}), "
        f"det=({self.det_bin_row}, {self.det_bin_col}), "
        f"mode={self.bin_mode}, edge={self.edge_mode}"
    )
    out.append(
        f"Real cal: ({self.pixel_size_row:.4g}, {self.pixel_size_col:.4g}) {self.pixel_unit}/px"
        f" -> ({self.binned_pixel_size_row:.4g}, {self.binned_pixel_size_col:.4g})"
    )
    out.append(
        f"K cal: ({self.k_pixel_size_row:.4g}, {self.k_pixel_size_col:.4g}) {self.k_unit}/px"
        f" -> ({self.binned_k_pixel_size_row:.4g}, {self.binned_k_pixel_size_col:.4g})"
    )
    # Optional lines only when the corresponding state is non-empty.
    if self.disabled_tools:
        out.append(f"Locked: {', '.join(self.disabled_tools)}")
    if self.hidden_tools:
        out.append(f"Hidden: {', '.join(self.hidden_tools)}")
    if self.status_message:
        out.append(f"Status: {self.status_level.upper()} - {self.status_message}")
    print("\n".join(out))
# ------------------------------------------------------------------
# Internal compute
# ------------------------------------------------------------------
def _on_params_changed(self, change=None):
    # Traitlets observer: any binning/mask parameter change invalidates the
    # cached previews, so recompute them eagerly.
    self._recompute_previews()
def _on_position_changed(self, change=None):
    """Refresh the per-position diffraction-pattern byte buffers."""
    position = self._scan_position
    # An invalid or cleared selection empties both buffers.
    if len(position) != 2 or position[0] < 0 or position[1] < 0:
        self._position_dp_bytes = b""
        self._binned_position_dp_bytes = b""
        return
    row = min(max(0, position[0]), self.scan_rows - 1)
    col = min(max(0, position[1]), self.scan_cols - 1)
    pattern = self._data_torch[row, col, :, :]
    self._position_dp_bytes = (
        pattern.detach().cpu().contiguous().float().numpy().astype(np.float32, copy=False).tobytes()
    )
    # Map the raw scan position into the binned grid (clamped to bounds).
    binned_row = min(row // max(1, self.scan_bin_row), self.binned_scan_rows - 1)
    binned_col = min(col // max(1, self.scan_bin_col), self.binned_scan_cols - 1)
    binned_pattern = self._binned_data_torch[binned_row, binned_col, :, :]
    self._binned_position_dp_bytes = (
        binned_pattern.detach().cpu().contiguous().float().numpy().astype(np.float32, copy=False).tobytes()
    )
def _on_npy_export(self, change=None):
    """Serialize the current binned data to .npy bytes when the frontend asks."""
    import io
    if not self._npy_export_requested:
        return
    # Reset the request flag first so a later export can retrigger the observer.
    self._npy_export_requested = False
    payload = self._binned_data_torch.detach().cpu().numpy().astype(np.float32, copy=False)
    buffer = io.BytesIO()
    np.save(buffer, payload)
    self._npy_export_data = buffer.getvalue()
def _resolve_torch_device(self, requested: str | None, numel: int) -> str | None:
    """Pick a valid torch device from user request or quantem config."""
    if not _HAS_TORCH:
        return None
    if requested is not None:
        device_str = str(requested).strip().lower()
    elif _HAS_VALIDATE_DEVICE:
        device_str, _ = validate_device(None)
    elif torch.backends.mps.is_available():
        device_str = "mps"
    elif torch.cuda.is_available():
        device_str = "cuda"
    else:
        device_str = "cpu"
    # MPS has tensor-size limitations similar to Show4DSTEM.
    if device_str == "mps" and numel > 2**31 - 1:
        device_str = "cpu"
    # Smoke-test the device before committing to it.
    try:
        torch.zeros(1, device=torch.device(device_str))
    except Exception:
        return None
    return device_str
def _resolve_export_format(
self,
path: pathlib.Path,
fmt: str | None,
) -> str:
if fmt is not None and str(fmt).strip():
resolved = str(fmt).strip().lower()
else:
resolved = path.suffix.lstrip(".").lower() or "png"
if resolved not in {"png", "pdf"}:
raise ValueError(f"Unsupported format '{resolved}'. Supported: png, pdf")
return resolved
def _get_panel_tensor(self, view: str):
    """Look up the cached preview tensor for one named panel."""
    normalized = str(view).strip().lower()
    panels = {
        "original_bf": self._original_bf_torch,
        "original_adf": self._original_adf_torch,
        "binned_bf": self._binned_bf_torch,
        "binned_adf": self._binned_adf_torch,
    }
    if normalized not in panels:
        raise ValueError(
            "view must be one of: original_bf, original_adf, binned_bf, binned_adf, grid"
        )
    return panels[normalized]
def _tensor_to_rgb(self, tensor, cmap: str, scale_mode: str) -> np.ndarray:
    """Map a 2D tensor to 8-bit RGB via min-max scaling and a matplotlib cmap."""
    from matplotlib import colormaps
    values = tensor.detach().cpu().float()
    # Optional intensity compression before normalization.
    if scale_mode == "log":
        values = torch.log1p(torch.clamp_min(values, 0.0))
    elif scale_mode == "power":
        values = torch.sqrt(torch.clamp_min(values, 0.0))
    elif scale_mode != "linear":
        raise ValueError("scale_mode must be 'linear', 'log', or 'power'")
    if values.numel() == 0:
        lo = hi = 0.0
    else:
        lo = float(torch.min(values))
        hi = float(torch.max(values))
    # Flat (or empty) inputs map to all-zero intensity.
    if hi > lo:
        unit = torch.clamp((values - lo) / (hi - lo), 0.0, 1.0)
    else:
        unit = torch.zeros_like(values, dtype=torch.float32)
    rgba = colormaps.get_cmap(cmap)(unit.numpy())
    return (rgba[..., :3] * 255).astype(np.uint8)
def _render_view_image(self, view: str, cmap: str, scale_mode: str):
    """Render one preview panel (or the 2x2 grid) into a PIL image + metadata."""
    from PIL import Image
    key = str(view).strip().lower()
    if key != "grid":
        panel = self._get_panel_tensor(key)
        rgb = self._tensor_to_rgb(panel, cmap=cmap, scale_mode=scale_mode)
        meta = {
            "view": key,
            "shape": [int(panel.shape[0]), int(panel.shape[1])],
            "colormap": cmap,
            "scale_mode": scale_mode,
        }
        return Image.fromarray(rgb, mode="RGB"), meta
    # 2x2 montage: originals on top, binned below; cells are top-left aligned.
    panel_names = ["original_bf", "original_adf", "binned_bf", "binned_adf"]
    images = [
        self._tensor_to_rgb(self._get_panel_tensor(name), cmap=cmap, scale_mode=scale_mode)
        for name in panel_names
    ]
    top_h = max(int(images[0].shape[0]), int(images[1].shape[0]))
    bottom_h = max(int(images[2].shape[0]), int(images[3].shape[0]))
    left_w = max(int(images[0].shape[1]), int(images[2].shape[1]))
    right_w = max(int(images[1].shape[1]), int(images[3].shape[1]))
    canvas = np.zeros((top_h + bottom_h, left_w + right_w, 3), dtype=np.uint8)
    canvas[: images[0].shape[0], : images[0].shape[1]] = images[0]
    canvas[: images[1].shape[0], left_w : left_w + images[1].shape[1]] = images[1]
    canvas[top_h : top_h + images[2].shape[0], : images[2].shape[1]] = images[2]
    canvas[top_h : top_h + images[3].shape[0], left_w : left_w + images[3].shape[1]] = images[3]
    meta = {"view": "grid", "panels": panel_names, "colormap": cmap, "scale_mode": scale_mode}
    return Image.fromarray(canvas, mode="RGB"), meta
def _recompute_previews(self) -> None:
    """Recompute binned data, BF/ADF previews, stats, and frontend byte buffers.

    Runs on every parameter change; on failure it records an error status
    instead of raising so the widget stays responsive.
    """
    # Ensure annulus is valid
    if self.adf_outer_ratio <= self.adf_inner_ratio:
        self.status_level = "warn"
        self.status_message = "ADF outer ratio must be greater than inner ratio."
        # Nudging the trait retriggers this observer with a valid annulus.
        self.adf_outer_ratio = float(self.adf_inner_ratio + 1e-3)
        return
    # Torch-only compute path
    try:
        binned_t = self._bin_4d_torch(
            self._data_torch,
            factors=(self.scan_bin_row, self.scan_bin_col, self.det_bin_row, self.det_bin_col),
            mode=self.bin_mode,
            edge=self.edge_mode,
        )
        orig_bf_t, orig_adf_t = self._virtual_images_torch(
            self._data_torch,
            center=(self.center_row, self.center_col),
            bf_ratio=self.bf_radius_ratio,
            adf_inner_ratio=self.adf_inner_ratio,
            adf_outer_ratio=self.adf_outer_ratio,
        )
    # e.g. edge-mode divisibility errors raised by _bin_axis_torch.
    except ValueError as exc:
        self.status_level = "error"
        self.status_message = str(exc)
        return
    # Any other compute failure (device errors, OOM, ...) also becomes status.
    except Exception as exc:
        self.status_level = "error"
        self.status_message = str(exc)
        return
    self._binned_data_torch = binned_t
    self.binned_scan_rows = int(binned_t.shape[0])
    self.binned_scan_cols = int(binned_t.shape[1])
    self.binned_det_rows = int(binned_t.shape[2])
    self.binned_det_cols = int(binned_t.shape[3])
    # Binning coarsens sampling: calibrated step sizes scale with the factors.
    self.binned_pixel_size_row = float(self.pixel_size_row * self.scan_bin_row)
    self.binned_pixel_size_col = float(self.pixel_size_col * self.scan_bin_col)
    self.binned_k_pixel_size_row = float(self.k_pixel_size_row * self.det_bin_row)
    self.binned_k_pixel_size_col = float(self.k_pixel_size_col * self.det_bin_col)
    # Detector center maps with detector bin factors (clamped to bounds)
    self.binned_center_row = float(
        max(0.0, min(self.binned_det_rows - 1, self.center_row / self.det_bin_row))
    )
    self.binned_center_col = float(
        max(0.0, min(self.binned_det_cols - 1, self.center_col / self.det_bin_col))
    )
    binned_bf_t, binned_adf_t = self._virtual_images_torch(
        binned_t,
        center=(self.binned_center_row, self.binned_center_col),
        bf_ratio=self.bf_radius_ratio,
        adf_inner_ratio=self.adf_inner_ratio,
        adf_outer_ratio=self.adf_outer_ratio,
    )
    # Cache CPU copies of the previews for exports (save_image/save_zip/...).
    self._original_bf_torch = orig_bf_t.detach().cpu().float()
    self._original_adf_torch = orig_adf_t.detach().cpu().float()
    self._binned_bf_torch = binned_bf_t.detach().cpu().float()
    self._binned_adf_torch = binned_adf_t.detach().cpu().float()
    # Raw float32 byte buffers synced to the frontend for canvas rendering.
    self.original_bf_bytes = (
        orig_bf_t.detach().cpu().contiguous().numpy().astype(np.float32, copy=False).tobytes()
    )
    self.original_adf_bytes = (
        orig_adf_t.detach().cpu().contiguous().numpy().astype(np.float32, copy=False).tobytes()
    )
    self.binned_bf_bytes = (
        binned_bf_t.detach().cpu().contiguous().numpy().astype(np.float32, copy=False).tobytes()
    )
    self.binned_adf_bytes = (
        binned_adf_t.detach().cpu().contiguous().numpy().astype(np.float32, copy=False).tobytes()
    )
    # Mean diffraction patterns (detector-space previews)
    mean_dp_orig = self._data_torch.mean(dim=(0, 1))
    self.original_mean_dp_bytes = (
        mean_dp_orig.detach().cpu().contiguous().float().numpy().astype(np.float32, copy=False).tobytes()
    )
    mean_dp_binned = binned_t.mean(dim=(0, 1))
    self.binned_mean_dp_bytes = (
        mean_dp_binned.detach().cpu().contiguous().float().numpy().astype(np.float32, copy=False).tobytes()
    )
    # QC statistics for each preview panel (module-level helper).
    self.original_bf_stats = _qc_stats_torch(orig_bf_t)
    self.original_adf_stats = _qc_stats_torch(orig_adf_t)
    self.binned_bf_stats = _qc_stats_torch(binned_bf_t)
    self.binned_adf_stats = _qc_stats_torch(binned_adf_t)
    self.status_level = "ok"
    self.status_message = (
        f"Preview updated on torch/{self.device}: "
        f"({self.scan_rows}×{self.scan_cols}×{self.det_rows}×{self.det_cols})"
        f" -> ({self.binned_scan_rows}×{self.binned_scan_cols}×{self.binned_det_rows}×{self.binned_det_cols})"
    )
    # Refresh per-position DP if a position is selected
    self._on_position_changed()
def _virtual_images_torch(
    self,
    data4d,
    center: tuple[float, float],
    bf_ratio: float,
    adf_inner_ratio: float,
    adf_outer_ratio: float,
):
    """Compute BF/ADF virtual images with circular detector masks.

    Ratios are relative to min(det_rows, det_cols); returns a
    ``(bf, adf)`` pair of scan-shaped tensors.
    """
    n_rows, n_cols = int(data4d.shape[2]), int(data4d.shape[3])
    c_row, c_col = float(center[0]), float(center[1])
    extent = float(min(n_rows, n_cols))
    # Convert relative ratios into pixel radii, with tiny floors so the
    # masks stay well-defined for degenerate ratios.
    r_bf = max(1e-6, bf_ratio * extent)
    r_inner = max(0.0, adf_inner_ratio * extent)
    r_outer = max(r_inner + 1e-6, adf_outer_ratio * extent)
    row_idx = torch.arange(n_rows, device=data4d.device, dtype=torch.float32)[:, None]
    col_idx = torch.arange(n_cols, device=data4d.device, dtype=torch.float32)[None, :]
    radius_sq = (row_idx - c_row) ** 2 + (col_idx - c_col) ** 2
    disk = (radius_sq <= r_bf**2).float()
    annulus = ((radius_sq >= r_inner**2) & (radius_sq <= r_outer**2)).float()
    # Contract the detector axes against each mask in one tensordot call.
    bright_field = torch.tensordot(data4d, disk, dims=([2, 3], [0, 1]))
    annular_dark = torch.tensordot(data4d, annulus, dims=([2, 3], [0, 1]))
    return bright_field, annular_dark
def _bin_axis_torch(self, data, axis: int, factor: int, mode: str, edge: str):
    """Torch equivalent of `_bin_axis`: bin one tensor axis by an integer factor.

    Parameters
    ----------
    data : torch.Tensor
        Input tensor.
    axis : int
        Axis to bin.
    factor : int
        Bin factor; must be >= 1 (1 is a no-op).
    mode : str
        "sum" sums each bin; any other value averages.
    edge : str
        "crop" trims the remainder, "pad" zero-pads up to a multiple of
        `factor` (padded zeros participate in the sum/mean of the last bin),
        anything else ("error") requires exact divisibility.

    Raises
    ------
    ValueError
        If `factor` < 1, if "crop" would remove the whole axis, or if
        "error" mode sees a non-divisible axis length.
    """
    factor = int(factor)
    # Guard before any arithmetic: factor 0 would otherwise surface as an
    # opaque ZeroDivisionError from `n // factor` below.
    if factor < 1:
        raise ValueError(f"bin factor must be >= 1, got {factor} (axis {axis})")
    if factor == 1:
        return data
    n = int(data.shape[axis])
    if edge == "crop":
        n_used = (n // factor) * factor
        if n_used <= 0:
            raise ValueError(
                f"crop mode: factor {factor} is larger than axis size {n} for axis {axis}"
            )
        trimmed = data.narrow(axis, 0, n_used)
    elif edge == "pad":
        n_used = int(math.ceil(n / factor) * factor)
        pad_amount = n_used - n
        if pad_amount > 0:
            pad_shape = list(data.shape)
            pad_shape[axis] = pad_amount
            pad_block = torch.zeros(
                pad_shape,
                dtype=data.dtype,
                device=data.device,
            )
            trimmed = torch.cat([data, pad_block], dim=axis)
        else:
            trimmed = data
    else:  # edge == "error"
        if n % factor != 0:
            raise ValueError(
                f"error mode: axis size {n} is not divisible by factor {factor} (axis {axis})"
            )
        n_used = n
        trimmed = data
    # Fold the axis into (n_used // factor, factor) and reduce the new axis.
    new_shape = (
        tuple(trimmed.shape[:axis])
        + (n_used // factor, factor)
        + tuple(trimmed.shape[axis + 1 :])
    )
    reshaped = trimmed.reshape(new_shape)
    reduce_axis = axis + 1
    if mode == "sum":
        return reshaped.sum(dim=reduce_axis)
    return reshaped.mean(dim=reduce_axis)
def _bin_4d_torch(
    self,
    data4d,
    factors: tuple[int, int, int, int],
    mode: str,
    edge: str,
):
    """Bin all four axes in sequence and return a float32 result."""
    result = data4d
    for dim, fac in enumerate(factors):
        result = self._bin_axis_torch(result, axis=dim, factor=int(fac), mode=mode, edge=edge)
    return result.float()
# ------------------------------------------------------------------
# Representation
# ------------------------------------------------------------------
def __repr__(self) -> str:
    """Concise construction-style summary of shapes, factors, and device."""
    parts = [
        f"shape=({self.scan_rows}, {self.scan_cols}, {self.det_rows}, {self.det_cols})",
        f"bin=({self.scan_bin_row}, {self.scan_bin_col}, {self.det_bin_row}, {self.det_bin_col})",
        f"binned_shape=({self.binned_scan_rows}, {self.binned_scan_cols}, {self.binned_det_rows}, {self.binned_det_cols})",
        f"mode={self.bin_mode}, edge={self.edge_mode}, device={self.device}",
    ]
    # Title is appended only when set, matching the original format.
    if self.title:
        parts.append(f"title='{self.title}'")
    return f"Bin({', '.join(parts)})"
# Attach the shared tool-runtime API (enable/disable/hide helpers) to Bin,
# mirroring how other quantem widgets register their tool surface.
bind_tool_runtime_api(Bin, "Bin")