"""
show3d: Interactive 3D stack viewer widget with advanced features.
For viewing a stack of 2D images (e.g., defocus sweep, time series, z-stack, movies).
Includes playback controls, statistics, ROI selection, FFT, and more.
"""
import json
import pathlib
from enum import Enum
from typing import Self
import anywidget
import numpy as np
import traitlets
from quantem.widget.array_utils import to_numpy
from quantem.widget.io import IO, IOResult
from quantem.widget.json_state import build_json_header, resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
bind_tool_runtime_api,
build_tool_groups,
normalize_tool_groups,
)
try:
import torch
_HAS_TORCH = True
except ImportError:
torch = None # type: ignore[assignment]
_HAS_TORCH = False
class Colormap(str, Enum):
    """Colormap names accepted by the viewer.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string name.
    """

    INFERNO = "inferno"
    VIRIDIS = "viridis"
    PLASMA = "plasma"
    MAGMA = "magma"
    HOT = "hot"
    GRAY = "gray"

    def __str__(self) -> str:
        # Render as the bare colormap name instead of "Colormap.MAGMA".
        return str(self.value)
[docs]
class Show3D(anywidget.AnyWidget):
    """
    Interactive 3D stack viewer with advanced features for electron microscopy.

    View a stack of 2D images along a specific dimension (e.g., defocus sweep,
    time series, depth stack, in-situ movies). Includes playback controls,
    statistics panel, ROI selection, FFT view, and more.

    Parameters
    ----------
    data : array_like
        3D array of shape (N, height, width) where N is the stack dimension.
        An ``IOResult`` or an object exposing ``array``, ``name`` and
        ``sampling`` is also accepted; title and pixel size are then taken
        from it unless given explicitly (labels too, for an ``IOResult``).
    labels : list of str, optional
        Labels for each slice (e.g., ["C10=-500nm", "C10=-400nm", ...]).
        If None, uses slice indices.
    title : str, optional
        Title to display above the image.
    cmap : str or Colormap, default Colormap.MAGMA
        Colormap name. Use Colormap enum (Colormap.MAGMA, Colormap.VIRIDIS, etc.)
        or string ("magma", "viridis", "gray", "inferno", "plasma", "hot").
    vmin : float, optional
        Minimum value for colormap. If None, uses data min.
    vmax : float, optional
        Maximum value for colormap. If None, uses data max.
    pixel_size : float, optional
        Pixel size in Å for scale bar display. 0 means no scale bar.
    log_scale : bool, default False
        Use log scale for intensity mapping.
    auto_contrast : bool, default False
        Use percentile-based contrast (ignores vmin/vmax).
    percentile_low : float, default 1.0
        Lower percentile for auto-contrast.
    percentile_high : float, default 99.0
        Upper percentile for auto-contrast.
    fps : float, default 5.0
        Frames per second for playback.
    timestamps : list of float, optional
        Timestamps for each frame (e.g., seconds or dose values).
    timestamp_unit : str, default "s"
        Unit for timestamps (e.g., "s", "ms", "e/A2").
    show_fft : bool, default False
        Initial value of the ``show_fft`` trait (FFT view).
    fft_window : bool, default True
        Initial value of the ``fft_window`` trait.
    show_playback : bool, default False
        Initial value of the ``show_playback`` trait.
    show_stats : bool, default True
        Initial value of the ``show_stats`` trait (statistics panel).
    show_controls : bool, default True
        Initial value of the ``show_controls`` trait.
    canvas_size : int, default 0
        Canvas size. If 0, the frontend defaults are used.
    disabled_tools : list of str, optional
        Tool groups to lock while still showing controls. Supported:
        ``"display"``, ``"histogram"``, ``"stats"``, ``"playback"``,
        ``"view"``, ``"export"``, ``"roi"``, ``"profile"``, ``"all"``.
        ``"navigation"`` is accepted as an alias of ``"playback"``.
    disable_* : bool, optional
        Convenience flags mirroring ``disabled_tools``. Includes
        ``disable_navigation`` as an alias of ``disable_playback``.
    hidden_tools : list of str, optional
        Tool groups to hide from the UI. Uses the same keys as
        ``disabled_tools``.
    hide_* : bool, optional
        Convenience flags mirroring ``disable_*`` for ``hidden_tools``.
    diff_mode : {"off", "previous", "first"}, default "off"
        ``"previous"`` displays each frame minus the preceding frame (the
        first frame becomes zeros); ``"first"`` displays each frame minus
        frame 0.
    buffer_size : int, default 64
        Number of frames sent per playback buffer. Limited to the number of
        slices and to a 64 MB transfer cap (never below 8 frames because of
        that cap).
    dim_label : str, default "Frame"
        Initial value of the ``dim_label`` trait.
    use_torch : bool, default False
        Keep a PyTorch copy of the data on ``device`` and use it for
        statistics. Raises ``ImportError`` if PyTorch is not installed.
    device : str, optional
        Torch device. If None, picks "mps", then "cuda", then "cpu",
        whichever is available first.
    state : dict, str or pathlib.Path, optional
        Saved state applied after construction via ``load_state_dict``.
        A path is read as JSON.
    **kwargs
        Forwarded to ``anywidget.AnyWidget``.

    Examples
    --------
    >>> import numpy as np
    >>> from quantem.widget import Show3D
    >>>
    >>> # View defocus sweep
    >>> labels = [f"C10={c10:.0f}nm" for c10 in np.linspace(-500, -200, 12)]
    >>> Show3D(stack, labels=labels, title="Defocus Sweep")
    >>>
    >>> # View in-situ movie with timestamps
    >>> times = np.arange(100) * 0.1 # 100 frames at 10 fps
    >>> Show3D(movie, timestamps=times, timestamp_unit="s", fps=30)
    >>>
    >>> # With scale bar
    >>> Show3D(data, pixel_size=0.5, title="HRTEM")
    """
    # Frontend bundle and stylesheet shipped next to this module.
    _esm = pathlib.Path(__file__).parent / "static" / "show3d.js"
    _css = pathlib.Path(__file__).parent / "static" / "show3d.css"
    # =========================================================================
    # Core State
    # =========================================================================
    slice_idx = traitlets.Int(0).tag(sync=True)
    n_slices = traitlets.Int(1).tag(sync=True)
    height = traitlets.Int(1).tag(sync=True)
    width = traitlets.Int(1).tag(sync=True)
    # Raw float32 bytes of the currently displayed frame (written by _update_all).
    frame_bytes = traitlets.Bytes(b"").tag(sync=True)
    labels = traitlets.List(traitlets.Unicode()).tag(sync=True)
    title = traitlets.Unicode("").tag(sync=True)
    cmap = traitlets.Unicode("magma").tag(sync=True)
    dim_label = traitlets.Unicode("Frame").tag(sync=True)
    # =========================================================================
    # Playback Controls
    # =========================================================================
    playing = traitlets.Bool(False).tag(sync=True)
    reverse = traitlets.Bool(False).tag(sync=True) # Play in reverse direction
    boomerang = traitlets.Bool(False).tag(sync=True) # Ping-pong playback
    fps = traitlets.Float(5.0).tag(sync=True) # Default 5 FPS for easier control
    loop = traitlets.Bool(True).tag(sync=True)
    loop_start = traitlets.Int(0).tag(sync=True) # Start frame for loop range
    loop_end = traitlets.Int(-1).tag(sync=True) # End frame for loop (-1 = last)
    bookmarked_frames = traitlets.List(traitlets.Int()).tag(sync=True)
    # Custom playback order as frame indices; empty means sequential.
    playback_path = traitlets.List(traitlets.Int()).tag(sync=True)
    # =========================================================================
    # Statistics Panel
    # =========================================================================
    show_controls = traitlets.Bool(True).tag(sync=True)
    show_stats = traitlets.Bool(True).tag(sync=True)
    # Per-frame statistics, refreshed by _update_all.
    stats_mean = traitlets.Float(0.0).tag(sync=True)
    stats_min = traitlets.Float(0.0).tag(sync=True)
    stats_max = traitlets.Float(0.0).tag(sync=True)
    stats_std = traitlets.Float(0.0).tag(sync=True)
    # =========================================================================
    # Display Options
    # =========================================================================
    log_scale = traitlets.Bool(False).tag(sync=True)
    auto_contrast = traitlets.Bool(False).tag(sync=True)
    percentile_low = traitlets.Float(1.0).tag(sync=True)
    percentile_high = traitlets.Float(99.0).tag(sync=True)
    # Global value range across all frames (recomputed when diff_mode changes).
    data_min = traitlets.Float(0.0).tag(sync=True)
    data_max = traitlets.Float(0.0).tag(sync=True)
    # =========================================================================
    # Scale Bar
    # =========================================================================
    pixel_size = traitlets.Float(0.0).tag(sync=True) # Å/pixel, 0 = no scale bar
    scale_bar_visible = traitlets.Bool(True).tag(sync=True)
    # =========================================================================
    # Timestamps / Dose
    # =========================================================================
    timestamps = traitlets.List(traitlets.Float()).tag(sync=True)
    timestamp_unit = traitlets.Unicode("s").tag(sync=True)
    current_timestamp = traitlets.Float(0.0).tag(sync=True)
    # =========================================================================
    # ROI Selection
    # =========================================================================
    roi_active = traitlets.Bool(False).tag(sync=True)
    # List of ROI dicts (keys such as "shape", "row", "col", "radius", ...).
    roi_list = traitlets.List([]).tag(sync=True)
    roi_selected_idx = traitlets.Int(-1).tag(sync=True) # -1 = nothing selected
    # mean/min/max/std of the selected ROI on the current frame; {} when unavailable.
    roi_stats = traitlets.Dict({}).tag(sync=True)
    # float32 bytes: mean of the selected ROI for every frame.
    roi_plot_data = traitlets.Bytes(b"").tag(sync=True)
    # =========================================================================
    # Sizing
    # =========================================================================
    canvas_size = traitlets.Int(0).tag(sync=True) # If 0, use frontend defaults
    # =========================================================================
    # Diff Mode
    # =========================================================================
    # One of "off", "previous", "first" (enforced by _validate_diff_mode).
    diff_mode = traitlets.Unicode("off").tag(sync=True)
    # =========================================================================
    # Analysis Panels (FFT + Histogram shown together)
    # =========================================================================
    show_fft = traitlets.Bool(False).tag(sync=True)
    fft_window = traitlets.Bool(True).tag(sync=True)
    show_playback = traitlets.Bool(False).tag(sync=True)
    # Tool-group names; both lists are normalized by their validators.
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    # =========================================================================
    # Line Profile
    # =========================================================================
    # Two endpoint dicts with "row" and "col" keys once a profile is set.
    profile_line = traitlets.List(traitlets.Dict()).tag(sync=True)
    profile_width = traitlets.Int(1).tag(sync=True)
    # =========================================================================
    # Export (GIF / ZIP of PNGs)
    # =========================================================================
    _gif_export_requested = traitlets.Bool(False).tag(sync=True)
    _gif_data = traitlets.Bytes(b"").tag(sync=True)
    _gif_metadata_json = traitlets.Unicode("").tag(sync=True)
    _zip_export_requested = traitlets.Bool(False).tag(sync=True)
    _zip_data = traitlets.Bytes(b"").tag(sync=True)
    _bundle_export_requested = traitlets.Bool(False).tag(sync=True)
    _bundle_data = traitlets.Bytes(b"").tag(sync=True)
    # =========================================================================
    # Playback Buffer (sliding prefetch)
    # =========================================================================
    # Filled by _send_buffer: raw bytes of a run of frames, its first frame
    # index and the number of frames it holds.
    _buffer_bytes = traitlets.Bytes(b"").tag(sync=True)
    _buffer_start = traitlets.Int(0).tag(sync=True)
    _buffer_count = traitlets.Int(0).tag(sync=True)
    # Start frame of the next buffer to send; negative values are ignored.
    _prefetch_request = traitlets.Int(-1).tag(sync=True)
@classmethod
def _normalize_tool_groups(cls, tool_groups):
    """Run *tool_groups* through the shared ``normalize_tool_groups`` helper for this widget."""
    widget_name = "Show3D"
    return normalize_tool_groups(widget_name, tool_groups)
@classmethod
def _build_disabled_tools(
    cls,
    disabled_tools=None,
    disable_display: bool = False,
    disable_histogram: bool = False,
    disable_stats: bool = False,
    disable_playback: bool = False,
    disable_navigation: bool = False,
    disable_view: bool = False,
    disable_export: bool = False,
    disable_roi: bool = False,
    disable_profile: bool = False,
    disable_all: bool = False,
):
    """Combine ``disabled_tools`` with the ``disable_*`` flags.

    The explicit list, the per-group flags and ``disable_all`` are handed to
    ``build_tool_groups``, whose result is returned unchanged.
    """
    # ``disable_navigation`` is an alias: either flag locks the playback group.
    playback_locked = disable_playback or disable_navigation
    group_flags = {
        "display": disable_display,
        "histogram": disable_histogram,
        "stats": disable_stats,
        "playback": playback_locked,
        "view": disable_view,
        "export": disable_export,
        "roi": disable_roi,
        "profile": disable_profile,
    }
    return build_tool_groups(
        "Show3D",
        tool_groups=disabled_tools,
        all_flag=disable_all,
        flag_map=group_flags,
    )
@classmethod
def _build_hidden_tools(
    cls,
    hidden_tools=None,
    hide_display: bool = False,
    hide_histogram: bool = False,
    hide_stats: bool = False,
    hide_playback: bool = False,
    hide_navigation: bool = False,
    hide_view: bool = False,
    hide_export: bool = False,
    hide_roi: bool = False,
    hide_profile: bool = False,
    hide_all: bool = False,
):
    """Combine ``hidden_tools`` with the ``hide_*`` flags.

    The explicit list, the per-group flags and ``hide_all`` are handed to
    ``build_tool_groups``, whose result is returned unchanged.
    """
    # ``hide_navigation`` is an alias: either flag hides the playback group.
    playback_hidden = hide_playback or hide_navigation
    group_flags = {
        "display": hide_display,
        "histogram": hide_histogram,
        "stats": hide_stats,
        "playback": playback_hidden,
        "view": hide_view,
        "export": hide_export,
        "roi": hide_roi,
        "profile": hide_profile,
    }
    return build_tool_groups(
        "Show3D",
        tool_groups=hidden_tools,
        all_flag=hide_all,
        flag_map=group_flags,
    )
_VALID_DIFF_MODES = {"off", "previous", "first"}
@traitlets.validate("diff_mode")
def _validate_diff_mode(self, proposal):
    """Accept only the values listed in ``_VALID_DIFF_MODES``.

    Raises
    ------
    traitlets.TraitError
        If the proposed value is not a known diff mode.
    """
    val = proposal["value"]
    if val in self._VALID_DIFF_MODES:
        return val
    raise traitlets.TraitError(
        f"Invalid diff_mode '{val}'. Must be one of: {sorted(self._VALID_DIFF_MODES)}"
    )
@traitlets.validate("disabled_tools")
def _validate_disabled_tools(self, proposal):
    """Normalize any value assigned to ``disabled_tools`` before it is stored."""
    requested = proposal["value"]
    return self._normalize_tool_groups(requested)
@traitlets.validate("hidden_tools")
def _validate_hidden_tools(self, proposal):
    """Normalize any value assigned to ``hidden_tools`` before it is stored."""
    requested = proposal["value"]
    return self._normalize_tool_groups(requested)
def __init__(
    self,
    data,
    labels: list[str] | None = None,
    title: str = "",
    cmap: str | Colormap = Colormap.MAGMA,
    vmin: float | None = None,
    vmax: float | None = None,
    pixel_size: float = 0.0,
    log_scale: bool = False,
    auto_contrast: bool = False,
    percentile_low: float = 1.0,
    percentile_high: float = 99.0,
    fps: float = 5.0,
    timestamps: list[float] | None = None,
    timestamp_unit: str = "s",
    show_fft: bool = False,
    fft_window: bool = True,
    show_playback: bool = False,
    show_stats: bool = True,
    show_controls: bool = True,
    canvas_size: int = 0,
    disabled_tools: list[str] | None = None,
    disable_display: bool = False,
    disable_histogram: bool = False,
    disable_stats: bool = False,
    disable_playback: bool = False,
    disable_navigation: bool = False,
    disable_view: bool = False,
    disable_export: bool = False,
    disable_roi: bool = False,
    disable_profile: bool = False,
    disable_all: bool = False,
    hidden_tools: list[str] | None = None,
    hide_display: bool = False,
    hide_histogram: bool = False,
    hide_stats: bool = False,
    hide_playback: bool = False,
    hide_navigation: bool = False,
    hide_view: bool = False,
    hide_export: bool = False,
    hide_roi: bool = False,
    hide_profile: bool = False,
    hide_all: bool = False,
    diff_mode: str = "off",
    buffer_size: int = 64,
    dim_label: str = "Frame",
    use_torch: bool = False,
    device: str | None = None,
    state=None,
    **kwargs,
):
    """Build the widget from ``data`` and the given display options.

    The data is unwrapped (``IOResult`` / dataset-like objects), converted
    to a float32 NumPy stack, and the synced traits are initialised from
    it. Observers are registered afterwards, the first frame is pushed, and
    ``state`` (if given) is applied last.

    Raises
    ------
    ImportError
        If ``use_torch=True`` but PyTorch is not installed.
    ValueError
        If the converted data is not 3D.
    """
    super().__init__(**kwargs)
    self.widget_version = resolve_widget_version()
    # Optional torch GPU acceleration
    self._use_torch = False
    self._device = None
    self._data_torch = None
    if use_torch:
        if not _HAS_TORCH:
            raise ImportError(
                "use_torch=True requires PyTorch. Install it with: pip install torch"
            )
        self._use_torch = True
        # Explicit device wins; otherwise prefer mps, then cuda, then cpu.
        self._device = torch.device(
            device or (
                "mps" if torch.backends.mps.is_available()
                else "cuda" if torch.cuda.is_available()
                else "cpu"
            )
        )
    # Check if data is an IOResult and extract metadata
    # (explicit arguments take precedence over the IOResult's values)
    if isinstance(data, IOResult):
        if not title and data.title:
            title = data.title
        if pixel_size == 0.0 and data.pixel_size is not None:
            pixel_size = data.pixel_size
        if labels is None and data.labels:
            labels = data.labels
        data = data.data
        # Wrap 2D to single-frame stack for Show3D
        if hasattr(data, "ndim") and data.ndim == 2:
            data = data[None, ...]
    # Check if data is a Dataset3d and extract metadata
    # (duck-typed: anything exposing array, name and sampling)
    _extracted_title = None
    _extracted_pixel_size = None
    if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
        _extracted_title = data.name if data.name else None
        # sampling is (z_sampling, y_sampling, x_sampling) - use y/x for pixel size
        if hasattr(data, "sampling") and len(data.sampling) >= 3:
            sampling_val = float(data.sampling[1])
            # pixel_size is in Å — convert if units are nm
            if hasattr(data, "units"):
                units = list(data.units)
                if units[1] in ("nm", "nanometer"):
                    sampling_val = sampling_val * 10 # nm → Å
            _extracted_pixel_size = sampling_val
        data = data.array
    # Convert input to NumPy (handles NumPy, CuPy, PyTorch)
    data = to_numpy(data)
    # Ensure 3D
    if data.ndim != 3:
        raise ValueError(f"Expected 3D array, got {data.ndim}D")
    # Store data as float32 numpy array
    self._data = data.astype(np.float32)
    # Create GPU copy if torch acceleration enabled
    if self._use_torch:
        self._data_torch = torch.from_numpy(self._data).to(self._device)
    # Dimensions
    self.n_slices = int(self._data.shape[0])
    self.height = int(self._data.shape[1])
    self.width = int(self._data.shape[2])
    # Color range (global across all frames)
    # The user-supplied limits are kept so set_image() can re-apply them.
    self._vmin_user = vmin
    self._vmax_user = vmax
    if self._use_torch:
        self._vmin = vmin if vmin is not None else float(self._data_torch.min().item())
        self._vmax = vmax if vmax is not None else float(self._data_torch.max().item())
        self.data_min = float(self._data_torch.min().item())
        self.data_max = float(self._data_torch.max().item())
    else:
        self._vmin = vmin if vmin is not None else float(self._data.min())
        self._vmax = vmax if vmax is not None else float(self._data.max())
        self.data_min = float(self._data.min())
        self.data_max = float(self._data.max())
    # Labels
    if labels is not None:
        self.labels = list(labels)
    else:
        self.labels = [str(i) for i in range(self.n_slices)]
    # Title and colormap - use extracted title if not explicitly provided
    self.title = title if title else (_extracted_title or "")
    self.cmap = str(cmap) # Convert Colormap enum to string
    # Use extracted pixel_size if not explicitly provided
    if pixel_size == 0.0 and _extracted_pixel_size is not None:
        pixel_size = _extracted_pixel_size
    # Display options
    self.pixel_size = pixel_size
    self.log_scale = log_scale
    self.auto_contrast = auto_contrast
    self.percentile_low = percentile_low
    self.percentile_high = percentile_high
    self.fps = fps
    # Timestamps
    if timestamps is not None:
        self.timestamps = [float(t) for t in timestamps]
    else:
        self.timestamps = []
    self.timestamp_unit = timestamp_unit
    self.dim_label = dim_label
    self.diff_mode = diff_mode
    self.show_fft = show_fft
    self.fft_window = fft_window
    self.show_playback = show_playback
    self.show_stats = show_stats
    self.show_controls = show_controls
    self.canvas_size = canvas_size
    self.disabled_tools = self._build_disabled_tools(
        disabled_tools=disabled_tools,
        disable_display=disable_display,
        disable_histogram=disable_histogram,
        disable_stats=disable_stats,
        disable_playback=disable_playback,
        disable_navigation=disable_navigation,
        disable_view=disable_view,
        disable_export=disable_export,
        disable_roi=disable_roi,
        disable_profile=disable_profile,
        disable_all=disable_all,
    )
    self.hidden_tools = self._build_hidden_tools(
        hidden_tools=hidden_tools,
        hide_display=hide_display,
        hide_histogram=hide_histogram,
        hide_stats=hide_stats,
        hide_playback=hide_playback,
        hide_navigation=hide_navigation,
        hide_view=hide_view,
        hide_export=hide_export,
        hide_roi=hide_roi,
        hide_profile=hide_profile,
        hide_all=hide_all,
    )
    # Playback buffer length: the requested size, limited by the stack
    # length and by how many frames fit in one capped transfer.
    # (`frame_bytes` here is a local byte count, not the synced trait.)
    frame_bytes = self.height * self.width * 4 # float32
    max_buffer_bytes = 64 * 1024 * 1024 # 64 MB cap per transfer
    min_buffer_frames = 8 # guarantee at least 8 frames for large images
    max_frames = max(min_buffer_frames, max_buffer_bytes // frame_bytes)
    self._buffer_size = min(buffer_size, self.n_slices, max_frames)
    # Initial position at middle
    self.slice_idx = int(self.n_slices // 2)
    # Observers
    # Registered only now, so the trait assignments above did not invoke
    # these callbacks.
    self.observe(self._on_slice_change, names=["slice_idx"])
    self.observe(
        self._on_roi_change,
        names=["roi_active", "roi_list", "roi_selected_idx"],
    )
    self.observe(self._on_gif_export, names=["_gif_export_requested"])
    self.observe(self._on_zip_export, names=["_zip_export_requested"])
    self.observe(self._on_bundle_export, names=["_bundle_export_requested"])
    self.observe(self._on_playing_change, names=["playing"])
    self.observe(self._on_prefetch, names=["_prefetch_request"])
    self.observe(self._on_diff_mode_change, names=["diff_mode"])
    # Initial update
    self._update_all()
    # Saved state is applied last so it overrides the constructor arguments.
    if state is not None:
        if isinstance(state, (str, pathlib.Path)):
            # A path is read as JSON; files are unwrapped with require_envelope=True.
            state = unwrap_state_payload(
                json.loads(pathlib.Path(state).read_text()),
                require_envelope=True,
            )
        else:
            state = unwrap_state_payload(state)
        self.load_state_dict(state)
[docs]
def set_image(self, data, labels=None):
    """Replace the stack data. Preserves all display settings.

    Parameters
    ----------
    data : array_like
        New stack of shape (N, height, width). An ``IOResult`` or an object
        exposing ``array``, ``name`` and ``sampling`` is unwrapped the same
        way the constructor does (a 2D ``IOResult`` payload becomes a
        single-frame stack).
    labels : list of str, optional
        Labels for the new slices. If None, the labels of an ``IOResult``
        are used when present, otherwise the slice indices.

    Raises
    ------
    ValueError
        If the converted data is not 3D.
    """
    # Accept the same wrapper types as __init__ so both entry points agree.
    if isinstance(data, IOResult):
        if labels is None and data.labels:
            labels = data.labels
        data = data.data
        if hasattr(data, "ndim") and data.ndim == 2:
            data = data[None, ...]
    if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
        data = data.array
    data = to_numpy(data)
    if data.ndim != 3:
        raise ValueError(f"Expected 3D array, got {data.ndim}D")
    self._data = data.astype(np.float32)
    if self._use_torch:
        self._data_torch = torch.from_numpy(self._data).to(self._device)
        new_min = float(self._data_torch.min().item())
        new_max = float(self._data_torch.max().item())
    else:
        new_min = float(self._data.min())
        new_max = float(self._data.max())
    # Batch every trait change into one sync so the frontend never sees the
    # new dimensions paired with the previous frame payload.
    with self.hold_sync():
        self.n_slices = int(data.shape[0])
        self.height = int(data.shape[1])
        self.width = int(data.shape[2])
        self.data_min = new_min
        self.data_max = new_max
        # Re-apply the limits given at construction; fall back to the new range.
        self._vmin = self._vmin_user if self._vmin_user is not None else self.data_min
        self._vmax = self._vmax_user if self._vmax_user is not None else self.data_max
        if labels is not None:
            self.labels = list(labels)
        else:
            self.labels = [str(i) for i in range(self.n_slices)]
        # Keep the current position valid for the new stack length.
        self.slice_idx = min(self.slice_idx, self.n_slices - 1)
        self._buffer_size = min(self._buffer_size, self.n_slices)
        self._update_all()
def __repr__(self) -> str:
parts = f"Show3D({self.n_slices}×{self.height}×{self.width}, frame={self.slice_idx}, cmap={self.cmap}"
if self.diff_mode != "off":
parts += f", diff={self.diff_mode}"
parts += ")"
return parts
[docs]
def state_dict(self):
    """Return the persistable display, playback, ROI and profile settings as a dict.

    Keys are trait names; values are the current trait values. The image
    data itself is not included.
    """
    persisted = (
        "title",
        "cmap",
        "log_scale",
        "auto_contrast",
        "percentile_low",
        "percentile_high",
        "show_stats",
        "show_controls",
        "show_fft",
        "fft_window",
        "show_playback",
        "disabled_tools",
        "hidden_tools",
        "pixel_size",
        "scale_bar_visible",
        "canvas_size",
        "fps",
        "loop",
        "reverse",
        "boomerang",
        "loop_start",
        "loop_end",
        "bookmarked_frames",
        "playback_path",
        "roi_active",
        "roi_list",
        "roi_selected_idx",
        "profile_line",
        "profile_width",
        "diff_mode",
        "dim_label",
        "timestamp_unit",
    )
    return {name: getattr(self, name) for name in persisted}
[docs]
def save(self, path: str):
    """Write the current ``state_dict()`` to *path* using ``save_state_file``."""
    snapshot = self.state_dict()
    save_state_file(path, "Show3D", snapshot)
[docs]
def load_state_dict(self, state):
    """Apply saved settings to the widget.

    Parameters
    ----------
    state : dict
        Mapping of trait name to value, e.g. the output of ``state_dict()``.
        Keys that are not traits of this widget are ignored.
    """
    # Only traits may be restored. A plain ``hasattr`` check would also let a
    # state file overwrite methods or private attributes (e.g. "save", "_data").
    with self.hold_sync():
        for key, val in state.items():
            if self.has_trait(key):
                setattr(self, key, val)
[docs]
def summary(self):
    """Print a plain-text overview of the stack and the current viewer settings."""
    # Stack line, with the pixel size appended when one is set.
    stack_line = f"Stack: {self.n_slices}×{self.height}×{self.width}"
    if self.pixel_size > 0:
        ps = self.pixel_size
        # Report in nm once the pixel size reaches 10 Å.
        stack_line += f" ({ps / 10:.2f} nm/px)" if ps >= 10 else f" ({ps:.2f} Å/px)"

    # Frame line, with the label of the current frame when available.
    frame_line = f"Frame: {self.slice_idx}/{self.n_slices - 1}"
    if self.labels and self.slice_idx < len(self.labels):
        frame_line += f" [{self.labels[self.slice_idx]}]"

    lines = [self.title or "Show3D", "═" * 32, stack_line, frame_line]

    if hasattr(self, "_data") and self._data is not None:
        arr = self._data
        lines.append(f"Data: min={float(arr.min()):.4g} max={float(arr.max()):.4g} mean={float(arr.mean()):.4g}")

    cmap = self.cmap
    scale = "linear" if not self.log_scale else "log"
    contrast = "manual contrast" if not self.auto_contrast else "auto contrast"
    display = f"{cmap} | {contrast} | {scale}"
    if self.show_fft:
        display += " | FFT"
        if not self.fft_window:
            display += " (no window)"
    if self.diff_mode != "off":
        display += f" | diff={self.diff_mode}"
    lines.append(f"Display: {display}")

    if self.disabled_tools:
        lines.append(f"Locked: {', '.join(self.disabled_tools)}")
    if self.hidden_tools:
        lines.append(f"Hidden: {', '.join(self.hidden_tools)}")

    lines.append(f"Playback: {self.fps} fps | loop={'on' if self.loop else 'off'} | reverse={'on' if self.reverse else 'off'} | boomerang={'on' if self.boomerang else 'off'}")
    # A loop range is shown only when it differs from the full stack.
    if self.loop_start > 0 or self.loop_end >= 0:
        end = self.n_slices - 1 if self.loop_end < 0 else self.loop_end
        lines.append(f"Range: {self.loop_start}–{end}")

    if self.roi_active and self.roi_list:
        lines.append(f"ROI: {len(self.roi_list)} region(s)")
    if len(self.profile_line) >= 2:
        p0, p1 = self.profile_line[:2]
        lines.append(f"Profile: ({p0['row']:.0f}, {p0['col']:.0f}) → ({p1['row']:.0f}, {p1['col']:.0f}) width={self.profile_width}")

    print("\n".join(lines))
def _get_color_range(self, frame: np.ndarray) -> tuple[float, float]:
"""Get vmin/vmax based on current settings."""
if self.auto_contrast:
vmin = float(np.percentile(frame, self.percentile_low))
vmax = float(np.percentile(frame, self.percentile_high))
else:
vmin = self._vmin
vmax = self._vmax
return vmin, vmax
def _normalize_frame(self, frame: np.ndarray) -> np.ndarray:
"""Normalize frame to uint8 with current display settings."""
# Apply log scale if enabled
if self.log_scale:
frame = np.log1p(np.maximum(frame, 0))
vmin, vmax = self._get_color_range(frame)
if vmax > vmin:
normalized = np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255)
return normalized.astype(np.uint8)
return np.zeros(frame.shape, dtype=np.uint8)
def _get_display_frame(self, idx=None):
if idx is None:
idx = self.slice_idx
frame = self._data[idx]
if self.diff_mode == "previous":
if idx == 0:
return np.zeros_like(frame)
return frame - self._data[idx - 1]
if self.diff_mode == "first":
return frame - self._data[0]
return frame
def _on_diff_mode_change(self, change=None):
if self.diff_mode == "off":
self.data_min = float(self._data.min())
self.data_max = float(self._data.max())
else:
# Recompute global range for diff frames
mins, maxs = [], []
for i in range(self.n_slices):
f = self._get_display_frame(i)
mins.append(float(f.min()))
maxs.append(float(f.max()))
self.data_min = min(mins)
self.data_max = max(maxs)
self._update_all()
def _update_all(self):
"""Update frame, stats, and all derived data. Uses hold_sync for batched transfer."""
frame = self._get_display_frame()
with self.hold_sync():
if self._use_torch:
t = self._data_torch[self.slice_idx]
self.stats_mean = float(t.mean().item())
self.stats_min = float(t.min().item())
self.stats_max = float(t.max().item())
self.stats_std = float(t.std().item())
else:
self.stats_mean = float(frame.mean())
self.stats_min = float(frame.min())
self.stats_max = float(frame.max())
self.stats_std = float(frame.std())
if self.timestamps and self.slice_idx < len(self.timestamps):
self.current_timestamp = self.timestamps[self.slice_idx]
if self.roi_active:
self._update_roi_stats(frame)
else:
self.roi_stats = {}
self.frame_bytes = frame.tobytes()
def _roi_mask(self, roi: dict):
r, c = np.ogrid[0 : self.height, 0 : self.width]
shape = roi.get("shape", "circle")
row = float(roi.get("row", 0))
col = float(roi.get("col", 0))
radius = max(1.0, float(roi.get("radius", 10)))
if shape == "circle":
return (c - col) ** 2 + (r - row) ** 2 <= radius**2
if shape == "square":
return (np.abs(c - col) <= radius) & (np.abs(r - row) <= radius)
if shape == "rectangle":
half_w = max(1.0, float(roi.get("width", 20)) / 2.0)
half_h = max(1.0, float(roi.get("height", 20)) / 2.0)
return (np.abs(c - col) <= half_w) & (np.abs(r - row) <= half_h)
if shape == "annular":
inner = max(0.0, float(roi.get("radius_inner", 5)))
dist2 = (c - col) ** 2 + (r - row) ** 2
return (dist2 >= inner**2) & (dist2 <= radius**2)
return (c - col) ** 2 + (r - row) ** 2 <= radius**2
def _update_roi_stats(self, frame: np.ndarray):
idx = self.roi_selected_idx
if idx < 0 or idx >= len(self.roi_list):
self.roi_stats = {}
return
roi = self.roi_list[idx]
mask = self._roi_mask(roi)
if self._use_torch:
mask_t = torch.from_numpy(mask).to(self._device)
t = self._data_torch[self.slice_idx]
region = t[mask_t]
if region.numel() > 0:
self.roi_stats = {
"mean": float(region.mean().item()),
"min": float(region.min().item()),
"max": float(region.max().item()),
"std": float(region.std().item()),
}
else:
self.roi_stats = {}
else:
region = frame[mask]
if region.size > 0:
self.roi_stats = {
"mean": float(region.mean()),
"min": float(region.min()),
"max": float(region.max()),
"std": float(region.std()),
}
else:
self.roi_stats = {}
def _send_buffer(self, start_idx: int):
end_idx = start_idx + self._buffer_size
if self.diff_mode == "off":
if end_idx <= self.n_slices:
chunk = self._data[start_idx:end_idx]
else:
chunk = np.concatenate(
[self._data[start_idx:], self._data[: end_idx - self.n_slices]]
)
else:
frames = []
for j in range(self._buffer_size):
idx = (start_idx + j) % self.n_slices
frames.append(self._get_display_frame(idx))
chunk = np.stack(frames)
with self.hold_sync():
self._buffer_start = int(start_idx)
self._buffer_count = int(chunk.shape[0])
self._buffer_bytes = chunk.tobytes()
def _on_playing_change(self, change=None):
if self.playing:
self._send_buffer(self.slice_idx)
else:
# Playback stopped — refresh stats for the current frame
self._update_all()
def _on_prefetch(self, change=None):
if self._prefetch_request >= 0 and self.playing:
self._send_buffer(self._prefetch_request % self.n_slices)
def _on_slice_change(self, change=None):
if self.playing:
return
self._update_all()
def _on_roi_change(self, change=None):
"""Handle ROI change."""
if self.roi_active:
self._update_roi_stats(self._get_display_frame())
self._compute_roi_plot()
else:
self.roi_stats = {}
self.roi_plot_data = b""
def _compute_roi_plot(self):
"""Compute selected ROI mean for all frames."""
idx = self.roi_selected_idx
if idx < 0 or idx >= len(self.roi_list):
self.roi_plot_data = b""
return
mask = self._roi_mask(self.roi_list[idx])
if mask.sum() == 0:
self.roi_plot_data = b""
return
if self._use_torch:
mask_t = torch.from_numpy(mask).to(self._device)
# Vectorized: (n_slices, n_masked_pixels) -> mean per frame
masked = self._data_torch[:, mask_t]
means = masked.mean(dim=1).cpu().numpy().astype(np.float32)
else:
means = np.array([float(self._data[i][mask].mean()) for i in range(self.n_slices)], dtype=np.float32)
self.roi_plot_data = means.tobytes()
# =========================================================================
# Public Methods
# =========================================================================
[docs]
def play(self) -> Self:
"""Start playback."""
self.playing = True
return self
[docs]
def pause(self) -> Self:
"""Pause playback."""
self.playing = False
return self
[docs]
def stop(self) -> Self:
"""Stop playback and reset to beginning."""
self.playing = False
self.slice_idx = 0
return self
[docs]
def goto(self, index: int) -> Self:
"""Jump to a specific frame index."""
self.slice_idx = int(index) % self.n_slices
return self
[docs]
def set_playback_path(self, path) -> Self:
"""Set custom playback order (list of frame indices)."""
self.playback_path = [int(i) % self.n_slices for i in path]
return self
[docs]
def clear_playback_path(self) -> Self:
"""Clear custom playback path (revert to sequential)."""
self.playback_path = []
return self
[docs]
def profile_all_frames(self, start: tuple | None = None, end: tuple | None = None) -> np.ndarray:
"""Extract the line profile from every frame, returning (n_slices, n_points).
Uses the current profile_line unless start/end are provided.
Always samples raw data (ignores diff_mode).
Parameters
----------
start : tuple of (row, col), optional
Start point. Overrides current profile_line.
end : tuple of (row, col), optional
End point. Overrides current profile_line.
Returns
-------
np.ndarray
Shape (n_slices, n_points) float32 array.
"""
if start is not None and end is not None:
row0, col0 = float(start[0]), float(start[1])
row1, col1 = float(end[0]), float(end[1])
elif len(self.profile_line) >= 2:
p0, p1 = self.profile_line[0], self.profile_line[1]
row0, col0 = p0["row"], p0["col"]
row1, col1 = p1["row"], p1["col"]
else:
raise ValueError(
"No profile line set. Call set_profile() first or pass start/end."
)
rows = []
for i in range(self.n_slices):
rows.append(self._sample_profile_on(self._data[i], row0, col0, row1, col1))
return np.stack(rows)
def _upsert_selected_roi(self, updates: dict):
rois = list(self.roi_list)
color_cycle = ["#4fc3f7", "#81c784", "#ffb74d", "#ce93d8", "#ef5350", "#ffd54f", "#90a4ae", "#a1887f"]
defaults = {
"shape": "circle",
"row": int(self.height // 2),
"col": int(self.width // 2),
"radius": 10,
"radius_inner": 5,
"width": 20,
"height": 20,
"line_width": 2,
"highlight": False,
"visible": True,
"locked": False,
}
if self.roi_selected_idx >= 0 and self.roi_selected_idx < len(rois):
current = {**defaults, **rois[self.roi_selected_idx]}
if not current.get("color"):
current["color"] = color_cycle[self.roi_selected_idx % len(color_cycle)]
rois[self.roi_selected_idx] = {**current, **updates}
else:
rois.append({**defaults, "color": color_cycle[len(rois) % len(color_cycle)], **updates})
self.roi_selected_idx = len(rois) - 1
self.roi_list = rois
self.roi_active = True
[docs]
def add_roi(self, row: int | None = None, col: int | None = None, shape: str = "square") -> Self:
with self.hold_sync():
self._upsert_selected_roi({
"shape": shape,
"row": int(self.height // 2 if row is None else row),
"col": int(self.width // 2 if col is None else col),
})
return self
[docs]
def clear_rois(self) -> Self:
with self.hold_sync():
self.roi_list = []
self.roi_selected_idx = -1
self.roi_active = False
return self
[docs]
def delete_selected_roi(self) -> Self:
"""Delete the currently selected ROI."""
idx = int(self.roi_selected_idx)
if idx < 0 or idx >= len(self.roi_list):
return self
with self.hold_sync():
rois = [roi for i, roi in enumerate(self.roi_list) if i != idx]
self.roi_list = rois
self.roi_selected_idx = min(idx, len(rois) - 1) if rois else -1
if not rois:
self.roi_active = False
return self
[docs]
def duplicate_selected_roi(self, row_offset: int = 3, col_offset: int = 3) -> Self:
"""Duplicate selected ROI with a small offset and auto-assigned color."""
idx = int(self.roi_selected_idx)
if idx < 0 or idx >= len(self.roi_list):
return self
color_cycle = ["#4fc3f7", "#81c784", "#ffb74d", "#ce93d8", "#ef5350", "#ffd54f", "#90a4ae", "#a1887f"]
src = dict(self.roi_list[idx])
with self.hold_sync():
rois = list(self.roi_list)
src["row"] = int(np.clip(float(src.get("row", self.height // 2)) + row_offset, 0, self.height - 1))
src["col"] = int(np.clip(float(src.get("col", self.width // 2)) + col_offset, 0, self.width - 1))
src["color"] = color_cycle[len(rois) % len(color_cycle)]
src["highlight"] = False
src["visible"] = True
src["locked"] = False
rois.append(src)
self.roi_list = rois
self.roi_selected_idx = len(rois) - 1
self.roi_active = True
return self
[docs]
def set_roi(self, row: int, col: int, radius: int = 10) -> Self:
"""Set selected ROI position and size (creates one if needed)."""
with self.hold_sync():
self._upsert_selected_roi({"shape": "circle", "row": int(row), "col": int(col), "radius": int(radius)})
return self
[docs]
def roi_circle(self, radius: int = 10) -> Self:
"""Set selected ROI shape to circle."""
with self.hold_sync():
self._upsert_selected_roi({"shape": "circle", "radius": int(radius)})
return self
[docs]
def roi_square(self, half_size: int = 10) -> Self:
"""Set selected ROI shape to square."""
with self.hold_sync():
self._upsert_selected_roi({"shape": "square", "radius": int(half_size)})
return self
[docs]
def roi_rectangle(self, width: int = 20, height: int = 10) -> Self:
"""Set selected ROI shape to rectangle."""
with self.hold_sync():
self._upsert_selected_roi({"shape": "rectangle", "width": int(width), "height": int(height)})
return self
[docs]
def roi_annular(self, inner: int = 5, outer: int = 10) -> Self:
"""Set selected ROI shape to annular (donut)."""
with self.hold_sync():
self._upsert_selected_roi({"shape": "annular", "radius_inner": int(inner), "radius": int(outer)})
return self
def _sample_line(self, img, row0, col0, row1, col1):
    """Bilinearly sample a 2D image along the segment (row0, col0) -> (row1, col1).

    The number of samples is ``max(2, ceil(segment length))``, spaced evenly
    from the start point to the end point inclusive. Neighbour indices that
    fall outside the image are clamped to the nearest edge pixel.

    Returns
    -------
    numpy.ndarray
        1D array with one interpolated value per sample.
    """
    n_rows, n_cols = img.shape
    d_col, d_row = col1 - col0, row1 - row0
    seg_len = (d_col**2 + d_row**2) ** 0.5
    n_samples = max(2, int(np.ceil(seg_len)))
    frac = np.linspace(0, 1, n_samples)

    # Fractional sample positions, their integer base pixels, and the
    # interpolation weights toward the next pixel.
    col_pos = col0 + frac * d_col
    row_pos = row0 + frac * d_row
    col_base = np.floor(col_pos).astype(int)
    row_base = np.floor(row_pos).astype(int)
    w_col = col_pos - col_base
    w_row = row_pos - row_base

    # Clamp the four neighbour indices so off-image samples reuse edge pixels.
    col_lo = np.clip(col_base, 0, n_cols - 1)
    col_hi = np.clip(col_base + 1, 0, n_cols - 1)
    row_lo = np.clip(row_base, 0, n_rows - 1)
    row_hi = np.clip(row_base + 1, 0, n_rows - 1)

    top_left = img[row_lo, col_lo]
    top_right = img[row_lo, col_hi]
    bottom_left = img[row_hi, col_lo]
    bottom_right = img[row_hi, col_hi]
    return (top_left * (1 - w_col) * (1 - w_row) +
            top_right * w_col * (1 - w_row) +
            bottom_left * (1 - w_col) * w_row +
            bottom_right * w_col * w_row)
def _sample_profile_on(self, img, row0, col0, row1, col1):
    """Sample a line profile on ``img``, averaged across ``self.profile_width``.

    For a width of 1 or less, or a segment of (near) zero length, this is a
    single ``_sample_line`` pass. Otherwise ``profile_width`` parallel lines,
    offset perpendicular to the segment and spanning
    ``±(profile_width - 1) / 2`` around it, are sampled and averaged.

    Returns
    -------
    numpy.ndarray
        float32 profile values.
    """
    width = self.profile_width
    d_col, d_row = col1 - col0, row1 - row0
    seg_len = (d_col**2 + d_row**2) ** 0.5
    # A perpendicular direction is undefined for a degenerate segment.
    if width <= 1 or seg_len < 1e-8:
        return self._sample_line(img, row0, col0, row1, col1).astype(np.float32)

    # Unit vector perpendicular to the segment, in (row, col) order.
    normal_row, normal_col = -d_col / seg_len, d_row / seg_len
    half_span = (width - 1) / 2.0
    lanes = [
        self._sample_line(img, row0 + shift * normal_row, col0 + shift * normal_col,
                          row1 + shift * normal_row, col1 + shift * normal_col)
        for shift in np.linspace(-half_span, half_span, width)
    ]
    # Sum in offset order, starting from a copy so the first lane is not mutated.
    total = lanes[0].copy()
    for lane in lanes[1:]:
        total += lane
    return (total / width).astype(np.float32)
def _sample_profile(self, row0, col0, row1, col1):
    """Sample the line profile on the frame returned by ``_get_display_frame()``."""
    display_frame = self._get_display_frame()
    return self._sample_profile_on(display_frame, row0, col0, row1, col1)
[docs]
def set_profile(self, start: tuple, end: tuple) -> Self:
"""Set a line profile between two points (image pixel coordinates).
Parameters
----------
start : tuple of (row, col)
Start point in pixel coordinates.
end : tuple of (row, col)
End point in pixel coordinates.
"""
row0, col0 = start
row1, col1 = end
self.profile_line = [
{"row": float(row0), "col": float(col0)},
{"row": float(row1), "col": float(col1)},
]
return self
[docs]
def clear_profile(self) -> Self:
"""Clear the current line profile."""
self.profile_line = []
return self
@property
def profile(self):
    """Profile line points as a list of ``(row, col)`` tuples; empty when no line is set."""
    endpoints = []
    for point in self.profile_line:
        endpoints.append((point["row"], point["col"]))
    return endpoints
@property
def profile_values(self):
    """Intensity values along the profile line for the current frame.

    Returns
    -------
    numpy.ndarray or None
        Result of ``_sample_profile`` between the first two profile points,
        or None when fewer than two points are set.
    """
    if len(self.profile_line) < 2:
        return None
    # Index rather than unpack: the guard above only rules out fewer than two
    # points, so unpacking would raise ValueError on a longer list.
    p0, p1 = self.profile_line[0], self.profile_line[1]
    return self._sample_profile(p0["row"], p0["col"], p1["row"], p1["col"])
@property
def profile_distance(self):
    """Length of the profile line between its first two points.

    Returns
    -------
    float or None
        The pixel distance multiplied by ``pixel_size`` when ``pixel_size`` is
        positive, otherwise the distance in pixels. None when fewer than two
        points are set.
    """
    if len(self.profile_line) < 2:
        return None
    # Index rather than unpack: the guard above only rules out fewer than two
    # points, so unpacking would raise ValueError on a longer list.
    p0, p1 = self.profile_line[0], self.profile_line[1]
    dc = p1["col"] - p0["col"]
    dr = p1["row"] - p0["row"]
    dist_px = (dc**2 + dr**2) ** 0.5
    if self.pixel_size > 0:
        return dist_px * self.pixel_size
    return dist_px
def _on_gif_export(self, change=None):
    """Run a pending GIF export: clear the request flag, then generate the GIF.

    ``change`` is accepted but not used.
    """
    if self._gif_export_requested:
        # Reset the flag before generating, so it is already False while the GIF is built.
        self._gif_export_requested = False
        self._generate_gif()
def _normalize_frames_torch(self, start: int, end: int, *,
                            quantile_sample_limit: int = 16_000_000) -> np.ndarray:
    """Batch-normalize frames ``[start, end]`` with torch and return uint8 data.

    Parameters
    ----------
    start, end : int
        Inclusive frame range taken from ``self._data_torch``.
    quantile_sample_limit : int, default 16_000_000
        Maximum number of values passed to ``torch.quantile`` when
        ``auto_contrast`` is on. Larger inputs are subsampled with a regular
        stride, because ``torch.quantile`` raises RuntimeError on very large
        tensors.

    Returns
    -------
    numpy.ndarray
        ``(N, H, W)`` uint8 array; all zeros when the contrast range is empty
        (``vmax <= vmin``). ``N`` is 0 for an empty frame range.

    Notes
    -----
    With ``auto_contrast`` the percentile limits are computed once over the
    whole selected range, not per frame. Otherwise ``self._vmin`` and
    ``self._vmax`` are used, mapped through ``log1p`` when ``log_scale`` is on.
    """
    # A slice is enough: nothing below modifies ``frames`` in place, so the
    # stack does not need to be cloned.
    frames = self._data_torch[start : end + 1]
    if frames.shape[0] == 0:
        # torch.quantile rejects empty input; return an empty stack instead.
        return np.zeros(tuple(frames.shape), dtype=np.uint8)
    if self.log_scale:
        frames = torch.log1p(torch.clamp(frames, min=0))
    if self.auto_contrast:
        flat = frames.reshape(-1).float()
        if flat.numel() > quantile_sample_limit:
            # Ceil-divide so the strided view holds at most the limit.
            stride = -(-flat.numel() // quantile_sample_limit)
            flat = flat[::stride]
        vmin = float(torch.quantile(flat, self.percentile_low / 100.0).item())
        vmax = float(torch.quantile(flat, self.percentile_high / 100.0).item())
    else:
        vmin = self._vmin
        vmax = self._vmax
        if self.log_scale:
            vmin = float(np.log1p(max(vmin, 0)))
            vmax = float(np.log1p(max(vmax, 0)))
    if vmax > vmin:
        normalized = torch.clamp((frames - vmin) / (vmax - vmin) * 255.0, 0, 255).to(torch.uint8)
    else:
        normalized = torch.zeros_like(frames, dtype=torch.uint8)
    return normalized.cpu().numpy()
def _generate_gif(self):
    """Render the active loop range as an animated GIF.

    The frame range is ``loop_start``..``loop_end`` clamped to the stack
    (a negative ``loop_end`` means "last frame"). Results are published through
    ``_gif_data`` (bytes) and ``_gif_metadata_json`` (JSON text). When the
    range is empty both are reset to empty values and nothing is rendered.
    """
    import io

    start = max(0, self.loop_start)
    end = self.loop_end if self.loop_end >= 0 else self.n_slices - 1
    end = min(end, self.n_slices - 1)
    if start > end:
        # Bail out before any normalisation: the torch batch path would call
        # torch.quantile on an empty tensor and raise instead of reporting
        # "no frames".
        with self.hold_sync():
            self._gif_data = b""
            self._gif_metadata_json = ""
        return

    from matplotlib import colormaps
    from PIL import Image

    cmap_fn = colormaps.get_cmap(self.cmap)
    duration_ms = int(1000 / max(0.1, self.fps))
    if self._use_torch:
        frames_u8 = self._normalize_frames_torch(start, end)
    else:
        frames_u8 = (self._normalize_frame(self._data[i]) for i in range(start, end + 1))
    # Colormap each uint8 frame and drop the alpha channel.
    pil_frames = [
        Image.fromarray((cmap_fn(frame_u8 / 255.0)[:, :, :3] * 255).astype(np.uint8))
        for frame_u8 in frames_u8
    ]

    buf = io.BytesIO()
    pil_frames[0].save(
        buf,
        format="GIF",
        save_all=True,
        append_images=pil_frames[1:],
        duration=duration_ms,
        loop=0,
    )
    metadata = {
        **build_json_header("Show3D"),
        "format": "gif",
        "export_kind": "animated_frames",
        "frame_range": {"start": int(start), "end": int(end)},
        "n_frames": int(len(pil_frames)),
        "duration_ms": int(duration_ms),
        "display": {
            "cmap": self.cmap,
            "log_scale": bool(self.log_scale),
            "auto_contrast": bool(self.auto_contrast),
            "percentile_low": float(self.percentile_low),
            "percentile_high": float(self.percentile_high),
        },
    }
    with self.hold_sync():
        # Metadata is assigned before the payload, as in the original ordering.
        self._gif_metadata_json = json.dumps(metadata, indent=2)
        self._gif_data = buf.getvalue()
def _on_zip_export(self, change=None):
    """Run a pending ZIP export: clear the request flag, then generate the archive.

    ``change`` is accepted but not used.
    """
    if self._zip_export_requested:
        # Reset the flag before generating, so it is already False while the archive is built.
        self._zip_export_requested = False
        self._generate_zip()
def _generate_zip(self):
    """Export the active loop range as a ZIP of PNG frames plus ``metadata.json``.

    The frame range is ``loop_start``..``loop_end`` clamped to the stack
    (a negative ``loop_end`` means "last frame"). The archive bytes are
    published through ``_zip_data``. An empty range yields an archive that
    contains only the metadata, with ``n_frames`` equal to 0.

    Notes
    -----
    Entry names are ``frame_<label>.png``. Labels are reduced to alphanumerics,
    ``-`` and ``_`` so that a label can never introduce a path separator or a
    character that is invalid on the target filesystem. Frames without a
    usable label fall back to the zero-padded frame index.
    """
    import io
    import zipfile
    from matplotlib import colormaps
    from PIL import Image

    start = max(0, self.loop_start)
    end = self.loop_end if self.loop_end >= 0 else self.n_slices - 1
    end = min(end, self.n_slices - 1)
    # An inverted range must report zero frames, never a negative count.
    n_frames = max(0, end - start + 1)
    cmap_fn = colormaps.get_cmap(self.cmap)
    labels = self.labels

    def entry_label(i: int) -> str:
        fallback = str(i).zfill(4)
        # Labels may be missing or shorter than the stack.
        if not labels or i >= len(labels):
            return fallback
        cleaned = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in str(labels[i]))
        return cleaned.strip("_") or fallback

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
        metadata = {
            **build_json_header("Show3D"),
            "format": "zip",
            "export_kind": "png_frames",
            "frame_range": {"start": int(start), "end": int(end)},
            "n_frames": int(n_frames),
            "display": {"cmap": self.cmap, "log_scale": bool(self.log_scale)},
        }
        zf.writestr("metadata.json", json.dumps(metadata, indent=2))
        if n_frames == 0:
            # Skip normalisation entirely: the torch batch path cannot
            # compute percentiles on an empty tensor.
            frames_u8 = ()
        elif self._use_torch:
            frames_u8 = self._normalize_frames_torch(start, end)
        else:
            frames_u8 = (self._normalize_frame(self._data[i]) for i in range(start, end + 1))
        for i, frame_u8 in enumerate(frames_u8, start=start):
            rgba = cmap_fn(frame_u8 / 255.0)
            rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
            img_buf = io.BytesIO()
            Image.fromarray(rgb).save(img_buf, format="PNG")
            zf.writestr(f"frame_{entry_label(i)}.png", img_buf.getvalue())
    self._zip_data = buf.getvalue()
def _on_bundle_export(self, change=None):
    """Run a pending bundle export: clear the request flag, then generate the bundle.

    ``change`` is accepted but not used.
    """
    if self._bundle_export_requested:
        # Reset the flag before generating, so it is already False while the bundle is built.
        self._bundle_export_requested = False
        self._generate_bundle()
def _roi_timeseries_csv(self) -> str:
    """Build CSV text holding the mean intensity of every ROI in every frame.

    Columns are ``frame_index``, ``label``, an optional
    ``timestamp_<unit>`` column (only when ``timestamps`` covers every frame),
    and one ``roi_<k>_mean`` column per ROI (1-based). A frame without a
    label uses its index as the label. A cell is left empty when the ROI mask
    selects no pixels; the torch path also leaves it empty when the mean is NaN.

    Returns
    -------
    str
        The CSV document, header row included.
    """
    import csv
    import io

    roi_masks = [self._roi_mask(entry) for entry in list(self.roi_list)]
    sink = io.StringIO()
    csv_out = csv.writer(sink)

    columns = ["frame_index", "label"]
    if self.timestamps and len(self.timestamps) >= self.n_slices:
        columns.append(f"timestamp_{self.timestamp_unit or 'value'}")
    columns += [f"roi_{k}_mean" for k in range(1, len(roi_masks) + 1)]
    csv_out.writerow(columns)

    def leading_cells(i):
        # Index, label, and (when available) timestamp for frame i.
        cells = [i, self.labels[i] if i < len(self.labels) else str(i)]
        if self.timestamps and len(self.timestamps) >= self.n_slices:
            cells.append(float(self.timestamps[i]))
        return cells

    if self._use_torch:
        # One reduction per ROI over the whole stack.
        device_masks = [torch.from_numpy(mask).to(self._device) for mask in roi_masks]
        per_roi = []
        for device_mask in device_masks:
            selected = self._data_torch[:, device_mask]  # (n_slices, n_pixels)
            if selected.shape[1] > 0:
                per_roi.append(selected.mean(dim=1).cpu().numpy())
            else:
                per_roi.append(np.full(self.n_slices, np.nan))
        for i in range(self.n_slices):
            cells = leading_cells(i)
            for series in per_roi:
                value = series[i]
                cells.append("" if np.isnan(value) else float(value))
            csv_out.writerow(cells)
    else:
        for i in range(self.n_slices):
            cells = leading_cells(i)
            frame = self._data[i]
            for mask in roi_masks:
                pixels = frame[mask]
                cells.append(float(pixels.mean()) if pixels.size > 0 else "")
            csv_out.writerow(cells)
    return sink.getvalue()
def _generate_bundle(self):
    """Build a ZIP bundle for the current slice and store it in ``_bundle_data``.

    The archive holds three entries, written in this order: the colormapped
    current frame as ``frame_<label>.png``, ``roi_timeseries.csv``, and
    ``state.json`` (JSON header plus ``state_dict()``). The label used in
    the PNG name is reduced to alphanumerics, ``-`` and ``_``; if nothing
    remains, the slice index is used.
    """
    import io
    import zipfile
    from matplotlib import colormaps
    from PIL import Image

    current = int(np.clip(self.slice_idx, 0, self.n_slices - 1))
    lut = colormaps.get_cmap(self.cmap)

    # Render the current frame to PNG bytes (alpha channel dropped).
    scaled = self._normalize_frame(self._data[current])
    colored = lut(scaled / 255.0)
    png_bytes = io.BytesIO()
    Image.fromarray((colored[:, :, :3] * 255).astype(np.uint8)).save(png_bytes, format="PNG")

    state_payload = {**build_json_header("Show3D"), "state": self.state_dict()}
    csv_text = self._roi_timeseries_csv()

    raw_label = self.labels[current] if current < len(self.labels) else str(current)
    cleaned = "".join(
        ch if (ch.isalnum() or ch in ("-", "_")) else "_" for ch in str(raw_label)
    )
    safe_label = cleaned.strip("_") or str(current)

    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr(f"frame_{safe_label}.png", png_bytes.getvalue())
        zf.writestr("roi_timeseries.csv", csv_text)
        zf.writestr("state.json", json.dumps(state_payload, indent=2))
    self._bundle_data = archive.getvalue()
[docs]
def save_image(self, path: str | pathlib.Path, *, frame_idx: int | None = None,
format: str | None = None, dpi: int = 150) -> pathlib.Path:
"""Save a single frame as PNG, PDF, or TIFF.
Parameters
----------
path : str or pathlib.Path
Output file path.
frame_idx : int, optional
Frame index to export. Defaults to current slice_idx.
format : str, optional
'png', 'pdf', or 'tiff'. If omitted, inferred from file extension.
dpi : int, default 150
Output DPI metadata.
Returns
-------
pathlib.Path
The written file path.
"""
from matplotlib import colormaps
from PIL import Image
path = pathlib.Path(path)
fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
if fmt not in ("png", "pdf", "tiff", "tif"):
raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")
idx = frame_idx if frame_idx is not None else self.slice_idx
if idx < 0 or idx >= self.n_slices:
raise IndexError(f"Frame index {idx} out of range [0, {self.n_slices})")
frame = self._data[idx]
normalized = self._normalize_frame(frame)
cmap_fn = colormaps.get_cmap(self.cmap)
rgba = (cmap_fn(normalized / 255.0) * 255).astype(np.uint8)
img = Image.fromarray(rgba)
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), dpi=(dpi, dpi))
return path
# Module-level call, executed at import time once the Show3D class is defined.
# NOTE(review): presumably attaches the shared tool-runtime API from
# quantem.widget.tool_parity to the class under the name "Show3D" — confirm
# against that module.
bind_tool_runtime_api(Show3D, "Show3D")