Source code for quantem.widget.align2d

"""
Align2D: Interactive image alignment widget.
Overlay two 2D images with alpha blending and drag/pad to align.
Auto-alignment via FFT cross-correlation with live NCC display.
"""
import json
import pathlib
from typing import Union

import anywidget
import numpy as np
import traitlets

from quantem.widget.array_utils import to_numpy, _resize_image
from quantem.widget.io import IOResult
from quantem.widget.json_state import resolve_widget_version, save_state_file, unwrap_state_payload
from quantem.widget.tool_parity import (
    bind_tool_runtime_api,
    build_tool_groups,
    normalize_tool_groups,
)


def _tukey_2d(h: int, w: int, alpha: float = 0.2) -> np.ndarray:
    """2D Tukey window — flat center, cosine-tapered edges."""
    def _t1d(n: int) -> np.ndarray:
        if n <= 1:
            return np.ones(n)
        x = np.linspace(0, 1, n)
        win = np.ones(n)
        left = x < alpha / 2
        right = x > 1 - alpha / 2
        win[left] = 0.5 * (1 + np.cos(2 * np.pi / alpha * (x[left] - alpha / 2)))
        win[right] = 0.5 * (1 + np.cos(2 * np.pi / alpha * (x[right] - 1 + alpha / 2)))
        return win
    return np.outer(_t1d(h), _t1d(w))


def _dft_upsample(
    fa_conj_fb: np.ndarray,
    peak_y: int,
    peak_x: int,
    upsample_factor: int = 100,
    region: float = 1.5,
) -> tuple[float, float]:
    """Matrix DFT sub-pixel refinement (Guizar-Sicairos et al. 2008).

    Evaluates the inverse DFT at upsampled coordinates in a small region
    around the integer peak. 100x upsampling -> 1/100 pixel accuracy.
    """
    h, w = fa_conj_fb.shape
    size = int(np.ceil(region * upsample_factor))
    ups_y = peak_y + (np.arange(size) - size // 2) / upsample_factor
    ups_x = peak_x + (np.arange(size) - size // 2) / upsample_factor
    # Use proper frequency indices (negative freqs for k >= N/2)
    freq_y = np.fft.fftfreq(h) * h
    freq_x = np.fft.fftfreq(w) * w
    row_kernel = np.exp(2j * np.pi * ups_y[:, None] * freq_y[None, :] / h)
    col_kernel = np.exp(2j * np.pi * ups_x[:, None] * freq_x[None, :] / w)
    upsampled = np.real(row_kernel @ fa_conj_fb @ col_kernel.T)
    up_y, up_x = np.unravel_index(np.argmax(upsampled), upsampled.shape)
    return float(ups_y[up_y]), float(ups_x[up_x])


def _cross_correlate_fft(
    a: np.ndarray,
    b: np.ndarray,
    max_shift_x: int = 0,
    max_shift_y: int = 0,
) -> tuple[float, float]:
    """Phase correlation with Tukey windowing and constrained peak search.

    Both images are mean-subtracted and Tukey-windowed, their cross-power
    spectrum is normalized to unit magnitude, and the maximum of its inverse
    FFT is located (optionally restricted to ``|shift| <= max_shift_*``) and
    refined with ``_dft_upsample`` at its default 1/100-pixel resolution.

    Parameters
    ----------
    a, b : np.ndarray
        2D images; ``b`` is assumed to have the same shape as ``a``.
    max_shift_x, max_shift_y : int, default 0
        Per-axis search limits in pixels, capped at half the image size.
        ``0`` leaves that axis unconstrained; if both are ``0`` no
        constraint is applied at all.

    Returns
    -------
    tuple[float, float]
        ``(dx, dy)``. Peak indices above ``n / 2`` are folded to negative
        shifts. If ``b`` is ``a`` circularly shifted by ``+s`` pixels the
        result is approximately ``-s``.
    """
    rows, cols = a.shape
    # Tukey taper suppresses the wrap-around edge discontinuity.
    taper = _tukey_2d(rows, cols)
    spec_a = np.fft.fft2((a - a.mean()) * taper)
    spec_b = np.fft.fft2((b - b.mean()) * taper)

    # Whitened cross-power spectrum -> sharp phase-correlation peak.
    cross_power = spec_a * np.conj(spec_b)
    cross_power /= np.abs(cross_power) + 1e-10
    corr = np.real(np.fft.ifft2(cross_power))

    if max_shift_x > 0 or max_shift_y > 0:
        def _allowed(n: int, limit: int) -> np.ndarray:
            # Shifts 0..reach sit at the start, -reach..-1 at the end.
            reach = min(limit, n // 2) if limit > 0 else n // 2
            mask = np.zeros(n, dtype=bool)
            mask[:reach + 1] = True
            mask[max(n - reach, 0):] = True
            return mask

        keep = np.outer(_allowed(rows, max_shift_y), _allowed(cols, max_shift_x))
        corr = np.where(keep, corr, -np.inf)

    coarse_y, coarse_x = np.unravel_index(np.argmax(corr), corr.shape)
    fine_y, fine_x = _dft_upsample(cross_power, int(coarse_y), int(coarse_x))
    shift_y = fine_y - rows if fine_y > rows / 2 else fine_y
    shift_x = fine_x - cols if fine_x > cols / 2 else fine_x
    return float(shift_x), float(shift_y)


def _compute_ncc(a: np.ndarray, b: np.ndarray, dx: float, dy: float) -> float:
    """Compute normalized cross-correlation at a specific sub-pixel offset."""
    h, w = a.shape
    idx, idy = int(np.floor(dx)), int(np.floor(dy))
    fx, fy = dx - idx, dy - idy
    # Overlap region (shrunk by 1 for bilinear +1 neighbor)
    y_start = max(0, idy)
    y_end = min(h, h + idy - 1)
    x_start = max(0, idx)
    x_end = min(w, w + idx - 1)
    if y_end <= y_start or x_end <= x_start:
        return 0.0
    a_crop = a[y_start:y_end, x_start:x_end]
    # Bilinear interpolation of b
    by_s, by_e = y_start - idy, y_end - idy
    bx_s, bx_e = x_start - idx, x_end - idx
    b_interp = (b[by_s:by_e, bx_s:bx_e] * (1 - fx) * (1 - fy)
                + b[by_s:by_e, bx_s + 1:bx_e + 1] * fx * (1 - fy)
                + b[by_s + 1:by_e + 1, bx_s:bx_e] * (1 - fx) * fy
                + b[by_s + 1:by_e + 1, bx_s + 1:bx_e + 1] * fx * fy)
    a_c = a_crop - a_crop.mean()
    b_c = b_interp - b_interp.mean()
    denom = np.sqrt(np.sum(a_c ** 2) * np.sum(b_c ** 2))
    if denom == 0:
        return 0.0
    return float(np.sum(a_c * b_c) / denom)


class Align2D(anywidget.AnyWidget):
    """
    Interactive alignment of two 2D images.

    Parameters
    ----------
    image_a : array_like
        First 2D image (reference, stays fixed).
    image_b : array_like
        Second 2D image (draggable).
    title : str, optional
        Title displayed above the viewer.
    label_a : str, default "Image A"
        Label for the first image.
    label_b : str, default "Image B"
        Label for the second image.
    cmap : str, default "gray"
        Colormap name.
    opacity : float, default 0.5
        Blend ratio (0 = only A, 1 = only B).
    padding : float, default 0.2
        Fractional padding on each side. Adjustable from the frontend.
    pixel_size : float, default 0.0
        Pixel size in Å for scale bar (0 = uncalibrated).
    canvas_size : int, default 300
        Initial canvas size in CSS pixels for each column.
    auto_align : bool, default True
        Automatically compute initial alignment via cross-correlation.
    max_shift : float, default 0.0
        Maximum allowed shift in pixels (0 = unlimited, constrained by padding).
    rotation : float, default 0.0
        Initial rotation angle of image B in degrees.
    hist_source : str, default "a"
        Which image to show in the histogram ("a" or "b").
    disabled_tools : list of str, optional
        Tool groups to disable; merged with the ``disable_*`` flags.
    disable_<group> : bool, default False
        Disable one tool group; ``<group>`` is one of ``alignment``,
        ``overlay``, ``display``, ``histogram``, ``stats``, ``export``, ``view``.
    disable_all : bool, default False
        Disable all tool groups.
    hidden_tools : list of str, optional
        Tool groups to hide; merged with the ``hide_*`` flags.
    hide_<group> : bool, default False
        Hide one tool group (same group names as ``disable_<group>``).
    hide_all : bool, default False
        Hide all tool groups.
    state : dict or str or pathlib.Path, optional
        Saved state applied after construction: a state dict, or the path of
        a JSON state file (read with ``require_envelope=True``).
    **kwargs
        Forwarded to ``anywidget.AnyWidget``.

    Raises
    ------
    ValueError
        If either image is not 2D.

    Notes
    -----
    ``IOResult`` and Dataset2d-like inputs (objects with ``array``, ``name``
    and ``sampling``) also supply ``title``/``pixel_size`` when those are left
    at their defaults. Images of different shapes are both resized to the
    larger height and the larger width.

    Examples
    --------
    >>> import numpy as np
    >>> from quantem.widget import Align2D
    >>> a = np.random.rand(64, 64).astype(np.float32)
    >>> b = np.random.rand(64, 64).astype(np.float32)
    >>> Align2D(a, b, title="Alignment")
    """

    _esm = pathlib.Path(__file__).parent / "static" / "align2d.js"
    _css = pathlib.Path(__file__).parent / "static" / "align2d.css"

    # Image dimensions (unpadded)
    height = traitlets.Int(1).tag(sync=True)
    width = traitlets.Int(1).tag(sync=True)

    # Image data (unpadded, float32 bytes)
    image_a_bytes = traitlets.Bytes(b"").tag(sync=True)
    image_b_bytes = traitlets.Bytes(b"").tag(sync=True)

    # Padding (fractional, adjustable from frontend)
    padding = traitlets.Float(0.2).tag(sync=True)

    # Median values for padding fill
    median_a = traitlets.Float(0.0).tag(sync=True)
    median_b = traitlets.Float(0.0).tag(sync=True)

    # Alignment offset (image B relative to A, in pixels)
    dx = traitlets.Float(0.0).tag(sync=True)
    dy = traitlets.Float(0.0).tag(sync=True)

    # Rotation angle (degrees, of image B around its center)
    rotation = traitlets.Float(0.0).tag(sync=True)

    # Auto-aligned values (stored so user can restore)
    auto_dx = traitlets.Float(0.0).tag(sync=True)
    auto_dy = traitlets.Float(0.0).tag(sync=True)

    # Cross-correlation: NCC at (0,0) offset (baseline before alignment)
    xcorr_zero = traitlets.Float(0.0).tag(sync=True)
    # NCC at auto-aligned position (accurate, computed by Python)
    ncc_aligned = traitlets.Float(0.0).tag(sync=True)

    # Display
    title = traitlets.Unicode("").tag(sync=True)
    cmap = traitlets.Unicode("gray").tag(sync=True)
    opacity = traitlets.Float(0.5).tag(sync=True)
    label_a = traitlets.Unicode("Image A").tag(sync=True)
    label_b = traitlets.Unicode("Image B").tag(sync=True)

    # Scale bar
    pixel_size = traitlets.Float(0.0).tag(sync=True)

    # Bounds
    max_shift = traitlets.Float(0.0).tag(sync=True)

    # UI
    canvas_size = traitlets.Int(300).tag(sync=True)
    hist_source = traitlets.Unicode("a").tag(sync=True)

    # Tool visibility / locking
    disabled_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)
    hidden_tools = traitlets.List(traitlets.Unicode()).tag(sync=True)

    @classmethod
    def _normalize_tool_groups(cls, tool_groups) -> list[str]:
        return normalize_tool_groups("Align2D", tool_groups)

    @classmethod
    def _build_disabled_tools(
        cls,
        disabled_tools=None,
        disable_alignment: bool = False,
        disable_overlay: bool = False,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_all: bool = False,
    ) -> list[str]:
        return build_tool_groups(
            "Align2D",
            tool_groups=disabled_tools,
            all_flag=disable_all,
            flag_map={
                "alignment": disable_alignment,
                "overlay": disable_overlay,
                "display": disable_display,
                "histogram": disable_histogram,
                "stats": disable_stats,
                "export": disable_export,
                "view": disable_view,
            },
        )

    @classmethod
    def _build_hidden_tools(
        cls,
        hidden_tools=None,
        hide_alignment: bool = False,
        hide_overlay: bool = False,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_all: bool = False,
    ) -> list[str]:
        return build_tool_groups(
            "Align2D",
            tool_groups=hidden_tools,
            all_flag=hide_all,
            flag_map={
                "alignment": hide_alignment,
                "overlay": hide_overlay,
                "display": hide_display,
                "histogram": hide_histogram,
                "stats": hide_stats,
                "export": hide_export,
                "view": hide_view,
            },
        )

    @traitlets.validate("disabled_tools")
    def _validate_disabled_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    @traitlets.validate("hidden_tools")
    def _validate_hidden_tools(self, proposal):
        return self._normalize_tool_groups(proposal["value"])

    @staticmethod
    def _resolve_metadata(image_a, image_b, title: str, pixel_size: float) -> tuple[str, float]:
        """Fill in ``title``/``pixel_size`` from IOResult or Dataset2d-like inputs.

        Explicit values win: ``title`` is only taken when empty and
        ``pixel_size`` only when ``0.0``. Image A is consulted before image B.
        """
        unwrapped = []
        for img in (image_a, image_b):
            if isinstance(img, IOResult):
                if not title and img.title:
                    title = img.title
                if pixel_size == 0.0 and img.pixel_size is not None:
                    pixel_size = img.pixel_size
                img = img.data
            unwrapped.append(img)
        for img in unwrapped:
            if hasattr(img, "array") and hasattr(img, "name") and hasattr(img, "sampling"):
                if not title and img.name:
                    title = img.name
                if pixel_size == 0.0 and hasattr(img, "units"):
                    units = list(img.units)
                    sampling_val = float(img.sampling[-1])
                    # pixel_size is in Å — convert if units are nm
                    if units[-1] in ("nm", "nanometer"):
                        sampling_val = sampling_val * 10  # nm → Å
                    pixel_size = sampling_val
        return title, pixel_size

    @staticmethod
    def _prepare_images(image_a, image_b) -> tuple[np.ndarray, np.ndarray]:
        """Convert both inputs to float32 2D arrays of one common shape.

        Unwraps ``IOResult`` (``.data``) and Dataset-like (``.array``) inputs,
        validates dimensionality, then resizes whichever image is smaller so
        both match the larger height and the larger width.

        Raises
        ------
        ValueError
            If either image is not 2D.
        """
        if isinstance(image_a, IOResult):
            image_a = image_a.data
        if isinstance(image_b, IOResult):
            image_b = image_b.data
        if hasattr(image_a, "array"):
            image_a = image_a.array
        if hasattr(image_b, "array"):
            image_b = image_b.array
        a = to_numpy(image_a).astype(np.float32)
        b = to_numpy(image_b).astype(np.float32)
        if a.ndim != 2:
            raise ValueError(f"Align2D requires 2D images, image_a is {a.ndim}D")
        if b.ndim != 2:
            raise ValueError(f"Align2D requires 2D images, image_b is {b.ndim}D")
        # Resize smaller to match larger
        target_h = max(a.shape[0], b.shape[0])
        target_w = max(a.shape[1], b.shape[1])
        if a.shape != (target_h, target_w):
            a = _resize_image(a, target_h, target_w)
        if b.shape != (target_h, target_w):
            b = _resize_image(b, target_h, target_w)
        return a, b

    def _compute_auto_alignment(self, a: np.ndarray, b: np.ndarray) -> None:
        """Estimate the B-vs-A translation and store it as the current offset.

        The search is bounded by ``max_shift`` when positive, otherwise by
        ``padding`` times the image size on each axis, and the estimate is
        clamped to that bound. Results go to ``dx``/``dy``,
        ``auto_dx``/``auto_dy`` and ``ncc_aligned``. Any failure is reported
        as a warning instead of raised, so the widget remains usable.
        """
        try:
            if self.max_shift > 0:
                limit_x = limit_y = int(self.max_shift)
            else:
                limit_x = int(self.width * self.padding)
                limit_y = int(self.height * self.padding)
            # Translation alignment via phase correlation
            best_dx, best_dy = _cross_correlate_fft(a, b, limit_x, limit_y)
            self.dx = max(-limit_x, min(limit_x, best_dx))
            self.dy = max(-limit_y, min(limit_y, best_dy))
            # Store auto values so user can restore
            self.auto_dx = self.dx
            self.auto_dy = self.dy
            # Compute NCC at aligned position
            self.ncc_aligned = _compute_ncc(a, b, self.dx, self.dy)
        except Exception as e:
            import warnings

            # stacklevel=3 points at the user's call to __init__/set_images.
            warnings.warn(f"Auto-alignment failed: {e}", stacklevel=3)

    def __init__(
        self,
        image_a: Union[np.ndarray, "torch.Tensor"],
        image_b: Union[np.ndarray, "torch.Tensor"],
        title: str = "",
        label_a: str = "Image A",
        label_b: str = "Image B",
        cmap: str = "gray",
        opacity: float = 0.5,
        padding: float = 0.2,
        pixel_size: float = 0.0,
        canvas_size: int = 300,
        auto_align: bool = True,
        max_shift: float = 0.0,
        rotation: float = 0.0,
        hist_source: str = "a",
        disabled_tools=None,
        disable_alignment: bool = False,
        disable_overlay: bool = False,
        disable_display: bool = False,
        disable_histogram: bool = False,
        disable_stats: bool = False,
        disable_export: bool = False,
        disable_view: bool = False,
        disable_all: bool = False,
        hidden_tools=None,
        hide_alignment: bool = False,
        hide_overlay: bool = False,
        hide_display: bool = False,
        hide_histogram: bool = False,
        hide_stats: bool = False,
        hide_export: bool = False,
        hide_view: bool = False,
        hide_all: bool = False,
        state=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.widget_version = resolve_widget_version()

        title, pixel_size = self._resolve_metadata(image_a, image_b, title, pixel_size)
        a, b = self._prepare_images(image_a, image_b)

        height, width = a.shape
        self.height = height
        self.width = width
        self.padding = padding
        # Medians for padding fill
        self.median_a = float(np.median(a))
        self.median_b = float(np.median(b))
        # Display options
        self.title = title
        self.label_a = label_a
        self.label_b = label_b
        self.cmap = cmap
        self.opacity = opacity
        self.pixel_size = pixel_size
        self.canvas_size = canvas_size
        self.max_shift = max_shift
        self.rotation = rotation
        self.hist_source = hist_source
        self.disabled_tools = self._build_disabled_tools(
            disabled_tools=disabled_tools,
            disable_alignment=disable_alignment,
            disable_overlay=disable_overlay,
            disable_display=disable_display,
            disable_histogram=disable_histogram,
            disable_stats=disable_stats,
            disable_export=disable_export,
            disable_view=disable_view,
            disable_all=disable_all,
        )
        self.hidden_tools = self._build_hidden_tools(
            hidden_tools=hidden_tools,
            hide_alignment=hide_alignment,
            hide_overlay=hide_overlay,
            hide_display=hide_display,
            hide_histogram=hide_histogram,
            hide_stats=hide_stats,
            hide_export=hide_export,
            hide_view=hide_view,
            hide_all=hide_all,
        )
        # Cross-correlation at (0,0) — baseline
        self.xcorr_zero = _compute_ncc(a, b, 0.0, 0.0)
        # Uses max_shift / padding / width / height assigned above.
        if auto_align:
            self._compute_auto_alignment(a, b)
        # Send unpadded bytes
        self.image_a_bytes = a.tobytes()
        self.image_b_bytes = b.tobytes()

        if state is not None:
            if isinstance(state, (str, pathlib.Path)):
                state = unwrap_state_payload(
                    json.loads(pathlib.Path(state).read_text()),
                    require_envelope=True,
                )
            else:
                state = unwrap_state_payload(state)
            self.load_state_dict(state)

    def set_images(self, image_a, image_b, auto_align=True):
        """Replace both images. Preserves display settings, recomputes alignment.

        The offset and rotation are reset to zero, and results derived from
        the previous pair (``auto_dx``, ``auto_dy``, ``ncc_aligned``) are
        cleared before an optional new auto-alignment.

        Parameters
        ----------
        image_a, image_b : array_like
            New reference and moving images. Arrays, tensors, ``IOResult``
            and objects with an ``.array`` attribute are accepted. Title and
            pixel size are not changed.
        auto_align : bool, default True
            Recompute the translation for the new pair.

        Raises
        ------
        ValueError
            If either image is not 2D.
        """
        a, b = self._prepare_images(image_a, image_b)
        # Batch the updates so the frontend never pairs the new dimensions
        # with the old pixel buffers (or vice versa).
        with self.hold_sync():
            height, width = a.shape
            self.height = height
            self.width = width
            self.median_a = float(np.median(a))
            self.median_b = float(np.median(b))
            self.xcorr_zero = _compute_ncc(a, b, 0.0, 0.0)
            self.dx = 0.0
            self.dy = 0.0
            self.rotation = 0.0
            # Do not let "restore auto" / summary() refer to the old images.
            self.auto_dx = 0.0
            self.auto_dy = 0.0
            self.ncc_aligned = 0.0
            if auto_align:
                self._compute_auto_alignment(a, b)
            self.image_a_bytes = a.tobytes()
            self.image_b_bytes = b.tobytes()

    def __repr__(self) -> str:
        return f"Align2D({self.height}×{self.width}, dx={self.dx:.1f}, dy={self.dy:.1f}, rot={self.rotation:.1f}°)"

    def state_dict(self):
        """Return the user-adjustable settings as a plain dict (str/number/list values)."""
        return {
            "title": self.title,
            "label_a": self.label_a,
            "label_b": self.label_b,
            "cmap": self.cmap,
            "opacity": self.opacity,
            "padding": self.padding,
            "dx": self.dx,
            "dy": self.dy,
            "rotation": self.rotation,
            "pixel_size": self.pixel_size,
            "max_shift": self.max_shift,
            "canvas_size": self.canvas_size,
            "hist_source": self.hist_source,
            "disabled_tools": self.disabled_tools,
            "hidden_tools": self.hidden_tools,
        }

    def save(self, path: str):
        """Write :meth:`state_dict` to ``path`` via ``save_state_file``."""
        save_state_file(path, "Align2D", self.state_dict())

    def load_state_dict(self, state):
        """Apply a state mapping; keys that are not attributes are ignored."""
        for key, val in state.items():
            if hasattr(self, key):
                setattr(self, key, val)

    def summary(self):
        """Print a short human-readable description of the widget state."""
        lines = [self.title or "Align2D", "═" * 32]
        lines.append(f"Image: {self.height}×{self.width}")
        if self.pixel_size > 0:
            ps = self.pixel_size
            if ps >= 10:
                lines[-1] += f" ({ps / 10:.2f} nm/px)"
            else:
                lines[-1] += f" ({ps:.2f} Å/px)"
        lines.append(f"Labels: A={self.label_a!r} B={self.label_b!r}")
        lines.append(f"Offset: dx={self.dx:.2f} dy={self.dy:.2f} rotation={self.rotation:.2f}°")
        lines.append(f"Display: {self.cmap} | opacity={self.opacity:.0%} | padding={self.padding:.0%}")
        if self.ncc_aligned != 0:
            lines.append(f"NCC: aligned={self.ncc_aligned:.4f} zero={self.xcorr_zero:.4f}")
        print("\n".join(lines))

    def reset_alignment(self):
        """Reset the translation and rotation of image B to zero."""
        self.dx = 0.0
        self.dy = 0.0
        self.rotation = 0.0

    @property
    def offset(self) -> tuple[float, float]:
        """Return (dx, dy) alignment offset."""
        return (self.dx, self.dy)


bind_tool_runtime_api(Align2D, "Align2D")