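Open In Colab

# Show4DSTEM — Quick Demo

Synthetic 4D-STEM dataset with an amorphous background, a bright-field disk, six first-order Bragg reflections, six weaker second-order spots, scan-position-dependent intensity variation, and Poisson shot noise. With the default 64×64 detector, the two second-order spots directly above and below the center lie just outside the detector edge, so only their tails are visible. The data are generated with PyTorch tensor operations (GPU-accelerated on MPS/CUDA when available). For 5D time/tilt series support, see show4dstem_5d.ipynb.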

[1]:
# Install in Google Colab
try:
    import google.colab
    !pip install -q -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ quantem-widget
except ImportError:
    pass  # Not in Colab, skip
[2]:
# Optional developer conveniences: %autoreload 2 picks up edits to imported
# modules live, and ANYWIDGET_HMR=1 presumably enables anywidget's front-end
# hot reload (NOTE(review): confirm against anywidget docs). Not needed to run the demo.
# (Magic lines below must not carry trailing comments: IPython would pass the
# comment text to the magic as part of its argument.)
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    # Deliberately best-effort. If %load_ext fails, the %autoreload and %env lines
    # after it are skipped as well, so ANYWIDGET_HMR is left unset in that case.
    pass  # autoreload unavailable (Colab Python 3.12+)
env: ANYWIDGET_HMR=1
[3]:
import torch
import numpy as np
import quantem.widget
from quantem.widget import Show4DSTEM
# Pick the compute device: MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
def _spot_ring(rr, cc, center_row, center_col, *, n_spots, radius, amplitude, sigma, phase=0.0):
    """Return the sum of ``n_spots`` Gaussian spots spaced evenly around a ring.

    Spot ``k`` is centred ``radius`` pixels from (center_row, center_col) at angle
    ``phase + 2*pi*k/n_spots``: its row offset is ``radius*sin(angle)`` and its
    column offset ``radius*cos(angle)``. Every spot has peak value ``amplitude``
    and standard deviation ``sigma`` pixels. ``rr``/``cc`` are matching row/column
    coordinate grids; the result has their shape, dtype and device.
    """
    k = torch.arange(n_spots, dtype=rr.dtype, device=rr.device)
    angles = phase + (2 * torch.pi / n_spots) * k
    # Spot centres as (n_spots, 1, 1) so they broadcast against the 2-D grids,
    # giving squared distances of shape (n_spots, *rr.shape).
    spot_rows = (center_row + radius * torch.sin(angles))[:, None, None]
    spot_cols = (center_col + radius * torch.cos(angles))[:, None, None]
    d2 = (rr - spot_rows) ** 2 + (cc - spot_cols) ** 2
    return amplitude * torch.exp(-d2 / (2 * sigma**2)).sum(dim=0)


def make_4dstem(scan_rows=16, scan_cols=16, det_rows=64, det_cols=64, *, dose=200.0, seed=None):
    """4D-STEM dataset with BF disk, Bragg spots, and amorphous background (PyTorch).

    Every diffraction pattern is the same base image, scaled by a smooth
    scan-position modulation and then corrupted with Poisson shot noise. The base
    image contains:

    - an amorphous radial background;
    - a bright-field disk of radius 8 px;
    - six first-order Gaussian spots at radius 20 px;
    - six weaker second-order spots at radius 35 px, rotated by 30 deg.

    Geometry is in absolute pixels about (det_rows/2, det_cols/2). With the default
    64x64 detector, the two second-order spots directly above and below the centre
    sit at rows 32 +/- 35, outside the detector, so only their tails appear.

    Parameters
    ----------
    scan_rows, scan_cols : int
        Scan grid size.
    det_rows, det_cols : int
        Detector size in pixels.
    dose : float, keyword-only
        Intensity-to-counts scale applied before Poisson sampling; the result is
        divided by it again, so larger values mean less relative noise.
    seed : int or None, keyword-only
        Seed for the shot noise. ``None`` draws from torch's global RNG;
        an integer makes the output reproducible.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (scan_rows, scan_cols, det_rows, det_cols).

    Raises
    ------
    ValueError
        If ``dose`` is not positive.

    Notes
    -----
    Tensors are built on the module-level ``device``.
    """
    if dose <= 0:
        raise ValueError(f"dose must be positive, got {dose!r}")

    # Detector coordinate grids, each (det_rows, det_cols)
    dr = torch.arange(det_rows, device=device, dtype=torch.float32)
    dc = torch.arange(det_cols, device=device, dtype=torch.float32)
    rr, cc = torch.meshgrid(dr, dc, indexing="ij")
    cr, cc0 = det_rows / 2, det_cols / 2
    center_dist = ((rr - cr) ** 2 + (cc - cc0) ** 2).sqrt()

    # Amorphous background (radial falloff) — same for all scan positions
    bg = 0.05 * torch.exp(-center_dist / 30)
    # BF disk: hard edge at r = 8 px with a radial cosine modulation inside
    bf = (center_dist < 8).float() * (1.0 + 0.2 * torch.cos(center_dist * 0.5))
    # Hexagonal spot pattern: first order at r = 20 px; second order weaker,
    # at r = 35 px (~ sqrt(3) * 20) and rotated by 30 deg.
    spots = _spot_ring(rr, cc, cr, cc0, n_spots=6, radius=20, amplitude=0.4, sigma=2.5)
    spots = spots + _spot_ring(
        rr, cc, cr, cc0, n_spots=6, radius=35, amplitude=0.1, sigma=2.0, phase=torch.pi / 6
    )
    base = bg + bf + spots  # (det_rows, det_cols)

    # Scan-position-dependent modulation (thickness/orientation variation)
    si = torch.arange(scan_rows, device=device, dtype=torch.float32)
    sj = torch.arange(scan_cols, device=device, dtype=torch.float32)
    si_grid, sj_grid = torch.meshgrid(si, sj, indexing="ij")  # (scan_rows, scan_cols)
    modulation = 1.0 + 0.15 * torch.sin(2 * torch.pi * si_grid / scan_rows) * torch.cos(
        2 * torch.pi * sj_grid / scan_cols
    )

    # Broadcast: (scan_rows, scan_cols, 1, 1) * (1, 1, det_rows, det_cols)
    data = modulation[:, :, None, None] * base[None, None, :, :]

    # Poisson shot noise
    rates = data.clamp(min=0) * dose
    if device.type == "mps":
        # MPS does not implement torch.poisson; sample on the CPU instead.
        rates = rates.cpu()
    # The generator must live on the same device as the tensor being sampled.
    generator = None if seed is None else torch.Generator(device=rates.device).manual_seed(seed)
    counts = torch.poisson(rates, generator=generator)
    return (counts / dose).cpu().numpy()
data = make_4dstem()
# Inline sanity checks: Restart & Run All fails loudly here if the generator regresses.
assert data.ndim == 4, f"expected (scan_rows, scan_cols, det_rows, det_cols), got {data.shape}"
assert data.dtype == np.float32, f"expected float32, got {data.dtype}"
assert np.isfinite(data).all() and (data >= 0).all(), "Poisson-sampled intensities must be finite and non-negative"
print(f"Shape: {data.shape}, dtype: {data.dtype}")
print(f"Range: [{data.min():.3f}, {data.max():.3f}]")
print(f"quantem.widget {quantem.widget.__version__}")
Using device: mps
Shape: (16, 16, 64, 64), dtype: float32
Range: [0.000, 1.620]
quantem.widget 0.4.0a3
[4]:
# Build the interactive viewer, locate the bright-field disk, and add a circular ROI.
w = Show4DSTEM(data)
# Keep this order: presumably roi_circle() positions itself from the detected
# center. The saved summary output shows the ROI at the detected center with
# r = 4.9, i.e. half of bf_radius.
w.auto_detect_center()
w.roi_circle()
# NOTE(review): make_4dstem draws the BF disk with a hard edge at 8 px, yet the saved
# output reports a detected BF radius of 9.8 — confirm how the radius is estimated.
print(f"Detected center: ({w.center_row:.1f}, {w.center_col:.1f}), BF radius: {w.bf_radius:.1f}")
# Bare last expression so Jupyter renders the widget.
w
Detected center: (32.0, 32.0), BF radius: 9.8
[4]:

## Inspect Widget State

[5]:
# Print a plain-text digest of the widget state (scan/detector geometry, position,
# center and BF radius, ROI, display settings) to check the configuration at a glance.
w.summary()
Show4DSTEM
════════════════════════════════
Scan:     16×16 (1.00 Å/px)
Detector: 64×64 (1.0000 px/px)
Position: (8, 8)
Center:   (32.0, 32.0)  BF r=9.8 px
Display:  DC masked
ROI:      circle at (32.0, 32.0) r=4.9
DP view:  inferno, linear, 0.0-100.0%
VI view:  inferno, linear, 0.0-100.0%
[ ]: