[1]:
# Install in Google Colab
try:
import google.colab
!pip install -q -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ quantem-widget
except ImportError:
pass # Not in Colab, skip
[2]:
try:
%load_ext autoreload
%autoreload 2
%env ANYWIDGET_HMR=1
except Exception:
pass # autoreload unavailable (Colab Python 3.12+)
env: ANYWIDGET_HMR=1
# Mark2D — All Features
Comprehensive demo of every Mark2D capability: basic atom picking, custom scale/dot size, image replacement, coordinate retrieval, lattice basis definition, PyTorch tensor input, gallery mode, snap-to-peak, and state save/load for reproducible analysis.
1. Basic HAADF-STEM atom picking#
Hexagonal lattice simulating a [110] zone axis. Click on bright atom columns to select positions.
[3]:
import json
from pathlib import Path

import numpy as np
import torch

import quantem.widget
from quantem.widget import Mark2D
# Pick the compute device once: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _lattice_indices(start, stop):
    """Return flattened (i, j) index pairs covering ``[start, stop)`` on both lattice axes."""
    idx = torch.arange(start, stop, device=device, dtype=torch.float32)
    ii, jj = torch.meshgrid(idx, idx, indexing="ij")
    return ii.reshape(-1), jj.reshape(-1)


def _render_gaussian_spots(size, cx, cy, amplitude, sigma):
    """Sum isotropic Gaussians centred at (column=cx, row=cy) into a (size, size) tensor.

    An isotropic 2-D Gaussian factorises into a row profile times a column
    profile, so the sum over all N spots is a single (size, N) @ (N, size)
    matrix product. This replaces a Python loop over spots and needs only
    O(N * size) temporary memory.
    """
    axis = torch.arange(size, device=device, dtype=torch.float32)
    two_var = 2 * sigma**2
    along_cols = torch.exp(-((axis[None, :] - cx[:, None]) ** 2) / two_var)  # (N, size)
    along_rows = torch.exp(-((axis[None, :] - cy[:, None]) ** 2) / two_var)  # (N, size)
    return (amplitude[:, None] * along_rows).t() @ along_cols


def _to_noisy_numpy(img, size, seed, background, pixel_sigma, scan_sigma=None):
    """Move ``img`` to NumPy, add background/noise, and clip negatives to zero.

    Noise is drawn in NumPy (Poisson/normal noise unreliable on MPS).
    With ``seed=None`` the global ``np.random`` state is used, exactly as
    before; an integer seed gives a private, reproducible generator.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    img_np = img.cpu().numpy()
    # Per-pixel noise on top of a constant background level.
    img_np += rng.normal(background, pixel_sigma, (size, size)).astype(np.float32)
    if scan_sigma is not None:
        # One offset per row, broadcast across the columns of that row.
        img_np += rng.normal(0, scan_sigma, (size, 1)).astype(np.float32)
    return np.clip(img_np, 0, None).astype(np.float32)


def make_haadf_stem(size=256, spacing=18, sigma=2.8, seed=None):
    """HAADF-STEM image with atomic columns on a hexagonal lattice.

    Args:
        size: Edge length of the square image in pixels.
        spacing: Lattice spacing in pixels.
        sigma: Gaussian width of each atomic column in pixels.
        seed: Optional seed for the noise; ``None`` uses NumPy's global state.

    Returns:
        A ``(size, size)`` float32 NumPy array with no negative values.
    """
    a1 = (spacing, 0.0)
    a2 = (spacing * 0.5, spacing * (3**0.5) / 2)
    ii, jj = _lattice_indices(-1, size // spacing + 2)
    cx = ii * a1[0] + jj * a2[0]
    cy = ii * a1[1] + jj * a2[1]
    # Keep only columns within one lattice spacing of the image.
    keep = (cx > -spacing) & (cx < size + spacing) & (cy > -spacing) & (cy < size + spacing)
    cx, cy, ii, jj = cx[keep], cy[keep], ii[keep], jj[keep]
    # Intensity variation (like mixed Z columns): sites with (i + j) % 3 == 0 are brighter.
    intensity = 0.7 + 0.3 * (((ii + jj) % 3) == 0).float()
    img = _render_gaussian_spots(size, cx, cy, intensity, sigma)
    return _to_noisy_numpy(img, size, seed, background=0.08, pixel_sigma=0.015, scan_sigma=0.01)


def make_cubic_stem(size=256, spacing=20, sigma=2.5, seed=None):
    """HAADF-STEM of cubic [001] zone axis (square grid, uniform column intensity).

    Args:
        size: Edge length of the square image in pixels.
        spacing: Lattice spacing in pixels.
        sigma: Gaussian width of each atomic column in pixels.
        seed: Optional seed for the noise; ``None`` uses NumPy's global state.

    Returns:
        A ``(size, size)`` float32 NumPy array with no negative values.
    """
    ii, jj = _lattice_indices(-1, size // spacing + 2)
    cx = ii * spacing
    cy = jj * spacing
    # Keep only columns within one lattice spacing of the image.
    keep = (cx > -spacing) & (cx < size + spacing) & (cy > -spacing) & (cy < size + spacing)
    cx, cy = cx[keep], cy[keep]
    intensity = torch.full_like(cx, 0.8)
    img = _render_gaussian_spots(size, cx, cy, intensity, sigma)
    return _to_noisy_numpy(img, size, seed, background=0.08, pixel_sigma=0.015, scan_sigma=0.01)


def make_diffraction_pattern(size=256, spot_sigma=0.8, seed=None):
    """Electron diffraction pattern with sharp Bragg spots on a hexagonal reciprocal lattice.

    Args:
        size: Edge length of the square image in pixels.
        spot_sigma: Gaussian width of each Bragg spot in pixels.
        seed: Optional seed for the noise; ``None`` uses NumPy's global state.

    Returns:
        A ``(size, size)`` float32 NumPy array with no negative values.
    """
    center = size // 2
    a = 28  # spot spacing (px)
    g1 = (a, 0.0)
    g2 = (a * 0.5, a * (3**0.5) / 2)
    ii, jj = _lattice_indices(-6, 7)
    sx = center + ii * g1[0] + jj * g2[0]
    sy = center + ii * g1[1] + jj * g2[1]
    # Keep only spots that fall inside the image.
    in_bounds = (sx >= 0) & (sx < size) & (sy >= 0) & (sy < size)
    sx, sy, ii, jj = sx[in_bounds], sy[in_bounds], ii[in_bounds], jj[in_bounds]
    dist = torch.sqrt((sx - center) ** 2 + (sy - center) ** 2)
    # Intensity envelope: central beam bright, outer spots dimmer.
    intensity = torch.exp(-(dist**2) / (2 * (3 * a) ** 2))
    intensity[(ii == 0) & (jj == 0)] = 1.0
    img = _render_gaussian_spots(size, sx, sy, intensity, spot_sigma)
    return _to_noisy_numpy(img, size, seed, background=0.02, pixel_sigma=0.005)
print(f"Generators ready (device={device})")
print(f"quantem.widget {quantem.widget.__version__}")

haadf = make_haadf_stem()
w1 = Mark2D(haadf, max_points=3)
# The widget must be the last expression of the cell to be rendered; section 5
# asks the reader to click on w1, so it has to be visible here.
w1
Generators ready (device=mps)
quantem.widget 0.4.0a3
2. Custom scale, dot size, max points#
Zoomed-in view with larger markers and more allowed selections.
[4]:
# Same HAADF image, drawn at scale 2.0 with 18 px markers and up to 10 points.
w2 = Mark2D(
    haadf,
    scale=2.0,
    dot_size=18,
    max_points=10,
)
w2
[4]:
3. Replace image with set_image()#
Switch between two different zone axes without creating a new widget. The cubic [001] pattern is a simple square lattice with uniform column intensity, while the hexagonal pattern above has a brighter column at every third lattice site.
[5]:
# The widget starts on the hexagonal image; the cubic image is generated here
# so that a later cell can swap it in.
w3 = Mark2D(
    haadf,
    scale=1.0,
    max_points=5,
)
cubic = make_cubic_stem()
w3
[5]:
[6]:
# Replace the hexagonal image with the cubic [001] zone axis.
# set_image() is called on the existing w3 object; no new Mark2D is constructed.
w3.set_image(cubic)
print("Image replaced: now showing cubic [001] zone axis")
Image replaced: now showing cubic [001] zone axis
4. Inspect widget state#
Use summary() to see a detailed breakdown of all widgets — image info, placed points, ROIs, display settings.
[7]:
# Print a titled summary for each of the three widgets created so far,
# separated by a blank line.
labelled_widgets = {"Hexagonal": w1, "Zoomed": w2, "Cubic": w3}
for label, marker in labelled_widgets.items():
    print(f"--- {label} ---")
    marker.summary()
    print()
--- Hexagonal ---
Mark2D
════════════════════════════════
Image: 256×256
Data: min=0.01285 max=1.104 mean=0.1909 dtype=float32
Display: gray | auto contrast | linear
Points: 0/3
Marker: circle red size=12px
--- Zoomed ---
Mark2D
════════════════════════════════
Image: 256×256 scale=2.0x
Data: min=0.01285 max=1.104 mean=0.1909 dtype=float32
Display: gray | auto contrast | linear
Points: 0/10
Marker: circle red size=18px
--- Cubic ---
Mark2D
════════════════════════════════
Image: 256×256
Data: min=0.00792 max=0.9349 mean=0.1568 dtype=float32
Display: gray | auto contrast | linear
Points: 0/5
Marker: circle red size=12px
5. Define lattice basis from 3 points#
Pick 3 atom columns on w1 above: an origin and two nearest neighbors. Then run this cell to compute lattice vectors u and v, plus the angle between them.
[8]:
# Derive two lattice vectors from the first three points picked on w1:
# point 0 is the origin, points 1 and 2 define u and v (in (row, col) order).
points = w1.selected_points
if len(points) < 3:
    print("Click 3 atom columns on w1 above, then re-run this cell.")
else:
    origin, p1, p2 = (np.array([pt["row"], pt["col"]], dtype=float) for pt in points[:3])
    u = p1 - origin
    v = p2 - origin
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    print(f"Origin: (row={origin[0]:.1f}, col={origin[1]:.1f})")
    print(f"u = ({u[0]:.1f}, {u[1]:.1f}), |u| = {norm_u:.1f} px")
    print(f"v = ({v[0]:.1f}, {v[1]:.1f}), |v| = {norm_v:.1f} px")
    if norm_u == 0 or norm_v == 0:
        # Two picks on the same pixel give a zero-length vector; the angle would be NaN.
        print("Angle(u, v) is undefined: a picked point coincides with the origin.")
    else:
        # Clip before arccos: rounding can push the cosine of (anti)parallel
        # vectors slightly outside [-1, 1], which would also produce NaN.
        cos_angle = np.clip(np.dot(u, v) / (norm_u * norm_v), -1.0, 1.0)
        angle = np.degrees(np.arccos(cos_angle))
        print(f"Angle(u, v) = {angle:.1f} degrees")
    print("Expected for hexagonal: |u| ~ |v| ~ 18 px, angle ~ 60 degrees")
Click 3 atom columns on w1 above, then re-run this cell.
6. PyTorch tensor input#
Mark2D accepts both NumPy arrays and PyTorch tensors.
[9]:
# Wrap the NumPy image as a torch tensor and pass the tensor straight to Mark2D.
haadf_tensor = torch.from_numpy(haadf)
print(f"Tensor shape: {haadf_tensor.shape}, dtype: {haadf_tensor.dtype}")
w4 = Mark2D(
    haadf_tensor,
    scale=1.5,
    dot_size=14,
    max_points=5,
)
w4
Tensor shape: torch.Size([256, 256]), dtype: torch.float32
[9]:
7. Gallery mode — pick points across multiple images#
Pass a list of images to pick points on each independently. Click an unselected image to select it. Only the selected image allows point placement.
[10]:
# Gallery with 3 different crystal structures
hexagonal = make_haadf_stem(size=128, spacing=18)
cubic_128 = make_cubic_stem(size=128, spacing=20)

# Ring pattern (simulated amorphous diffraction): two Gaussian rings around
# pixel (64, 64), at radii 40 px and 20 px, the inner one at half amplitude.
axis = torch.arange(128, device=device, dtype=torch.float32)
yy, xx = torch.meshgrid(axis, axis, indexing="ij")
r = torch.sqrt((xx - 64) ** 2 + (yy - 64) ** 2)
outer_ring = torch.exp(-((r - 40) ** 2) / 20)
inner_ring = 0.5 * torch.exp(-((r - 20) ** 2) / 10)
ring = (outer_ring + inner_ring).cpu().numpy().astype(np.float32)

gallery_images = [hexagonal, cubic_128, ring]
gallery_labels = ["Hex [110]", "Cubic [001]", "Ring"]
w5 = Mark2D(gallery_images, ncols=3, max_points=5, labels=gallery_labels)
w5
[10]:
[11]:
# Print the gallery widget's text summary; the point counts are listed per image label.
w5.summary()
Mark2D
════════════════════════════════
Image: 3×128×128 (3 cols)
Data: min=0 max=1.131 mean=0.1702 dtype=float32
Display: gray | auto contrast | linear
Points [Hex [110]]: 0/5
Points [Cubic [001]]: 0/5
Points [Ring]: 0/5
Marker: circle red size=12px
8. Gallery with torch tensors#
[12]:
# Gallery with torch tensors: the two NumPy images are wrapped as tensors first.
t1, t2 = (torch.from_numpy(arr) for arr in (hexagonal, cubic_128))
w6 = Mark2D(
    [t1, t2],
    ncols=2,
    max_points=4,
    labels=["Hex (torch)", "Cubic (torch)"],
)
w6
[12]:
[13]:
# Print the text summary of the tensor-backed gallery widget.
w6.summary()
Mark2D
════════════════════════════════
Image: 2×128×128 (2 cols)
Data: min=0.007051 max=1.131 mean=0.1838 dtype=float32
Display: gray | auto contrast | linear
Points [Hex (torch)]: 0/4
Points [Cubic (torch)]: 0/4
Marker: circle red size=12px
9. Snap-to-peak on a sharp diffraction pattern#
Snap-to-peak finds the nearest local intensity maximum within a search radius. This is most useful on images with sharp, well-separated peaks — like electron diffraction patterns with Bragg spots. Try it: Click anywhere near a Bragg spot. With snap enabled (green), your point jumps to the exact peak center. Toggle snap off to see the difference — points land exactly where you click instead.
[14]:
diffraction = make_diffraction_pattern()
# Snap enabled with 8px search radius — clicks jump to the nearest Bragg spot
w7 = Mark2D(
    diffraction,
    snap_enabled=True,
    snap_radius=8,
    max_points=10,
    dot_size=8,
    # The colormap attribute is `cmap`; with `colormap=` the summary still
    # reported "Display: gray", i.e. the setting was not applied.
    cmap="viridis",
    log_scale=True,
)
w7
[14]:
[15]:
# Print the text summary of the snap-enabled diffraction widget.
w7.summary()
Mark2D
════════════════════════════════
Image: 256×256
Data: min=0 max=1.019 mean=0.02315 dtype=float32
Display: gray | auto contrast | log
Points: 0/10
Marker: circle red size=8px
Snap: ON (radius=8 px)
Side-by-side: snap on vs snap off#
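Same diffraction pattern and same settings in both panels. Snap is enabled when the widget is created, so toggle it off for the left panel (as its label says), then click near the same Bragg spot in both images to compare precision.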
[16]:
# Gallery: snap OFF (left) vs snap ON (right).
# Both panels show the same image and the widget is created with snap enabled;
# the left label tells the reader to switch snap off by hand.
w8 = Mark2D(
    [diffraction, diffraction],
    ncols=2,
    max_points=5,
    dot_size=8,
    # The colormap attribute is `cmap`; `colormap=` was not applied (display stayed gray).
    cmap="viridis",
    log_scale=True,
    snap_enabled=True,
    snap_radius=8,
    labels=["Snap OFF (toggle it off)", "Snap ON (default)"],
)
w8
[16]:
10. Save and load state#
All widget state — points, ROIs, profile lines, display settings — can be saved to a JSON file with save() and restored with the state parameter. This lets you resume analysis after a kernel restart or share exact results with a colleague.
[17]:
# Create a widget with pre-placed points and custom settings
w9 = Mark2D(
    haadf,
    points=[(36, 36), (54, 36), (45, 52)],
    snap_enabled=True,
    snap_radius=8,
    # The colormap attribute is `cmap`; with `colormap=` the saved state and the
    # restored widget both reported "gray".
    cmap="viridis",
    marker_shape="diamond",
    marker_color="#00bcd4",
    title="HAADF analysis",
    # NOTE(review): the saved state showed "pixel_size": 0.0, so this keyword
    # does not appear to take effect either -- confirm the parameter name and units.
    pixel_size_angstrom=1.5,
)
# Add one circular ROI (radius 30) at position (128, 128).
w9.add_roi(128, 128, shape="circle", radius=30)
w9
[17]:
[18]:
# Save all state to a JSON file.
# The path is relative, so the file is written to the current working directory;
# later cells read it back and finally delete it.
w9.save("haadf_analysis.json")
print("Saved to haadf_analysis.json")
Saved to haadf_analysis.json
[19]:
# Restore from file — same image, all state comes back
w10 = Mark2D(haadf, state="haadf_analysis.json")
# Report what was restored: counts of points and ROIs, then colormap and title.
n_points, n_rois = len(w10.selected_points), len(w10.roi_list)
print(f"Restored: {n_points} points, {n_rois} ROIs")
print(f"Colormap: {w10.cmap}, title: {w10.title}")
w10
Restored: 3 points, 1 ROIs
Colormap: gray, title: HAADF analysis
[19]:
[20]:
# The JSON file is small and human-readable — great for version control
# (json and Path come from the notebook's import cell at the top.)
state = json.loads(Path("haadf_analysis.json").read_text())
print(json.dumps(state, indent=2))
{
"metadata_version": "1.0",
"widget_name": "Mark2D",
"widget_version": "0.4.0a3",
"state": {
"selected_points": [
{
"row": 36,
"col": 36,
"shape": "circle",
"color": "#f44336"
},
{
"row": 54,
"col": 36,
"shape": "triangle",
"color": "#4caf50"
},
{
"row": 45,
"col": 52,
"shape": "square",
"color": "#2196f3"
}
],
"roi_list": [
{
"id": 0,
"shape": "circle",
"row": 128,
"col": 128,
"radius": 30,
"width": 60,
"height": 40,
"color": "#0f0",
"opacity": 0.8
}
],
"profile_line": [],
"selected_idx": 0,
"marker_shape": "diamond",
"marker_color": "#00bcd4",
"dot_size": 12,
"max_points": 10,
"marker_border": 2,
"marker_opacity": 1.0,
"label_size": 0,
"label_color": "",
"snap_enabled": true,
"snap_radius": 8,
"cmap": "gray",
"auto_contrast": true,
"log_scale": false,
"show_fft": false,
"show_stats": true,
"show_controls": true,
"disabled_tools": [],
"hidden_tools": [],
"percentile_low": 2.0,
"percentile_high": 98.0,
"title": "HAADF analysis",
"pixel_size": 0.0,
"scale": 1.0,
"canvas_size": 0
}
}
[21]:
# Clean up the saved file.
# missing_ok=True makes this a no-op when the file is already gone, without a
# separate exists() check that could race with the delete.
p = Path("haadf_analysis.json")
p.unlink(missing_ok=True)