"""
showcomplex: Interactive complex-valued image viewer.
For displaying complex data from ptychography, holography, and exit wave
reconstruction. Supports amplitude, phase, HSV, real, and imaginary display modes.
"""
import json
import pathlib
from typing import List, Optional
import anywidget
import numpy as np
import traitlets
from quantem.widget.array_utils import to_numpy
from quantem.widget.io import IOResult
from quantem.widget.json_state import resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
bind_tool_runtime_api,
build_tool_groups,
normalize_tool_groups,
)
class ShowComplex2D(anywidget.AnyWidget):
    """
    Interactive viewer for complex-valued 2D data.

    Display complex images from ptychography, holography, or exit wave
    reconstruction with five visualization modes: amplitude, phase, HSV
    (hue=phase, brightness=amplitude), real part, and imaginary part.

    Parameters
    ----------
    data : array_like (complex), tuple of (real, imag), IOResult, or dataset
        Complex 2D array of shape (height, width) with dtype complex64 or
        complex128. Also accepts a tuple ``(real, imag)`` of two real arrays
        of identical 2D shape, an ``IOResult`` (its ``title`` and
        ``pixel_size`` are used when the corresponding arguments are left at
        their defaults), or a dataset-like object exposing ``array``,
        ``name`` and ``sampling`` (its ``name`` and, for ``"nm"`` or ``"Å"``
        units, its last sampling value are used the same way).
    display_mode : str, default "amplitude"
        Initial display mode: ``"amplitude"``, ``"phase"``, ``"hsv"``,
        ``"real"``, or ``"imag"``.
    title : str, optional
        Title displayed in the widget header.
    cmap : str, default "inferno"
        Colormap for amplitude/real/imag modes. Phase and HSV modes use
        a fixed cyclic colormap.
    pixel_size : float, default 0.0
        Pixel size in angstroms for scale bar display.
    log_scale : bool, default False
        Apply log(1+x) to amplitude before display.
    auto_contrast : bool, default False
        Use percentile-based contrast.
    percentile_low, percentile_high : float, default 1.0 and 99.0
        Percentiles used as contrast limits when ``auto_contrast`` is on.
    show_fft : bool, default False
        Show FFT panel.
    fft_window : bool, default True
        FFT windowing toggle, synced to the frontend.
    show_stats : bool, default True
        Show statistics bar.
    show_controls : bool, default True
        Show control panel.
    scale_bar_visible : bool, default True
        Show the scale bar.
    image_width_px : int, default 0
        Requested display width in pixels, synced to the frontend.
    disabled_tools, hidden_tools : list of str, optional
        Tool groups to disable / hide, normalized by ``normalize_tool_groups``.
    disable_* , hide_* : bool, default False
        Shorthand flags for the ``display``, ``histogram``, ``fft``, ``roi``,
        ``stats``, ``export`` and ``view`` tool groups. ``disable_all`` /
        ``hide_all`` are forwarded as the ``all_flag`` of
        ``build_tool_groups``.
    state : dict, str or pathlib.Path, optional
        State to restore after construction. A path is read as a JSON file
        (as written by :meth:`save`).
    **kwargs
        Forwarded to ``anywidget.AnyWidget``.

    Examples
    --------
    >>> import numpy as np
    >>> from quantem.widget import ShowComplex2D
    >>>
    >>> # Complex exit wave
    >>> amplitude = np.ones((64, 64))
    >>> phase = np.linspace(-np.pi, np.pi, 64 * 64).reshape(64, 64)
    >>> data = np.exp(1j * phase) * amplitude
    >>> ShowComplex2D(data, title="Exit Wave", display_mode="hsv")
    >>>
    >>> # From real and imaginary parts
    >>> ShowComplex2D((data.real, data.imag), display_mode="phase")
    """

    _esm = pathlib.Path(__file__).parent / "static" / "showcomplex.js"
    _css = pathlib.Path(__file__).parent / "static" / "showcomplex.css"

    # Valid ``display_mode`` values (order is used in error messages).
    _DISPLAY_MODES = ("amplitude", "phase", "hsv", "real", "imag")
    # Traits describing the loaded data rather than view settings;
    # ``load_state_dict`` never overwrites them so the shape, raw bytes and
    # statistics always stay consistent with ``_real`` / ``_imag``.
    _STATE_PROTECTED_KEYS = frozenset(
        {
            "height",
            "width",
            "real_bytes",
            "imag_bytes",
            "stats_mean",
            "stats_min",
            "stats_max",
            "stats_std",
        }
    )
    # Export format (lower-case) -> Pillow format identifier.
    _PIL_FORMATS = {"png": "PNG", "pdf": "PDF", "tiff": "TIFF", "tif": "TIFF"}

    # Core state
    height = traitlets.Int(1).tag(sync=True)
    width = traitlets.Int(1).tag(sync=True)
    real_bytes = traitlets.Bytes(b"").tag(sync=True)
    imag_bytes = traitlets.Bytes(b"").tag(sync=True)
    title = traitlets.Unicode("").tag(sync=True)

    # Display mode
    display_mode = traitlets.Unicode("amplitude").tag(sync=True)
    cmap = traitlets.Unicode("inferno").tag(sync=True)

    # Display options
    log_scale = traitlets.Bool(False).tag(sync=True)
    auto_contrast = traitlets.Bool(False).tag(sync=True)
    percentile_low = traitlets.Float(1.0).tag(sync=True)
    percentile_high = traitlets.Float(99.0).tag(sync=True)

    # Scale bar
    pixel_size = traitlets.Float(0.0).tag(sync=True)
    scale_bar_visible = traitlets.Bool(True).tag(sync=True)

    # UI
    show_stats = traitlets.Bool(True).tag(sync=True)
    show_fft = traitlets.Bool(False).tag(sync=True)
    fft_window = traitlets.Bool(True).tag(sync=True)
    show_controls = traitlets.Bool(True).tag(sync=True)
    image_width_px = traitlets.Int(0).tag(sync=True)
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)

    # ROI (single-mode, same pattern as Show4D)
    roi_mode = traitlets.Unicode("off").tag(sync=True)  # "off", "circle", "square", "rect"
    roi_center_row = traitlets.Float(0.0).tag(sync=True)
    roi_center_col = traitlets.Float(0.0).tag(sync=True)
    # Combined [row, col] mirror of roi_center_row / roi_center_col.
    roi_center = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0]).tag(sync=True)
    roi_radius = traitlets.Float(5.0).tag(sync=True)
    roi_width = traitlets.Float(10.0).tag(sync=True)
    roi_height = traitlets.Float(10.0).tag(sync=True)

    # Statistics (recomputed per display_mode)
    stats_mean = traitlets.Float(0.0).tag(sync=True)
    stats_min = traitlets.Float(0.0).tag(sync=True)
    stats_max = traitlets.Float(0.0).tag(sync=True)
    stats_std = traitlets.Float(0.0).tag(sync=True)

    # =========================================================================
    # Tool groups
    # =========================================================================

    @classmethod
    def _normalize_tool_groups(cls, tool_groups) -> List[str]:
        """Normalize a tool-group list using the ShowComplex2D vocabulary."""
        return normalize_tool_groups("ShowComplex2D", tool_groups)

    @classmethod
    def _build_disabled_tools(
        cls,
        disabled_tools=None,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_fft: bool = False,
        disable_roi: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_all: bool = False,
    ) -> List[str]:
        """Combine an explicit list and per-group flags into ``disabled_tools``."""
        return build_tool_groups(
            "ShowComplex2D",
            tool_groups=disabled_tools,
            all_flag=disable_all,
            flag_map={
                "display": disable_display,
                "histogram": disable_histogram,
                "fft": disable_fft,
                "roi": disable_roi,
                "stats": disable_stats,
                "export": disable_export,
                "view": disable_view,
            },
        )

    @classmethod
    def _build_hidden_tools(
        cls,
        hidden_tools=None,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_fft: bool = False,
        hide_roi: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_all: bool = False,
    ) -> List[str]:
        """Combine an explicit list and per-group flags into ``hidden_tools``."""
        return build_tool_groups(
            "ShowComplex2D",
            tool_groups=hidden_tools,
            all_flag=hide_all,
            flag_map={
                "display": hide_display,
                "histogram": hide_histogram,
                "fft": hide_fft,
                "roi": hide_roi,
                "stats": hide_stats,
                "export": hide_export,
                "view": hide_view,
            },
        )

    @traitlets.validate("disabled_tools")
    def _validate_disabled_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    @traitlets.validate("hidden_tools")
    def _validate_hidden_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    # =========================================================================
    # Construction / data
    # =========================================================================

    def __init__(
        self,
        data,
        display_mode: str = "amplitude",
        title: str = "",
        cmap: str = "inferno",
        pixel_size: float = 0.0,
        log_scale: bool = False,
        auto_contrast: bool = False,
        percentile_low: float = 1.0,
        percentile_high: float = 99.0,
        show_fft: bool = False,
        fft_window: bool = True,
        show_stats: bool = True,
        show_controls: bool = True,
        scale_bar_visible: bool = True,
        image_width_px: int = 0,
        disabled_tools: Optional[List[str]] = None,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_fft: bool = False,
        disable_roi: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_all: bool = False,
        hidden_tools: Optional[List[str]] = None,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_fft: bool = False,
        hide_roi: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_all: bool = False,
        state=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.widget_version = resolve_widget_version()

        # IOResult: unwrap, using its metadata only where the caller left
        # the corresponding argument at its default.
        if isinstance(data, IOResult):
            if not title and data.title:
                title = data.title
            if pixel_size == 0.0 and data.pixel_size is not None:
                pixel_size = data.pixel_size
            data = data.data

        # Dataset duck typing (objects exposing array/name/sampling).
        _extracted_title = None
        _extracted_pixel_size = None
        if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
            _extracted_title = data.name if data.name else None
            if hasattr(data, "units"):
                units = list(data.units)
                sampling = list(data.sampling)
                # Empty metadata yields no pixel size instead of IndexError.
                if units and sampling:
                    sampling_val = float(sampling[-1])
                    if units[-1] == "nm":
                        _extracted_pixel_size = sampling_val * 10  # nm → Å
                    elif units[-1] in ("Å", "angstrom", "A"):
                        _extracted_pixel_size = sampling_val
            data = data.array

        self._real, self._imag = self._split_complex(data)
        self.height = int(self._real.shape[0])
        self.width = int(self._real.shape[1])

        # Options
        self.display_mode = display_mode
        self.title = title if title else (_extracted_title or "")
        self.cmap = cmap
        if pixel_size == 0.0 and _extracted_pixel_size is not None:
            pixel_size = _extracted_pixel_size
        self.pixel_size = pixel_size
        self.log_scale = log_scale
        self.auto_contrast = auto_contrast
        self.percentile_low = percentile_low
        self.percentile_high = percentile_high
        self.show_fft = show_fft
        self.fft_window = fft_window
        self.show_stats = show_stats
        self.show_controls = show_controls
        self.scale_bar_visible = scale_bar_visible
        self.image_width_px = image_width_px
        self.disabled_tools = self._build_disabled_tools(
            disabled_tools=disabled_tools,
            disable_display=disable_display,
            disable_histogram=disable_histogram,
            disable_fft=disable_fft,
            disable_roi=disable_roi,
            disable_stats=disable_stats,
            disable_export=disable_export,
            disable_view=disable_view,
            disable_all=disable_all,
        )
        self.hidden_tools = self._build_hidden_tools(
            hidden_tools=hidden_tools,
            hide_display=hide_display,
            hide_histogram=hide_histogram,
            hide_fft=hide_fft,
            hide_roi=hide_roi,
            hide_stats=hide_stats,
            hide_export=hide_export,
            hide_view=hide_view,
            hide_all=hide_all,
        )

        # ROI defaults (centered, radius proportional to image size)
        default_roi_size = max(3, min(self.height, self.width) // 6)
        self.roi_center_row = float(self.height / 2)
        self.roi_center_col = float(self.width / 2)
        self.roi_center = [float(self.height / 2), float(self.width / 2)]
        self.roi_radius = float(default_roi_size)
        self.roi_width = float(default_roi_size * 2)
        self.roi_height = float(default_roi_size)

        # Compute stats for initial display mode
        self._update_stats()

        # Send data to JS
        self.real_bytes = self._real.tobytes()
        self.imag_bytes = self._imag.tobytes()

        # Observers
        self.observe(self._on_display_mode_change, names=["display_mode"])
        self.observe(self._on_roi_center_change, names=["roi_center"])

        # State restoration (must be last)
        if state is not None:
            if isinstance(state, (str, pathlib.Path)):
                state = unwrap_state_payload(
                    json.loads(pathlib.Path(state).read_text()),
                    require_envelope=True,
                )
            else:
                state = unwrap_state_payload(state)
            self.load_state_dict(state)

    @staticmethod
    def _split_complex(data):
        """Convert complex input into float32 ``(real, imag)`` 2D arrays.

        Parameters
        ----------
        data : array_like (complex) or tuple of (real, imag)
            Either a complex-dtype array or a 2-tuple of real arrays.

        Returns
        -------
        tuple of np.ndarray
            ``(real, imag)`` as float32 arrays of identical 2D shape.

        Raises
        ------
        ValueError
            If the parts differ in shape, the data is not 2D, or a single
            array does not have a complex dtype.
        """
        if isinstance(data, tuple) and len(data) == 2:
            real_arr = to_numpy(data[0]).astype(np.float32)
            imag_arr = to_numpy(data[1]).astype(np.float32)
            if real_arr.shape != imag_arr.shape:
                raise ValueError(
                    f"Real and imaginary parts must have same shape, "
                    f"got {real_arr.shape} and {imag_arr.shape}"
                )
            if real_arr.ndim != 2:
                raise ValueError(f"Expected 2D arrays, got {real_arr.ndim}D")
            return real_arr, imag_arr
        arr = to_numpy(data)
        if not np.issubdtype(arr.dtype, np.complexfloating):
            raise ValueError(
                f"Expected complex array (complex64/complex128), got {arr.dtype}. "
                f"Use ShowComplex2D((real, imag)) for real-valued input."
            )
        if arr.ndim != 2:
            raise ValueError(f"Expected 2D array, got {arr.ndim}D")
        return arr.real.astype(np.float32), arr.imag.astype(np.float32)

    def _amplitude(self) -> np.ndarray:
        """Return ``|z|``; ``np.hypot`` avoids float32 overflow in ``r**2 + i**2``."""
        return np.hypot(self._real, self._imag)

    def _get_display_data(self, mode: str | None = None) -> np.ndarray:
        """Return the real-valued array shown for ``mode`` (default: current).

        ``"hsv"`` returns the amplitude, which drives brightness in that mode.

        Raises
        ------
        ValueError
            For an unknown mode.
        """
        mode = mode or self.display_mode
        if mode in ("amplitude", "hsv"):
            return self._amplitude()
        if mode == "phase":
            return np.arctan2(self._imag, self._real)
        if mode == "real":
            return self._real
        if mode == "imag":
            return self._imag
        raise ValueError(f"Unknown display mode: {mode!r}")

    def _update_stats(self):
        """Recompute the ``stats_*`` traits for the current display mode."""
        data = self._get_display_data()
        self.stats_mean = float(data.mean())
        self.stats_min = float(data.min())
        self.stats_max = float(data.max())
        self.stats_std = float(data.std())

    def _on_display_mode_change(self, change=None):
        self._update_stats()

    def _on_roi_center_change(self, change=None):
        """Mirror the combined ``roi_center`` list into row/col traits."""
        val = self.roi_center
        if isinstance(val, (list, tuple)) and len(val) >= 2:
            self.roi_center_row = float(val[0])
            self.roi_center_col = float(val[1])

    # =========================================================================
    # ROI
    # =========================================================================

    def _sync_roi_center(self):
        """Copy ``roi_center_row`` / ``roi_center_col`` into ``roi_center``.

        Setting ``roi_center`` fires ``_on_roi_center_change``, which writes
        the same row/col values back, so this is a no-op for row/col.
        """
        self.roi_center = [float(self.roi_center_row), float(self.roi_center_col)]

    def _move_roi(self, row=None, col=None):
        """Update the ROI center from optional ``row``/``col`` and keep all center traits in sync."""
        if row is not None:
            self.roi_center_row = float(row)
        if col is not None:
            self.roi_center_col = float(col)
        self._sync_roi_center()

    def roi_circle(self, row=None, col=None, radius=None) -> "ShowComplex2D":
        """Switch to a circular ROI, optionally moving or resizing it.

        Parameters
        ----------
        row, col : float, optional
            New ROI center; unchanged when omitted.
        radius : float, optional
            New ROI radius; unchanged when omitted.

        Returns
        -------
        ShowComplex2D
            ``self``, for chaining.
        """
        self._move_roi(row, col)
        if radius is not None:
            self.roi_radius = float(radius)
        self.roi_mode = "circle"
        return self

    def roi_square(self, row=None, col=None, radius=None) -> "ShowComplex2D":
        """Switch to a square ROI, optionally moving or resizing it.

        Parameters
        ----------
        row, col : float, optional
            New ROI center; unchanged when omitted.
        radius : float, optional
            New value for ``roi_radius``; unchanged when omitted.

        Returns
        -------
        ShowComplex2D
            ``self``, for chaining.
        """
        self._move_roi(row, col)
        if radius is not None:
            self.roi_radius = float(radius)
        self.roi_mode = "square"
        return self

    def roi_rect(self, row=None, col=None, width=None, height=None) -> "ShowComplex2D":
        """Switch to a rectangular ROI, optionally moving or resizing it.

        Parameters
        ----------
        row, col : float, optional
            New ROI center; unchanged when omitted.
        width, height : float, optional
            New ``roi_width`` / ``roi_height``; unchanged when omitted.

        Returns
        -------
        ShowComplex2D
            ``self``, for chaining.
        """
        self._move_roi(row, col)
        if width is not None:
            self.roi_width = float(width)
        if height is not None:
            self.roi_height = float(height)
        self.roi_mode = "rect"
        return self

    def set_image(self, data):
        """Replace the complex data. Preserves all display settings.

        Accepts the same ``data`` forms as the constructor (complex array,
        ``(real, imag)`` tuple, ``IOResult`` or dataset-like object). Any
        title/pixel-size metadata carried by the input is ignored.

        Raises
        ------
        ValueError
            If the data is not 2D complex (or a matching real/imag pair).
        """
        if isinstance(data, IOResult):
            data = data.data
        if hasattr(data, "array") and hasattr(data, "name") and hasattr(data, "sampling"):
            data = data.array
        self._real, self._imag = self._split_complex(data)
        self.height = int(self._real.shape[0])
        self.width = int(self._real.shape[1])
        self._update_stats()
        self.real_bytes = self._real.tobytes()
        self.imag_bytes = self._imag.tobytes()

    # =========================================================================
    # Export
    # =========================================================================

    def _normalize_frame(self, frame: np.ndarray) -> np.ndarray:
        """Map ``frame`` to uint8 0-255 using the current log/contrast settings.

        A constant frame maps to all zeros.
        """
        if self.log_scale:
            frame = np.log1p(np.maximum(frame, 0))
        if self.auto_contrast:
            vmin = float(np.percentile(frame, self.percentile_low))
            vmax = float(np.percentile(frame, self.percentile_high))
        else:
            vmin = float(frame.min())
            vmax = float(frame.max())
        if vmax > vmin:
            return np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8)
        return np.zeros(frame.shape, dtype=np.uint8)

    def save_image(
        self,
        path: str | pathlib.Path,
        *,
        display_mode: str | None = None,
        format: str | None = None,
        dpi: int = 150,
    ) -> pathlib.Path:
        """Save current view as PNG, PDF, or TIFF.

        Parameters
        ----------
        path : str or pathlib.Path
            Output file path. Parent directories are created as needed.
        display_mode : str, optional
            Override display mode. One of 'amplitude', 'phase', 'hsv',
            'real', 'imag'. Defaults to current display_mode.
        format : str, optional
            'png', 'pdf', or 'tiff'. If omitted, inferred from file extension
            (falling back to 'png'). An explicit format always wins over the
            extension.
        dpi : int, default 150
            Output DPI metadata (page resolution for PDF).

        Returns
        -------
        pathlib.Path
            The written file path.

        Raises
        ------
        ValueError
            For an unsupported format or unknown display mode.
        """
        import matplotlib.colors as mcolors
        from matplotlib import colormaps
        from PIL import Image

        path = pathlib.Path(path)
        fmt = (format or path.suffix.lstrip(".").lower() or "png").lower()
        if fmt not in self._PIL_FORMATS:
            raise ValueError(f"Unsupported format: {fmt!r}. Use 'png', 'pdf', or 'tiff'.")
        mode = display_mode or self.display_mode
        valid_modes = self._DISPLAY_MODES
        if mode not in valid_modes:
            raise ValueError(f"Unknown display_mode: {mode!r}. Use one of {valid_modes}.")

        if mode == "hsv":
            # Hue encodes phase, value encodes min-max normalized amplitude.
            amp = self._amplitude()
            phase = np.arctan2(self._imag, self._real)
            amp_min, amp_max = float(amp.min()), float(amp.max())
            if amp_max > amp_min:
                amp_norm = (amp - amp_min) / (amp_max - amp_min)
            else:
                amp_norm = np.zeros_like(amp)
            hue = (phase + np.pi) / (2 * np.pi)
            hsv_array = np.stack([hue, np.ones_like(hue), amp_norm], axis=-1)
            rgb = mcolors.hsv_to_rgb(hsv_array)
            rgba = np.zeros((*rgb.shape[:2], 4), dtype=np.uint8)
            rgba[:, :, :3] = (rgb * 255).astype(np.uint8)
            rgba[:, :, 3] = 255
            img = Image.fromarray(rgba)
        else:
            data = self._get_display_data(mode)
            if self.log_scale and mode in ("amplitude", "real", "imag"):
                data = np.log1p(np.maximum(data, 0))
            if self.auto_contrast:
                vmin = float(np.percentile(data, self.percentile_low))
                vmax = float(np.percentile(data, self.percentile_high))
            else:
                vmin = float(data.min())
                vmax = float(data.max())
            if vmax > vmin:
                normalized = np.clip((data - vmin) / (vmax - vmin), 0, 1)
            else:
                normalized = np.zeros_like(data)
            cmap_name = "hsv" if mode == "phase" else self.cmap
            cmap_fn = colormaps.get_cmap(cmap_name)
            rgba = (cmap_fn(normalized) * 255).astype(np.uint8)
            img = Image.fromarray(rgba)

        # Pass the resolved format explicitly so ``format=`` is honored and
        # extension-less paths work.
        save_kwargs = {"format": self._PIL_FORMATS[fmt], "dpi": (dpi, dpi)}
        if fmt == "pdf":
            # Older Pillow releases cannot write RGBA to PDF, and the PDF
            # writer takes its page scale from ``resolution``.
            img = img.convert("RGB")
            save_kwargs["resolution"] = float(dpi)
        path.parent.mkdir(parents=True, exist_ok=True)
        img.save(str(path), **save_kwargs)
        return path

    # =========================================================================
    # State Protocol
    # =========================================================================

    def state_dict(self) -> dict:
        """Return the user-adjustable view settings as a plain dict."""
        return {
            "display_mode": self.display_mode,
            "title": self.title,
            "cmap": self.cmap,
            "log_scale": self.log_scale,
            "auto_contrast": self.auto_contrast,
            "percentile_low": self.percentile_low,
            "percentile_high": self.percentile_high,
            "pixel_size": self.pixel_size,
            "scale_bar_visible": self.scale_bar_visible,
            "show_fft": self.show_fft,
            "fft_window": self.fft_window,
            "show_stats": self.show_stats,
            "show_controls": self.show_controls,
            "image_width_px": self.image_width_px,
            "roi_mode": self.roi_mode,
            "roi_center_row": self.roi_center_row,
            "roi_center_col": self.roi_center_col,
            "roi_radius": self.roi_radius,
            "roi_width": self.roi_width,
            "roi_height": self.roi_height,
            "disabled_tools": self.disabled_tools,
            "hidden_tools": self.hidden_tools,
        }

    def save(self, path: str | pathlib.Path):
        """Save widget state to a JSON file."""
        save_state_file(path, "ShowComplex2D", self.state_dict())

    def load_state_dict(self, state):
        """Restore widget state from a dict.

        The legacy key ``pixel_size_angstrom`` is accepted as an alias for
        ``pixel_size``. The following keys are ignored:

        - non-string keys;
        - private names (leading underscore);
        - names of methods;
        - names not present on the widget;
        - data-derived traits (shape, raw bytes, statistics).

        The caller's dict is never modified.
        """
        state = dict(state)  # work on a copy; never mutate the caller's dict
        if "pixel_size_angstrom" in state and "pixel_size" not in state:
            state["pixel_size"] = state.pop("pixel_size_angstrom")
        for key, val in state.items():
            if not isinstance(key, str) or key.startswith("_"):
                continue
            if key in self._STATE_PROTECTED_KEYS:
                continue
            if not hasattr(self, key) or callable(getattr(self, key)):
                continue
            setattr(self, key, val)
        # Row/col are the keys state_dict() emits; mirror them into roi_center.
        if "roi_center_row" in state or "roi_center_col" in state:
            self._sync_roi_center()

    def summary(self):
        """Print a human-readable summary of the widget state."""
        name = self.title if self.title else "ShowComplex2D"
        lines = [name, "═" * 32]
        lines.append(f"Image:    {self.height}×{self.width} (complex)")
        if self.pixel_size > 0:
            ps = self.pixel_size
            if ps >= 10:
                lines[-1] += f" ({ps / 10:.2f} nm/px)"
            else:
                lines[-1] += f" ({ps:.2f} Å/px)"
        amp = self._amplitude()
        lines.append(
            f"Amp:      min={float(amp.min()):.4g}  max={float(amp.max()):.4g}  "
            f"mean={float(amp.mean()):.4g}"
        )
        phase = np.arctan2(self._imag, self._real)
        lines.append(
            f"Phase:    min={float(phase.min()):.4g}  max={float(phase.max()):.4g}  "
            f"mean={float(phase.mean()):.4g}"
        )
        mode = self.display_mode
        cmap = self.cmap if mode in ("amplitude", "real", "imag") else "hsv (cyclic)"
        scale = "log" if self.log_scale else "linear"
        contrast = "auto" if self.auto_contrast else "manual"
        lines.append(f"Display:  {mode} | {cmap} | {contrast} | {scale}")
        if self.show_fft:
            lines[-1] += " | FFT"
            if not self.fft_window:
                lines[-1] += " (no window)"
        print("\n".join(lines))

    def __repr__(self) -> str:
        name = self.title if self.title else "ShowComplex2D"
        parts = [f"{name}({self.height}×{self.width}"]
        parts.append(f"mode={self.display_mode}")
        if self.pixel_size > 0:
            ps = self.pixel_size
            if ps >= 10:
                parts.append(f"px={ps / 10:.2f} nm")
            else:
                parts.append(f"px={ps:.2f} Å")
        if self.log_scale:
            parts.append("log")
        if self.show_fft:
            parts.append("fft")
        return ", ".join(parts) + ")"
# Register the tool-group runtime helpers from quantem.widget.tool_parity on
# the class, keyed by its own class name ("ShowComplex2D"). NOTE(review):
# exact methods added are defined in tool_parity — confirm there.
bind_tool_runtime_api(ShowComplex2D, ShowComplex2D.__name__)