"""
mark2d: Interactive 2D image annotation widget.
Mark points (atom positions, features), draw ROIs, measure distances,
snap to intensity peaks. Supports gallery mode with multiple images.
"""
import json
import pathlib
from typing import List, Optional
import anywidget
import numpy as np
import traitlets
from quantem.widget.array_utils import to_numpy, _resize_image
from quantem.widget.io import IOResult
from quantem.widget.json_state import resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
bind_tool_runtime_api,
build_tool_groups,
normalize_tool_groups,
)
_MARKER_SHAPES = ["circle", "triangle", "square", "diamond", "star"]
_MARKER_COLORS = [
"#f44336", "#4caf50", "#2196f3", "#ff9800", "#9c27b0",
"#00bcd4", "#ffeb3b", "#e91e63", "#8bc34a", "#ff5722",
]
_COLOR_NAMES = {
"#f44336": "red", "#4caf50": "green", "#2196f3": "blue",
"#ff9800": "orange", "#9c27b0": "purple", "#00bcd4": "cyan",
"#ffeb3b": "yellow", "#e91e63": "pink", "#8bc34a": "lime",
"#ff5722": "deep orange",
# ROI colors
"#0f0": "green", "#ff0": "yellow", "#0af": "cyan",
"#f0f": "magenta", "#f80": "orange", "#f44": "red",
# Common CSS names
"white": "white", "black": "black", "red": "red",
"#ff0": "yellow", "#0ff": "cyan", "#f00": "red",
}
def _color_name(hex_code: str) -> str:
return _COLOR_NAMES.get(hex_code.lower(), hex_code)
[docs]
class Mark2D(anywidget.AnyWidget):
"""
Interactive point picker for 2D images.
Click on an image to select points (atom positions, features, lattice
vectors). Supports gallery mode for comparing multiple images, pre-loaded
points from detection algorithms, multiple ROI overlays with pixel
statistics, snap-to-peak for precise atom column picking, and calibrated
distance measurements between points.
Parameters
----------
data : array_like
Image data. Accepts:
- 2D array ``(H, W)`` — single image
- 3D array ``(N, H, W)`` — gallery of N images
- List of 2D arrays — gallery (resized to common dimensions)
- ``Dataset2d`` object — array and sampling auto-extracted
scale : float, default 1.0
Display scale factor. Values > 1 enlarge the canvas.
dot_size : int, default 12
Diameter of point markers in CSS pixels.
max_points : int, default 10
Maximum number of points per image. Oldest points are removed
when the limit is exceeded.
ncols : int, default 3
Number of columns in the gallery grid (ignored for single images).
labels : list of str, optional
Per-image labels shown below each gallery tile and in the header.
Defaults to ``"Image 1"``, ``"Image 2"``, etc.
marker_border : int, default 2
Border width of point markers in pixels (0–6). The border grows
inward from the marker edge, so the overall marker size stays
constant. Set to 0 for borderless markers.
marker_opacity : float, default 1.0
Opacity of point markers (0.1–1.0).
label_size : int, default 0
Font size in pixels for the numbered label above each marker.
``0`` means auto-scale relative to ``dot_size``.
label_color : str, default ""
CSS color for numbered labels (e.g. ``"white"``, ``"#ff0"``).
Empty string uses the automatic theme color.
pixel_size : float, default 0.0
Pixel size in angstroms. When set, the widget displays a calibrated
scale bar and shows point-to-point distances in physical units
(angstroms or nanometers). ``0`` means uncalibrated.
points : list or ndarray, optional
Pre-populate the widget with initial points. Useful for reviewing
or refining positions from an atom-finding algorithm. Accepts:
- List of ``(row, col)`` tuples: ``[(10, 20), (30, 40)]``
- List of dicts with optional shape/color:
``[{"row": 10, "col": 20, "shape": "star", "color": "#f00"}]``
- NumPy array of shape ``(N, 2)`` with columns ``[row, col]``
- For gallery: list of the above, one per image
When ``shape`` or ``color`` are omitted, they cycle through the
built-in palettes (5 shapes, 10 colors).
marker_shape : str, default "circle"
Active marker shape for new points. One of ``"circle"``,
``"triangle"``, ``"square"``, ``"diamond"``, ``"star"``.
Synced bidirectionally — changes in the UI are reflected in Python.
marker_color : str, default "#f44336"
Active marker color for new points (CSS color string).
Synced bidirectionally — changes in the UI are reflected in Python.
snap_enabled : bool, default False
Whether snap-to-peak is active. When ``True``, clicked positions
are snapped to the local intensity maximum within ``snap_radius``.
snap_radius : int, default 5
Search radius in pixels for snap-to-peak.
title : str, default ""
Title displayed in the widget header. Empty string shows ``"Mark2D"``.
show_stats : bool, default True
Show statistics bar (mean, min, max, std) below the canvas.
cmap : str, default "gray"
Colormap for image rendering. Options: ``"gray"``, ``"inferno"``,
``"viridis"``, ``"plasma"``, ``"magma"``, ``"hot"``.
auto_contrast : bool, default True
Enable automatic contrast via 2–98% percentile clipping.
When ``False``, contrast is controlled by the histogram slider.
log_scale : bool, default False
Apply log(1+x) transform before rendering. Useful for images
with large dynamic range (e.g. diffraction patterns).
show_fft : bool, default False
Show FFT power spectrum alongside the image.
disabled_tools : list of str, optional
Tool groups to disable in the frontend UI/interaction layer. This is
useful for shared notebooks where viewers should not be able to modify
selected controls or annotations. Supported values:
``"points"``, ``"roi"``, ``"profile"``, ``"display"``,
``"marker_style"``, ``"snap"``, ``"navigation"``, ``"view"``,
``"export"``, ``"all"``.
disable_points : bool, default False
Convenience flag equivalent to including ``"points"`` in
``disabled_tools``.
disable_roi : bool, default False
Convenience flag equivalent to including ``"roi"`` in
``disabled_tools``.
disable_profile : bool, default False
Convenience flag equivalent to including ``"profile"`` in
``disabled_tools``.
disable_display : bool, default False
Convenience flag equivalent to including ``"display"`` in
``disabled_tools``.
disable_marker_style : bool, default False
Convenience flag equivalent to including ``"marker_style"`` in
``disabled_tools``.
disable_snap : bool, default False
Convenience flag equivalent to including ``"snap"`` in
``disabled_tools``.
disable_navigation : bool, default False
Convenience flag equivalent to including ``"navigation"`` in
``disabled_tools``.
disable_view : bool, default False
Convenience flag equivalent to including ``"view"`` in
``disabled_tools``.
disable_export : bool, default False
Convenience flag equivalent to including ``"export"`` in
``disabled_tools``.
disable_all : bool, default False
Convenience flag equivalent to ``disabled_tools=["all"]``.
hidden_tools : list of str, optional
Tool groups to hide from the frontend UI. Hidden tools are also
interaction-locked (equivalent to disabled for behavior), but their
controls are not rendered. Supported values:
``"points"``, ``"roi"``, ``"profile"``, ``"display"``,
``"marker_style"``, ``"snap"``, ``"navigation"``, ``"view"``,
``"export"``, ``"all"``.
hide_points : bool, default False
Convenience flag equivalent to including ``"points"`` in
``hidden_tools``.
hide_roi : bool, default False
Convenience flag equivalent to including ``"roi"`` in
``hidden_tools``.
hide_profile : bool, default False
Convenience flag equivalent to including ``"profile"`` in
``hidden_tools``.
hide_display : bool, default False
Convenience flag equivalent to including ``"display"`` in
``hidden_tools``.
hide_marker_style : bool, default False
Convenience flag equivalent to including ``"marker_style"`` in
``hidden_tools``.
hide_snap : bool, default False
Convenience flag equivalent to including ``"snap"`` in
``hidden_tools``.
hide_navigation : bool, default False
Convenience flag equivalent to including ``"navigation"`` in
``hidden_tools``.
hide_view : bool, default False
Convenience flag equivalent to including ``"view"`` in
``hidden_tools``.
hide_export : bool, default False
Convenience flag equivalent to including ``"export"`` in
``hidden_tools``.
hide_all : bool, default False
Convenience flag equivalent to ``hidden_tools=["all"]``.
Attributes
----------
selected_points : list
Currently placed points, synced bidirectionally with the widget.
- **Single image**: flat list of point dicts
``[{"row": 10, "col": 20, "shape": "circle", "color": "#f44336"}, ...]``
- **Gallery mode**: list of lists, one per image
``[[point, ...], [point, ...], ...]``
Each point dict has keys: ``row`` (int), ``col`` (int), ``shape`` (str),
``color`` (str). You can set this attribute from Python to update
the widget in real time.
roi_list : list
Currently defined ROI overlays, synced with the widget. Each ROI
is a dict with keys:
- ``id`` (int) — unique identifier
- ``mode`` (str) — ``"circle"``, ``"square"``, or ``"rectangle"``
- ``row``, ``col`` (int) — center position in image pixels
- ``radius`` (int) — radius for circle/square modes
- ``rectW``, ``rectH`` (int) — width/height for rectangle mode
- ``color`` (str) — CSS stroke color
- ``opacity`` (float) — opacity (0.1–1.0)
Set from Python to programmatically define ROIs, or read after
interactive use to retrieve user-defined regions.
Notes
-----
**Marker shapes**: circle, triangle, square, diamond, star (5 shapes
that cycle automatically).
**Marker colors**: 10 colors that cycle: red, green, blue, orange,
purple, cyan, yellow, pink, lime, deep orange.
**Snap-to-peak**: When enabled in the UI, clicking snaps the point to
the local intensity maximum within a configurable search radius. Useful
for precise atom column picking on HAADF-STEM images.
**Distance measurements**: Distances between consecutive points are
displayed in the point list. With ``pixel_size`` set, distances
are shown in angstroms (< 10 Å) or nanometers (>= 10 Å).
**ROI statistics**: When an ROI is active, the widget computes and
displays mean, standard deviation, min, max, and pixel count for the
region. When multiple ROIs exist, a summary table shows all ROI stats.
Active ROIs also show dotted horizontal/vertical center guide lines.
**Pairwise distances**: When 2+ points are placed, a table below the
point list shows distances between all pairs of points.
**Line profile**: Toggle "Profile" mode in the controls, then click
two points to sample intensity along a line. A sparkline graph with
calibrated x-axis appears below the canvas. Also available
programmatically via ``set_profile()``, ``profile_values``, and
``profile_distance``.
**Keyboard shortcuts** (widget must be focused):
- ``Delete`` / ``Backspace`` — remove last point (undo)
- ``Ctrl+Z`` / ``Cmd+Z`` — undo
- ``Ctrl+Shift+Z`` / ``Cmd+Shift+Z`` — redo
- ``1``–``6`` — select ROI #1–6
- Arrow keys — nudge active ROI by 1 pixel
- Arrow keys (no ROI, gallery) — navigate between images
- ``Escape`` — deselect ROI
Examples
--------
Basic point picking:
>>> import numpy as np
>>> from quantem.widget import Mark2D
>>> img = np.random.rand(256, 256).astype(np.float32)
>>> w = Mark2D(img, scale=1.5, max_points=5)
>>> w # display in notebook; click to place points
>>> w.selected_points # read back placed points
Pre-loaded points from a detection algorithm:
>>> peaks = find_atom_columns(img) # returns (N, 2) array
>>> w = Mark2D(img, points=peaks, pixel_size=0.82)
>>> # Points appear immediately; user can add/remove/adjust
Pre-loaded points with custom appearance:
>>> pts = [
... {"row": 200, "col": 100, "shape": "star", "color": "#ff0"},
... {"row": 250, "col": 150, "shape": "diamond", "color": "#0ff"},
... ]
>>> w = Mark2D(img, points=pts, marker_border=0, marker_opacity=0.8)
Gallery mode for comparing multiple images:
>>> imgs = [original, filtered, denoised]
>>> w = Mark2D(imgs, ncols=3, labels=["Raw", "Filtered", "Denoised"])
>>> # Points are tracked independently per image
Gallery with per-image pre-loaded points:
>>> per_image_pts = [[(10, 20)], [(30, 40), (50, 60)], []]
>>> w = Mark2D(imgs, points=per_image_pts)
Programmatic ROI management:
>>> w = Mark2D(img)
>>> w.add_roi(row=128, col=128, mode="circle", radius=30, color="#0f0")
>>> w.add_roi(row=200, col=200, mode="rectangle", rect_w=80, rect_h=40)
>>> w.roi_list # inspect ROI parameters
>>> w.roi_center() # center of most recently added ROI -> (200, 200)
>>> w.roi_radius() # radius for circle/square, None for rectangle
>>> w.roi_size() # shape-aware size dict (e.g. width/height for rectangle)
>>> w.roi_list = [] # clear all ROIs
Snap-to-peak for precise atom picking:
>>> w = Mark2D(haadf_image, snap_enabled=True, snap_radius=8,
... pixel_size=0.82)
>>> # Clicks auto-snap to the nearest intensity maximum
Custom marker defaults:
>>> w = Mark2D(img, marker_shape="star", marker_color="#ff9800")
>>> # All new points will be orange stars until changed in the UI
Human-friendly tool locking:
>>> w = Mark2D(img, disable_points=True, disable_roi=True, disable_display=True)
>>> w_read_only = Mark2D(img, disable_all=True)
>>> w_clean = Mark2D(img, hide_display=True, hide_export=True)
Save and restore full widget state (state portability):
>>> # User A: create widget, place points and ROIs interactively
>>> w = Mark2D(img, pixel_size=1.5)
>>> # ... user clicks to place points, adds ROIs, enables snap ...
>>> state = {
... "points": w.selected_points,
... "rois": w.roi_list,
... "marker_shape": w.marker_shape,
... "marker_color": w.marker_color,
... "snap_enabled": w.snap_enabled,
... "snap_radius": w.snap_radius,
... }
>>> # User B: restore exact same state on another machine
>>> w2 = Mark2D(img, pixel_size=1.5,
... points=state["points"],
... marker_shape=state["marker_shape"],
... marker_color=state["marker_color"],
... snap_enabled=state["snap_enabled"],
... snap_radius=state["snap_radius"])
>>> w2.roi_list = state["rois"]
Line profile (programmatic):
>>> w = Mark2D(img, pixel_size=0.82)
>>> w.set_profile((10, 20), (100, 200))
>>> w.profile_values # sampled intensities along the line
>>> w.profile_distance # total distance in angstroms
Export points as NumPy array:
>>> w = Mark2D(img, points=[(10, 20), (30, 40)])
>>> w.points_as_array() # shape (2, 2), columns [row, col]
>>> w.points_as_dict() # [{"row": 10, "col": 20}, ...]
"""
_esm = pathlib.Path(__file__).parent / "static" / "mark2d.js"
# Image data (gallery-capable, matching Show2D pattern)
widget_version = traitlets.Unicode("unknown").tag(sync=True)
n_images = traitlets.Int(1).tag(sync=True)
width = traitlets.Int(0).tag(sync=True)
height = traitlets.Int(0).tag(sync=True)
frame_bytes = traitlets.Bytes(b"").tag(sync=True)
img_min = traitlets.List(traitlets.Float()).tag(sync=True)
img_max = traitlets.List(traitlets.Float()).tag(sync=True)
# Gallery controls
selected_idx = traitlets.Int(0).tag(sync=True)
ncols = traitlets.Int(3).tag(sync=True)
labels = traitlets.List(traitlets.Unicode()).tag(sync=True)
# UI controls
scale = traitlets.Float(1.0).tag(sync=True)
selected_points = traitlets.List().tag(sync=True)
dot_size = traitlets.Int(12).tag(sync=True)
max_points = traitlets.Int(10).tag(sync=True)
# Marker styling (advanced)
marker_border = traitlets.Int(2).tag(sync=True)
marker_opacity = traitlets.Float(1.0).tag(sync=True)
label_size = traitlets.Int(0).tag(sync=True)
label_color = traitlets.Unicode("").tag(sync=True)
# Scale bar
pixel_size = traitlets.Float(0.0).tag(sync=True)
# Active marker selection (synced for state portability)
marker_shape = traitlets.Unicode("circle").tag(sync=True)
marker_color = traitlets.Unicode("#f44336").tag(sync=True)
# Snap-to-peak (synced for state portability)
snap_enabled = traitlets.Bool(False).tag(sync=True)
snap_radius = traitlets.Int(5).tag(sync=True)
# ROI overlays (synced to JS)
roi_list = traitlets.List().tag(sync=True)
# Line profile
profile_line = traitlets.List(traitlets.Dict()).tag(sync=True)
# Display options
title = traitlets.Unicode("").tag(sync=True)
show_stats = traitlets.Bool(True).tag(sync=True)
# Colormap and contrast (synced for state portability)
cmap = traitlets.Unicode("gray").tag(sync=True)
auto_contrast = traitlets.Bool(True).tag(sync=True)
log_scale = traitlets.Bool(False).tag(sync=True)
show_fft = traitlets.Bool(False).tag(sync=True)
fft_window = traitlets.Bool(True).tag(sync=True)
# Canvas sizing
canvas_size = traitlets.Int(0).tag(sync=True)
# Control visibility
show_controls = traitlets.Bool(True).tag(sync=True)
# Optional UI/tool lockout for shared notebooks
disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
@classmethod
def _normalize_tool_groups(cls, tool_groups) -> List[str]:
"""Validate and normalize tool group values with stable ordering."""
return normalize_tool_groups("Mark2D", tool_groups)
@classmethod
def _build_disabled_tools(
cls,
disabled_tools=None,
disable_points: bool = False,
disable_roi: bool = False,
disable_profile: bool = False,
disable_display: bool = False,
disable_marker_style: bool = False,
disable_snap: bool = False,
disable_navigation: bool = False,
disable_view: bool = False,
disable_export: bool = False,
disable_all: bool = False,
) -> List[str]:
"""Build disabled_tools from explicit list and ergonomic boolean flags."""
return build_tool_groups(
"Mark2D",
tool_groups=disabled_tools,
all_flag=disable_all,
flag_map={
"points": disable_points,
"roi": disable_roi,
"profile": disable_profile,
"display": disable_display,
"marker_style": disable_marker_style,
"snap": disable_snap,
"navigation": disable_navigation,
"view": disable_view,
"export": disable_export,
},
)
@classmethod
def _build_hidden_tools(
cls,
hidden_tools=None,
hide_points: bool = False,
hide_roi: bool = False,
hide_profile: bool = False,
hide_display: bool = False,
hide_marker_style: bool = False,
hide_snap: bool = False,
hide_navigation: bool = False,
hide_view: bool = False,
hide_export: bool = False,
hide_all: bool = False,
) -> List[str]:
"""Build hidden_tools from explicit list and ergonomic boolean flags."""
return build_tool_groups(
"Mark2D",
tool_groups=hidden_tools,
all_flag=hide_all,
flag_map={
"points": hide_points,
"roi": hide_roi,
"profile": hide_profile,
"display": hide_display,
"marker_style": hide_marker_style,
"snap": hide_snap,
"navigation": hide_navigation,
"view": hide_view,
"export": hide_export,
},
)
@traitlets.validate("disabled_tools")
def _validate_disabled_tools(self, proposal):
return self._normalize_tool_groups(proposal["value"])
@traitlets.validate("hidden_tools")
def _validate_hidden_tools(self, proposal):
return self._normalize_tool_groups(proposal["value"])
# Percentile clipping
percentile_low = traitlets.Float(2.0).tag(sync=True)
percentile_high = traitlets.Float(98.0).tag(sync=True)
# Per-image statistics
stats_mean = traitlets.List(traitlets.Float()).tag(sync=True)
stats_min = traitlets.List(traitlets.Float()).tag(sync=True)
stats_max = traitlets.List(traitlets.Float()).tag(sync=True)
stats_std = traitlets.List(traitlets.Float()).tag(sync=True)
def __init__(
self,
data,
scale: float = 1.0,
dot_size: int = 12,
max_points: int = 10,
ncols: int = 3,
labels: Optional[List[str]] = None,
marker_border: int = 2,
marker_opacity: float = 1.0,
label_size: int = 0,
label_color: str = "",
pixel_size: float = 0.0,
points=None,
marker_shape: str = "circle",
marker_color: str = "#f44336",
snap_enabled: bool = False,
snap_radius: int = 5,
title: str = "",
show_stats: bool = True,
cmap: str = "gray",
auto_contrast: bool = True,
log_scale: bool = False,
show_fft: bool = False,
fft_window: bool = True,
canvas_size: int = 0,
show_controls: bool = True,
disabled_tools: Optional[List[str]] = None,
disable_points: bool = False,
disable_roi: bool = False,
disable_profile: bool = False,
disable_display: bool = False,
disable_marker_style: bool = False,
disable_snap: bool = False,
disable_navigation: bool = False,
disable_view: bool = False,
disable_export: bool = False,
disable_all: bool = False,
hidden_tools: Optional[List[str]] = None,
hide_points: bool = False,
hide_roi: bool = False,
hide_profile: bool = False,
hide_display: bool = False,
hide_marker_style: bool = False,
hide_snap: bool = False,
hide_navigation: bool = False,
hide_view: bool = False,
hide_export: bool = False,
hide_all: bool = False,
percentile_low: float = 2.0,
percentile_high: float = 98.0,
state=None,
**kwargs,
):
super().__init__(**kwargs)
self.widget_version = resolve_widget_version()
self.show_stats = show_stats
self.scale = scale
self.dot_size = dot_size
self.max_points = max_points
self.ncols = ncols
self.marker_border = marker_border
self.marker_opacity = marker_opacity
self.label_size = label_size
self.label_color = label_color
self.marker_shape = marker_shape
self.marker_color = marker_color
self.snap_enabled = snap_enabled
self.snap_radius = snap_radius
self.cmap = cmap
self.auto_contrast = auto_contrast
self.log_scale = log_scale
self.show_fft = show_fft
self.fft_window = fft_window
self.canvas_size = canvas_size
self.show_controls = show_controls
self.disabled_tools = self._build_disabled_tools(
disabled_tools=disabled_tools,
disable_points=disable_points,
disable_roi=disable_roi,
disable_profile=disable_profile,
disable_display=disable_display,
disable_marker_style=disable_marker_style,
disable_snap=disable_snap,
disable_navigation=disable_navigation,
disable_view=disable_view,
disable_export=disable_export,
disable_all=disable_all,
)
self.hidden_tools = self._build_hidden_tools(
hidden_tools=hidden_tools,
hide_points=hide_points,
hide_roi=hide_roi,
hide_profile=hide_profile,
hide_display=hide_display,
hide_marker_style=hide_marker_style,
hide_snap=hide_snap,
hide_navigation=hide_navigation,
hide_view=hide_view,
hide_export=hide_export,
hide_all=hide_all,
)
self.percentile_low = percentile_low
self.percentile_high = percentile_high
# Check if data is an IOResult and extract metadata
if isinstance(data, IOResult):
if not title and data.title:
title = data.title
if pixel_size == 0.0 and data.pixel_size is not None:
pixel_size = data.pixel_size
if labels is None and data.labels:
labels = data.labels
data = data.data
self._set_data(data, labels)
# Explicit overrides take priority over Dataset metadata
if title:
self.title = title
if pixel_size != 0.0:
self.pixel_size = pixel_size
if points is not None:
self.selected_points = self._normalize_points(points)
if state is not None:
if isinstance(state, (str, pathlib.Path)):
state = unwrap_state_payload(
json.loads(pathlib.Path(state).read_text()),
require_envelope=True,
)
else:
state = unwrap_state_payload(state)
self.load_state_dict(state)
def _set_data(self, data, labels=None):
# Check if data is a Dataset2d and extract metadata
if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
if data.name:
self.title = data.name
if hasattr(data, "units"):
units = list(data.units)
sampling_val = float(data.sampling[-1])
if units[-1] in ("nm",):
self.pixel_size = sampling_val * 10 # nm → Å
elif units[-1] in ("Å", "angstrom", "A"):
self.pixel_size = sampling_val
data = data.array
if isinstance(data, list):
images = [to_numpy(d) for d in data]
for img in images:
if img.ndim != 2:
raise ValueError("Each image in the list must be 2D (H, W).")
shapes = [img.shape for img in images]
if len(set(shapes)) > 1:
max_h = max(s[0] for s in shapes)
max_w = max(s[1] for s in shapes)
images = [_resize_image(img, max_h, max_w) for img in images]
arr = np.stack(images).astype(np.float32)
else:
arr = to_numpy(data).astype(np.float32)
if arr.ndim == 2:
arr = arr[np.newaxis, ...]
elif arr.ndim == 3:
pass # (N, H, W) gallery
else:
raise ValueError("Expected 2D (H,W) or 3D (N,H,W) array, or list of 2D arrays.")
self._data = arr
n, h, w = arr.shape
self.n_images = n
self.height = h
self.width = w
# Per-image min/max and statistics
mins, maxs, means, stds = [], [], [], []
for i in range(n):
frame = arr[i]
mins.append(float(frame.min()))
maxs.append(float(frame.max()))
means.append(float(frame.mean()))
stds.append(float(frame.std()))
self.img_min = mins
self.img_max = maxs
self.stats_mean = means
self.stats_min = mins
self.stats_max = maxs
self.stats_std = stds
# Labels
if labels is not None:
self.labels = list(labels)
else:
self.labels = [f"Image {i + 1}" for i in range(n)]
# Frame bytes (raw float32, all images concatenated)
self.frame_bytes = arr.tobytes()
# Reset points
if n == 1:
self.selected_points = []
else:
self.selected_points = [[] for _ in range(n)]
self.selected_idx = 0
def _normalize_point(self, p, idx):
if isinstance(p, dict):
return {
"row": int(p["row"]),
"col": int(p["col"]),
"shape": p.get("shape", _MARKER_SHAPES[idx % len(_MARKER_SHAPES)]),
"color": p.get("color", _MARKER_COLORS[idx % len(_MARKER_COLORS)]),
}
if isinstance(p, (list, tuple)) and len(p) == 2:
return {
"row": int(p[0]),
"col": int(p[1]),
"shape": _MARKER_SHAPES[idx % len(_MARKER_SHAPES)],
"color": _MARKER_COLORS[idx % len(_MARKER_COLORS)],
}
raise ValueError(f"Invalid point format: {p}")
def _normalize_point_list(self, pts):
if isinstance(pts, np.ndarray):
if pts.ndim == 2 and pts.shape[1] == 2:
return [self._normalize_point((int(pts[i, 0]), int(pts[i, 1])), i)
for i in range(pts.shape[0])]
raise ValueError(f"Expected (N, 2) array, got shape {pts.shape}")
return [self._normalize_point(p, i) for i, p in enumerate(pts)]
def _normalize_points(self, raw_points):
if self.n_images == 1:
return self._normalize_point_list(raw_points)
# Gallery: expect list of point lists, one per image
if not isinstance(raw_points, (list, tuple)):
raise ValueError("Gallery mode requires list of point lists")
if len(raw_points) != self.n_images:
raise ValueError(
f"Expected {self.n_images} point lists, got {len(raw_points)}"
)
return [self._normalize_point_list(pts) for pts in raw_points]
[docs]
def set_image(self, data, labels=None):
"""
Replace the displayed image(s) and reset all points.
Can switch between single-image and gallery modes. All existing
points are cleared; ROIs are preserved.
Parameters
----------
data : array_like
2D array ``(H, W)``, 3D array ``(N, H, W)``, or list of 2D
arrays. Same formats as the constructor.
labels : list of str, optional
Per-image labels for gallery mode. If ``None``, defaults to
``"Image 1"``, ``"Image 2"``, etc.
Examples
--------
>>> w = Mark2D(img1)
>>> w.set_image(img2) # switch to a different image
>>> w.set_image([img1, img2, img3], labels=["A", "B", "C"])
"""
self._set_data(data, labels)
[docs]
def add_roi(self, row, col, shape="square", radius=30, width=60, height=40,
color="#0f0", opacity=0.8):
"""
Add an ROI overlay to the widget.
Multiple ROIs can be added. Each gets a unique ID. In the widget,
the user can click ROI centers to select them, drag to reposition,
and adjust size/color/opacity interactively. The widget also displays
pixel statistics (mean, std, min, max) for the active ROI.
Parameters
----------
row, col : int
Center position in image pixel coordinates (row, col).
shape : str, default "circle"
ROI shape. One of ``"circle"``, ``"square"``, or ``"rectangle"``.
radius : int, default 30
Radius in pixels for circle and square modes.
width : int, default 60
Width in pixels for rectangle mode.
height : int, default 40
Height in pixels for rectangle mode.
color : str, default "#0f0"
Stroke color as a CSS color string (e.g. ``"#ff0"``, ``"red"``).
opacity : float, default 0.8
Stroke opacity (0.1–1.0).
Examples
--------
>>> w = Mark2D(img)
>>> w.add_roi(128, 128) # green circle at center
>>> w.add_roi(50, 50, shape="square", radius=20, color="#ff0")
>>> w.add_roi(200, 100, shape="rectangle", width=80, height=30)
>>> len(w.roi_list) # 3
"""
roi_id = max((r["id"] for r in self.roi_list), default=-1) + 1
roi = {
"id": roi_id,
"shape": shape,
"row": int(row), "col": int(col),
"radius": int(radius),
"width": int(width), "height": int(height),
"color": color, "opacity": float(opacity),
}
self.roi_list = [*self.roi_list, roi]
[docs]
def clear_rois(self):
"""
Remove all ROI overlays.
Examples
--------
>>> w.add_roi(100, 100)
>>> w.add_roi(200, 200)
>>> w.clear_rois()
>>> w.roi_list # []
"""
self.roi_list = []
def _resolve_roi(self, index: Optional[int] = None, roi_id: Optional[int] = None):
"""Resolve one ROI by index or id (defaults to the most recently added ROI)."""
if index is not None and roi_id is not None:
raise ValueError("Pass either index or roi_id, not both.")
if not self.roi_list:
raise ValueError("No ROIs are defined.")
if roi_id is not None:
target_id = int(roi_id)
for roi in self.roi_list:
if int(roi.get("id", -1)) == target_id:
return roi
raise ValueError(f"ROI id {roi_id} not found.")
idx = -1 if index is None else int(index)
try:
return self.roi_list[idx]
except IndexError as exc:
raise IndexError(
f"ROI index {idx} out of range for {len(self.roi_list)} ROIs."
) from exc
[docs]
def roi_center(self, index: Optional[int] = None, roi_id: Optional[int] = None):
"""
Return ROI center coordinates as ``(row, col)``.
By default, returns the center of the most recently added ROI.
Parameters
----------
index : int, optional
ROI list index to query. Supports negative indexing.
roi_id : int, optional
ROI ``id`` value to query. Mutually exclusive with ``index``.
Returns
-------
tuple of int
Center point ``(row, col)``.
"""
roi = self._resolve_roi(index=index, roi_id=roi_id)
return int(roi["row"]), int(roi["col"])
[docs]
def roi_radius(self, index: Optional[int] = None, roi_id: Optional[int] = None):
"""
Return ROI radius for ``circle``/``square`` ROIs.
For ``rectangle`` ROIs, returns ``None`` (use ``roi_size()`` for
rectangle width/height).
By default, queries the most recently added ROI.
Parameters
----------
index : int, optional
ROI list index to query. Supports negative indexing.
roi_id : int, optional
ROI ``id`` value to query. Mutually exclusive with ``index``.
Returns
-------
int or None
Radius in pixels for circle/square ROIs, otherwise ``None``.
"""
roi = self._resolve_roi(index=index, roi_id=roi_id)
if roi.get("shape") == "rectangle":
return None
return int(roi["radius"])
[docs]
def roi_size(self, index: Optional[int] = None, roi_id: Optional[int] = None):
"""
Return shape-aware ROI size information.
- ``circle`` / ``square`` -> ``{"shape", "radius", "diameter"}``
- ``rectangle`` -> ``{"shape", "width", "height"}``
Parameters
----------
index : int, optional
ROI list index to query. Supports negative indexing.
roi_id : int, optional
ROI ``id`` value to query. Mutually exclusive with ``index``.
Returns
-------
dict
Shape-aware size dictionary for the selected ROI.
"""
roi = self._resolve_roi(index=index, roi_id=roi_id)
shape = str(roi.get("shape", "circle"))
if shape == "rectangle":
return {
"shape": shape,
"width": int(roi["width"]),
"height": int(roi["height"]),
}
radius = int(roi["radius"])
return {
"shape": shape,
"radius": radius,
"diameter": 2 * radius,
}
def _sample_profile(self, row0, col0, row1, col1):
"""Sample intensity values along a line using bilinear interpolation."""
idx = self.selected_idx if self.n_images > 1 else 0
img = self._data[idx]
h, w = img.shape
dc, dr = col1 - col0, row1 - row0
length = (dc**2 + dr**2) ** 0.5
n = max(2, int(np.ceil(length)))
t = np.linspace(0, 1, n)
cs = col0 + t * dc
rs = row0 + t * dr
ci = np.floor(cs).astype(int)
ri = np.floor(rs).astype(int)
cf = cs - ci
rf = rs - ri
c0c = np.clip(ci, 0, w - 1)
c1c = np.clip(ci + 1, 0, w - 1)
r0c = np.clip(ri, 0, h - 1)
r1c = np.clip(ri + 1, 0, h - 1)
vals = (img[r0c, c0c] * (1 - cf) * (1 - rf) +
img[r0c, c1c] * cf * (1 - rf) +
img[r1c, c0c] * (1 - cf) * rf +
img[r1c, c1c] * cf * rf)
return vals.astype(np.float32)
[docs]
def set_profile(self, start: tuple, end: tuple):
"""
Set a line profile between two points.
The profile is drawn on the canvas and intensity values are sampled
along the line with bilinear interpolation. A sparkline graph appears
below the canvas.
Parameters
----------
start : tuple of (row, col)
Start point in image pixel coordinates.
end : tuple of (row, col)
End point in image pixel coordinates.
Examples
--------
>>> w = Mark2D(img, pixel_size=0.82)
>>> w.set_profile((10, 20), (100, 200))
>>> w.profile_values # sampled intensities along the line
"""
row0, col0 = start
row1, col1 = end
self.profile_line = [
{"row": float(row0), "col": float(col0)},
{"row": float(row1), "col": float(col1)},
]
[docs]
def clear_profile(self):
"""Clear the current line profile."""
self.profile_line = []
@property
def profile(self):
"""
Get profile line endpoints as ``[(row0, col0), (row1, col1)]`` or ``[]``.
Returns
-------
list of tuple
Line endpoints in pixel coordinates, or empty list if no profile.
"""
return [(p["row"], p["col"]) for p in self.profile_line]
@property
def profile_values(self):
"""
Get intensity values along the profile line as a numpy array.
Returns
-------
np.ndarray or None
Float32 array of sampled intensities, or ``None`` if no profile.
"""
if len(self.profile_line) < 2:
return None
p0, p1 = self.profile_line
return self._sample_profile(p0["row"], p0["col"], p1["row"], p1["col"])
@property
def profile_distance(self):
"""
Get total distance of the profile line in calibrated units.
Returns
-------
float or None
Distance in angstroms (if ``pixel_size > 0``) or pixels.
``None`` if no profile line is set.
"""
if len(self.profile_line) < 2:
return None
p0, p1 = self.profile_line
dc = p1["col"] - p0["col"]
dr = p1["row"] - p0["row"]
dist_px = (dc**2 + dr**2) ** 0.5
if self.pixel_size > 0:
return dist_px * self.pixel_size
return dist_px
def __repr__(self) -> str:
is_gallery = self.n_images > 1
# Points count
if is_gallery:
per_img = [len(pts) if isinstance(pts, list) else 0
for pts in self.selected_points]
pts_str = "+".join(str(n) for n in per_img)
total = sum(per_img)
else:
total = len(self.selected_points)
pts_str = str(total)
# Shape string
if is_gallery:
shape = f"{self.n_images}×{self.height}×{self.width}"
else:
shape = f"{self.height}×{self.width}"
name = self.title if self.title else "Mark2D"
parts = [f"{name}({shape}"]
if is_gallery:
parts.append(f"idx={self.selected_idx}")
# Pixel size
if self.pixel_size > 0:
ps = self.pixel_size
if ps >= 10:
parts.append(f"px={ps / 10:.2f} nm")
else:
parts.append(f"px={ps:.2f} Å")
# Points
if total > 0:
parts.append(f"pts={pts_str}")
else:
parts.append("pts=0")
# ROIs
n_rois = len(self.roi_list)
if n_rois > 0:
parts.append(f"rois={n_rois}")
# Non-default imaging settings
if self.cmap != "gray":
parts.append(f"cmap={self.cmap}")
if self.log_scale:
parts.append("log")
if not self.auto_contrast:
parts.append("manual contrast")
if self.show_fft:
parts.append("fft")
if self.snap_enabled:
parts.append(f"snap r={self.snap_radius}")
return ", ".join(parts) + ")"
[docs]
def summary(self):
"""
Print a detailed summary of the widget state.
Shows image info, display settings, points with coordinates,
ROI details, and marker configuration.
Examples
--------
>>> w = Mark2D(img, points=[(10, 20), (30, 40)],
... pixel_size=0.82, cmap='viridis')
>>> w.summary()
Mark2D
═══════════════════════════════
Image: 128×128 (0.82 Å/px)
Display: viridis | auto contrast | linear
...
"""
is_gallery = self.n_images > 1
name = self.title if self.title else "Mark2D"
lines = [name, "═" * 32]
# Image info
if is_gallery:
shape = f"{self.n_images}×{self.height}×{self.width}"
lines.append(f"Image: {shape} ({self.ncols} cols)")
else:
shape = f"{self.height}×{self.width}"
lines.append(f"Image: {shape}")
if self.pixel_size > 0:
ps = self.pixel_size
if ps >= 10:
lines[-1] += f" ({ps / 10:.2f} nm/px)"
else:
lines[-1] += f" ({ps:.2f} Å/px)"
if self.scale != 1.0:
lines[-1] += f" scale={self.scale}x"
# Data range
if hasattr(self, "_data") and self._data is not None:
arr = self._data
lines.append(f"Data: min={float(arr.min()):.4g} max={float(arr.max()):.4g} mean={float(arr.mean()):.4g} dtype={arr.dtype}")
# Display settings
cmap = self.cmap
scale = "log" if self.log_scale else "linear"
contrast = "auto contrast" if self.auto_contrast else "manual contrast"
display = f"{cmap} | {contrast} | {scale}"
if self.show_fft:
display += " | FFT"
if not self.fft_window:
display += " (no window)"
lines.append(f"Display: {display}")
# Point formatting helper
def _fmt_point(j, p, prev=None):
color = _color_name(p.get("color", ""))
coord = f" {j + 1}. ({p['row']}, {p['col']}) {p.get('shape', 'circle')} {color}"
if prev is not None:
dr, dc = p["row"] - prev["row"], p["col"] - prev["col"]
dist = (dr * dr + dc * dc) ** 0.5
if self.pixel_size > 0:
phys = dist * self.pixel_size
if phys >= 10:
coord += f" ↔ {phys / 10:.2f} nm"
else:
coord += f" ↔ {phys:.2f} Å"
else:
coord += f" ↔ {dist:.1f} px"
return coord
# Points
if is_gallery:
for i in range(self.n_images):
pts = self.selected_points[i] if i < len(self.selected_points) else []
label = self.labels[i] if i < len(self.labels) else f"Image {i + 1}"
lines.append(f"Points [{label}]: {len(pts)}/{self.max_points}")
for j, p in enumerate(pts):
lines.append(_fmt_point(j, p, pts[j - 1] if j > 0 else None))
else:
pts = self.selected_points
lines.append(f"Points: {len(pts)}/{self.max_points}")
for j, p in enumerate(pts):
lines.append(_fmt_point(j, p, pts[j - 1] if j > 0 else None))
# ROIs
if self.roi_list:
lines.append(f"ROIs: {len(self.roi_list)}")
for roi in self.roi_list:
mode = roi["shape"]
pos = f"({roi['row']}, {roi['col']})"
if mode == "rectangle":
size = f"{roi['width']}×{roi['height']}"
area_px = roi["width"] * roi["height"]
elif mode == "circle":
size = f"r={roi['radius']}"
area_px = 3.14159265 * roi["radius"] ** 2
else: # square
size = f"r={roi['radius']}"
area_px = (2 * roi["radius"]) ** 2
color = _color_name(roi["color"])
if self.pixel_size > 0:
ps = self.pixel_size
area_phys = area_px * ps * ps
if area_phys >= 100:
area_str = f" area={area_phys / 100:.1f} nm²"
else:
area_str = f" area={area_phys:.1f} Ų"
else:
area_str = f" area={area_px:.0f} px²"
lines.append(f" {roi['id']+1}. {mode} at {pos} {size} {color}{area_str}")
# Marker settings
color = _color_name(self.marker_color)
marker = f"{self.marker_shape} {color} size={self.dot_size}px"
if self.marker_border != 2:
marker += f" border={self.marker_border}"
if self.marker_opacity != 1.0:
marker += f" opacity={self.marker_opacity:.0%}"
lines.append(f"Marker: {marker}")
# Snap
if self.snap_enabled:
lines.append(f"Snap: ON (radius={self.snap_radius} px)")
if self.disabled_tools:
lines.append(f"Locked: {', '.join(self.disabled_tools)}")
if self.hidden_tools:
lines.append(f"Hidden: {', '.join(self.hidden_tools)}")
print("\n".join(lines))
[docs]
def points_as_array(self):
"""
Return placed points as a NumPy array of shape ``(N, 2)`` with columns ``[row, col]``.
In gallery mode, returns a list of arrays (one per image).
Examples
--------
>>> w = Mark2D(img, points=[(10, 20), (30, 40)])
>>> w.points_as_array()
array([[10, 20],
[30, 40]])
"""
if self.n_images > 1:
result = []
for pts in self.selected_points:
if pts:
result.append(np.array([[p["row"], p["col"]] for p in pts], dtype=np.float64))
else:
result.append(np.empty((0, 2), dtype=np.float64))
return result
pts = self.selected_points
if not pts:
return np.empty((0, 2), dtype=np.float64)
return np.array([[p["row"], p["col"]] for p in pts], dtype=np.float64)
[docs]
def points_as_dict(self):
"""
Return placed points as a list of ``{"row": int, "col": int}`` dicts.
In gallery mode, returns a list of lists (one per image).
Examples
--------
>>> w = Mark2D(img, points=[(10, 20), (30, 40)])
>>> w.points_as_dict()
[{'row': 10, 'col': 20}, {'row': 30, 'col': 40}]
"""
if self.n_images > 1:
return [
[{"row": p["row"], "col": p["col"]} for p in pts]
for pts in self.selected_points
]
return [{"row": p["row"], "col": p["col"]} for p in self.selected_points]
[docs]
def clear_points(self):
"""
Remove all placed points from all images.
Examples
--------
>>> w.clear_points()
>>> w.selected_points # [] or [[], [], ...]
"""
if self.n_images == 1:
self.selected_points = []
else:
self.selected_points = [[] for _ in range(self.n_images)]
@property
def points_enabled(self) -> bool:
"""
Whether adding/editing points is enabled via ``disabled_tools``.
This convenience toggle controls the ``"points"`` lock in
``disabled_tools``. It does not modify ``hidden_tools``.
"""
disabled = {str(t).strip().lower() for t in self.disabled_tools}
return "all" not in disabled and "points" not in disabled
@points_enabled.setter
def points_enabled(self, enabled: bool):
enabled = bool(enabled)
disabled = [str(t).strip().lower() for t in self.disabled_tools]
if enabled:
if "all" in disabled:
raise ValueError(
"Cannot enable points while disabled_tools contains 'all'. "
"Remove 'all' first."
)
if "points" in disabled:
self.disabled_tools = [t for t in disabled if t != "points"]
return
if "all" in disabled or "points" in disabled:
return
self.disabled_tools = [*disabled, "points"]
def _normalize_frame(self, frame: np.ndarray) -> np.ndarray:
if self.log_scale:
frame = np.log1p(np.maximum(frame, 0))
if self.auto_contrast:
vmin = float(np.percentile(frame, self.percentile_low))
vmax = float(np.percentile(frame, self.percentile_high))
else:
vmin = float(frame.min())
vmax = float(frame.max())
if vmax > vmin:
normalized = np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255)
return normalized.astype(np.uint8)
return np.zeros(frame.shape, dtype=np.uint8)
[docs]
def save_image(
self,
path: str | pathlib.Path,
*,
idx: int | None = None,
include_markers: bool = True,
format: str | None = None,
dpi: int = 150,
) -> pathlib.Path:
"""Save current image as PNG or PDF, optionally with marker overlays.
Parameters
----------
path : str or pathlib.Path
Output file path.
idx : int, optional
Image index in gallery mode. Defaults to current selected_idx.
include_markers : bool, default True
If True, render marker points on the exported image.
format : str, optional
'png' or 'pdf'. If omitted, inferred from file extension.
dpi : int, default 150
Output DPI metadata.
Returns
-------
pathlib.Path
The written file path.
"""
from matplotlib import colormaps
from PIL import Image, ImageDraw
path = pathlib.Path(path)
fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
if fmt not in ("png", "pdf", "tiff", "tif"):
raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")
i = idx if idx is not None else self.selected_idx
if i < 0 or i >= self.n_images:
raise IndexError(f"Image index {i} out of range [0, {self.n_images})")
frame = self._data[i]
normalized = self._normalize_frame(frame)
cmap_fn = colormaps.get_cmap(self.cmap)
rgba = (cmap_fn(normalized / 255.0) * 255).astype(np.uint8)
img = Image.fromarray(rgba)
if include_markers:
# In gallery mode, points are nested per image; single-image is flat
if self.n_images > 1:
pts = self.selected_points[i] if i < len(self.selected_points) else []
else:
pts = self.selected_points
if pts:
draw = ImageDraw.Draw(img)
r = max(2, self.dot_size // 2)
for pt in pts:
row, col = pt.get("row", 0), pt.get("col", 0)
color = pt.get("color", "#f44336")
draw.ellipse(
[col - r, row - r, col + r, row + r],
fill=color,
outline="white",
)
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), dpi=(dpi, dpi))
return path
[docs]
def state_dict(self):
"""
Return a dict of all restorable widget state.
Use this to persist the widget state across kernel restarts.
Pass the returned dict as the ``state`` parameter to a new
``Mark2D`` to restore everything.
Examples
--------
>>> w = Mark2D(img, pixel_size=1.5)
>>> # ... user places points, adds ROIs, changes settings ...
>>> state = w.state_dict()
>>> # Later (or after kernel restart):
>>> w2 = Mark2D(img, state=state)
"""
return {
"selected_points": self.selected_points,
"roi_list": self.roi_list,
"profile_line": self.profile_line,
"selected_idx": self.selected_idx,
"marker_shape": self.marker_shape,
"marker_color": self.marker_color,
"dot_size": self.dot_size,
"max_points": self.max_points,
"marker_border": self.marker_border,
"marker_opacity": self.marker_opacity,
"label_size": self.label_size,
"label_color": self.label_color,
"snap_enabled": self.snap_enabled,
"snap_radius": self.snap_radius,
"cmap": self.cmap,
"auto_contrast": self.auto_contrast,
"log_scale": self.log_scale,
"show_fft": self.show_fft,
"fft_window": self.fft_window,
"show_stats": self.show_stats,
"show_controls": self.show_controls,
"disabled_tools": self.disabled_tools,
"hidden_tools": self.hidden_tools,
"percentile_low": self.percentile_low,
"percentile_high": self.percentile_high,
"title": self.title,
"pixel_size": self.pixel_size,
"scale": self.scale,
"canvas_size": self.canvas_size,
}
[docs]
def save(self, path: str):
"""
Save widget state to a JSON file.
Parameters
----------
path : str
File path to write (e.g. ``"analysis.json"``).
Examples
--------
>>> w = Mark2D(img)
>>> # ... place points, add ROIs ...
>>> w.save("my_analysis.json")
>>> # After kernel restart:
>>> w2 = Mark2D(img, state="my_analysis.json")
"""
save_state_file(path, "Mark2D", self.state_dict())
[docs]
def load_state_dict(self, state):
"""
Restore widget state from a dict returned by ``state_dict()``.
Parameters
----------
state : dict
State dict from a previous ``state_dict()`` call.
Missing keys are silently skipped.
Examples
--------
>>> state = old_widget.state_dict()
>>> new_widget = Mark2D(img)
>>> new_widget.load_state_dict(state)
"""
COMPAT_MAP = {"colormap": "cmap", "pixel_size_angstrom": "pixel_size"}
ROI_KEY_MAP = {"mode": "shape", "rectW": "width", "rectH": "height"}
for key, val in state.items():
key = COMPAT_MAP.get(key, key)
if key == "roi_list" and isinstance(val, list):
val = [
{ROI_KEY_MAP.get(k, k): v for k, v in roi.items()}
for roi in val
]
if hasattr(self, key):
setattr(self, key, val)
bind_tool_runtime_api(Mark2D, "Mark2D")