# Source code for quantem.widget.edit2d

"""
edit2d: Interactive crop/pad/mask tool for 2D images.

Visually define a rectangular output region on a 2D image.
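Portions of the region inside the image bounds are cropped; portions extending
beyond them are padded with a fill value.
Mask mode lets you paint a binary mask instead; masked pixels are replaced by
the fill value in the result.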
"""

import json
import pathlib
from typing import Optional, Union, List, Tuple

import anywidget
import numpy as np
import traitlets

from quantem.widget.array_utils import to_numpy, _resize_image
from quantem.widget.io import IOResult
from quantem.widget.json_state import resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
    bind_tool_runtime_api,
    build_tool_groups,
    normalize_tool_groups,
)


class Edit2D(anywidget.AnyWidget):
    """
    Interactive visual crop/pad/mask tool for 2D images.

    Display a 2D image with a draggable crop rectangle. The rectangle can be
    positioned anywhere -- inside the image for cropping, extending beyond
    image bounds for padding, or fully enclosing the image for pure padding.
    In mask mode a binary mask is painted instead; masked pixels are replaced
    by ``fill_value`` in :attr:`result`.

    Parameters
    ----------
    data : array_like
        2D array (height, width) for a single image, or 3D array
        (N, height, width) or list of 2D arrays for multi-image mode.
        List entries of different shapes are passed through ``_resize_image``
        to the largest height and width. An ``IOResult`` (title, pixel size
        and labels are taken from it) or a Dataset2d-like object exposing
        ``array``, ``name`` and ``sampling`` (title and, for "nm"/Å units,
        pixel size are taken from it) is also accepted; extracted metadata
        is only used when the matching argument is left at its default.
    bounds : tuple of int, optional
        Initial crop bounds as (top, left, bottom, right) in image pixel
        coordinates. Negative values and values exceeding image dimensions
        are allowed for padding. If None, defaults to the full image extent.
    fill_value : float, default 0.0
        Fill value for padded regions outside the original image bounds and
        for masked pixels in mask mode.
    mode : str, default "crop"
        Initial editing mode, ``"crop"`` or ``"mask"``.
    shared : bool, default True
        If True, all images share one crop region / mask. If False and more
        than one image is given, each image has its own crop and mask.
    labels : list of str, optional
        Per-image labels. Defaults to ``"Image"`` for a single image, or
        ``"Image 1"``, ``"Image 2"``, ... for multiple images.
    title : str, default ""
        Title displayed in the widget header.
    cmap : str, default "gray"
        Colormap name.
    pixel_size : float, default 0.0
        Pixel size in angstroms for scale bar display.
    show_stats : bool, default True
        Show statistics bar.
    show_controls : bool, default True
        Show control row.
    show_display_controls : bool, default True
        Show display control group.
    show_edit_controls : bool, default True
        Show edit control group.
    show_histogram : bool, default True
        Show histogram control group.
    log_scale : bool, default False
        Log intensity mapping.
    auto_contrast : bool, default True
        Percentile-based contrast.
    disabled_tools : list of str, optional
        Tool groups to disable in the frontend UI/interaction layer.
        Supported values: ``"mode"``, ``"edit"``, ``"display"``,
        ``"histogram"``, ``"stats"``, ``"navigation"``, ``"export"``,
        ``"view"``, ``"all"``.
    disable_* : bool, optional
        Convenience flags (``disable_mode``, ``disable_edit``,
        ``disable_display``, ``disable_histogram``, ``disable_stats``,
        ``disable_navigation``, ``disable_export``, ``disable_view``,
        ``disable_all``) equivalent to including those tool names in
        ``disabled_tools``.
    hidden_tools : list of str, optional
        Tool groups to hide from the frontend UI. Hidden tools are also
        interaction-locked (equivalent to disabled for behavior).
    hide_* : bool, optional
        Convenience flags (``hide_mode``, ``hide_edit``, ``hide_display``,
        ``hide_histogram``, ``hide_stats``, ``hide_navigation``,
        ``hide_export``, ``hide_view``, ``hide_all``) equivalent to
        including those tool names in ``hidden_tools``.
    state : dict or str or pathlib.Path, optional
        Saved widget state applied after all other arguments via
        :meth:`load_state_dict`. A str/path is read as a JSON state file.
    **kwargs
        Forwarded to ``anywidget.AnyWidget.__init__``.

    Raises
    ------
    ValueError
        If ``data`` is neither 2D, 3D, nor a list of 2D arrays.

    Examples
    --------
    >>> import numpy as np
    >>> from quantem.widget import Edit2D
    >>> img = np.random.rand(256, 256).astype(np.float32)
    >>> crop = Edit2D(img)
    >>> crop  # display, draw crop region interactively
    >>> crop.result  # returns cropped NumPy array
    >>> crop.crop_bounds  # (top, left, bottom, right) tuple
    """

    _esm = pathlib.Path(__file__).parent / "static" / "edit2d.js"
    _css = pathlib.Path(__file__).parent / "static" / "edit2d.css"

    # =========================================================================
    # Core State
    # =========================================================================
    n_images = traitlets.Int(1).tag(sync=True)
    height = traitlets.Int(1).tag(sync=True)
    width = traitlets.Int(1).tag(sync=True)
    # Raw float32 pixels of all images (self._data.tobytes()).
    frame_bytes = traitlets.Bytes(b"").tag(sync=True)
    labels = traitlets.List(traitlets.Unicode()).tag(sync=True)
    title = traitlets.Unicode("").tag(sync=True)
    cmap = traitlets.Unicode("gray").tag(sync=True)

    # =========================================================================
    # Crop Region (synced bidirectionally with JS)
    # =========================================================================
    crop_top = traitlets.Int(0).tag(sync=True)
    crop_left = traitlets.Int(0).tag(sync=True)
    crop_bottom = traitlets.Int(0).tag(sync=True)
    crop_right = traitlets.Int(0).tag(sync=True)
    fill_value = traitlets.Float(0.0).tag(sync=True)

    # =========================================================================
    # Display Options
    # =========================================================================
    log_scale = traitlets.Bool(False).tag(sync=True)
    auto_contrast = traitlets.Bool(True).tag(sync=True)

    # =========================================================================
    # Scale Bar
    # =========================================================================
    pixel_size = traitlets.Float(0.0).tag(sync=True)

    # =========================================================================
    # UI Visibility
    # =========================================================================
    show_controls = traitlets.Bool(True).tag(sync=True)
    show_stats = traitlets.Bool(True).tag(sync=True)
    show_display_controls = traitlets.Bool(True).tag(sync=True)
    show_edit_controls = traitlets.Bool(True).tag(sync=True)
    show_histogram = traitlets.Bool(True).tag(sync=True)
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    # Statistics of the currently selected image (see _compute_stats).
    stats_mean = traitlets.Float(0.0).tag(sync=True)
    stats_min = traitlets.Float(0.0).tag(sync=True)
    stats_max = traitlets.Float(0.0).tag(sync=True)
    stats_std = traitlets.Float(0.0).tag(sync=True)

    # =========================================================================
    # Mode: "crop" or "mask"
    # =========================================================================
    mode = traitlets.Unicode("crop").tag(sync=True)

    # =========================================================================
    # Mask State
    # =========================================================================
    # Shared mask: one uint8 per pixel (H * W bytes); nonzero means masked.
    mask_bytes = traitlets.Bytes(b"").tag(sync=True)
    mask_tool = traitlets.Unicode("rectangle").tag(sync=True)
    mask_action = traitlets.Unicode("add").tag(sync=True)

    # =========================================================================
    # Gallery (multi-image)
    # =========================================================================
    selected_idx = traitlets.Int(0).tag(sync=True)

    # =========================================================================
    # Shared / Independent editing
    # =========================================================================
    shared = traitlets.Bool(True).tag(sync=True)
    # JSON list of {"top", "left", "bottom", "right"} dicts, one per image.
    per_image_crops_json = traitlets.Unicode("[]").tag(sync=True)
    # Independent masks: N * H * W uint8 bytes; nonzero means masked.
    per_image_masks_bytes = traitlets.Bytes(b"").tag(sync=True)

    @classmethod
    def _normalize_tool_groups(cls, tool_groups) -> List[str]:
        """Validate and normalize tool group values with stable ordering."""
        return normalize_tool_groups("Edit2D", tool_groups)

    @classmethod
    def _build_disabled_tools(
        cls,
        disabled_tools=None,
        disable_mode: bool = False,
        disable_edit: bool = False,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_stats: bool = False,
        disable_navigation: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_all: bool = False,
    ) -> List[str]:
        """Build disabled_tools from explicit list and ergonomic boolean flags."""
        return build_tool_groups(
            "Edit2D",
            tool_groups=disabled_tools,
            all_flag=disable_all,
            flag_map={
                "mode": disable_mode,
                "edit": disable_edit,
                "display": disable_display,
                "histogram": disable_histogram,
                "stats": disable_stats,
                "navigation": disable_navigation,
                "export": disable_export,
                "view": disable_view,
            },
        )

    @classmethod
    def _build_hidden_tools(
        cls,
        hidden_tools=None,
        hide_mode: bool = False,
        hide_edit: bool = False,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_stats: bool = False,
        hide_navigation: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_all: bool = False,
    ) -> List[str]:
        """Build hidden_tools from explicit list and ergonomic boolean flags."""
        return build_tool_groups(
            "Edit2D",
            tool_groups=hidden_tools,
            all_flag=hide_all,
            flag_map={
                "mode": hide_mode,
                "edit": hide_edit,
                "display": hide_display,
                "histogram": hide_histogram,
                "stats": hide_stats,
                "navigation": hide_navigation,
                "export": hide_export,
                "view": hide_view,
            },
        )

    @traitlets.validate("disabled_tools")
    def _validate_disabled_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    @traitlets.validate("hidden_tools")
    def _validate_hidden_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    # =========================================================================
    # Input helpers
    # =========================================================================
    @staticmethod
    def _extract_source(data) -> Tuple[object, dict]:
        """Unwrap IOResult / Dataset2d-like inputs into (array_like, metadata).

        The metadata dict may contain "title", "pixel_size" (angstrom) and
        "labels"; callers decide whether to apply each entry.
        """
        meta = {}
        if isinstance(data, IOResult):
            if data.title:
                meta["title"] = data.title
            if data.pixel_size is not None:
                meta["pixel_size"] = data.pixel_size
            if data.labels:
                meta["labels"] = data.labels
            data = data.data
        # Dataset2d-like: duck-typed on array/name/sampling attributes.
        if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
            if data.name and "title" not in meta:
                meta["title"] = data.name
            # Only derive a pixel size when the IOResult did not provide a nonzero one.
            if not meta.get("pixel_size") and hasattr(data, "units"):
                units = list(data.units)
                if units:
                    sampling_val = float(data.sampling[-1])
                    if units[-1] in ("nm",):
                        meta["pixel_size"] = sampling_val * 10  # nm -> angstrom
                    elif units[-1] in ("\u00c5", "angstrom", "A"):
                        meta["pixel_size"] = sampling_val
            data = data.array
        return data, meta

    @staticmethod
    def _to_image_stack(data) -> np.ndarray:
        """Convert input (NumPy, CuPy, PyTorch, or a list of them) to float32 (N, H, W).

        Raises
        ------
        ValueError
            If the converted array is not 2D or 3D.
        """
        if isinstance(data, list):
            images = [to_numpy(d) for d in data]
            shapes = [img.shape for img in images]
            if len(set(shapes)) > 1:
                max_h = max(s[0] for s in shapes)
                max_w = max(s[1] for s in shapes)
                images = [_resize_image(img, max_h, max_w) for img in images]
            arr = np.stack(images)
        else:
            arr = to_numpy(data)
        if arr.ndim == 2:
            arr = arr[np.newaxis, ...]
        if arr.ndim != 3:
            raise ValueError(
                "Edit2D expects a 2D image, a 3D (N, H, W) stack, or a list of "
                f"2D images; got an array with shape {arr.shape}."
            )
        return arr.astype(np.float32)

    @staticmethod
    def _default_labels(n_images: int) -> List[str]:
        """Default labels: "Image" for one image, "Image 1".."Image N" otherwise."""
        if n_images == 1:
            return ["Image"]
        return [f"Image {i + 1}" for i in range(n_images)]

    @staticmethod
    def _as_bounds(bounds) -> Tuple[int, int, int, int]:
        """Unpack (top, left, bottom, right), converting NumPy integers to int.

        traitlets.Int and json.dumps both reject NumPy integer scalars; other
        values are passed through unchanged so existing validation still applies.
        """
        top, left, bottom, right = (
            int(b) if isinstance(b, np.integer) else b for b in bounds
        )
        return top, left, bottom, right

    def _current_index(self) -> int:
        """Index of the selected image, clamped to the last available image."""
        return min(self.selected_idx, self.n_images - 1)

    def __init__(
        self,
        data: Union[np.ndarray, List[np.ndarray]],
        bounds: Optional[Tuple[int, int, int, int]] = None,
        fill_value: float = 0.0,
        mode: str = "crop",
        shared: bool = True,
        labels: Optional[List[str]] = None,
        title: str = "",
        cmap: str = "gray",
        pixel_size: float = 0.0,
        show_controls: bool = True,
        show_stats: bool = True,
        show_display_controls: bool = True,
        show_edit_controls: bool = True,
        show_histogram: bool = True,
        log_scale: bool = False,
        auto_contrast: bool = True,
        disabled_tools: Optional[List[str]] = None,
        disable_mode: bool = False,
        disable_edit: bool = False,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_stats: bool = False,
        disable_navigation: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_all: bool = False,
        hidden_tools: Optional[List[str]] = None,
        hide_mode: bool = False,
        hide_edit: bool = False,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_stats: bool = False,
        hide_navigation: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_all: bool = False,
        state=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.widget_version = resolve_widget_version()
        self.mode = mode

        # Pull metadata out of IOResult / Dataset2d inputs; explicit args win.
        data, meta = self._extract_source(data)
        if not title and "title" in meta:
            title = meta["title"]
        if pixel_size == 0.0 and "pixel_size" in meta:
            pixel_size = meta["pixel_size"]
        if labels is None and "labels" in meta:
            labels = meta["labels"]

        self._data = self._to_image_stack(data)
        self.n_images = int(self._data.shape[0])
        self.height = int(self._data.shape[1])
        self.width = int(self._data.shape[2])

        # Labels
        if labels is None:
            self.labels = self._default_labels(self.n_images)
        else:
            self.labels = list(labels)

        # Options
        self.title = title
        self.cmap = cmap
        self.pixel_size = pixel_size
        self.show_controls = show_controls
        self.show_stats = show_stats
        self.show_display_controls = show_display_controls
        self.show_edit_controls = show_edit_controls
        self.show_histogram = show_histogram
        self.log_scale = log_scale
        self.auto_contrast = auto_contrast
        self.disabled_tools = self._build_disabled_tools(
            disabled_tools=disabled_tools,
            disable_mode=disable_mode,
            disable_edit=disable_edit,
            disable_display=disable_display,
            disable_histogram=disable_histogram,
            disable_stats=disable_stats,
            disable_navigation=disable_navigation,
            disable_export=disable_export,
            disable_view=disable_view,
            disable_all=disable_all,
        )
        self.hidden_tools = self._build_hidden_tools(
            hidden_tools=hidden_tools,
            hide_mode=hide_mode,
            hide_edit=hide_edit,
            hide_display=hide_display,
            hide_histogram=hide_histogram,
            hide_stats=hide_stats,
            hide_navigation=hide_navigation,
            hide_export=hide_export,
            hide_view=hide_view,
            hide_all=hide_all,
        )
        self.fill_value = fill_value

        # Crop bounds (default: full image extent).
        if bounds is not None:
            (
                self.crop_top,
                self.crop_left,
                self.crop_bottom,
                self.crop_right,
            ) = self._as_bounds(bounds)
        else:
            self.crop_top = 0
            self.crop_left = 0
            self.crop_bottom = self.height
            self.crop_right = self.width

        self.shared = shared
        if not self.shared and self.n_images > 1:
            # Seed every image's independent crop with the initial bounds.
            crop = {
                "top": self.crop_top,
                "left": self.crop_left,
                "bottom": self.crop_bottom,
                "right": self.crop_right,
            }
            self.per_image_crops_json = json.dumps([crop] * self.n_images)

        # Compute stats for current image
        self._compute_stats()

        # Send raw float32 data to JS
        self.frame_bytes = self._data.tobytes()
        self.selected_idx = 0

        # State restoration (must be last so it overrides the arguments above)
        if state is not None:
            if isinstance(state, (str, pathlib.Path)):
                state = unwrap_state_payload(
                    json.loads(pathlib.Path(state).read_text()),
                    require_envelope=True,
                )
            else:
                state = unwrap_state_payload(state)
            self.load_state_dict(state)

    def _compute_stats(self):
        """Update the stats_* traits from the currently selected image."""
        img = self._data[self._current_index()]
        self.stats_mean = float(np.mean(img))
        self.stats_min = float(np.min(img))
        self.stats_max = float(np.max(img))
        self.stats_std = float(np.std(img))

    def _crop_single_with_bounds(self, img, top, left, bottom, right):
        """Crop/pad one 2D image to rows [top, bottom) and columns [left, right).

        Pixels of the output that fall outside ``img`` are set to
        ``fill_value``. A non-positive output height or width yields an
        empty (0, 0) array.
        """
        h, w = img.shape
        out_h = bottom - top
        out_w = right - left
        if out_h <= 0 or out_w <= 0:
            return np.empty((0, 0), dtype=img.dtype)
        result = np.full((out_h, out_w), self.fill_value, dtype=img.dtype)
        # Intersection of the requested region with the source image.
        src_top = max(0, top)
        src_left = max(0, left)
        src_bottom = min(h, bottom)
        src_right = min(w, right)
        if src_top >= src_bottom or src_left >= src_right:
            return result  # region lies entirely outside the image: pure padding
        # Offset of the intersection inside the output array.
        dst_top = src_top - top
        dst_left = src_left - left
        result[
            dst_top:dst_top + (src_bottom - src_top),
            dst_left:dst_left + (src_right - src_left),
        ] = img[src_top:src_bottom, src_left:src_right]
        return result

    def _crop_single(self, img: np.ndarray) -> np.ndarray:
        """Crop/pad ``img`` with the shared crop_* bounds."""
        return self._crop_single_with_bounds(
            img, self.crop_top, self.crop_left, self.crop_bottom, self.crop_right
        )

    def _apply_mask(self, img: np.ndarray, m: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` with pixels where ``m`` is True set to fill_value."""
        out = img.copy()
        out[m] = self.fill_value
        return out

    def _get_per_image_crops(self):
        """Independent crop dicts, one per image, padded with the shared bounds."""
        default = {
            "top": self.crop_top,
            "left": self.crop_left,
            "bottom": self.crop_bottom,
            "right": self.crop_right,
        }
        if not self.per_image_crops_json or self.per_image_crops_json == "[]":
            return [{**default} for _ in range(self.n_images)]
        crops = json.loads(self.per_image_crops_json)
        while len(crops) < self.n_images:
            crops.append({**default})
        return crops[:self.n_images]

    def _get_per_image_masks(self):
        """Independent boolean masks, one (H, W) array per image.

        Missing or wrongly sized mask bytes yield all-False masks.
        """
        size = self.height * self.width
        total = self.n_images * size
        if not self.per_image_masks_bytes or len(self.per_image_masks_bytes) != total:
            return [np.zeros((self.height, self.width), dtype=bool) for _ in range(self.n_images)]
        all_masks = np.frombuffer(self.per_image_masks_bytes, dtype=np.uint8).reshape(
            self.n_images, self.height, self.width
        )
        return [all_masks[i] > 0 for i in range(self.n_images)]

    @property
    def mask(self) -> np.ndarray:
        """Current mask as a boolean array (H, W). True = masked.

        In independent mode this is the selected image's mask. Missing or
        wrongly sized mask bytes yield an all-False mask.
        """
        if not self.shared and self.n_images > 1:
            return self._get_per_image_masks()[self._current_index()]
        if not self.mask_bytes or len(self.mask_bytes) != self.height * self.width:
            return np.zeros((self.height, self.width), dtype=bool)
        arr = np.frombuffer(self.mask_bytes, dtype=np.uint8).reshape(
            self.height, self.width
        )
        return arr > 0

    @property
    def result(self) -> Union[np.ndarray, List[np.ndarray]]:
        """Return result based on current mode.

        Crop mode: cropped/padded image(s).
        Mask mode: image(s) with masked pixels set to fill_value.
        A single array is returned for one image, a list otherwise.
        In independent mode (shared=False), each image gets its own crop/mask.
        """
        if self.shared or self.n_images == 1:
            if self.mode == "mask":
                m = self.mask
                if self.n_images == 1:
                    return self._apply_mask(self._data[0], m)
                return [self._apply_mask(self._data[i], m) for i in range(self.n_images)]
            if self.n_images == 1:
                return self._crop_single(self._data[0])
            return [self._crop_single(self._data[i]) for i in range(self.n_images)]
        # Independent mode
        if self.mode == "mask":
            masks = self._get_per_image_masks()
            return [self._apply_mask(self._data[i], masks[i]) for i in range(self.n_images)]
        crops = self._get_per_image_crops()
        return [
            self._crop_single_with_bounds(
                self._data[i], c["top"], c["left"], c["bottom"], c["right"]
            )
            for i, c in enumerate(crops)
        ]

    @property
    def crop_bounds(self) -> Tuple[int, int, int, int]:
        """Current crop bounds as (top, left, bottom, right).

        In independent mode, returns the current image's bounds.
        """
        if not self.shared and self.n_images > 1:
            c = self._get_per_image_crops()[self._current_index()]
            return (c["top"], c["left"], c["bottom"], c["right"])
        return (self.crop_top, self.crop_left, self.crop_bottom, self.crop_right)

    @crop_bounds.setter
    def crop_bounds(self, bounds: Tuple[int, int, int, int]):
        top, left, bottom, right = self._as_bounds(bounds)
        if not self.shared and self.n_images > 1:
            crops = self._get_per_image_crops()
            crops[self._current_index()] = {
                "top": top, "left": left, "bottom": bottom, "right": right,
            }
            self.per_image_crops_json = json.dumps(crops)
        else:
            self.crop_top, self.crop_left, self.crop_bottom, self.crop_right = (
                top, left, bottom, right,
            )

    @property
    def crop_size(self) -> Tuple[int, int]:
        """Output size as (height, width)."""
        top, left, bottom, right = self.crop_bounds
        return (bottom - top, right - left)

    def set_image(self, data, **kwargs):
        """Replace the image data.

        The crop region is reset to the full new image extent, all mask and
        per-image state is cleared, and ``selected_idx`` is clamped to the
        new number of images.

        Parameters
        ----------
        data : array_like
            Same accepted inputs as the constructor (2D array, 3D stack, list
            of 2D arrays, ``IOResult`` or Dataset2d-like object).
        **kwargs
            ``title`` : new title; if omitted, a title embedded in ``data``
            (if any) is used, otherwise the current title is kept.
            ``labels`` : new per-image labels; if omitted, labels embedded in
            ``data`` are used, or defaults are generated when the number of
            images changed. Other keyword arguments are ignored.

        Raises
        ------
        ValueError
            If ``data`` is neither 2D, 3D, nor a list of 2D arrays.
        """
        # Validate/convert first so a bad input leaves the widget untouched.
        data, meta = self._extract_source(data)
        images = self._to_image_stack(data)

        if kwargs.get("title") is not None:
            self.title = kwargs["title"]
        elif "title" in meta:
            self.title = meta["title"]
        labels = kwargs.get("labels")
        if labels is None:
            labels = meta.get("labels")

        self._data = images
        self.n_images = int(images.shape[0])
        self.height = int(images.shape[1])
        self.width = int(images.shape[2])

        if labels is not None:
            self.labels = list(labels)
        elif len(self.labels) != self.n_images:
            self.labels = self._default_labels(self.n_images)
        # A previously selected index may not exist in the new data.
        if self.selected_idx >= self.n_images:
            self.selected_idx = self.n_images - 1

        self.crop_top = 0
        self.crop_left = 0
        self.crop_bottom = self.height
        self.crop_right = self.width
        self.mask_bytes = b""
        self.per_image_crops_json = "[]"
        self.per_image_masks_bytes = b""
        self._compute_stats()
        self.frame_bytes = self._data.tobytes()

    # =========================================================================
    # Export
    # =========================================================================
    def _normalize_frame(self, frame: np.ndarray) -> np.ndarray:
        """Map ``frame`` to uint8 [0, 255] using the current log/contrast settings.

        Auto contrast clips to the 2nd-98th percentile; otherwise min/max is
        used. Constant or empty frames map to all zeros.
        """
        if frame.size == 0:
            return np.zeros(frame.shape, dtype=np.uint8)
        if self.log_scale:
            frame = np.log1p(np.maximum(frame, 0))
        if self.auto_contrast:
            vmin = float(np.percentile(frame, 2))
            vmax = float(np.percentile(frame, 98))
        else:
            vmin = float(frame.min())
            vmax = float(frame.max())
        if vmax > vmin:
            normalized = np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255)
            return normalized.astype(np.uint8)
        return np.zeros(frame.shape, dtype=np.uint8)

    def save_image(
        self,
        path: str | pathlib.Path,
        *,
        format: str | None = None,
        dpi: int = 150,
    ) -> pathlib.Path:
        """Save current image as PNG, PDF, or TIFF.

        In crop mode, saves the cropped/padded result. In mask mode, saves
        the masked result. For multiple images, the selected image is saved.

        Parameters
        ----------
        path : str or pathlib.Path
            Output file path. Parent directories are created as needed.
        format : str, optional
            'png', 'pdf', or 'tiff'. If omitted, inferred from file extension
            (falling back to 'png' when there is none). The chosen format is
            used even if it does not match the extension.
        dpi : int, default 150
            Output DPI metadata (page resolution for PDF).

        Returns
        -------
        pathlib.Path
            The written file path.

        Raises
        ------
        ValueError
            If the format is unsupported or the current crop region is empty.
        """
        from matplotlib import colormaps
        from PIL import Image

        pil_formats = {"png": "PNG", "pdf": "PDF", "tiff": "TIFF", "tif": "TIFF"}
        path = pathlib.Path(path)
        fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
        if fmt not in pil_formats:
            raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")

        result = self.result
        if isinstance(result, list):
            result = result[self._current_index()]
        if result.size == 0:
            raise ValueError("Nothing to save: the current crop region is empty.")

        normalized = self._normalize_frame(result)
        cmap_fn = colormaps.get_cmap(self.cmap)
        rgba = (cmap_fn(normalized / 255.0) * 255).astype(np.uint8)
        img = Image.fromarray(rgba)
        save_kwargs = {"dpi": (dpi, dpi)}
        if fmt == "pdf":
            # The PDF writer does not reliably accept RGBA; colormap alpha is opaque anyway.
            img = img.convert("RGB")
            save_kwargs["resolution"] = float(dpi)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Pass the format explicitly: PIL would otherwise infer it from the extension.
        img.save(str(path), format=pil_formats[fmt], **save_kwargs)
        return path

    # =========================================================================
    # State Protocol
    # =========================================================================
    def state_dict(self):
        """Return display, crop and mode settings as a plain dict.

        Independent per-image crops are included under "per_image_crops" in
        independent multi-image mode. Image data and masks are not included.
        """
        sd = {
            "title": self.title,
            "cmap": self.cmap,
            "mode": self.mode,
            "log_scale": self.log_scale,
            "auto_contrast": self.auto_contrast,
            "show_controls": self.show_controls,
            "show_stats": self.show_stats,
            "show_display_controls": self.show_display_controls,
            "show_edit_controls": self.show_edit_controls,
            "show_histogram": self.show_histogram,
            "disabled_tools": self.disabled_tools,
            "hidden_tools": self.hidden_tools,
            "pixel_size": self.pixel_size,
            "fill_value": self.fill_value,
            "crop_top": self.crop_top,
            "crop_left": self.crop_left,
            "crop_bottom": self.crop_bottom,
            "crop_right": self.crop_right,
            "shared": self.shared,
        }
        if not self.shared and self.n_images > 1:
            sd["per_image_crops"] = self._get_per_image_crops()
        return sd

    def save(self, path: str):
        """Write :meth:`state_dict` to ``path`` as an "Edit2D" state file."""
        save_state_file(path, "Edit2D", self.state_dict())

    def load_state_dict(self, state):
        """Apply a state dict produced by :meth:`state_dict`.

        The legacy key "pixel_size_angstrom" maps to ``pixel_size``; keys
        that are not attributes of the widget are ignored. Restoring a shared
        state without per-image crops clears stale per-image crops and masks.
        """
        for key, val in state.items():
            if key == "pixel_size_angstrom":
                key = "pixel_size"
            if key == "per_image_crops":
                self.per_image_crops_json = json.dumps(val)
                continue
            if hasattr(self, key):
                setattr(self, key, val)
        # Clear stale per-image state when restoring to shared mode
        if state.get("shared", True) and "per_image_crops" not in state:
            self.per_image_crops_json = "[]"
            self.per_image_masks_bytes = b""

    def summary(self):
        """Print a human-readable summary of the image, mode and display settings."""
        name = self.title if self.title else "Edit2D"
        lines = [name, "═" * 32]
        lines.append(f"Image:    {self.height}×{self.width}")
        if self.n_images > 1:
            link = "shared" if self.shared else "independent"
            lines[-1] += f" ({self.n_images} images, {link})"
        if self.pixel_size > 0:
            ps = self.pixel_size
            if ps >= 10:
                lines[-1] += f" ({ps / 10:.2f} nm/px)"
            else:
                lines[-1] += f" ({ps:.2f} Å/px)"
        lines.append(f"Mode:     {self.mode}")
        if self.mode == "crop":
            crop_h, crop_w = self.crop_size
            top, left, bottom, right = self.crop_bounds
            lines.append(
                f"Crop:     ({top}, {left}) → "
                f"({bottom}, {right}) "
                f"= {crop_h}×{crop_w}"
            )
            lines.append(f"Fill:     {self.fill_value}")
        else:
            mask_px = int(np.sum(self.mask)) if (self.mask_bytes or self.per_image_masks_bytes) else 0
            total = self.height * self.width
            pct = 100 * mask_px / total if total > 0 else 0
            lines.append(f"Mask:     {mask_px} px ({pct:.1f}%)")
        scale = "log" if self.log_scale else "linear"
        contrast = "auto" if self.auto_contrast else "manual"
        lines.append(f"Display:  {self.cmap} | {contrast} | {scale}")
        if self.disabled_tools:
            lines.append(f"Locked:   {', '.join(self.disabled_tools)}")
        if self.hidden_tools:
            lines.append(f"Hidden:   {', '.join(self.hidden_tools)}")
        print("\n".join(lines))

    def __repr__(self):
        independent = not self.shared and self.n_images > 1
        suffix = ", independent" if independent else ""
        imgs = f", {self.n_images} images" if self.n_images > 1 else ""
        if self.mode == "mask":
            mask_px = int(np.sum(self.mask)) if (self.mask_bytes or self.per_image_masks_bytes) else 0
            total = self.height * self.width
            pct = 100 * mask_px / total if total > 0 else 0
            return f"Edit2D({self.height}x{self.width}{imgs}, mask={mask_px}px ({pct:.1f}%){suffix})"
        crop_h, crop_w = self.crop_size
        top, left, _, _ = self.crop_bounds
        return (
            f"Edit2D({self.height}x{self.width}{imgs}, "
            f"crop={crop_h}x{crop_w} at ({top},{left}), "
            f"fill={self.fill_value}{suffix})"
        )


bind_tool_runtime_api(Edit2D, "Edit2D")