[1]:
# Install in Google Colab
try:
    import google.colab
    !pip install -q -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ quantem-widget
except ImportError:
    pass  # Not in Colab, skip
[2]:
# Optional developer conveniences, applied best-effort: autoreload re-imports edited
# modules before running code, and ANYWIDGET_HMR=1 asks anywidget to hot-reload
# front-end code. All three magics share one try, so a failure in %load_ext also
# skips setting the environment variable.
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    pass  # autoreload unavailable (Colab Python 3.12+)
env: ANYWIDGET_HMR=1

Open In Colab

# Mark2D — All Features

Comprehensive demo of every Mark2D capability: basic atom picking, custom scale/dot size, image replacement, coordinate retrieval, lattice basis definition, PyTorch tensor input, gallery mode, snap-to-peak, and state save/load for reproducible analysis.

## 1. Basic HAADF-STEM atom picking

Hexagonal lattice simulating a [110] zone axis. Click on bright atom columns to select positions (up to three, `max_points=3`). Section 5 uses these three points to define a lattice basis.

[3]:
import numpy as np
import torch
import quantem.widget
from quantem.widget import Mark2D
# Pick the compute device: Apple-silicon MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def make_haadf_stem(size=256, spacing=18, sigma=2.8):
    """HAADF-STEM image with atomic columns on a hexagonal lattice.

    Column centres are ``i * a1 + j * a2`` with ``a1 = (spacing, 0)`` and
    ``a2 = (spacing / 2, spacing * sqrt(3) / 2)`` in (x=col, y=row) pixels.
    Every column whose centre lies within one ``spacing`` of the image is
    rendered as a Gaussian of width ``sigma``. Sites with ``(i + j) % 3 == 0``
    have amplitude 1.0 and all others 0.7, which mimics mixed-Z columns.
    Background (mean 0.08, std 0.015) and per-row scan noise (std 0.01) are
    drawn from NumPy's global RNG.

    Args:
        size: Image side length in pixels.
        spacing: Nearest-neighbour column distance in pixels (must be > 0).
        sigma: Gaussian width of each column in pixels.

    Returns:
        (size, size) float32 NumPy array, clipped at zero.

    Raises:
        ValueError: If ``spacing`` is not positive.
    """
    if spacing <= 0:
        raise ValueError(f"spacing must be positive, got {spacing}")
    coords = torch.stack(torch.meshgrid(torch.arange(size, device=device, dtype=torch.float32),
                                        torch.arange(size, device=device, dtype=torch.float32), indexing="ij"), dim=-1)
    y = coords[..., 0]  # row index of every pixel, (size, size)
    x = coords[..., 1]  # column index of every pixel, (size, size)
    img = torch.zeros(size, size, device=device, dtype=torch.float32)
    # Real-space basis vectors in (x, y) = (col, row) pixels
    a1 = torch.tensor([spacing, 0.0], device=device)
    a2 = torch.tensor([spacing * 0.5, spacing * (3**0.5) / 2], device=device)
    # a2 shifts each successive lattice row right by spacing/2, and the row pitch is only
    # spacing*sqrt(3)/2. A fixed index range starting at -1 therefore left the lower-left
    # triangle and the bottom rows of the image without atoms. Size the ranges from the
    # geometry instead, and let the bounds mask below trim the surplus sites.
    n_rows = int(np.ceil((size + spacing) / (spacing * (3**0.5) / 2))) + 1
    i_range = torch.arange(-n_rows, n_rows + 1, device=device, dtype=torch.float32)
    j_range = torch.arange(-1, n_rows + 1, device=device, dtype=torch.float32)
    ii, jj = torch.meshgrid(i_range, j_range, indexing="ij")
    ii = ii.reshape(-1)
    jj = jj.reshape(-1)
    cx = ii * a1[0] + jj * a2[0]  # (N,)
    cy = ii * a1[1] + jj * a2[1]  # (N,)
    # Keep columns whose centre is within one spacing of the image (their tails reach it)
    mask = (cx > -spacing) & (cx < size + spacing) & (cy > -spacing) & (cy < size + spacing)
    cx = cx[mask]
    cy = cy[mask]
    ii_filt = ii[mask]
    jj_filt = jj[mask]
    # Intensity variation (like mixed Z columns): sites with (i + j) % 3 == 0 are brighter
    intensity = 0.7 + 0.3 * (((ii_filt + jj_filt) % 3) == 0).float()
    # Accumulate one Gaussian per column; a per-column loop avoids building an (N, size, size) tensor
    for k in range(len(cx)):
        img += intensity[k] * torch.exp(-((x - cx[k])**2 + (y - cy[k])**2) / (2 * sigma**2))
    img_np = img.cpu().numpy()
    # Add background and noise in NumPy (Poisson/normal noise unreliable on MPS)
    img_np += np.random.normal(0.08, 0.015, (size, size)).astype(np.float32)
    scan_noise = np.random.normal(0, 0.01, (size, 1)).astype(np.float32) * np.ones((1, size), dtype=np.float32)
    img_np += scan_noise
    return np.clip(img_np, 0, None).astype(np.float32)
def make_cubic_stem(size=256, spacing=20, sigma=2.5):
    """Simulate a HAADF-STEM image of a cubic crystal viewed along [001].

    Atom columns sit on a square grid with pitch ``spacing`` px. Grid indices
    run from -1 to ``size // spacing + 1`` in both directions, and only centres
    within one ``spacing`` of the image are kept. Each column is a Gaussian of
    width ``sigma`` and peak 0.8. Gaussian background (mean 0.08, std 0.015)
    and per-row scan noise (std 0.01) are drawn from NumPy's global RNG.

    Returns:
        (size, size) float32 NumPy array, clipped at zero.
    """
    pixel_axis = torch.arange(size, device=device, dtype=torch.float32)
    row_idx, col_idx = torch.meshgrid(pixel_axis, pixel_axis, indexing="ij")
    canvas = torch.zeros(size, size, device=device, dtype=torch.float32)

    # Square lattice: column centres at (i * spacing, j * spacing)
    lattice_axis = torch.arange(-1, size // spacing + 2, device=device, dtype=torch.float32) * spacing
    grid_x, grid_y = torch.meshgrid(lattice_axis, lattice_axis, indexing="ij")
    centres_x = grid_x.reshape(-1)
    centres_y = grid_y.reshape(-1)

    lo, hi = -spacing, size + spacing
    keep = (centres_x > lo) & (centres_x < hi) & (centres_y > lo) & (centres_y < hi)
    centres_x, centres_y = centres_x[keep], centres_y[keep]

    two_sigma_sq = 2 * sigma**2
    for cx_k, cy_k in zip(centres_x, centres_y):
        canvas += 0.8 * torch.exp(-((col_idx - cx_k)**2 + (row_idx - cy_k)**2) / two_sigma_sq)

    stem = canvas.cpu().numpy()
    # Noise is drawn with NumPy on the CPU. Draw order: full-frame background, then one offset per scan row.
    stem += np.random.normal(0.08, 0.015, (size, size)).astype(np.float32)
    stem += np.random.normal(0, 0.01, (size, 1)).astype(np.float32)  # broadcast along each row
    return np.clip(stem, 0, None).astype(np.float32)
def make_diffraction_pattern(size=256, spot_sigma=0.8, spot_spacing=28):
    """Electron diffraction pattern with sharp Bragg spots on a hexagonal reciprocal lattice.

    Spots sit at ``centre + i * g1 + j * g2`` with ``g1 = (spot_spacing, 0)`` and
    ``g2 = (spot_spacing / 2, spot_spacing * sqrt(3) / 2)`` in (x=col, y=row)
    pixels, where the centre is ``(size // 2, size // 2)``. Every spot inside the
    image is drawn. Spot amplitude follows a Gaussian envelope of width
    ``3 * spot_spacing`` (exactly 1.0 for the central beam), and each spot is a
    Gaussian of width ``spot_sigma``. A small background (mean 0.02, std 0.005)
    is drawn from NumPy's global RNG.

    Args:
        size: Image side length in pixels.
        spot_sigma: Gaussian width of each spot in pixels.
        spot_spacing: Length of the reciprocal basis vectors in pixels
            (default 28, the value previously hard-coded).

    Returns:
        (size, size) float32 NumPy array, clipped at zero.

    Raises:
        ValueError: If ``spot_spacing`` is not positive.
    """
    if spot_spacing <= 0:
        raise ValueError(f"spot_spacing must be positive, got {spot_spacing}")
    pixel_axis = torch.arange(size, device=device, dtype=torch.float32)
    y, x = torch.meshgrid(pixel_axis, pixel_axis, indexing="ij")  # row / column index of every pixel
    cx_center, cy_center = size // 2, size // 2
    img = torch.zeros(size, size, device=device, dtype=torch.float32)
    # Hexagonal reciprocal lattice
    a = spot_spacing  # spot spacing (px)
    g1 = torch.tensor([a, 0.0], device=device)
    g2 = torch.tensor([a * 0.5, a * (3**0.5) / 2], device=device)
    # The index range must scale with the image. Because g2 is sheared, in-bounds spots
    # near the corners need |i| up to ~0.8 * size / a. A fixed -6..6 range dropped the
    # corner spots at size=256 and most of the pattern for larger images.
    n_max = int(np.ceil(size / a)) + 1
    idx = torch.arange(-n_max, n_max + 1, device=device, dtype=torch.float32)
    ii, jj = torch.meshgrid(idx, idx, indexing="ij")
    ii = ii.reshape(-1)
    jj = jj.reshape(-1)
    sx = cx_center + ii * g1[0] + jj * g2[0]
    sy = cy_center + ii * g1[1] + jj * g2[1]
    # Filter spots within image bounds
    in_bounds = (sx >= 0) & (sx < size) & (sy >= 0) & (sy < size)
    sx = sx[in_bounds]
    sy = sy[in_bounds]
    dist = torch.sqrt((sx - cx_center)**2 + (sy - cy_center)**2)
    # Intensity envelope: the central beam (dist == 0) gets exactly 1.0, outer spots are dimmer
    intensity = torch.exp(-dist**2 / (2 * (3 * a)**2))
    for k in range(len(sx)):
        img += intensity[k] * torch.exp(-((x - sx[k])**2 + (y - sy[k])**2) / (2 * spot_sigma**2))
    img_np = img.cpu().numpy()
    img_np += np.random.normal(0.02, 0.005, (size, size)).astype(np.float32)
    return np.clip(img_np, 0, None).astype(np.float32)
print(f"Generators ready (device={device})")
print(f"quantem.widget {quantem.widget.__version__}")
# Seed NumPy's global RNG so the synthetic noise (and every summary() statistic below)
# is reproducible under Restart & Run All.
np.random.seed(0)
haadf = make_haadf_stem()
w1 = Mark2D(haadf, max_points=3)
w1  # must be the cell's last expression, otherwise Jupyter does not render the widget
Generators ready (device=mps)
quantem.widget 0.4.0a3

## 2. Custom scale, dot size, max points

Zoomed-in view with larger markers and more allowed selections.

[4]:
# Same HAADF image at 2x display scale, with 18 px markers and room for 10 points
w2 = Mark2D(haadf, scale=2.0, dot_size=18, max_points=10)
w2
[4]:

## 3. Replace image with set_image()

Switch between two different zone axes without creating a new widget. The cubic [001] pattern has a simple square lattice, while the hexagonal pattern above has alternating column intensities.

[5]:
# Build the cubic [001] image now; w3 starts on the hexagonal image and is switched in the next cell
cubic = make_cubic_stem()
w3 = Mark2D(haadf, scale=1.0, max_points=5)
w3
[5]:
[6]:
# Replace the hexagonal image with the cubic [001] zone axis
# (set_image() updates the existing w3 in place; no new widget is created)
w3.set_image(cubic)
print("Image replaced: now showing cubic [001] zone axis")
Image replaced: now showing cubic [001] zone axis

## 4. Inspect widget state

Use summary() to print a detailed breakdown of each widget — image size, data statistics, placed points, ROIs, and display and marker settings.

[7]:
# Print each widget's summary under a "--- label ---" header, followed by a blank line
labelled_widgets = {"Hexagonal": w1, "Zoomed": w2, "Cubic": w3}
for label, wgt in labelled_widgets.items():
    print(f"--- {label} ---")
    wgt.summary()
    print()
--- Hexagonal ---
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.01285  max=1.104  mean=0.1909  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/3
Marker:   circle red  size=12px

--- Zoomed ---
Mark2D
════════════════════════════════
Image:    256×256  scale=2.0x
Data:     min=0.01285  max=1.104  mean=0.1909  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/10
Marker:   circle red  size=18px

--- Cubic ---
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.00792  max=0.9349  mean=0.1568  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/5
Marker:   circle red  size=12px

## 5. Define lattice basis from 3 points

Pick 3 atom columns on w1 above: an origin and two nearest neighbors. Then run this cell to compute lattice vectors u and v, plus the angle between them.

[8]:
# Lattice basis from the first three picks on w1: an origin, then two neighbouring columns.
points = w1.selected_points
if len(points) < 3:
    print("Click 3 atom columns on w1 above, then re-run this cell.")
else:
    # Each selected point is a dict with "row" and "col" keys
    origin, p1, p2 = (np.array([pt["row"], pt["col"]], dtype=float) for pt in points[:3])
    u = p1 - origin
    v = p2 - origin
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    if norm_u == 0 or norm_v == 0:
        # A zero-length vector has no direction; the angle would be NaN (with a divide warning)
        print("Two of the picked points coincide; pick 3 distinct atom columns and re-run this cell.")
    else:
        # Rounding can push |cos| slightly above 1, where arccos returns NaN, so clip first
        cos_uv = np.clip(np.dot(u, v) / (norm_u * norm_v), -1.0, 1.0)
        angle = np.degrees(np.arccos(cos_uv))
        print(f"Origin: (row={origin[0]:.1f}, col={origin[1]:.1f})")
        print(f"u = ({u[0]:.1f}, {u[1]:.1f}), |u| = {norm_u:.1f} px")
        print(f"v = ({v[0]:.1f}, {v[1]:.1f}), |v| = {norm_v:.1f} px")
        print(f"Angle(u, v) = {angle:.1f} degrees")
        # Two nearest neighbours of a hexagonal site are 60 or 120 degrees apart
        print("Expected for hexagonal: |u| ~ |v| ~ 18 px, angle ~ 60 (or 120) degrees")
Click 3 atom columns on w1 above, then re-run this cell.

## 6. PyTorch tensor input

Mark2D accepts both NumPy arrays and PyTorch tensors.

[9]:
# torch.from_numpy returns a CPU tensor that shares memory with `haadf` (no copy)
haadf_tensor = torch.from_numpy(haadf)
print(f"Tensor shape: {haadf_tensor.shape}, dtype: {haadf_tensor.dtype}")
# Mark2D takes the tensor directly, in place of a NumPy array
w4 = Mark2D(haadf_tensor, scale=1.5, dot_size=14, max_points=5)
w4
Tensor shape: torch.Size([256, 256]), dtype: torch.float32
[9]:

## 9. Snap-to-peak on a sharp diffraction pattern

Snap-to-peak finds the nearest local intensity maximum within a search radius. This is most useful on images with sharp, well-separated peaks — like electron diffraction patterns with Bragg spots. Try it: click near a Bragg spot. With snap enabled (green), your point jumps to the spot's intensity maximum. Toggle snap off to see the difference — points then land exactly where you click.

[14]:
diffraction = make_diffraction_pattern()
# Snap enabled with 8px search radius — clicks jump to the nearest Bragg spot
w7 = Mark2D(
    diffraction,
    snap_enabled=True,
    snap_radius=8,
    max_points=10,
    dot_size=8,
    log_scale=True,
)
# Passing `colormap=` to the constructor left the display "gray" (see summary() and the
# saved-state `cmap` key), so set the widget's `cmap` trait directly.
w7.cmap = "viridis"
w7
[14]:
[15]:
# The "Snap:" line reports whether snap is on and its search radius
w7.summary()
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0  max=1.019  mean=0.02315  dtype=float32
Display:  gray | auto contrast | log
Points:   0/10
Marker:   circle red  size=8px
Snap:     ON (radius=8 px)

### Side-by-side: snap on vs snap off

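Same diffraction pattern and settings, shown as a two-panel gallery with snap initially enabled. As the left label says, toggle snap off before clicking in the left panel, then click near the same Bragg spot in both images to compare precision.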

[16]:
# Gallery of the same pattern twice. Snap starts enabled for the widget; as the left
# label says, toggle snap off before clicking in the left panel, then compare with the right.
w8 = Mark2D(
    [diffraction, diffraction],
    ncols=2,
    max_points=5,
    dot_size=8,
    log_scale=True,
    snap_enabled=True,
    snap_radius=8,
    labels=["Snap OFF (toggle it off)", "Snap ON (default)"],
)
# The constructor does not apply `colormap=` (it stays "gray"), so set the `cmap` trait directly
w8.cmap = "viridis"
w8
[16]:

## 10. Save and load state

All widget state — points, ROIs, profile lines, display settings — can be saved to a JSON file with save() and restored with the state parameter. This lets you resume analysis after a kernel restart or share exact results with a colleague.

[17]:
# Create a widget with pre-placed points (row, col) and custom settings
w9 = Mark2D(
    haadf,
    points=[(36, 36), (54, 36), (45, 52)],
    snap_enabled=True,
    snap_radius=8,
    marker_shape="diamond",
    marker_color="#00bcd4",
    title="HAADF analysis",
)
# The constructor ignored `colormap=` and `pixel_size_angstrom=` (the saved state recorded
# cmap "gray" and pixel_size 0.0), so set the widget's traits directly.
w9.cmap = "viridis"
w9.pixel_size = 1.5  # NOTE(review): assumed to be Å per pixel, as the old keyword implied — confirm
w9.add_roi(128, 128, shape="circle", radius=30)
w9
[17]:
[18]:
# Save all state to a JSON file
# (written to the current working directory; the cleanup cell further down deletes it)
w9.save("haadf_analysis.json")
print("Saved to haadf_analysis.json")
Saved to haadf_analysis.json
[19]:
# Restore from file — same image, all state comes back
# The JSON holds widget state only (no pixel data), so the image is passed again
w10 = Mark2D(haadf, state="haadf_analysis.json")
print(f"Restored: {len(w10.selected_points)} points, {len(w10.roi_list)} ROIs")
print(f"Colormap: {w10.cmap}, title: {w10.title}")
w10
Restored: 3 points, 1 ROIs
Colormap: gray, title: HAADF analysis
[19]:
[20]:
# The saved state is plain, human-readable JSON — small enough to keep under version control
import json
from pathlib import Path
state_file = Path("haadf_analysis.json")
with state_file.open() as fh:
    state = json.load(fh)
print(json.dumps(state, indent=2))
{
  "metadata_version": "1.0",
  "widget_name": "Mark2D",
  "widget_version": "0.4.0a3",
  "state": {
    "selected_points": [
      {
        "row": 36,
        "col": 36,
        "shape": "circle",
        "color": "#f44336"
      },
      {
        "row": 54,
        "col": 36,
        "shape": "triangle",
        "color": "#4caf50"
      },
      {
        "row": 45,
        "col": 52,
        "shape": "square",
        "color": "#2196f3"
      }
    ],
    "roi_list": [
      {
        "id": 0,
        "shape": "circle",
        "row": 128,
        "col": 128,
        "radius": 30,
        "width": 60,
        "height": 40,
        "color": "#0f0",
        "opacity": 0.8
      }
    ],
    "profile_line": [],
    "selected_idx": 0,
    "marker_shape": "diamond",
    "marker_color": "#00bcd4",
    "dot_size": 12,
    "max_points": 10,
    "marker_border": 2,
    "marker_opacity": 1.0,
    "label_size": 0,
    "label_color": "",
    "snap_enabled": true,
    "snap_radius": 8,
    "cmap": "gray",
    "auto_contrast": true,
    "log_scale": false,
    "show_fft": false,
    "show_stats": true,
    "show_controls": true,
    "disabled_tools": [],
    "hidden_tools": [],
    "percentile_low": 2.0,
    "percentile_high": 98.0,
    "title": "HAADF analysis",
    "pixel_size": 0.0,
    "scale": 1.0,
    "canvas_size": 0
  }
}
[21]:
# Delete the state file written above so a re-run starts clean
p = Path("haadf_analysis.json")
p.unlink(missing_ok=True)  # no-op when the file is already gone

## 11. Disable tools for shared notebooks

Use the disable_* flags to lock selected controls so collaborators cannot change them in a shared notebook.

[22]:
# Disable selected editing/control groups
# (points, ROIs and display settings are locked; the widget is still shown)
w11 = Mark2D(
    haadf,
    disable_points=True,
    disable_roi=True,
    disable_display=True,
)
w11
[22]:
[23]:
# The "Locked:" line lists the disabled control groups
w11.summary()
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.01285  max=1.104  mean=0.1909  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/10
Marker:   circle red  size=12px
Locked:   points, roi, display
[24]:
# Fully lock the widget as a read-only viewer
w12 = Mark2D(haadf, disable_all=True)
print(w12)  # compact text repr, e.g. "Mark2D(256×256, pts=0)"
w12  # last expression, so Jupyter renders the locked viewer itself
Mark2D(256×256, pts=0)
[25]:
# With disable_all=True the "Locked:" line reports "all"
w12.summary()
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.01285  max=1.104  mean=0.1909  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/10
Marker:   circle red  size=12px
Locked:   all

### Hide selected controls

Use the hide_* flags to remove controls from the layout entirely, instead of showing them locked, which keeps the viewer clean.

[26]:
# Hide (rather than lock) the display, export and marker-style controls
w13 = Mark2D(
    haadf,
    hide_display=True,
    hide_export=True,
    hide_marker_style=True,
)
w13
[26]:
[27]:
# The "Hidden:" line lists the hidden control groups
w13.summary()
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.01285  max=1.104  mean=0.1909  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/10
Marker:   circle red  size=12px
Hidden:   display, marker_style, export