# Source code for quantem.widget.show4d

"""
show4d: General-purpose 4D data explorer widget.

Interactive dual-panel viewer for any 4D dataset (nav_rows, nav_cols, sig_rows, sig_cols).
Left panel shows a navigation image (real-space), right panel shows the signal
at the selected position. Supports ROI masking on the navigation panel to
average signals from a region, path animation, and GPU acceleration.
"""

import json
import pathlib
from typing import Self

import numpy as np
import anywidget
import traitlets

from quantem.widget.array_utils import to_numpy
from quantem.widget.io import IOResult
from quantem.widget.json_state import build_json_header, resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
    bind_tool_runtime_api,
    build_tool_groups,
    normalize_tool_groups,
)

try:
    import torch
    _HAS_TORCH = True
except ImportError:
    _HAS_TORCH = False

try:
    from quantem.core.config import validate_device
    _HAS_VALIDATE_DEVICE = True
except ImportError:
    _HAS_VALIDATE_DEVICE = False

class Show4D(anywidget.AnyWidget):
    """
    General-purpose 4D data explorer.

    Displays a navigation image (left) and signal at selected position (right).
    Click/drag on the navigation image to explore. Draw ROI masks to average
    signals from a region. Supports path animation for automated scanning.

    Parameters
    ----------
    data : array_like
        4D array of shape (nav_rows, nav_cols, sig_rows, sig_cols).
        Accepts NumPy, PyTorch, CuPy, or any np.asarray()-compatible object.
        An ``IOResult`` or an object exposing ``array`` and ``sampling``
        attributes is unwrapped, and its title/name and sampling are used
        when the corresponding arguments are not given.
    nav_image : array_like, optional
        2D array (nav_rows, nav_cols) to use as navigation image.
        If not provided, defaults to mean over signal dimensions.
    title : str, optional
        Title displayed in the widget header.
    cmap : str, default "inferno"
        Colormap for the signal panel.
    log_scale : bool, default False
        Apply log scale to signal panel.
    auto_contrast : bool, default True
        Auto-contrast with percentile clipping on signal panel.
    show_stats : bool, default True
        Show statistics bar below canvases.
    show_controls : bool, default True
        Synced to the front end as the ``show_controls`` trait
        (presumably toggles the control panel; rendering is done in JS).
    show_fft : bool, default False
        Show FFT panel for signal.
    fft_window : bool, default True
        Synced to the front end as the ``fft_window`` trait
        (presumably enables windowing before the FFT; handled in JS).
    nav_pixel_size : float, optional
        Pixel size for navigation space scale bar.
    sig_pixel_size : float, optional
        Pixel size for signal space scale bar.
    nav_pixel_unit : str, default "px"
        Unit for navigation pixel size.
    sig_pixel_unit : str, default "px"
        Unit for signal pixel size.
    percentile_low : float, default 0.5
        Low percentile for auto-contrast clipping.
    percentile_high : float, default 99.5
        High percentile for auto-contrast clipping.
    snap_enabled : bool, default False
        Initial value of the snap-to-peak toggle.
    snap_radius : int, default 5
        Initial snap-to-peak radius.
    disabled_tools : list of str, optional
        Tool groups to lock while still showing controls. Supported:
        ``"display"``, ``"roi"``, ``"histogram"``, ``"profile"``,
        ``"navigation"``, ``"playback"``, ``"stats"``, ``"export"``,
        ``"view"``, ``"fft"``, ``"all"``.
    disable_* : bool, optional
        Convenience flags mirroring ``disabled_tools`` for each tool group,
        plus ``disable_all``.
    hidden_tools : list of str, optional
        Tool groups to hide from the UI. Uses the same keys as
        ``disabled_tools``.
    hide_* : bool, optional
        Convenience flags mirroring ``disable_*`` for ``hidden_tools``.
    state : dict, str or pathlib.Path, optional
        Previously saved state applied after construction via
        :meth:`load_state_dict`. A ``str``/``Path`` is read as a JSON file
        and must carry the state envelope; a mapping is unwrapped leniently.
    **kwargs
        Forwarded to ``anywidget.AnyWidget.__init__``.

    Raises
    ------
    ValueError
        If ``data`` is not 4D, or ``nav_image`` does not have shape
        (nav_rows, nav_cols).
    """

    _esm = pathlib.Path(__file__).parent / "static" / "show4d.js"
    _css = pathlib.Path(__file__).parent / "static" / "show4d.css"

    # Navigation position (row, col)
    pos_row = traitlets.Int(0).tag(sync=True)
    pos_col = traitlets.Int(0).tag(sync=True)

    # Shape
    nav_rows = traitlets.Int(1).tag(sync=True)
    nav_cols = traitlets.Int(1).tag(sync=True)
    sig_rows = traitlets.Int(1).tag(sync=True)
    sig_cols = traitlets.Int(1).tag(sync=True)

    # Data transfer (raw float32 bytes)
    frame_bytes = traitlets.Bytes(b"").tag(sync=True)
    nav_image_bytes = traitlets.Bytes(b"").tag(sync=True)

    # Data ranges for JS normalization
    nav_data_min = traitlets.Float(0.0).tag(sync=True)
    nav_data_max = traitlets.Float(1.0).tag(sync=True)
    sig_data_min = traitlets.Float(0.0).tag(sync=True)
    sig_data_max = traitlets.Float(1.0).tag(sync=True)

    # ROI on navigation panel
    roi_mode = traitlets.Unicode("off").tag(sync=True)  # "off", "circle", "square", "rect", "annular"
    roi_reduce = traitlets.Unicode("mean").tag(sync=True)  # "mean", "max", "min", "sum"
    roi_center_row = traitlets.Float(0.0).tag(sync=True)
    roi_center_col = traitlets.Float(0.0).tag(sync=True)
    roi_center = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0]).tag(sync=True)
    roi_radius = traitlets.Float(5.0).tag(sync=True)
    roi_radius_inner = traitlets.Float(0.0).tag(sync=True)
    roi_width = traitlets.Float(10.0).tag(sync=True)
    roi_height = traitlets.Float(10.0).tag(sync=True)

    # Statistics ([mean, min, max, std])
    nav_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True)
    sig_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True)

    # Display traits (synced to Python)
    cmap = traitlets.Unicode("inferno").tag(sync=True)
    log_scale = traitlets.Bool(False).tag(sync=True)
    auto_contrast = traitlets.Bool(True).tag(sync=True)
    show_stats = traitlets.Bool(True).tag(sync=True)
    show_controls = traitlets.Bool(True).tag(sync=True)
    show_fft = traitlets.Bool(False).tag(sync=True)
    fft_window = traitlets.Bool(True).tag(sync=True)
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    percentile_low = traitlets.Float(0.5).tag(sync=True)
    percentile_high = traitlets.Float(99.5).tag(sync=True)

    # Scale bars
    nav_pixel_size = traitlets.Float(0.0).tag(sync=True)
    sig_pixel_size = traitlets.Float(0.0).tag(sync=True)
    nav_pixel_unit = traitlets.Unicode("px").tag(sync=True)
    sig_pixel_unit = traitlets.Unicode("px").tag(sync=True)

    # Title
    title = traitlets.Unicode("").tag(sync=True)

    # Snap-to-peak
    snap_enabled = traitlets.Bool(False).tag(sync=True)
    snap_radius = traitlets.Int(5).tag(sync=True)

    # Path animation
    path_playing = traitlets.Bool(False).tag(sync=True)
    path_index = traitlets.Int(0).tag(sync=True)
    path_length = traitlets.Int(0).tag(sync=True)
    path_interval_ms = traitlets.Int(100).tag(sync=True)
    path_loop = traitlets.Bool(True).tag(sync=True)

    # Export (GIF)
    _gif_export_requested = traitlets.Bool(False).tag(sync=True)
    _gif_data = traitlets.Bytes(b"").tag(sync=True)
    _gif_metadata_json = traitlets.Unicode("").tag(sync=True)

    # Line Profile
    profile_line = traitlets.List(traitlets.Dict()).tag(sync=True)
    profile_width = traitlets.Int(1).tag(sync=True)

    @classmethod
    def _normalize_tool_groups(cls, tool_groups):
        """Delegate tool-group normalization to ``normalize_tool_groups`` for ``"Show4D"``."""
        return normalize_tool_groups("Show4D", tool_groups)

    @classmethod
    def _build_disabled_tools(
        cls,
        disabled_tools=None,
        disable_display: bool = False,
        disable_roi: bool = False,
        disable_histogram: bool = False,
        disable_profile: bool = False,
        disable_navigation: bool = False,
        disable_playback: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_fft: bool = False,
        disable_all: bool = False,
    ):
        """Combine ``disabled_tools`` and the ``disable_*`` flags via ``build_tool_groups``."""
        return build_tool_groups(
            "Show4D",
            tool_groups=disabled_tools,
            all_flag=disable_all,
            flag_map={
                "display": disable_display,
                "roi": disable_roi,
                "histogram": disable_histogram,
                "profile": disable_profile,
                "navigation": disable_navigation,
                "playback": disable_playback,
                "stats": disable_stats,
                "export": disable_export,
                "view": disable_view,
                "fft": disable_fft,
            },
        )

    @classmethod
    def _build_hidden_tools(
        cls,
        hidden_tools=None,
        hide_display: bool = False,
        hide_roi: bool = False,
        hide_histogram: bool = False,
        hide_profile: bool = False,
        hide_navigation: bool = False,
        hide_playback: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_fft: bool = False,
        hide_all: bool = False,
    ):
        """Combine ``hidden_tools`` and the ``hide_*`` flags via ``build_tool_groups``."""
        return build_tool_groups(
            "Show4D",
            tool_groups=hidden_tools,
            all_flag=hide_all,
            flag_map={
                "display": hide_display,
                "roi": hide_roi,
                "histogram": hide_histogram,
                "profile": hide_profile,
                "navigation": hide_navigation,
                "playback": hide_playback,
                "stats": hide_stats,
                "export": hide_export,
                "view": hide_view,
                "fft": hide_fft,
            },
        )

    @traitlets.validate("disabled_tools")
    def _validate_disabled_tools(self, proposal):
        """Normalize any value assigned to ``disabled_tools``."""
        return self._normalize_tool_groups(proposal["value"])

    @traitlets.validate("hidden_tools")
    def _validate_hidden_tools(self, proposal):
        """Normalize any value assigned to ``hidden_tools``."""
        return self._normalize_tool_groups(proposal["value"])

    def __init__(
        self,
        data,
        nav_image=None,
        title="",
        cmap="inferno",
        log_scale=False,
        auto_contrast=True,
        show_stats=True,
        show_controls=True,
        show_fft=False,
        fft_window=True,
        nav_pixel_size=None,
        sig_pixel_size=None,
        nav_pixel_unit="px",
        sig_pixel_unit="px",
        percentile_low=0.5,
        percentile_high=99.5,
        snap_enabled=False,
        snap_radius=5,
        disabled_tools=None,
        disable_display=False,
        disable_roi=False,
        disable_histogram=False,
        disable_profile=False,
        disable_navigation=False,
        disable_playback=False,
        disable_stats=False,
        disable_export=False,
        disable_view=False,
        disable_fft=False,
        disable_all=False,
        hidden_tools=None,
        hide_display=False,
        hide_roi=False,
        hide_histogram=False,
        hide_profile=False,
        hide_navigation=False,
        hide_playback=False,
        hide_stats=False,
        hide_export=False,
        hide_view=False,
        hide_fft=False,
        hide_all=False,
        state=None,
        **kwargs,
    ):
        """Build the widget from 4D data; see the class docstring for parameters."""
        super().__init__(**kwargs)
        self.widget_version = resolve_widget_version()

        # Check if data is an IOResult and extract metadata
        if isinstance(data, IOResult):
            if not title and data.title:
                title = data.title
            data = data.data

        # Dataset duck typing: pull calibration from `sampling`/`units` unless
        # the caller supplied explicit pixel sizes. nm is converted to Å (x10).
        if hasattr(data, "array") and hasattr(data, "sampling"):
            units = getattr(data, "units", ["pixels"] * 4)
            if nav_pixel_size is None and units[0] in ("Å", "angstrom", "A"):
                nav_pixel_size = float(data.sampling[0])
                nav_pixel_unit = "Å"
            elif nav_pixel_size is None and units[0] == "nm":
                nav_pixel_size = float(data.sampling[0]) * 10
                nav_pixel_unit = "Å"
            if sig_pixel_size is None and len(units) >= 4:
                if units[2] in ("Å", "angstrom", "A"):
                    sig_pixel_size = float(data.sampling[2])
                    sig_pixel_unit = "Å"
                elif units[2] == "nm":
                    sig_pixel_size = float(data.sampling[2]) * 10
                    sig_pixel_unit = "Å"
                elif units[2] == "mrad":
                    sig_pixel_size = float(data.sampling[2])
                    sig_pixel_unit = "mrad"
            if not title and hasattr(data, "name") and data.name:
                title = data.name
            data = data.array

        # Convert to NumPy float32
        data_np = to_numpy(data).astype(np.float32)
        if data_np.ndim != 4:
            raise ValueError(
                f"Expected 4D array (nav_rows, nav_cols, sig_rows, sig_cols), got {data_np.ndim}D"
            )
        # Validate the optional nav image before touching any widget state.
        nav_img = self._coerce_nav_image(nav_image, data_np.shape[:2])

        self.nav_rows = data_np.shape[0]
        self.nav_cols = data_np.shape[1]
        self.sig_rows = data_np.shape[2]
        self.sig_cols = data_np.shape[3]
        self.title = title

        # Display traits
        self.cmap = str(cmap)
        self.log_scale = log_scale
        self.auto_contrast = auto_contrast
        self.show_stats = show_stats
        self.show_controls = show_controls
        self.show_fft = show_fft
        self.fft_window = fft_window
        self.percentile_low = percentile_low
        self.percentile_high = percentile_high
        self.snap_enabled = snap_enabled
        self.snap_radius = snap_radius
        self.disabled_tools = self._build_disabled_tools(
            disabled_tools=disabled_tools,
            disable_display=disable_display,
            disable_roi=disable_roi,
            disable_histogram=disable_histogram,
            disable_profile=disable_profile,
            disable_navigation=disable_navigation,
            disable_playback=disable_playback,
            disable_stats=disable_stats,
            disable_export=disable_export,
            disable_view=disable_view,
            disable_fft=disable_fft,
            disable_all=disable_all,
        )
        self.hidden_tools = self._build_hidden_tools(
            hidden_tools=hidden_tools,
            hide_display=hide_display,
            hide_roi=hide_roi,
            hide_histogram=hide_histogram,
            hide_profile=hide_profile,
            hide_navigation=hide_navigation,
            hide_playback=hide_playback,
            hide_stats=hide_stats,
            hide_export=hide_export,
            hide_view=hide_view,
            hide_fft=hide_fft,
            hide_all=hide_all,
        )

        # Scale bar (0.0 means "uncalibrated")
        self.nav_pixel_size = nav_pixel_size if nav_pixel_size is not None else 0.0
        self.sig_pixel_size = sig_pixel_size if sig_pixel_size is not None else 0.0
        self.nav_pixel_unit = nav_pixel_unit
        self.sig_pixel_unit = sig_pixel_unit

        # Store data — GPU (PyTorch) if available, else NumPy
        if _HAS_TORCH:
            if _HAS_VALIDATE_DEVICE:
                device_str, _ = validate_device(None)
                self._device = torch.device(device_str)
            else:
                self._device = torch.device(
                    "mps" if torch.backends.mps.is_available()
                    else "cuda" if torch.cuda.is_available()
                    else "cpu"
                )
        else:
            self._device = None
        self._store_data(data_np)

        # Navigation image, its range/bytes/stats (sent once)
        self._publish_nav_image(nav_img, data_np)

        # Global signal range; the NumPy fallback samples at most 100 scan
        # positions (fixed seed) instead of scanning the full array.
        self._publish_signal_range(data_np, max_samples=100)

        # Initial position at center
        self.pos_row = self.nav_rows // 2
        self.pos_col = self.nav_cols // 2

        # ROI defaults
        default_roi_size = max(3, min(self.nav_rows, self.nav_cols) * 0.15)
        self.roi_center_row = float(self.nav_rows / 2)
        self.roi_center_col = float(self.nav_cols / 2)
        self.roi_center = [float(self.nav_rows / 2), float(self.nav_cols / 2)]
        self.roi_radius = float(default_roi_size)
        self.roi_radius_inner = float(default_roi_size * 0.5)
        self.roi_width = float(default_roi_size * 2)
        self.roi_height = float(default_roi_size)

        # Path animation state
        self._path_points: list[tuple[int, int]] = []

        # Observers
        self.observe(self._update_frame, names=["pos_row", "pos_col"])
        self.observe(self._on_roi_change, names=[
            "roi_mode", "roi_reduce", "roi_center_row", "roi_center_col",
            "roi_radius", "roi_radius_inner", "roi_width", "roi_height",
        ])
        self.observe(self._on_roi_center_change, names=["roi_center"])
        self.observe(self._on_path_index_change, names=["path_index"])
        self.observe(self._on_gif_export, names=["_gif_export_requested"])

        # Initial frame
        self._update_frame()

        if state is not None:
            if isinstance(state, (str, pathlib.Path)):
                state = unwrap_state_payload(
                    json.loads(pathlib.Path(state).read_text()),
                    require_envelope=True,
                )
            else:
                state = unwrap_state_payload(state)
            self.load_state_dict(state)

    # ── Data ingestion helpers ────────────────────────────────────────────────

    @staticmethod
    def _coerce_nav_image(nav_image, nav_shape):
        """Return ``nav_image`` as float32 of shape ``nav_shape``, or ``None`` if not given.

        Raises ``ValueError`` on a shape mismatch: the front end interprets the
        transferred bytes using ``nav_rows``/``nav_cols``, so a mis-shaped image
        would otherwise be rendered as garbage.
        """
        if nav_image is None:
            return None
        nav_img = to_numpy(nav_image).astype(np.float32)
        expected = tuple(nav_shape)
        if nav_img.shape != expected:
            raise ValueError(
                f"nav_image must have shape {expected} (nav_rows, nav_cols), got {nav_img.shape}"
            )
        return nav_img

    def _store_data(self, data_np):
        """Keep ``data_np`` as a torch tensor on ``self._device`` when torch is available.

        Also rebuilds the cached coordinate tensors used for ROI masks.
        Requires ``nav_rows``/``nav_cols`` to be set already.
        """
        if _HAS_TORCH:
            self._data = torch.from_numpy(data_np).to(self._device)
            # Cache coordinate tensors for fast mask creation
            self._nav_row_coords = torch.arange(
                self.nav_rows, device=self._device, dtype=torch.float32
            )[:, None]
            self._nav_col_coords = torch.arange(
                self.nav_cols, device=self._device, dtype=torch.float32
            )[None, :]
        else:
            self._data = data_np

    def _publish_nav_image(self, nav_img, data_np):
        """Set the navigation image (mean over signal axes if ``nav_img`` is None) and sync it."""
        if nav_img is None:
            if _HAS_TORCH and isinstance(self._data, torch.Tensor):
                nav_img = self._data.mean(dim=(2, 3)).cpu().numpy()
            else:
                nav_img = data_np.mean(axis=(2, 3))
        self._nav_image = nav_img
        self.nav_data_min = float(nav_img.min())
        self.nav_data_max = float(nav_img.max())
        self.nav_image_bytes = nav_img.tobytes()
        self.nav_stats = [
            float(nav_img.mean()),
            float(nav_img.min()),
            float(nav_img.max()),
            float(nav_img.std()),
        ]

    def _publish_signal_range(self, data_np, max_samples=None):
        """Set ``sig_data_min``/``sig_data_max``.

        With torch the full tensor is reduced. In the NumPy fallback, when
        ``max_samples`` is given and there are more scan positions than that,
        only a fixed-seed random subset of positions is inspected.
        """
        if _HAS_TORCH and isinstance(self._data, torch.Tensor):
            lo, hi = float(self._data.min()), float(self._data.max())
        else:
            n_total = self.nav_rows * self.nav_cols
            if max_samples is not None and n_total > max_samples:
                rng = np.random.default_rng(42)
                picks = rng.choice(n_total, max_samples, replace=False)
                sampled = data_np.reshape(n_total, self.sig_rows, self.sig_cols)[picks]
                lo, hi = float(sampled.min()), float(sampled.max())
            else:
                lo, hi = float(data_np.min()), float(data_np.max())
        self.sig_data_min = lo
        self.sig_data_max = hi

    def _frame_at(self, row, col):
        """Return the signal frame at scan position (row, col) as a float32 NumPy array."""
        if _HAS_TORCH and isinstance(self._data, torch.Tensor):
            return self._data[row, col].cpu().numpy()
        return np.asarray(self._data[row, col], dtype=np.float32)

    def set_image(self, data, nav_image=None):
        """Replace the 4D data. Preserves all display settings.

        Parameters
        ----------
        data : array_like
            New 4D array (nav_rows, nav_cols, sig_rows, sig_cols). ``IOResult``
            and objects exposing ``array``/``sampling`` are unwrapped.
        nav_image : array_like, optional
            Navigation image of shape (nav_rows, nav_cols); defaults to the
            mean over the signal axes.

        Raises
        ------
        ValueError
            If ``data`` is not 4D or ``nav_image`` has the wrong shape. The
            widget is left unchanged in that case.
        """
        if isinstance(data, IOResult):
            data = data.data
        if hasattr(data, "array") and hasattr(data, "sampling"):
            data = data.array
        data_np = to_numpy(data).astype(np.float32)
        if data_np.ndim != 4:
            raise ValueError(
                f"Expected 4D array (nav_rows, nav_cols, sig_rows, sig_cols), got {data_np.ndim}D"
            )
        # Validate before mutating so a bad nav_image cannot leave a half-updated widget.
        nav_img = self._coerce_nav_image(nav_image, data_np.shape[:2])

        self.nav_rows = data_np.shape[0]
        self.nav_cols = data_np.shape[1]
        self.sig_rows = data_np.shape[2]
        self.sig_cols = data_np.shape[3]
        self._store_data(data_np)
        self._publish_nav_image(nav_img, data_np)
        self._publish_signal_range(data_np)

        # Keep the cursor inside the (possibly smaller) new scan grid.
        self.pos_row = min(self.pos_row, self.nav_rows - 1)
        self.pos_col = min(self.pos_col, self.nav_cols - 1)
        # Refresh the signal panel; honours an active ROI instead of always
        # showing the single-position frame.
        self._on_roi_change()

    def __repr__(self) -> str:
        return (
            f"Show4D(shape=({self.nav_rows}, {self.nav_cols}, {self.sig_rows}, {self.sig_cols}), "
            f"pos=({self.pos_row}, {self.pos_col}))"
        )

    def state_dict(self):
        """Return the persistable display/ROI/playback/profile settings as a dict."""
        return {
            "title": self.title,
            "cmap": self.cmap,
            "log_scale": self.log_scale,
            "auto_contrast": self.auto_contrast,
            "show_stats": self.show_stats,
            "show_controls": self.show_controls,
            "show_fft": self.show_fft,
            "fft_window": self.fft_window,
            "disabled_tools": self.disabled_tools,
            "hidden_tools": self.hidden_tools,
            "percentile_low": self.percentile_low,
            "percentile_high": self.percentile_high,
            "nav_pixel_size": self.nav_pixel_size,
            "sig_pixel_size": self.sig_pixel_size,
            "nav_pixel_unit": self.nav_pixel_unit,
            "sig_pixel_unit": self.sig_pixel_unit,
            "roi_mode": self.roi_mode,
            "roi_reduce": self.roi_reduce,
            "roi_center_row": self.roi_center_row,
            "roi_center_col": self.roi_center_col,
            "roi_radius": self.roi_radius,
            "roi_radius_inner": self.roi_radius_inner,
            "roi_width": self.roi_width,
            "roi_height": self.roi_height,
            "snap_enabled": self.snap_enabled,
            "snap_radius": self.snap_radius,
            "path_interval_ms": self.path_interval_ms,
            "path_loop": self.path_loop,
            "profile_line": self.profile_line,
            "profile_width": self.profile_width,
        }

    def save(self, path: str):
        """Write :meth:`state_dict` to ``path`` via ``save_state_file``."""
        save_state_file(path, "Show4D", self.state_dict())

    def load_state_dict(self, state):
        """Apply each ``state`` entry whose key is an existing attribute; ignore the rest."""
        for key, val in state.items():
            if hasattr(self, key):
                setattr(self, key, val)

    def summary(self):
        """Print a human-readable summary of shape, calibration and display state."""
        lines = [self.title or "Show4D", "═" * 32]
        lines.append(f"Nav: {self.nav_rows}×{self.nav_cols}")
        if self.nav_pixel_size > 0:
            lines[-1] += f" ({self.nav_pixel_size:.2f} {self.nav_pixel_unit}/px)"
        lines.append(f"Signal: {self.sig_rows}×{self.sig_cols}")
        if self.sig_pixel_size > 0:
            lines[-1] += f" ({self.sig_pixel_size:.2f} {self.sig_pixel_unit}/px)"
        lines.append(f"Position: ({self.pos_row}, {self.pos_col})")
        cmap = self.cmap
        scale = "log" if self.log_scale else "linear"
        contrast = "auto contrast" if self.auto_contrast else "manual contrast"
        display = f"{cmap} | {contrast} | {scale}"
        if self.show_fft:
            display += " | FFT"
            if not self.fft_window:
                display += " (no window)"
        lines.append(f"Display: {display}")
        if self.roi_mode != "off":
            lines.append(f"ROI: {self.roi_mode} ({self.roi_reduce}) at ({self.roi_center_row:.1f}, {self.roi_center_col:.1f}) r={self.roi_radius:.1f}")
        if self.snap_enabled:
            lines.append(f"Snap: ON (radius={self.snap_radius} px)")
        if self.profile_line and len(self.profile_line) == 2:
            p0, p1 = self.profile_line[0], self.profile_line[1]
            lines.append(f"Profile: ({p0['row']:.0f}, {p0['col']:.0f}) -> ({p1['row']:.0f}, {p1['col']:.0f}) width={self.profile_width}")
        if self.disabled_tools:
            lines.append(f"Locked: {', '.join(self.disabled_tools)}")
        if self.hidden_tools:
            lines.append(f"Hidden: {', '.join(self.hidden_tools)}")
        print("\n".join(lines))

    @property
    def position(self) -> tuple[int, int]:
        """Current navigation position as (row, col)."""
        return (self.pos_row, self.pos_col)

    @position.setter
    def position(self, value: tuple[int, int]) -> None:
        self.pos_row, self.pos_col = value

    @property
    def nav_shape(self) -> tuple[int, int]:
        """Navigation grid shape (nav_rows, nav_cols)."""
        return (self.nav_rows, self.nav_cols)

    @property
    def sig_shape(self) -> tuple[int, int]:
        """Signal frame shape (sig_rows, sig_cols)."""
        return (self.sig_rows, self.sig_cols)

    # ── Line Profile ────────────────────────────────────────────────────────

    def set_profile(self, start: tuple, end: tuple) -> Self:
        """Set the profile line from ``start`` (row, col) to ``end`` (row, col); returns ``self``."""
        row0, col0 = start
        row1, col1 = end
        self.profile_line = [
            {"row": float(row0), "col": float(col0)},
            {"row": float(row1), "col": float(col1)},
        ]
        return self

    def clear_profile(self) -> Self:
        """Remove the profile line; returns ``self``."""
        self.profile_line = []
        return self

    @property
    def profile(self) -> list[tuple[float, float]]:
        """Profile endpoints as ``[(row0, col0), (row1, col1)]``, or ``[]`` if unset."""
        if len(self.profile_line) == 2:
            p0, p1 = self.profile_line[0], self.profile_line[1]
            return [(p0["row"], p0["col"]), (p1["row"], p1["col"])]
        return []

    @property
    def profile_values(self):
        """Bilinearly sampled values along the profile in the current frame, or ``None`` if unset."""
        if len(self.profile_line) != 2:
            return None
        p0, p1 = self.profile_line[0], self.profile_line[1]
        frame = self._frame_at(self.pos_row, self.pos_col)
        return self._sample_line(frame, p0["row"], p0["col"], p1["row"], p1["col"])

    @property
    def profile_distance(self) -> float:
        """Profile length: pixels scaled by ``sig_pixel_size`` when calibrated, else raw pixels."""
        if len(self.profile_line) != 2:
            return 0.0
        p0, p1 = self.profile_line[0], self.profile_line[1]
        dist_px = np.sqrt((p1["row"] - p0["row"]) ** 2 + (p1["col"] - p0["col"]) ** 2)
        if self.sig_pixel_size > 0:
            return float(dist_px * self.sig_pixel_size)
        return float(dist_px)

    def _sample_line(self, frame, row0, col0, row1, col1):
        """Sample ``frame`` along a segment with bilinear interpolation.

        Takes ``max(2, ceil(length))`` evenly spaced samples including both
        endpoints. Neighbour indices are clamped to the frame bounds, so
        points outside the frame read the nearest edge pixels. Returns a 1D
        float32 array.
        """
        h, w = frame.shape[:2]
        dc = col1 - col0
        dr = row1 - row0
        length = np.sqrt(dc * dc + dr * dr)
        n = max(2, int(np.ceil(length)))

        t = np.arange(n) / (n - 1)
        cols = col0 + t * dc
        rows = row0 + t * dr
        ci = np.floor(cols).astype(np.intp)
        ri = np.floor(rows).astype(np.intp)
        cf = cols - ci
        rf = rows - ri

        c_lo = np.clip(ci, 0, w - 1)
        c_hi = np.clip(ci + 1, 0, w - 1)
        r_lo = np.clip(ri, 0, h - 1)
        r_hi = np.clip(ri + 1, 0, h - 1)

        out = (
            frame[r_lo, c_lo] * (1 - cf) * (1 - rf)
            + frame[r_lo, c_hi] * cf * (1 - rf)
            + frame[r_hi, c_lo] * (1 - cf) * rf
            + frame[r_hi, c_hi] * cf * rf
        )
        return np.asarray(out, dtype=np.float32)

    # ── Path Animation ──────────────────────────────────────────────────────

    def set_path(
        self,
        points: list[tuple[int, int]],
        interval_ms: int = 100,
        loop: bool = True,
        autoplay: bool = True,
    ) -> Self:
        """Install a list of (row, col) scan positions as the animation path.

        Resets ``path_index`` to 0 and, if ``autoplay`` is true and the path is
        non-empty, sets ``path_playing``. Returns ``self``.
        """
        self._path_points = list(points)
        self.path_length = len(self._path_points)
        self.path_index = 0
        self.path_interval_ms = interval_ms
        self.path_loop = loop
        if autoplay and self.path_length > 0:
            self.path_playing = True
        return self

    def play(self) -> Self:
        """Start playback if a path is set; returns ``self``."""
        if self.path_length > 0:
            self.path_playing = True
        return self

    def pause(self) -> Self:
        """Pause playback, keeping the current path index; returns ``self``."""
        self.path_playing = False
        return self

    def stop(self) -> Self:
        """Stop playback and rewind to the first path point; returns ``self``."""
        self.path_playing = False
        self.path_index = 0
        return self

    def goto(self, index: int) -> Self:
        """Jump to path point ``index``; out-of-range indices are ignored. Returns ``self``."""
        if 0 <= index < self.path_length:
            self.path_index = index
        return self

    def raster(
        self,
        step: int = 1,
        bidirectional: bool = False,
        interval_ms: int = 100,
        loop: bool = True,
    ) -> Self:
        """Build a row-by-row raster path over the navigation grid and start it.

        Parameters
        ----------
        step : int, default 1
            Stride in both row and column direction; must be >= 1.
        bidirectional : bool, default False
            If true, every other visited row is traversed right-to-left
            (serpentine scan).
        interval_ms, loop
            Passed through to :meth:`set_path`.

        Raises
        ------
        ValueError
            If ``step`` is smaller than 1.
        """
        if step < 1:
            raise ValueError(f"step must be >= 1, got {step}")
        points = []
        for r in range(0, self.nav_rows, step):
            cols = list(range(0, self.nav_cols, step))
            if bidirectional and (r // step % 2 == 1):
                cols = cols[::-1]
            points.extend((r, c) for c in cols)
        return self.set_path(points=points, interval_ms=interval_ms, loop=loop)

    # ── Export (GIF) ──────────────────────────────────────────────────────────

    def _normalize_frame(self, frame: np.ndarray) -> np.ndarray:
        """Map ``frame`` to uint8 [0, 255] using the current log/contrast settings.

        Applies ``log1p`` of the non-negative part when ``log_scale`` is on,
        then scales between the percentile limits (``auto_contrast``) or the
        min/max. A flat frame yields all zeros.
        """
        if self.log_scale:
            frame = np.log1p(np.maximum(frame, 0))
        fmin, fmax = float(frame.min()), float(frame.max())
        if self.auto_contrast:
            fmin = float(np.percentile(frame, self.percentile_low))
            fmax = float(np.percentile(frame, self.percentile_high))
        if fmax > fmin:
            return np.clip((frame - fmin) / (fmax - fmin) * 255, 0, 255).astype(np.uint8)
        return np.zeros(frame.shape, dtype=np.uint8)

    def _on_gif_export(self, change=None):
        """Observer: when the export flag is raised, clear it and generate the GIF."""
        if not self._gif_export_requested:
            return
        self._gif_export_requested = False
        self._generate_gif()

    def _generate_gif(self):
        """Render the current path as an animated GIF into ``_gif_data`` with JSON metadata.

        With no path set, both outputs are cleared instead.
        """
        import io

        from matplotlib import colormaps
        from PIL import Image

        if not self._path_points:
            with self.hold_sync():
                self._gif_data = b""
                self._gif_metadata_json = ""
            return

        cmap_fn = colormaps.get_cmap(self.cmap)
        duration_ms = max(10, self.path_interval_ms)

        pil_frames = []
        for row, col in self._path_points:
            # Path points are not validated on entry; clamp them to the grid.
            row = max(0, min(self.nav_rows - 1, row))
            col = max(0, min(self.nav_cols - 1, col))
            normalized = self._normalize_frame(self._frame_at(row, col))
            rgba = cmap_fn(normalized / 255.0)
            rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
            pil_frames.append(Image.fromarray(rgb))

        buf = io.BytesIO()
        pil_frames[0].save(
            buf,
            format="GIF",
            save_all=True,
            append_images=pil_frames[1:],
            duration=duration_ms,
            loop=0,
        )
        metadata = {
            **build_json_header("Show4D"),
            "view": "signal",
            "format": "gif",
            "export_kind": "path_animation",
            "n_frames": int(len(pil_frames)),
            "duration_ms": int(duration_ms),
            "path_loop": bool(self.path_loop),
            "path_points": [{"row": int(row), "col": int(col)} for row, col in self._path_points],
            "scan_shape": {"rows": int(self.nav_rows), "cols": int(self.nav_cols)},
            "detector_shape": {"rows": int(self.sig_rows), "cols": int(self.sig_cols)},
            "display": {
                "cmap": self.cmap,
                "log_scale": bool(self.log_scale),
                "auto_contrast": bool(self.auto_contrast),
                "percentile_low": float(self.percentile_low),
                "percentile_high": float(self.percentile_high),
            },
        }
        with self.hold_sync():
            self._gif_metadata_json = json.dumps(metadata, indent=2)
            self._gif_data = buf.getvalue()

    # ── Save Image ────────────────────────────────────────────────────────────

    def save_image(
        self,
        path: str | pathlib.Path,
        *,
        view: str | None = None,
        position: tuple[int, int] | None = None,
        format: str | None = None,
        dpi: int = 150,
    ) -> pathlib.Path:
        """Save current signal or navigation image as PNG, PDF, or TIFF.

        Parameters
        ----------
        path : str or pathlib.Path
            Output file path. Parent directories are created if needed.
        view : str, optional
            'signal' or 'nav'. Defaults to 'signal'.
        position : tuple[int, int], optional
            Navigation position as (row, col) for the signal view.
            Defaults to current position.
        format : str, optional
            'png', 'pdf', or 'tiff'. If omitted, inferred from file extension
            (falling back to 'png'). When given, it takes precedence over the
            extension.
        dpi : int, default 150
            Output DPI metadata.

        Returns
        -------
        pathlib.Path
            The written file path.

        Raises
        ------
        ValueError
            For an unsupported format or unknown view.
        IndexError
            If ``position`` lies outside the navigation grid.
        """
        from matplotlib import colormaps
        from PIL import Image

        path = pathlib.Path(path)
        fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
        if fmt not in ("png", "pdf", "tiff", "tif"):
            raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")

        view_key = (view or "signal").lower()
        if view_key not in ("signal", "nav"):
            raise ValueError(f"Unknown view: {view_key!r}. Use 'signal' or 'nav'.")

        if view_key == "signal":
            row, col = position if position is not None else (self.pos_row, self.pos_col)
            if row < 0 or row >= self.nav_rows or col < 0 or col >= self.nav_cols:
                raise IndexError(
                    f"Position ({row}, {col}) out of range "
                    f"[0, {self.nav_rows}) x [0, {self.nav_cols})"
                )
            normalized = self._normalize_frame(self._frame_at(row, col))
        else:
            # The navigation image is always min/max scaled (no log/percentiles).
            frame = self._nav_image
            nav_min, nav_max = float(frame.min()), float(frame.max())
            if nav_max > nav_min:
                normalized = np.clip(
                    (frame - nav_min) / (nav_max - nav_min) * 255, 0, 255
                ).astype(np.uint8)
            else:
                normalized = np.zeros(frame.shape, dtype=np.uint8)

        cmap_fn = colormaps.get_cmap(self.cmap)
        rgba = (cmap_fn(normalized / 255.0) * 255).astype(np.uint8)
        img = Image.fromarray(rgba)

        # Pass the resolved format explicitly so `format=` is honoured even
        # when the path has a different (or no) extension.
        pil_format = {"png": "PNG", "pdf": "PDF", "tiff": "TIFF", "tif": "TIFF"}[fmt]
        save_kwargs = {"dpi": (dpi, dpi)}
        if pil_format == "PDF":
            # Not every Pillow release can write RGBA to PDF; alpha carries no
            # information here, so drop it. `resolution` is the PDF writer's
            # own DPI option.
            img = img.convert("RGB")
            save_kwargs["resolution"] = float(dpi)

        path.parent.mkdir(parents=True, exist_ok=True)
        img.save(str(path), format=pil_format, **save_kwargs)
        return path

    # ── Internal Observers ───────────────────────────────────────────────────

    def _update_frame(self, change=None):
        """Send the frame at the current position and its [mean, min, max, std] to the front end."""
        frame = np.asarray(self._frame_at(self.pos_row, self.pos_col), dtype=np.float32)
        with self.hold_sync():
            self.sig_stats = [
                float(frame.mean()),
                float(frame.min()),
                float(frame.max()),
                float(frame.std()),
            ]
            self.frame_bytes = frame.tobytes()

    def _on_roi_change(self, change=None):
        """Observer: show the single frame when the ROI is off, else the ROI reduction."""
        if self.roi_mode == "off":
            self._update_frame()
            return
        self._compute_roi_signal()

    def _on_roi_center_change(self, change=None):
        """Observer for the compound ``roi_center`` trait.

        Copies the new centre into ``roi_center_row``/``roi_center_col`` with
        the per-component observer detached, so the ROI signal is recomputed
        once instead of once per component.
        """
        if self.roi_mode == "off":
            return
        if change and "new" in change:
            row, col = change["new"]
            self.unobserve(self._on_roi_change, names=["roi_center_row", "roi_center_col"])
            try:
                self.roi_center_row = row
                self.roi_center_col = col
            finally:
                # Always re-attach, even if the assignment above raises;
                # otherwise later ROI centre edits would silently stop updating.
                self.observe(self._on_roi_change, names=["roi_center_row", "roi_center_col"])
        self._compute_roi_signal()

    def _on_path_index_change(self, change):
        """Observer: move the cursor to the path point at the new index (clamped to the grid)."""
        idx = change["new"]
        if 0 <= idx < len(self._path_points):
            row, col = self._path_points[idx]
            self.pos_row = max(0, min(self.nav_rows - 1, row))
            self.pos_col = max(0, min(self.nav_cols - 1, col))

    def _roi_mask(self, row_coords, col_coords):
        """Return the boolean ROI mask over the navigation grid, or ``None`` for an unknown mode.

        ``row_coords`` has shape (nav_rows, 1) and ``col_coords`` (1, nav_cols);
        both may be NumPy arrays or torch tensors (only operators and the
        builtin ``abs`` are used, which both support).
        """
        cr, cc = self.roi_center_row, self.roi_center_col
        mode = self.roi_mode
        if mode == "circle":
            return (row_coords - cr) ** 2 + (col_coords - cc) ** 2 <= self.roi_radius ** 2
        if mode == "square":
            return (abs(row_coords - cr) <= self.roi_radius) & (abs(col_coords - cc) <= self.roi_radius)
        if mode == "rect":
            return (abs(row_coords - cr) <= self.roi_height / 2) & (abs(col_coords - cc) <= self.roi_width / 2)
        if mode == "annular":
            dist_sq = (row_coords - cr) ** 2 + (col_coords - cc) ** 2
            return (dist_sq >= self.roi_radius_inner ** 2) & (dist_sq <= self.roi_radius ** 2)
        return None

    def _compute_roi_signal(self):
        """Reduce all frames inside the ROI with ``roi_reduce`` and send the result.

        Falls back to the single-position frame when the mode is unknown or the
        mask selects no positions. An unknown ``roi_reduce`` is treated as
        ``"mean"``.
        """
        on_tensor = _HAS_TORCH and isinstance(self._data, torch.Tensor)
        if on_tensor:
            mask = self._roi_mask(self._nav_row_coords, self._nav_col_coords)
        else:
            mask = self._roi_mask(
                np.arange(self.nav_rows)[:, None],
                np.arange(self.nav_cols)[None, :],
            )

        if mask is None or int(mask.sum()) == 0:
            self._update_frame()
            return

        selected = self._data[mask]  # (n_positions, sig_rows, sig_cols)
        if on_tensor:
            reduce_ops = {
                "mean": lambda x: x.mean(dim=0),
                "max": lambda x: x.max(dim=0).values,
                "min": lambda x: x.min(dim=0).values,
                "sum": lambda x: x.sum(dim=0),
            }
            op = reduce_ops.get(self.roi_reduce, reduce_ops["mean"])
            result = op(selected).cpu().numpy().astype(np.float32)
        else:
            reduce_ops = {"mean": np.mean, "max": np.max, "min": np.min, "sum": np.sum}
            op = reduce_ops.get(self.roi_reduce, np.mean)
            result = op(selected, axis=0).astype(np.float32)

        with self.hold_sync():
            self.sig_stats = [
                float(result.mean()),
                float(result.min()),
                float(result.max()),
                float(result.std()),
            ]
            self.frame_bytes = result.tobytes()
# Runs once at import time, after the class body is complete.
# NOTE(review): presumably attaches the shared tool lock/hide runtime methods
# to Show4D under the "Show4D" key — confirm in quantem.widget.tool_parity.
bind_tool_runtime_api(Show4D, "Show4D")