"""
show2d: Static 2D image viewer with optional FFT and histogram analysis.
For displaying a single image or a static gallery of multiple images.
Unlike Show3D (interactive), Show2D focuses on static visualization.
"""
import json
import pathlib
import io
import base64
import math
from enum import StrEnum
from typing import Optional, Union, List, Self
import anywidget
import matplotlib.pyplot as plt
import numpy as np
import traitlets
from quantem.widget.array_utils import to_numpy, _resize_image
from quantem.widget.io import IO, IOResult
from quantem.widget.json_state import resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
bind_tool_runtime_api,
build_tool_groups,
normalize_tool_groups,
)
class Colormap(StrEnum):
INFERNO = "inferno"
VIRIDIS = "viridis"
MAGMA = "magma"
PLASMA = "plasma"
GRAY = "gray"
[docs]
class Show2D(anywidget.AnyWidget):
    """
    Static 2D image viewer with optional FFT and histogram analysis.

    Display a single image or multiple images in a gallery layout.
    For interactive stack viewing with playback, use Show3D instead.

    Parameters
    ----------
    data : array_like, list of array_like, IOResult, or Dataset2d-like
        2D array (height, width) for a single image, or
        3D array (N, height, width) for multiple images displayed as gallery.
        A list of 2D arrays is stacked into a gallery; if their shapes
        differ, every image is resized to the largest height and width.
        An ``IOResult`` supplies ``title``, ``pixel_size`` and ``labels``
        when they are not given explicitly. An object with ``array``,
        ``name`` and ``sampling`` attributes supplies ``title`` and, for
        nm or Å units, ``pixel_size``. Pixel data is stored as float32.
    labels : list of str, optional
        Labels for each image in gallery mode. Defaults to
        ``"Image 1"``, ``"Image 2"``, ...
    title : str, optional
        Title to display above the image(s).
    cmap : str or Colormap, default "inferno"
        Colormap name ("magma", "viridis", "gray", "inferno", "plasma").
    pixel_size : float, optional
        Pixel size in angstroms for scale bar display. ``0`` means
        uncalibrated (``profile_distance`` is then reported in pixels).
    scale_bar_visible : bool, default True
        Scale-bar visibility flag synced to the front end.
    show_fft : bool, default False
        Show FFT and histogram panels.
    fft_window : bool, default True
        FFT windowing flag synced to the front end; ``summary()`` reports
        "(no window)" when it is False.
    show_controls : bool, default True
        Controls visibility flag synced to the front end.
    show_stats : bool, default True
        Show statistics (mean, min, max, std).
    log_scale : bool, default False
        Use log scale for intensity mapping.
    auto_contrast : bool, default False
        Use percentile-based contrast.
    ncols : int, default 3
        Number of columns in gallery mode.
    canvas_size : int, default 0
        Canvas size value synced to the front end (0 presumably means
        automatic sizing -- confirm against show2d.js).
    disabled_tools : list of str, optional
        Tool groups to lock while still showing controls. Supported:
        ``"display"``, ``"histogram"``, ``"stats"``, ``"navigation"``,
        ``"view"``, ``"export"``, ``"roi"``, ``"profile"``, ``"all"``.
    disable_* : bool, optional
        Convenience flags (``disable_display``, ``disable_histogram``,
        ``disable_stats``, ``disable_navigation``, ``disable_view``,
        ``disable_export``, ``disable_roi``, ``disable_profile``,
        ``disable_all``) equivalent to adding those keys to
        ``disabled_tools``.
    hidden_tools : list of str, optional
        Tool groups to hide from the UI. Uses the same keys as
        ``disabled_tools``.
    hide_* : bool, optional
        Convenience flags mirroring ``disable_*`` for ``hidden_tools``.
    state : dict, str or pathlib.Path, optional
        Saved state applied last via ``load_state_dict``. A path is read as
        a JSON state file (an envelope is required); a dict is unwrapped
        with ``unwrap_state_payload``.
    **kwargs
        Forwarded to ``anywidget.AnyWidget``.

    Examples
    --------
    >>> import numpy as np
    >>> from quantem.widget import Show2D
    >>>
    >>> # Single image with FFT
    >>> image = np.random.rand(256, 256)
    >>> Show2D(image, title="HRTEM Image", show_fft=True, pixel_size=1.0)
    >>>
    >>> # Gallery of multiple images
    >>> labels = ["Raw", "Filtered", "FFT"]
    >>> Show2D([img1, img2, img3], labels=labels, ncols=3)
    """
    # Front-end ES module and stylesheet loaded by anywidget.
    _esm = pathlib.Path(__file__).parent / "static" / "show2d.js"
    _css = pathlib.Path(__file__).parent / "static" / "show2d.css"
    # =========================================================================
    # Core State
    # =========================================================================
    widget_version = traitlets.Unicode("unknown").tag(sync=True)
    # Image stack dimensions, kept in step with self._data (N, H, W).
    n_images = traitlets.Int(1).tag(sync=True)
    height = traitlets.Int(1).tag(sync=True)
    width = traitlets.Int(1).tag(sync=True)
    # Raw float32 bytes of the whole stack (see _update_all_frames).
    frame_bytes = traitlets.Bytes(b"").tag(sync=True)
    labels = traitlets.List(traitlets.Unicode()).tag(sync=True)
    title = traitlets.Unicode("").tag(sync=True)
    cmap = traitlets.Unicode("inferno").tag(sync=True)
    ncols = traitlets.Int(3).tag(sync=True)
    # =========================================================================
    # Display Options
    # =========================================================================
    log_scale = traitlets.Bool(False).tag(sync=True)
    auto_contrast = traitlets.Bool(False).tag(sync=True)
    # =========================================================================
    # Scale Bar
    # =========================================================================
    # Angstroms per pixel; 0 means uncalibrated.
    pixel_size = traitlets.Float(0.0).tag(sync=True)
    scale_bar_visible = traitlets.Bool(True).tag(sync=True)
    canvas_size = traitlets.Int(0).tag(sync=True)
    # =========================================================================
    # UI Visibility
    # =========================================================================
    show_controls = traitlets.Bool(True).tag(sync=True)
    show_stats = traitlets.Bool(True).tag(sync=True)
    # Tool-group keys; normalized by the @traitlets.validate hooks below.
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    # Per-image statistics, one entry per image (see _compute_all_stats).
    stats_mean = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_min = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_max = traitlets.List(traitlets.Float()).tag(sync=True)
    stats_std = traitlets.List(traitlets.Float()).tag(sync=True)
    # =========================================================================
    # Analysis Panels (FFT + Histogram shown together)
    # =========================================================================
    show_fft = traitlets.Bool(False).tag(sync=True)
    fft_window = traitlets.Bool(True).tag(sync=True)
    # =========================================================================
    # Selected Image (for single-image analysis display)
    # =========================================================================
    selected_idx = traitlets.Int(0).tag(sync=True)
    # =========================================================================
    # ROI Selection
    # =========================================================================
    roi_active = traitlets.Bool(False).tag(sync=True)
    # List of ROI dicts (shape, row, col, radius, ...); see _upsert_selected_roi.
    roi_list = traitlets.List([]).tag(sync=True)
    # Index into roi_list, or -1 when no ROI is selected.
    roi_selected_idx = traitlets.Int(-1).tag(sync=True)
    # =========================================================================
    # Line Profile
    # =========================================================================
    # [{"row": r0, "col": c0}, {"row": r1, "col": c1}] or [] when unset.
    profile_line = traitlets.List(traitlets.Dict()).tag(sync=True)
@classmethod
def _normalize_tool_groups(cls, tool_groups) -> List[str]:
    """Run ``tool_groups`` through ``normalize_tool_groups`` for the "Show2D" widget.

    Used by the ``disabled_tools`` / ``hidden_tools`` validators so every
    assignment is stored in the helper's normalized form.
    """
    normalized = normalize_tool_groups("Show2D", tool_groups)
    return normalized
@classmethod
def _build_disabled_tools(
    cls,
    disabled_tools=None,
    disable_display: bool = False,
    disable_histogram: bool = False,
    disable_stats: bool = False,
    disable_navigation: bool = False,
    disable_view: bool = False,
    disable_export: bool = False,
    disable_roi: bool = False,
    disable_profile: bool = False,
    disable_all: bool = False,
) -> List[str]:
    """Merge an explicit ``disabled_tools`` list with the ``disable_*`` flags.

    Parameters
    ----------
    disabled_tools : list of str, optional
        Tool-group keys to lock explicitly.
    disable_<group> : bool
        Per-group convenience flag; True is equivalent to listing the group
        key in ``disabled_tools``.
    disable_all : bool
        Forwarded to ``build_tool_groups`` as ``all_flag``.

    Returns
    -------
    list of str
        The combined key list produced by ``build_tool_groups``.
    """
    per_group_flags = {
        "display": disable_display,
        "histogram": disable_histogram,
        "stats": disable_stats,
        "navigation": disable_navigation,
        "view": disable_view,
        "export": disable_export,
        "roi": disable_roi,
        "profile": disable_profile,
    }
    return build_tool_groups(
        "Show2D",
        flag_map=per_group_flags,
        tool_groups=disabled_tools,
        all_flag=disable_all,
    )
@classmethod
def _build_hidden_tools(
    cls,
    hidden_tools=None,
    hide_display: bool = False,
    hide_histogram: bool = False,
    hide_stats: bool = False,
    hide_navigation: bool = False,
    hide_view: bool = False,
    hide_export: bool = False,
    hide_roi: bool = False,
    hide_profile: bool = False,
    hide_all: bool = False,
) -> List[str]:
    """Merge an explicit ``hidden_tools`` list with the ``hide_*`` flags.

    Parameters
    ----------
    hidden_tools : list of str, optional
        Tool-group keys to hide explicitly.
    hide_<group> : bool
        Per-group convenience flag; True is equivalent to listing the group
        key in ``hidden_tools``.
    hide_all : bool
        Forwarded to ``build_tool_groups`` as ``all_flag``.

    Returns
    -------
    list of str
        The combined key list produced by ``build_tool_groups``.
    """
    per_group_flags = {
        "display": hide_display,
        "histogram": hide_histogram,
        "stats": hide_stats,
        "navigation": hide_navigation,
        "view": hide_view,
        "export": hide_export,
        "roi": hide_roi,
        "profile": hide_profile,
    }
    return build_tool_groups(
        "Show2D",
        flag_map=per_group_flags,
        tool_groups=hidden_tools,
        all_flag=hide_all,
    )
@traitlets.validate("disabled_tools")
def _validate_disabled_tools(self, proposal):
    """Normalize every value assigned to ``disabled_tools`` before it is stored."""
    proposed_groups = proposal["value"]
    return self._normalize_tool_groups(proposed_groups)
@traitlets.validate("hidden_tools")
def _validate_hidden_tools(self, proposal):
    """Normalize every value assigned to ``hidden_tools`` before it is stored."""
    proposed_groups = proposal["value"]
    return self._normalize_tool_groups(proposed_groups)
def __init__(
    self,
    data: Union[np.ndarray, List[np.ndarray]],
    labels: Optional[List[str]] = None,
    title: str = "",
    cmap: Union[str, Colormap] = Colormap.INFERNO,
    pixel_size: float = 0.0,
    scale_bar_visible: bool = True,
    show_fft: bool = False,
    fft_window: bool = True,
    show_controls: bool = True,
    show_stats: bool = True,
    log_scale: bool = False,
    auto_contrast: bool = False,
    disabled_tools: Optional[List[str]] = None,
    disable_display: bool = False,
    disable_histogram: bool = False,
    disable_stats: bool = False,
    disable_navigation: bool = False,
    disable_view: bool = False,
    disable_export: bool = False,
    disable_roi: bool = False,
    disable_profile: bool = False,
    disable_all: bool = False,
    hidden_tools: Optional[List[str]] = None,
    hide_display: bool = False,
    hide_histogram: bool = False,
    hide_stats: bool = False,
    hide_navigation: bool = False,
    hide_view: bool = False,
    hide_export: bool = False,
    hide_roi: bool = False,
    hide_profile: bool = False,
    hide_all: bool = False,
    ncols: int = 3,
    canvas_size: int = 0,
    state=None,
    **kwargs,
):
    """Build the viewer from image data; parameters are described on the class.

    Metadata is used only for arguments left at their defaults:

    - An ``IOResult`` provides its title, pixel size, and labels.
    - A Dataset2d-like object (``array``/``name``/``sampling``) provides its
      name as the title. It provides the pixel size from the last-axis
      sampling when the unit is nm (converted to Å) or Å.

    ``state``, if given, is applied last and therefore overrides the keyword
    arguments.

    Raises
    ------
    ValueError
        If the image data is not 2D (H, W) or 3D (N, H, W) after conversion.
    """
    super().__init__(**kwargs)
    self.widget_version = resolve_widget_version()

    # IOResult: take its title / pixel size / labels unless given explicitly.
    if isinstance(data, IOResult):
        if not title and data.title:
            title = data.title
        if pixel_size == 0.0 and data.pixel_size is not None:
            pixel_size = data.pixel_size
        if labels is None and data.labels:
            labels = data.labels
        data = data.data

    # Dataset2d-like object (duck-typed): name -> title, sampling -> pixel_size.
    if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
        if not title and data.name:
            title = data.name
        if pixel_size == 0.0 and hasattr(data, "units"):
            units = list(data.units)
            sampling_val = float(data.sampling[-1])
            if units[-1] in ("nm",):
                pixel_size = sampling_val * 10  # nm → Å
            elif units[-1] in ("Å", "angstrom", "A"):
                pixel_size = sampling_val
        data = data.array

    # Convert input to NumPy (handles NumPy, CuPy, PyTorch)
    if isinstance(data, list):
        images = [to_numpy(d) for d in data]
        shapes = [img.shape for img in images]
        if len(set(shapes)) > 1:
            # Different sizes - resize all to the largest
            max_h = max(s[0] for s in shapes)
            max_w = max(s[1] for s in shapes)
            images = [_resize_image(img, max_h, max_w) for img in images]
        data = np.stack(images)
    else:
        data = to_numpy(data)

    # Ensure 3D shape (N, H, W). Any other rank would leave n_images /
    # height / width inconsistent with the pixel buffer sent to the front end.
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    if data.ndim != 3:
        raise ValueError(
            "Show2D expects a 2D (H, W) image or a 3D (N, H, W) stack, "
            f"got an array with shape {tuple(data.shape)}"
        )
    self._data = data.astype(np.float32)
    self.n_images = int(data.shape[0])
    self.height = int(data.shape[1])
    self.width = int(data.shape[2])

    # Labels
    if labels is None:
        self.labels = [f"Image {i+1}" for i in range(self.n_images)]
    else:
        self.labels = list(labels)

    # Options
    self.title = title
    self.cmap = cmap
    self.pixel_size = pixel_size
    self.scale_bar_visible = scale_bar_visible
    self.canvas_size = canvas_size
    self.show_fft = show_fft
    self.fft_window = fft_window
    self.show_controls = show_controls
    self.show_stats = show_stats
    self.log_scale = log_scale
    self.auto_contrast = auto_contrast
    self.disabled_tools = self._build_disabled_tools(
        disabled_tools=disabled_tools,
        disable_display=disable_display,
        disable_histogram=disable_histogram,
        disable_stats=disable_stats,
        disable_navigation=disable_navigation,
        disable_view=disable_view,
        disable_export=disable_export,
        disable_roi=disable_roi,
        disable_profile=disable_profile,
        disable_all=disable_all,
    )
    self.hidden_tools = self._build_hidden_tools(
        hidden_tools=hidden_tools,
        hide_display=hide_display,
        hide_histogram=hide_histogram,
        hide_stats=hide_stats,
        hide_navigation=hide_navigation,
        hide_view=hide_view,
        hide_export=hide_export,
        hide_roi=hide_roi,
        hide_profile=hide_profile,
        hide_all=hide_all,
    )
    self.ncols = ncols

    # Compute initial stats
    self._compute_all_stats()
    # Send raw float32 data to JS (normalization happens in JS for speed)
    self._update_all_frames()
    self.selected_idx = 0

    if state is not None:
        if isinstance(state, (str, pathlib.Path)):
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
            state = unwrap_state_payload(
                json.loads(pathlib.Path(state).read_text(encoding="utf-8")),
                require_envelope=True,
            )
        else:
            state = unwrap_state_payload(state)
        self.load_state_dict(state)
[docs]
def set_image(self, data, labels=None):
    """Replace the displayed image(s). Preserves all display settings.

    Parameters
    ----------
    data : array_like, list of array_like, IOResult, or Dataset2d-like
        New pixel data in any form accepted by the constructor. Only the
        pixels (and, for an ``IOResult``, its labels) are used. Title, pixel
        size, colormap, and other display settings are left untouched.
    labels : list of str, optional
        Per-image labels. Defaults to the ``IOResult`` labels when present,
        otherwise ``"Image 1"``, ``"Image 2"``, ...

    Raises
    ------
    ValueError
        If the data is not 2D (H, W) or 3D (N, H, W) after conversion. The
        widget is left unchanged in that case.
    """
    # Same unwrapping order as __init__: IOResult first, then Dataset2d-like.
    if isinstance(data, IOResult):
        if labels is None and data.labels:
            labels = data.labels
        data = data.data
    if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
        data = data.array
    if isinstance(data, list):
        images = [to_numpy(d) for d in data]
        shapes = [img.shape for img in images]
        if len(set(shapes)) > 1:
            # Different sizes - resize all to the largest
            max_h = max(s[0] for s in shapes)
            max_w = max(s[1] for s in shapes)
            images = [_resize_image(img, max_h, max_w) for img in images]
        data = np.stack(images)
    else:
        data = to_numpy(data)
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    if data.ndim != 3:
        raise ValueError(
            "Show2D expects a 2D (H, W) image or a 3D (N, H, W) stack, "
            f"got an array with shape {tuple(data.shape)}"
        )
    frames = data.astype(np.float32)
    n_images = int(frames.shape[0])
    if labels is not None:
        new_labels = list(labels)
    else:
        new_labels = [f"Image {i+1}" for i in range(n_images)]
    # hold_sync batches every trait change below into one sync message, so the
    # front end never pairs the new dimensions with the previous pixel buffer.
    with self.hold_sync():
        self._data = frames
        self.n_images = n_images
        self.height = int(frames.shape[1])
        self.width = int(frames.shape[2])
        self.labels = new_labels
        self.selected_idx = 0
        self._compute_all_stats()
        self._update_all_frames()
def __repr__(self) -> str:
    """Compact one-line description: dimensions, colormap, and (gallery only) selected index."""
    if self.n_images <= 1:
        return f"Show2D({self.height}×{self.width}, cmap={self.cmap})"
    shape = f"{self.n_images}×{self.height}×{self.width}"
    return f"Show2D({shape}, idx={self.selected_idx}, cmap={self.cmap})"
def _repr_mimebundle_(self, **kwargs):
    """Return widget view + static PNG fallback.

    Live Jupyter renders the interactive widget. Static contexts
    (nbsphinx, GitHub, nbviewer) fall back to the embedded PNG. The PNG is
    rendered through ``_normalize_frame``, so it honours ``log_scale`` and
    ``auto_contrast`` exactly like ``save_image``.
    """
    bundle = super()._repr_mimebundle_(**kwargs)
    data_dict = bundle[0] if isinstance(bundle, tuple) else bundle
    if data_dict is None:
        # Parent produced no widget view; still offer the static PNG.
        data_dict = {}
    n = self.n_images
    # Clamp to >= 1 so a non-positive ncols (or an empty stack) cannot
    # divide by zero or request a zero-sized subplot grid.
    ncols = max(1, min(self.ncols, n))
    nrows = max(1, math.ceil(n / ncols))
    cell = 4
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(cell * ncols, cell * nrows),
        squeeze=False,
    )
    try:
        for i in range(nrows * ncols):
            r, c = divmod(i, ncols)
            ax = axes[r][c]
            if i < n:
                ax.imshow(
                    self._normalize_frame(self._data[i]),
                    cmap=self.cmap,
                    vmin=0,
                    vmax=255,
                    origin="upper",
                )
                # Fall back to the default label when fewer labels than images were given.
                label = self.labels[i] if i < len(self.labels) else f"Image {i + 1}"
                ax.set_title(label, fontsize=10)
            ax.axis("off")
        if self.title:
            fig.suptitle(self.title, fontsize=12)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    data_dict["image/png"] = base64.b64encode(buf.getvalue()).decode("ascii")
    if isinstance(bundle, tuple):
        return (data_dict, bundle[1])
    return data_dict
def _normalize_frame(self, frame: np.ndarray) -> np.ndarray:
    """Map one image to uint8 [0, 255] using the current display settings.

    When ``log_scale`` is set, negatives are clamped to 0 and ``log1p`` is
    applied. Values are then scaled linearly between either the 2nd/98th
    percentiles (``auto_contrast``) or the min/max.

    The contrast limits come from finite pixels only, so a NaN or inf pixel
    no longer blanks the whole image. NaN pixels map to 0; +/-inf saturate
    at 255/0.

    Parameters
    ----------
    frame : np.ndarray
        A single image.

    Returns
    -------
    np.ndarray
        uint8 array with ``frame``'s shape. It is all zeros when the image
        is constant or has no finite pixels.
    """
    if self.log_scale:
        frame = np.log1p(np.maximum(frame, 0))
    finite = frame[np.isfinite(frame)]
    if finite.size == 0:
        return np.zeros(frame.shape, dtype=np.uint8)
    if self.auto_contrast:
        vmin = float(np.percentile(finite, 2))
        vmax = float(np.percentile(finite, 98))
    else:
        vmin = float(finite.min())
        vmax = float(finite.max())
    if vmax <= vmin:
        return np.zeros(frame.shape, dtype=np.uint8)
    normalized = np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255)
    # NaN survives clip; replace it before the integer cast (undefined otherwise).
    normalized = np.nan_to_num(normalized, nan=0.0)
    return normalized.astype(np.uint8)
[docs]
def save_image(
    self,
    path: str | pathlib.Path,
    *,
    idx: int | None = None,
    format: str | None = None,
    dpi: int = 150,
) -> pathlib.Path:
    """Save one image, colormapped with the current display settings.

    Parameters
    ----------
    path : str or pathlib.Path
        Output file path. Missing parent directories are created.
    idx : int, optional
        Image index in gallery mode. Defaults to current selected_idx.
    format : str, optional
        'png', 'pdf', 'tiff' or 'tif'. If omitted, it is inferred from the
        file extension, falling back to 'png' when there is no extension.
        An explicit format is used even if it disagrees with the extension.
    dpi : int, default 150
        Output DPI metadata (page resolution for PDF).

    Returns
    -------
    pathlib.Path
        The written file path.

    Raises
    ------
    ValueError
        If the format is not supported.
    IndexError
        If the image index is out of range.
    """
    from matplotlib import colormaps
    from PIL import Image

    # Our format keys -> Pillow format names. Passing the format explicitly
    # means Pillow never has to guess from the extension (which fails for
    # paths without a recognised suffix).
    pil_formats = {"png": "PNG", "pdf": "PDF", "tiff": "TIFF", "tif": "TIFF"}
    path = pathlib.Path(path)
    fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
    if fmt not in pil_formats:
        raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")
    i = idx if idx is not None else self.selected_idx
    if i < 0 or i >= self.n_images:
        raise IndexError(f"Image index {i} out of range [0, {self.n_images})")
    normalized = self._normalize_frame(self._data[i])
    cmap_fn = colormaps.get_cmap(self.cmap)
    rgba = (cmap_fn(normalized / 255.0) * 255).astype(np.uint8)
    img = Image.fromarray(rgba)
    save_kwargs = {"format": pil_formats[fmt], "dpi": (dpi, dpi)}
    if fmt == "pdf":
        # Older Pillow releases cannot write RGBA into a PDF, and the colormap
        # output is opaque anyway. Pillow's PDF writer reads page resolution
        # from `resolution`.
        img = img.convert("RGB")
        save_kwargs["resolution"] = float(dpi)
    path.parent.mkdir(parents=True, exist_ok=True)
    img.save(str(path), **save_kwargs)
    return path
[docs]
def state_dict(self):
    """Return the persistable view settings as a plain dict.

    Pixel data and labels are not included. The result is what ``save``
    writes and what ``load_state_dict`` accepts.
    """
    persisted_keys = (
        "title",
        "cmap",
        "log_scale",
        "auto_contrast",
        "show_stats",
        "show_fft",
        "fft_window",
        "show_controls",
        "disabled_tools",
        "hidden_tools",
        "pixel_size",
        "scale_bar_visible",
        "canvas_size",
        "ncols",
        "selected_idx",
        "roi_active",
        "roi_list",
        "roi_selected_idx",
        "profile_line",
    )
    return {key: getattr(self, key) for key in persisted_keys}
[docs]
def save(self, path: str):
    """Write the current ``state_dict()`` to ``path`` as a "Show2D" state file.

    The file layout is delegated to ``save_state_file``.
    """
    snapshot = self.state_dict()
    save_state_file(path, "Show2D", snapshot)
[docs]
def load_state_dict(self, state):
    """Apply a state mapping (e.g. from ``state_dict()``) to this widget.

    Only public, synced traits are set; any other key is ignored. This stops
    a state payload from overwriting methods or private attributes. The key
    ``pixel_size_angstrom`` is accepted as an alias for ``pixel_size``.

    Data-derived traits are never overwritten, because they must stay
    consistent with the loaded image data. These are the pixel buffer, the
    image dimensions, the per-image statistics, and the widget version.

    Parameters
    ----------
    state : dict
        Mapping of trait name to value.
    """
    derived = {
        "frame_bytes",
        "n_images",
        "height",
        "width",
        "stats_mean",
        "stats_min",
        "stats_max",
        "stats_std",
        "widget_version",
    }
    settable = set(self.trait_names(sync=True)) - derived
    for key, val in state.items():
        if key == "pixel_size_angstrom":
            key = "pixel_size"
        if key.startswith("_") or key not in settable:
            continue
        setattr(self, key, val)
[docs]
def summary(self):
    """Print a human-readable overview of the data and display settings.

    Covers the image shape, calibration, data range, and display mode, plus
    any locked/hidden tools, ROIs, and line profile.
    """
    lines = [self.title or "Show2D", "═" * 32]
    if self.n_images > 1:
        lines.append(f"Image: {self.n_images}×{self.height}×{self.width} ({self.ncols} cols)")
    else:
        lines.append(f"Image: {self.height}×{self.width}")
    if self.pixel_size > 0:
        ps = self.pixel_size
        # pixel_size is stored in Å; show nm/px for coarser sampling.
        if ps >= 10:
            lines[-1] += f" ({ps / 10:.2f} nm/px)"
        else:
            lines[-1] += f" ({ps:.2f} Å/px)"
    if hasattr(self, "_data") and self._data is not None:
        arr = self._data
        lines.append(f"Data: min={float(arr.min()):.4g} max={float(arr.max()):.4g} mean={float(arr.mean()):.4g}")
    cmap = self.cmap
    scale = "log" if self.log_scale else "linear"
    contrast = "auto contrast" if self.auto_contrast else "manual contrast"
    display = f"{cmap} | {contrast} | {scale}"
    if self.show_fft:
        display += " | FFT"
        if not self.fft_window:
            display += " (no window)"
    lines.append(f"Display: {display}")
    if self.disabled_tools:
        lines.append(f"Locked: {', '.join(self.disabled_tools)}")
    if self.hidden_tools:
        lines.append(f"Hidden: {', '.join(self.hidden_tools)}")
    if self.roi_active and self.roi_list:
        lines.append(f"ROI: {len(self.roi_list)} region(s)")
    # A profile needs both endpoints; a partial (1-point) line is skipped
    # instead of raising IndexError.
    if len(self.profile_line) >= 2:
        p0, p1 = self.profile_line[0], self.profile_line[1]
        lines.append(f"Profile: ({p0['row']:.0f}, {p0['col']:.0f}) → ({p1['row']:.0f}, {p1['col']:.0f})")
    print("\n".join(lines))
def _compute_all_stats(self):
    """Refresh the per-image mean/min/max/std traits from ``self._data``.

    All four lists are computed first and then assigned in the order
    mean, min, max, std.
    """
    frames = [self._data[k] for k in range(self.n_images)]
    mean_vals = [float(np.mean(f)) for f in frames]
    min_vals = [float(np.min(f)) for f in frames]
    max_vals = [float(np.max(f)) for f in frames]
    std_vals = [float(np.std(f)) for f in frames]
    self.stats_mean, self.stats_min, self.stats_max, self.stats_std = (
        mean_vals,
        min_vals,
        max_vals,
        std_vals,
    )
def _update_all_frames(self):
    """Publish the whole float32 stack to the front end as raw bytes.

    No normalization is done here; per the design notes in this file,
    intensity mapping happens in JS for speed.
    """
    raw_buffer = self._data.tobytes()
    self.frame_bytes = raw_buffer
def _sample_profile(self, row0, col0, row1, col1):
    """Bilinearly sample the selected image along a straight line.

    The segment (row0, col0) -> (row1, col1) is sampled at
    ``max(2, ceil(length))`` evenly spaced points, endpoints included.
    Neighbour indices are clamped to the image bounds, so samples outside the
    image take edge values.

    Returns
    -------
    np.ndarray
        float32 intensities, one per sample point.
    """
    image = self._data[self.selected_idx]
    n_rows, n_cols = image.shape
    d_col = col1 - col0
    d_row = row1 - row0
    span = (d_col**2 + d_row**2) ** 0.5
    n_samples = max(2, int(np.ceil(span)))
    frac = np.linspace(0, 1, n_samples)
    cols = col0 + frac * d_col
    rows = row0 + frac * d_row
    col_lo = np.floor(cols).astype(int)
    row_lo = np.floor(rows).astype(int)
    wc = cols - col_lo
    wr = rows - row_lo
    left = np.clip(col_lo, 0, n_cols - 1)
    right = np.clip(col_lo + 1, 0, n_cols - 1)
    top = np.clip(row_lo, 0, n_rows - 1)
    bottom = np.clip(row_lo + 1, 0, n_rows - 1)
    # Keep the exact multiplication/summation order of the four corner terms.
    top_left = image[top, left] * (1 - wc) * (1 - wr)
    top_right = image[top, right] * wc * (1 - wr)
    bottom_left = image[bottom, left] * (1 - wc) * wr
    bottom_right = image[bottom, right] * wc * wr
    blended = top_left + top_right + bottom_left + bottom_right
    return blended.astype(np.float32)
[docs]
def set_profile(self, start: tuple, end: tuple):
    """Define a line profile between two points in image pixel coordinates.

    Parameters
    ----------
    start : tuple of (row, col)
        First endpoint.
    end : tuple of (row, col)
        Second endpoint.

    Both endpoints are stored as floats in ``profile_line``.
    """
    (r0, c0), (r1, c1) = start, end
    self.profile_line = [
        {"row": float(r), "col": float(c)} for r, c in ((r0, c0), (r1, c1))
    ]
[docs]
def clear_profile(self):
    """Remove the line profile by resetting ``profile_line`` to an empty list."""
    self.profile_line = list()
def _upsert_selected_roi(self, updates: dict):
    """Update the selected ROI with ``updates``, or append a new one.

    A valid ``roi_selected_idx`` means the existing ROI is first completed
    with default fields and a palette colour (if it has none), then
    overlaid with ``updates``. Otherwise a new ROI is built from the defaults
    (centred in the image), given the next palette colour, overlaid with
    ``updates``, appended, and selected. ROI mode is switched on in both cases.
    """
    palette = ["#4fc3f7", "#81c784", "#ffb74d", "#ce93d8", "#ef5350", "#ffd54f", "#90a4ae", "#a1887f"]
    regions = list(self.roi_list)
    selected = self.roi_selected_idx
    base = {
        "shape": "square",
        "row": int(self.height // 2),
        "col": int(self.width // 2),
        "radius": 10,
        "radius_inner": 5,
        "width": 20,
        "height": 20,
        "line_width": 2,
        "highlight": False,
        "visible": True,
        "locked": False,
    }
    if 0 <= selected < len(regions):
        merged = dict(base)
        merged.update(regions[selected])
        if not merged.get("color"):
            merged["color"] = palette[selected % len(palette)]
        merged.update(updates)
        regions[selected] = merged
    else:
        fresh = dict(base)
        fresh["color"] = palette[len(regions) % len(palette)]
        fresh.update(updates)
        regions.append(fresh)
        self.roi_selected_idx = len(regions) - 1
    self.roi_list = regions
    self.roi_active = True
[docs]
def add_roi(self, row: int | None = None, col: int | None = None, shape: str = "square") -> Self:
with self.hold_sync():
self.roi_selected_idx = -1
self._upsert_selected_roi({
"shape": shape,
"row": int(self.height // 2 if row is None else row),
"col": int(self.width // 2 if col is None else col),
})
return self
[docs]
def clear_rois(self) -> Self:
with self.hold_sync():
self.roi_list = []
self.roi_selected_idx = -1
self.roi_active = False
return self
[docs]
def delete_selected_roi(self) -> Self:
idx = int(self.roi_selected_idx)
if idx < 0 or idx >= len(self.roi_list):
return self
with self.hold_sync():
rois = [roi for i, roi in enumerate(self.roi_list) if i != idx]
self.roi_list = rois
self.roi_selected_idx = min(idx, len(rois) - 1) if rois else -1
if not rois:
self.roi_active = False
return self
[docs]
def set_roi(self, row: int, col: int, radius: int = 10) -> Self:
with self.hold_sync():
self._upsert_selected_roi({"shape": "circle", "row": int(row), "col": int(col), "radius": int(radius)})
return self
[docs]
def roi_circle(self, radius: int = 10) -> Self:
with self.hold_sync():
self._upsert_selected_roi({"shape": "circle", "radius": int(radius)})
return self
[docs]
def roi_square(self, half_size: int = 10) -> Self:
with self.hold_sync():
self._upsert_selected_roi({"shape": "square", "radius": int(half_size)})
return self
[docs]
def roi_rectangle(self, width: int = 20, height: int = 10) -> Self:
with self.hold_sync():
self._upsert_selected_roi({"shape": "rectangle", "width": int(width), "height": int(height)})
return self
[docs]
def roi_annular(self, inner: int = 5, outer: int = 10) -> Self:
with self.hold_sync():
self._upsert_selected_roi({"shape": "annular", "radius_inner": int(inner), "radius": int(outer)})
return self
@property
def profile(self):
    """Profile endpoints as ``[(row0, col0), (row1, col1)]``, or ``[]``.

    Returns
    -------
    list of tuple
        One (row, col) tuple per entry of ``profile_line``, in pixel
        coordinates; empty when no profile is set.
    """
    endpoints = []
    for point in self.profile_line:
        endpoints.append((point["row"], point["col"]))
    return endpoints
@property
def profile_values(self):
    """Intensity values sampled along the profile line.

    Only the first two entries of ``profile_line`` are used as endpoints.
    Any extra entries are ignored rather than causing an unpacking error.

    Returns
    -------
    np.ndarray or None
        Float32 array of bilinearly sampled intensities from the selected
        image, or None if fewer than two endpoints are set.
    """
    if len(self.profile_line) < 2:
        return None
    p0, p1 = self.profile_line[:2]
    return self._sample_profile(p0["row"], p0["col"], p1["row"], p1["col"])
@property
def profile_distance(self):
    """Length of the profile line in calibrated units.

    Only the first two entries of ``profile_line`` are used as endpoints.

    Returns
    -------
    float or None
        Distance in angstroms (if pixel_size > 0) or in pixels.
        None if no profile line is set.
    """
    if len(self.profile_line) < 2:
        return None
    p0, p1 = self.profile_line[:2]
    dist_px = math.hypot(p1["col"] - p0["col"], p1["row"] - p0["row"])
    if self.pixel_size > 0:
        return dist_px * self.pixel_size
    return dist_px
# Attach the shared tool-group runtime API to Show2D at import time.
# NOTE(review): what bind_tool_runtime_api adds (methods/properties) lives in
# quantem.widget.tool_parity and is not visible here -- confirm there.
bind_tool_runtime_api(Show2D, "Show2D")