# Source code for quantem.widget.show3dvolume

"""
Show3DVolume: Orthogonal slice viewer for 3D volumetric data.
Displays XY, XZ, YZ planes with interactive sliders.
All slicing happens in JavaScript for instant response.
"""
import json
import pathlib
from typing import Optional, Union
import anywidget
import numpy as np
import traitlets

from quantem.widget.array_utils import to_numpy
from quantem.widget.io import IOResult
from quantem.widget.json_state import build_json_header, resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
    bind_tool_runtime_api,
    build_tool_groups,
    normalize_tool_groups,
)


class Show3DVolume(anywidget.AnyWidget):
    """
    3D volume viewer with three orthogonal slice planes.

    Parameters
    ----------
    data : array_like
        3D array of shape (nz, ny, nx). An ``IOResult`` or a Dataset3d-like
        object (anything exposing ``array``, ``name`` and ``sampling``) is
        also accepted; its title / pixel size are used when not given.
    data_b : array_like, optional
        Second volume for dual comparison mode. Must match ``data`` shape.
    title : str, optional
        Title displayed above the viewer.
    title_b : str, optional
        Title for the second volume (only used when ``data_b`` is given).
    cmap : str, default "inferno"
        Colormap name.
    pixel_size : float, optional
        Pixel size in angstroms for scale bar.
    show_stats : bool, default True
        Show per-slice statistics.
    log_scale : bool, default False
        Use log scale for intensity mapping.
    auto_contrast : bool, default False
        Use percentile-based contrast.
    disabled_tools : list of str, optional
        Tool groups to lock while still showing controls. Supported:
        ``"display"``, ``"histogram"``, ``"playback"``, ``"fft"``,
        ``"navigation"``, ``"stats"``, ``"export"``, ``"view"``,
        ``"volume"``, ``"all"``
    disable_* : bool, optional
        Convenience flags mirroring ``disabled_tools``.
    hidden_tools : list of str, optional
        Tool groups to hide from the UI. Uses the same keys as
        ``disabled_tools``.
    hide_* : bool, optional
        Convenience flags mirroring ``disable_*`` for ``hidden_tools``.
    fps : float, default 5.0
        Playback / GIF export speed in frames per second.
    dim_labels : list of str, optional
        Labels for dims 0, 1, 2 (default ``["Z", "Y", "X"]``).
    state : dict or path, optional
        State dict, or path to a JSON state file (e.g. one written by
        :meth:`save`), restored after construction.

    Examples
    --------
    >>> import numpy as np
    >>> from quantem.widget import Show3DVolume
    >>> volume = np.random.rand(64, 64, 64).astype(np.float32)
    >>> Show3DVolume(volume, title="My Volume", cmap="viridis")
    """

    _esm = pathlib.Path(__file__).parent / "static" / "show3dvolume.js"
    _css = pathlib.Path(__file__).parent / "static" / "show3dvolume.css"

    # Volume dimensions
    nx = traitlets.Int(1).tag(sync=True)
    ny = traitlets.Int(1).tag(sync=True)
    nz = traitlets.Int(1).tag(sync=True)

    # Slice positions
    slice_x = traitlets.Int(0).tag(sync=True)
    slice_y = traitlets.Int(0).tag(sync=True)
    slice_z = traitlets.Int(0).tag(sync=True)

    # Raw volume data (sent once)
    volume_bytes = traitlets.Bytes(b"").tag(sync=True)

    # Dual-volume comparison mode
    volume_bytes_b = traitlets.Bytes(b"").tag(sync=True)
    title_b = traitlets.Unicode("").tag(sync=True)
    dual_mode = traitlets.Bool(False).tag(sync=True)

    # Stats for volume B (3 values: xy, xz, yz)
    stats_mean_b = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_min_b = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_max_b = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_std_b = traitlets.List(traitlets.Float()).tag(sync=True)

    # Display
    title = traitlets.Unicode("").tag(sync=True)
    cmap = traitlets.Unicode("inferno").tag(sync=True)
    log_scale = traitlets.Bool(False).tag(sync=True)
    auto_contrast = traitlets.Bool(False).tag(sync=True)

    # Scale bar
    pixel_size = traitlets.Float(0.0).tag(sync=True)
    scale_bar_visible = traitlets.Bool(True).tag(sync=True)

    # UI
    show_controls = traitlets.Bool(True).tag(sync=True)
    show_stats = traitlets.Bool(True).tag(sync=True)
    show_crosshair = traitlets.Bool(True).tag(sync=True)
    show_fft = traitlets.Bool(False).tag(sync=True)
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)

    # Axis labels (dim 0, 1, 2 → default "Z", "Y", "X")
    dim_labels = traitlets.List(traitlets.Unicode(), default_value=["Z", "Y", "X"]).tag(sync=True)

    # Stats (3 values: xy, xz, yz)
    stats_mean = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_min = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_max = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_std = traitlets.List(traitlets.Float()).tag(sync=True)

    # Playback
    playing = traitlets.Bool(False).tag(sync=True)
    reverse = traitlets.Bool(False).tag(sync=True)
    boomerang = traitlets.Bool(False).tag(sync=True)
    fps = traitlets.Float(5.0).tag(sync=True)
    loop = traitlets.Bool(True).tag(sync=True)
    play_axis = traitlets.Int(0).tag(sync=True)  # 0=Z, 1=Y, 2=X, 3=All

    # Export
    _export_axis = traitlets.Int(0).tag(sync=True)  # 0=Z, 1=Y, 2=X
    _gif_export_requested = traitlets.Bool(False).tag(sync=True)
    _gif_data = traitlets.Bytes(b"").tag(sync=True)
    _gif_metadata_json = traitlets.Unicode("").tag(sync=True)
    _zip_export_requested = traitlets.Bool(False).tag(sync=True)
    _zip_data = traitlets.Bytes(b"").tag(sync=True)

    @classmethod
    def _normalize_tool_groups(cls, tool_groups):
        """Normalize a tool-group list via the shared tool_parity helper."""
        return normalize_tool_groups("Show3DVolume", tool_groups)

    @classmethod
    def _build_disabled_tools(
        cls,
        disabled_tools=None,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_playback: bool = False,
        disable_fft: bool = False,
        disable_navigation: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_volume: bool = False,
        disable_all: bool = False,
    ):
        """Merge an explicit ``disabled_tools`` list with the ``disable_*`` flags."""
        return build_tool_groups(
            "Show3DVolume",
            tool_groups=disabled_tools,
            all_flag=disable_all,
            flag_map={
                "display": disable_display,
                "histogram": disable_histogram,
                "playback": disable_playback,
                "fft": disable_fft,
                "navigation": disable_navigation,
                "stats": disable_stats,
                "export": disable_export,
                "view": disable_view,
                "volume": disable_volume,
            },
        )

    @classmethod
    def _build_hidden_tools(
        cls,
        hidden_tools=None,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_playback: bool = False,
        hide_fft: bool = False,
        hide_navigation: bool = False,
        hide_stats: bool = False,
        hide_view: bool = False,
        hide_export: bool = False,
        hide_volume: bool = False,
        hide_all: bool = False,
    ):
        """Merge an explicit ``hidden_tools`` list with the ``hide_*`` flags."""
        return build_tool_groups(
            "Show3DVolume",
            tool_groups=hidden_tools,
            all_flag=hide_all,
            flag_map={
                "display": hide_display,
                "histogram": hide_histogram,
                "playback": hide_playback,
                "fft": hide_fft,
                "navigation": hide_navigation,
                "stats": hide_stats,
                "export": hide_export,
                "view": hide_view,
                "volume": hide_volume,
            },
        )

    @traitlets.validate("disabled_tools")
    def _validate_disabled_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    @traitlets.validate("hidden_tools")
    def _validate_hidden_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    @staticmethod
    def _unwrap_volume(data, title="", pixel_size=None):
        """Strip ``IOResult`` / Dataset3d-like containers down to a NumPy array.

        Returns ``(array, title, pixel_size)``. Container metadata only fills
        in values the caller left empty (``title`` falsy, ``pixel_size == 0.0``).
        Passing ``pixel_size=None`` skips pixel-size extraction entirely.
        """
        if isinstance(data, IOResult):
            if not title and data.title:
                title = data.title
            if pixel_size == 0.0 and data.pixel_size is not None:
                pixel_size = data.pixel_size
            data = data.data
        # Duck-typed Dataset3d: anything exposing array / name / sampling.
        if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
            if not title and data.name:
                title = data.name
            if pixel_size == 0.0 and hasattr(data, "units"):
                units = list(data.units)
                sampling_val = float(data.sampling[-1])
                if units[-1] in ("nm",):
                    pixel_size = sampling_val * 10  # nm → Å
                elif units[-1] in ("Å", "angstrom", "A"):
                    pixel_size = sampling_val
            data = data.array
        return to_numpy(data), title, pixel_size

    def __init__(
        self,
        data: Union[np.ndarray, "torch.Tensor"],
        data_b: Union[np.ndarray, "torch.Tensor", None] = None,
        title: str = "",
        title_b: str = "",
        cmap: str = "inferno",
        pixel_size: float = 0.0,
        scale_bar_visible: bool = True,
        show_controls: bool = True,
        show_stats: bool = True,
        show_crosshair: bool = True,
        show_fft: bool = False,
        log_scale: bool = False,
        auto_contrast: bool = False,
        disabled_tools: list[str] | None = None,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_playback: bool = False,
        disable_fft: bool = False,
        disable_navigation: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_volume: bool = False,
        disable_all: bool = False,
        hidden_tools: list[str] | None = None,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_playback: bool = False,
        hide_fft: bool = False,
        hide_navigation: bool = False,
        hide_stats: bool = False,
        hide_view: bool = False,
        hide_export: bool = False,
        hide_volume: bool = False,
        hide_all: bool = False,
        fps: float = 5.0,
        dim_labels: Optional[list] = None,
        state=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.widget_version = resolve_widget_version()
        self.fps = fps
        if dim_labels is not None:
            self.dim_labels = dim_labels

        # IOResult / Dataset3d inputs contribute title and pixel size.
        data, title, pixel_size = self._unwrap_volume(data, title, pixel_size)
        if data.ndim != 3:
            raise ValueError(f"Show3DVolume requires 3D data, got {data.ndim}D")
        if data.size == 0:
            # Without this, _compute_stats fails later with an obscure IndexError.
            raise ValueError(f"Show3DVolume requires a non-empty volume, got shape {data.shape}")

        self._data = data.astype(np.float32)
        self.nz, self.ny, self.nx = self._data.shape

        # Default to middle slices
        self.slice_z = self.nz // 2
        self.slice_y = self.ny // 2
        self.slice_x = self.nx // 2

        self.title = title
        self.cmap = cmap
        self.pixel_size = pixel_size
        self.scale_bar_visible = scale_bar_visible
        self.show_controls = show_controls
        self.show_stats = show_stats
        self.show_crosshair = show_crosshair
        self.show_fft = show_fft
        self.log_scale = log_scale
        self.auto_contrast = auto_contrast
        self.disabled_tools = self._build_disabled_tools(
            disabled_tools=disabled_tools,
            disable_display=disable_display,
            disable_histogram=disable_histogram,
            disable_playback=disable_playback,
            disable_fft=disable_fft,
            disable_navigation=disable_navigation,
            disable_stats=disable_stats,
            disable_export=disable_export,
            disable_view=disable_view,
            disable_volume=disable_volume,
            disable_all=disable_all,
        )
        self.hidden_tools = self._build_hidden_tools(
            hidden_tools=hidden_tools,
            hide_display=hide_display,
            hide_histogram=hide_histogram,
            hide_playback=hide_playback,
            hide_fft=hide_fft,
            hide_navigation=hide_navigation,
            hide_stats=hide_stats,
            hide_view=hide_view,
            hide_export=hide_export,
            hide_volume=hide_volume,
            hide_all=hide_all,
        )

        # Volume B (dual comparison mode)
        self._data_b: np.ndarray | None = None
        if data_b is not None:
            data_b, title_b, _ = self._unwrap_volume(data_b, title_b)
            if data_b.ndim != 3:
                raise ValueError(f"data_b must be 3D, got {data_b.ndim}D")
            if data_b.shape != self._data.shape:
                raise ValueError(
                    f"data_b shape {data_b.shape} must match data shape {self._data.shape}"
                )
            self._data_b = data_b.astype(np.float32)
            self.dual_mode = True
            self.title_b = title_b
            self.volume_bytes_b = self._data_b.tobytes()

        self._compute_stats()
        self.volume_bytes = self._data.tobytes()

        self.observe(self._on_slice_change, names=["slice_x", "slice_y", "slice_z"])
        self.observe(self._on_gif_export, names=["_gif_export_requested"])
        self.observe(self._on_zip_export, names=["_zip_export_requested"])

        if state is not None:
            if isinstance(state, (str, pathlib.Path)):
                state = unwrap_state_payload(
                    json.loads(pathlib.Path(state).read_text()),
                    require_envelope=True,
                )
            else:
                state = unwrap_state_payload(state)
            self.load_state_dict(state)

    def set_image(self, data, data_b=None):
        """Replace the volume data. Preserves all display settings.

        Parameters
        ----------
        data : array_like
            New 3D volume for volume A (``IOResult`` / Dataset3d-like
            containers are unwrapped; their metadata is ignored).
        data_b : array_like, optional
            New 3D volume for volume B. Must match data shape. If not
            provided and dual mode is active, volume B is dropped when the
            new data shape differs from the old B shape.

        Raises
        ------
        ValueError
            If a volume is not 3D, ``data`` is empty, or ``data_b``'s shape
            differs from ``data``'s. All validation happens before any widget
            state is modified, so a failed call leaves the widget unchanged.
        """
        data, _, _ = self._unwrap_volume(data)
        if data.ndim != 3:
            raise ValueError(f"Show3DVolume requires 3D data, got {data.ndim}D")
        if data.size == 0:
            raise ValueError(f"Show3DVolume requires a non-empty volume, got shape {data.shape}")
        new_data = data.astype(np.float32)

        new_b = None
        if data_b is not None:
            data_b, _, _ = self._unwrap_volume(data_b)
            if data_b.ndim != 3:
                raise ValueError(f"data_b must be 3D, got {data_b.ndim}D")
            if data_b.shape != new_data.shape:
                raise ValueError(
                    f"data_b shape {data_b.shape} must match data shape {new_data.shape}"
                )
            new_b = data_b.astype(np.float32)

        drop_b = new_b is None and self._data_b is not None and self._data_b.shape != new_data.shape

        # Batch all trait updates so the frontend never sees the new shape
        # paired with the old volume bytes (or vice versa).
        with self.hold_sync():
            # Swap arrays first: clamping slices below fires _on_slice_change,
            # which must see A and B with matching shapes.
            self._data = new_data
            if new_b is not None:
                self._data_b = new_b
            elif drop_b:
                self._data_b = None

            self.nz, self.ny, self.nx = new_data.shape
            self.slice_z = min(self.slice_z, self.nz - 1)
            self.slice_y = min(self.slice_y, self.ny - 1)
            self.slice_x = min(self.slice_x, self.nx - 1)

            if new_b is not None:
                self.dual_mode = True
                self.volume_bytes_b = new_b.tobytes()
            elif drop_b:
                self.dual_mode = False
                self.volume_bytes_b = b""
                # Don't leave stats of a discarded volume around.
                self.stats_mean_b = []
                self.stats_min_b = []
                self.stats_max_b = []
                self.stats_std_b = []

            self._compute_stats()
            self.volume_bytes = new_data.tobytes()

    def __repr__(self) -> str:
        base = f"Show3DVolume({self.nz}×{self.ny}×{self.nx}, slices=({self.slice_z},{self.slice_y},{self.slice_x}), cmap={self.cmap}"
        if self.dual_mode:
            base += ", dual=True"
        return base + ")"

    def state_dict(self):
        """Return the serializable display/playback state (no volume data)."""
        d = {
            "title": self.title,
            "cmap": self.cmap,
            "log_scale": self.log_scale,
            "auto_contrast": self.auto_contrast,
            "show_stats": self.show_stats,
            "show_controls": self.show_controls,
            "show_crosshair": self.show_crosshair,
            "show_fft": self.show_fft,
            "disabled_tools": self.disabled_tools,
            "hidden_tools": self.hidden_tools,
            "pixel_size": self.pixel_size,
            "scale_bar_visible": self.scale_bar_visible,
            "slice_x": self.slice_x,
            "slice_y": self.slice_y,
            "slice_z": self.slice_z,
            "fps": self.fps,
            "loop": self.loop,
            "reverse": self.reverse,
            "boomerang": self.boomerang,
            "play_axis": self.play_axis,
            "dim_labels": self.dim_labels,
            "dual_mode": self.dual_mode,
            "title_b": self.title_b,
        }
        return d

    def save(self, path: str):
        """Write :meth:`state_dict` to ``path`` via ``save_state_file``."""
        save_state_file(path, "Show3DVolume", self.state_dict())

    def load_state_dict(self, state):
        """Apply a state mapping (e.g. from :meth:`state_dict`) to this widget.

        Unknown keys are ignored, as are private names, methods and the
        volume-shape traits (``nx``/``ny``/``nz``, which are derived from the
        data). Slice indices are clamped to the current volume so a state
        saved from a larger volume cannot cause out-of-range indexing, and
        ``dual_mode`` is only enabled when a volume B is actually loaded.
        The legacy key ``pixel_size_angstrom`` maps to ``pixel_size``.
        """
        slice_limits = {"slice_x": self.nx, "slice_y": self.ny, "slice_z": self.nz}
        for key, val in state.items():
            if key == "pixel_size_angstrom":
                key = "pixel_size"
            if key.startswith("_") or key in ("nx", "ny", "nz"):
                continue
            if not hasattr(self, key) or callable(getattr(self, key)):
                continue
            if key in slice_limits:
                val = min(max(int(val), 0), slice_limits[key] - 1)
            elif key == "dual_mode":
                val = bool(val) and self._data_b is not None
            setattr(self, key, val)

    def summary(self):
        """Print a short human-readable description of the volume and display settings."""
        lines = [self.title or "Show3DVolume", "═" * 32]
        lines.append(f"Volume: {self.nz}×{self.ny}×{self.nx}")
        if self.pixel_size > 0:
            ps = self.pixel_size
            if ps >= 10:
                lines[-1] += f" ({ps / 10:.2f} nm/px)"
            else:
                lines[-1] += f" ({ps:.2f} Å/px)"
        labels = self.dim_labels
        lines.append(f"Slices: {labels[0]}={self.slice_z} {labels[1]}={self.slice_y} {labels[2]}={self.slice_x}")
        if hasattr(self, "_data") and self._data is not None:
            arr = self._data
            lines.append(f"Data: min={float(arr.min()):.4g} max={float(arr.max()):.4g} mean={float(arr.mean()):.4g}")
        if self.dual_mode and self._data_b is not None:
            lines.append(f"Volume B: {self.title_b or 'Volume B'}")
            arr_b = self._data_b
            lines.append(f"Data B: min={float(arr_b.min()):.4g} max={float(arr_b.max()):.4g} mean={float(arr_b.mean()):.4g}")
        cmap = self.cmap
        scale = "log" if self.log_scale else "linear"
        contrast = "auto contrast" if self.auto_contrast else "manual contrast"
        display = f"{cmap} | {contrast} | {scale}"
        if self.show_fft:
            display += " | FFT"
        lines.append(f"Display: {display}")
        if self.disabled_tools:
            lines.append(f"Locked: {', '.join(self.disabled_tools)}")
        if self.hidden_tools:
            lines.append(f"Hidden: {', '.join(self.hidden_tools)}")
        print("\n".join(lines))

    def _plane_stats(self, volume):
        """Return (means, mins, maxs, stds) for the XY, XZ, YZ planes at the current slices."""
        planes = (
            volume[self.slice_z, :, :],
            volume[:, self.slice_y, :],
            volume[:, :, self.slice_x],
        )
        return (
            [float(np.mean(p)) for p in planes],
            [float(np.min(p)) for p in planes],
            [float(np.max(p)) for p in planes],
            [float(np.std(p)) for p in planes],
        )

    def _compute_stats(self):
        """Compute statistics for the 3 current slices (and volume B, if loaded)."""
        with self.hold_sync():
            self.stats_mean, self.stats_min, self.stats_max, self.stats_std = self._plane_stats(self._data)
            if self._data_b is not None:
                (
                    self.stats_mean_b,
                    self.stats_min_b,
                    self.stats_max_b,
                    self.stats_std_b,
                ) = self._plane_stats(self._data_b)

    def _on_slice_change(self, change):
        self._compute_stats()

    def play(self):
        """Start playback."""
        self.playing = True

    def pause(self):
        """Pause playback, keeping the current slice positions."""
        self.playing = False

    def stop(self):
        """Stop playback and reset every slice to the middle of its axis."""
        self.playing = False
        self.slice_z = self.nz // 2
        self.slice_y = self.ny // 2
        self.slice_x = self.nx // 2

    def _on_gif_export(self, change=None):
        # The frontend raises the flag; reset it first, then build the GIF.
        if not self._gif_export_requested:
            return
        self._gif_export_requested = False
        self._generate_gif()

    def _on_zip_export(self, change=None):
        # The frontend raises the flag; reset it first, then build the ZIP.
        if not self._zip_export_requested:
            return
        self._zip_export_requested = False
        self._generate_zip()

    def _get_export_slices(self):
        """Return every 2D slice along ``_export_axis`` (0=Z, 1=Y, else X)."""
        axis = self._export_axis
        if axis == 0:
            return [self._data[z, :, :] for z in range(self.nz)]
        elif axis == 1:
            return [self._data[:, y, :] for y in range(self.ny)]
        else:
            return [self._data[:, :, x] for x in range(self.nx)]

    def _normalize_slice(self, slc: np.ndarray) -> np.ndarray:
        """Map a 2D slice to uint8 [0, 255] using the current log/contrast settings.

        Contrast limits are computed from finite pixels only, so NaN/Inf
        pixels cannot poison the whole frame. Non-finite pixels render as 0
        (NaN, -Inf) or 255 (+Inf). A constant or fully non-finite slice
        renders as all zeros.
        """
        if self.log_scale:
            slc = np.log1p(np.maximum(slc, 0))
        finite = slc[np.isfinite(slc)]
        if finite.size == 0:
            return np.zeros(slc.shape, dtype=np.uint8)
        if self.auto_contrast:
            vmin = float(np.percentile(finite, 2))
            vmax = float(np.percentile(finite, 98))
        else:
            vmin = float(finite.min())
            vmax = float(finite.max())
        if vmax > vmin:
            scaled = (slc - vmin) / (vmax - vmin) * 255
            scaled = np.nan_to_num(scaled, nan=0.0, posinf=255.0, neginf=0.0)
            return np.clip(scaled, 0, 255).astype(np.uint8)
        return np.zeros(slc.shape, dtype=np.uint8)

    def _slice_to_rgb(self, slc: np.ndarray, cmap_fn) -> np.ndarray:
        """Normalize a slice and colorize it with ``cmap_fn`` into an (H, W, 3) uint8 array."""
        normalized = self._normalize_slice(slc)
        rgba = cmap_fn(normalized / 255.0)
        return (rgba[:, :, :3] * 255).astype(np.uint8)

    def _generate_gif(self):
        """Render all export slices into an animated GIF and publish it to ``_gif_data``."""
        import io

        from matplotlib import colormaps
        from PIL import Image

        slices = self._get_export_slices()
        cmap_fn = colormaps.get_cmap(self.cmap)
        pil_frames = [Image.fromarray(self._slice_to_rgb(slc, cmap_fn)) for slc in slices]
        if not pil_frames:
            with self.hold_sync():
                self._gif_data = b""
                self._gif_metadata_json = ""
            return
        buf = io.BytesIO()
        # Clamp fps so a zero/negative value cannot divide by zero.
        duration_ms = int(1000 / max(0.1, self.fps))
        pil_frames[0].save(buf, format="GIF", save_all=True, append_images=pil_frames[1:], duration=duration_ms, loop=0)
        metadata = {
            **build_json_header("Show3DVolume"),
            "format": "gif",
            "export_kind": "animated_slices",
            "export_axis": int(self._export_axis),
            "n_slices": int(len(pil_frames)),
            "duration_ms": int(duration_ms),
            "display": {
                "cmap": self.cmap,
                "log_scale": bool(self.log_scale),
                "auto_contrast": bool(self.auto_contrast),
            },
        }
        with self.hold_sync():
            self._gif_metadata_json = json.dumps(metadata, indent=2)
            self._gif_data = buf.getvalue()

    def _generate_zip(self):
        """Render all export slices as PNGs in a ZIP (plus metadata.json) into ``_zip_data``."""
        import io
        import zipfile

        from matplotlib import colormaps
        from PIL import Image

        slices = self._get_export_slices()
        cmap_fn = colormaps.get_cmap(self.cmap)
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
            # Same display/axis fields as the GIF metadata so exports are comparable.
            metadata = {
                **build_json_header("Show3DVolume"),
                "format": "zip",
                "export_kind": "png_slices",
                "export_axis": int(self._export_axis),
                "n_slices": int(len(slices)),
                "display": {
                    "cmap": self.cmap,
                    "log_scale": bool(self.log_scale),
                    "auto_contrast": bool(self.auto_contrast),
                },
            }
            zf.writestr("metadata.json", json.dumps(metadata, indent=2))
            for i, slc in enumerate(slices):
                img_buf = io.BytesIO()
                Image.fromarray(self._slice_to_rgb(slc, cmap_fn)).save(img_buf, format="PNG")
                zf.writestr(f"slice_{i:04d}.png", img_buf.getvalue())
        self._zip_data = buf.getvalue()

    def save_image(self, path: str | pathlib.Path, *, plane: str | None = None, slice_idx: int | None = None, format: str | None = None, dpi: int = 150) -> pathlib.Path:
        """Save a volume slice as PNG, PDF, or TIFF.

        Parameters
        ----------
        path : str or pathlib.Path
            Output file path. Parent directories are created if needed.
        plane : str, optional
            One of 'xy', 'xz', 'yz'. Defaults to 'xy'.
        slice_idx : int, optional
            Slice index along the chosen axis. Defaults to current position.
        format : str, optional
            'png', 'pdf', or 'tiff'. If omitted, inferred from file extension
            (falling back to 'png' when the path has no extension). An
            explicit format is honored even if the extension disagrees.
        dpi : int, default 150
            Output DPI metadata.

        Returns
        -------
        pathlib.Path
            The written file path.

        Raises
        ------
        ValueError
            Unsupported format or unknown plane.
        IndexError
            ``slice_idx`` outside the chosen axis.
        """
        from matplotlib import colormaps
        from PIL import Image

        # Map user-facing format names to PIL's writer names; passing the
        # format explicitly means PIL no longer has to guess from the suffix.
        pil_formats = {"png": "PNG", "pdf": "PDF", "tiff": "TIFF", "tif": "TIFF"}

        path = pathlib.Path(path)
        fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
        if fmt not in pil_formats:
            raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")

        plane = (plane or "xy").lower()
        if plane == "xy":
            idx = slice_idx if slice_idx is not None else self.slice_z
            max_idx = self.nz
        elif plane == "xz":
            idx = slice_idx if slice_idx is not None else self.slice_y
            max_idx = self.ny
        elif plane == "yz":
            idx = slice_idx if slice_idx is not None else self.slice_x
            max_idx = self.nx
        else:
            raise ValueError(f"Unknown plane: {plane!r}. Use 'xy', 'xz', or 'yz'.")
        if idx < 0 or idx >= max_idx:
            raise IndexError(f"Slice index {idx} out of range [0, {max_idx}) for plane '{plane}'")

        if plane == "xy":
            slc = self._data[idx]
        elif plane == "xz":
            slc = self._data[:, idx, :]
        else:
            slc = self._data[:, :, idx]

        normalized = self._normalize_slice(slc)
        cmap_fn = colormaps.get_cmap(self.cmap)
        rgba = (cmap_fn(normalized / 255.0) * 255).astype(np.uint8)
        img = Image.fromarray(rgba)

        save_kwargs = {"format": pil_formats[fmt], "dpi": (dpi, dpi)}
        if save_kwargs["format"] == "PDF":
            # Older Pillow PDF writers read page resolution from `resolution`, not `dpi`.
            save_kwargs["resolution"] = float(dpi)

        path.parent.mkdir(parents=True, exist_ok=True)
        img.save(str(path), **save_kwargs)
        return path
# Attach the shared tool-group runtime API from quantem.widget.tool_parity to
# Show3DVolume under the widget name "Show3DVolume". NOTE(review): presumably
# this adds runtime methods for locking/hiding tool groups — confirm in tool_parity.
bind_tool_runtime_api(Show3DVolume, "Show3DVolume")