GPU optimization
Objective: Reduce ptychography reconstruction time from 0.11s to < 0.01s per iteration on a single L40S GPU, achieving 1 second for 100 iterations. On more powerful GPUs like the H100, this could be even faster due to higher memory bandwidth (3.0-3.3 TB/s vs the L40S's 864 GB/s).
Current performance: 0.11s per iteration
512×512 object (amplitude + phase)
80×80 probe
256×256 diffraction patterns
4,096 scan positions per iteration
Target performance: < 0.01s per iteration (11× speedup required)
Strategy: Profile to identify bottlenecks, write custom FP16 CUDA kernels, fuse operations across FFT boundaries.
GPU memory bandwidth reference:
L40S: 864 GB/s
A100 80GB: 2.0 TB/s
H100: 3.0-3.3 TB/s
This document covers the actual implementation with measured profiling data. For CUDA programming fundamentals, see CUDA tutorials.
PyTorch optimization techniques
This section covers general PyTorch optimization patterns used throughout the ptychography implementation. Understanding these techniques is essential before diving into the custom CUDA kernels.
Using batch gather to obtain object patches (foundational technique, used once)
In ptychography, we need to extract 4,096 overlapping patches (128×128 pixels each in this generic example; in the actual reconstruction the patch size equals the probe size) from a large object (512×512 pixels) at different scan positions.
Visualization:
Object (512×512): Extract 4,096 patches (128×128):
┌────────────────────┐ ┌──────┐ ┌──────┐ ┌──────┐
│ │ │Patch │ │Patch │ │Patch │
│ ┌──────┐ │ ───→ │ 1 │ │ 2 │ │ 3 │
│ │ │ │ └──────┘ └──────┘ └──────┘
│ └──────┘ │ ...
│ ┌──────┐ │ ┌──────┐
│ │ │ │ │Patch │
│ └──────┘ │ │ 4096 │
└────────────────────┘ └──────┘
Each patch: 128×128 pixels at position (y_i, x_i)
Naive approach with Python loops (extremely slow)
# Bad: Sequential extraction with Python loop
patches = []
for i in range(4096):
y_start, x_start = scan_positions[i]
patch = object[y_start:y_start+128, x_start:x_start+128]
patches.append(patch)
patches = torch.stack(patches) # [4096, 128, 128]
# Problem: 4,096 separate CPU→GPU→CPU round trips
# Time: ~500ms on T4 GPU
Why this is slow:
Each iteration launches a separate GPU kernel
Python loop prevents vectorization
Overhead: 4,096 × ~0.12ms = 492ms wasted
Optimized approach with batch gather (fast)
import torch
# Setup: precompute indices once
object = torch.randn(512, 512).cuda() # Large 2D array on GPU
scan_positions = torch.tensor([...]).cuda() # [4096, 2] (y, x) positions
# Create coordinate grids for patch extraction
patch_size = 128 # Must match probe size
grid_y = torch.arange(patch_size).cuda()
grid_x = torch.arange(patch_size).cuda()
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij')
# Broadcast scan positions + local grid
y_indices = scan_positions[:, 0:1, None] + grid_y[None, :, :] # [4096, 128, 128]
x_indices = scan_positions[:, 1:2, None] + grid_x[None, :, :] # [4096, 128, 128]
# Single batched gather operation
patches = object[y_indices, x_indices] # [4096, 128, 128]
# Time: ~0.8ms on T4 GPU (625× faster!)
How it works:
Step 1: Create local coordinate grid (128×128)
┌─────────────┐
│ 0 1 2 ... │
│ 1 2 3 ... │
│ 2 3 4 ... │
│ ... │
└─────────────┘
Step 2: Add scan position offset to each grid
Position 1: (y=50, x=100) Position 2: (y=75, x=150)
┌─────────────┐ ┌─────────────┐
│ 50 51 52... │ │ 75 76 77... │
│ 51 52 53... │ │ 76 77 78... │
│ ... │ │ ... │
└─────────────┘ └─────────────┘
Step 3: Use advanced indexing to gather all 4,096 patches in parallel
object[y_indices, x_indices] → Single GPU kernel launch
Performance comparison
| Approach | Time (ms) | Kernel Launches | Speedup |
|---|---|---|---|
| Python loop | ~500ms | 4,096 | 1× |
| Batch gather | ~0.8ms | 1 | 625× |
Key insights
Avoid Python loops over GPU operations - Always vectorize when possible
Advanced indexing is highly optimized - PyTorch's tensor[indices] uses efficient GPU gather kernels
Precompute coordinate grids - Create indices once, reuse across iterations
Broadcasting eliminates loops - Use broadcasting to generate all indices simultaneously
Warning
Gotcha: Make sure indices are on the same device as the tensor. Mixed CPU/GPU indices force expensive device transfers:
# Bad: indices on CPU, object on GPU
patches = object_gpu[indices_cpu] # Transfers indices CPU→GPU every call
# Good: both on GPU
patches = object_gpu[indices_gpu] # Fast, no transfers
When to write custom CUDA kernels
PyTorch core developers write custom CUDA kernels for operations not covered by cuBLAS or cuDNN. These include element-wise operations, reductions, indexing, and specialized functions. Users can also write custom CUDA kernels using PyTorch’s C++ extension API or tools like Triton (which compiles Python-like code to GPU kernels).
Most PyTorch users never write kernels since the built-in operations cover typical neural network needs. However, custom kernels enable specialized optimizations for novel architectures or unique algorithms.
When custom CUDA justifies the complexity:
For specialized algorithms where memory access patterns matter more than ease of development, CUDA offers performance gains that justify implementation complexity. PyTorch’s general-purpose kernels cannot provide precise control over GPU resources that production systems need.
Why use custom CUDA when PyTorch is already fast?
PyTorch provides high-level operations but custom CUDA kernels give fine-grained control over memory access patterns, thread synchronization, and data layout. Optimizations like memory tiling can achieve 3-8× speedups. Production systems like FlashAttention (GPT-4), xFormers (Llama), DeepSpeed (Microsoft), and vLLM use custom CUDA C++ for maximum performance.
Mixing PyTorch operations with custom CUDA kernels
You can mix PyTorch operations with custom CUDA kernels. For example, PyTorch handles FFTs using cuFFT while your kernel implements specialized phase retrieval updates. PyTorch tensors expose raw GPU memory pointers that your CUDA kernel can access directly through the C++ extension API.
This hybrid approach lets you leverage PyTorch’s optimized FFT implementation while adding custom operations that aren’t built-in, like constraint projection or error reduction.
import torch
from torch.utils.cpp_extension import load
# Load custom CUDA kernel
custom_kernel = load(
name='custom_ops',
sources=['custom_kernel.cu'],
extra_cuda_cflags=['-O3']
)
# Mix PyTorch and custom operations
x = torch.randn(512, 512, dtype=torch.complex64).cuda()
# Use PyTorch's optimized FFT
x_fft = torch.fft.fft2(x)
# Apply custom CUDA kernel
x_processed = custom_kernel.apply_constraint(x_fft)
# Back to PyTorch for inverse FFT
result = torch.fft.ifft2(x_processed)
Where are PyTorch CUDA kernels in source code?
PyTorch CUDA kernels live in aten/src/ATen/native/cuda/ for core operations and aten/src/ATen/native/cudnn/ for cuDNN wrappers. Files like BinaryOps.cu contain element-wise operations and Activation.cu has functions like ReLU. Each .cu file contains CUDA C++ code with kernel launches and device functions organized by operation type. You can browse them at pytorch/pytorch on GitHub.
PyTorch vs CuPy: When to use automatic differentiation
The problem:
CuPy handles standard GPU operations (FFT, array math, linear algebra) but lacks automatic differentiation. For optimization problems including loss function minimization, parameter fitting, and deep learning, you need autodiff. Writing gradients manually is error-prone and slow. Additionally, cross-platform support for NVIDIA, Mac, and AMD is important for many workflows.
The solution:
PyTorch provides GPU arrays with automatic differentiation (autodiff). You define your computation like loss = (model(x) - y)**2 then call loss.backward() where PyTorch computes all gradients automatically. GPU-accelerated optimizers update parameters efficiently. PyTorch works on NVIDIA (CUDA), Mac (Metal), and AMD (ROCm).
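A minimal sketch of that workflow (the names and toy data here are illustrative, not taken from the ptychography code; assumes a CUDA device):
import torch

# Toy fitting problem: recover w and b from noisy-free data y = 3x + 1
x = torch.randn(100, device='cuda')
y = 3.0 * x + 1.0
w = torch.zeros(1, device='cuda', requires_grad=True)
b = torch.zeros(1, device='cuda', requires_grad=True)

optimizer = torch.optim.Adam([w, b], lr=0.1)
for _ in range(200):
    optimizer.zero_grad()
    loss = ((w * x + b - y) ** 2).mean()  # forward pass builds the autograd graph
    loss.backward()                       # gradients for w and b computed automatically
    optimizer.step()                      # GPU-accelerated parameter update

print(w.item(), b.item())  # approaches (3, 1)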
When to use PyTorch vs CuPy:
Use PyTorch for: Deep learning, training neural networks on diffraction data, parameter optimization, gradient-based fitting, Mac compatibility
Use CuPy for: Pure array operations (FFT, virtual imaging) on NVIDIA GPUs without needing gradients
Performance comparison example
256×256 4D-STEM FFT on RTX 4090:
| Framework | Time (seconds) | Notes |
|---|---|---|
| CuPy | 0.8s | cuFFT directly |
| PyTorch | 0.85s | cuFFT through PyTorch |
| NumPy (CPU) | 45s | 56× slower |
Both call cuFFT underneath with nearly identical speed, so choose by features not performance.
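If you want to reproduce this kind of comparison on your own hardware, a rough timing sketch looks like the following (array sizes, warmup handling, and results are illustrative, not the exact benchmark above):
import time
import numpy as np
import torch

# A stack of 256x256 complex patterns, as in a 4D-STEM dataset
data = np.random.rand(1024, 256, 256).astype(np.complex64)
gpu = torch.from_numpy(data).cuda()

_ = torch.fft.fft2(gpu)           # warmup: first call includes cuFFT plan creation
torch.cuda.synchronize()
t0 = time.perf_counter()
_ = torch.fft.fft2(gpu)
torch.cuda.synchronize()
print(f"PyTorch (cuFFT): {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
_ = np.fft.fft2(data)             # CPU reference
print(f"NumPy (CPU): {time.perf_counter() - t0:.3f}s")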
Background: Ptychography reconstruction
Ptychography reconstructs a high-resolution complex-valued object \(O(x,y)\) and probe \(P(x,y)\) from diffraction intensity measurements. The object and probe are both trainable parameters updated via gradient descent.
Problem setup:
Object: Complex-valued transmission function at each pixel (amplitude and phase)
Probe: Complex-valued electron beam wavefront (contains aberrations)
Measurements: Diffraction intensity patterns at detector (magnitude only, phase lost)
Goal: Recover object and probe from intensity-only measurements
Forward model: Seven computational steps per iteration
Step 1: Generate probe from aberration parameters (k-space → real space)
The probe starts in reciprocal space with aberration function:
\[\chi(k) = \pi \lambda \Delta f\, k^2 + \tfrac{1}{2} \pi \lambda^3 C_s k^4 + \pi \lambda A_{astig} k^2 \cos\!\big(2(\theta_k - \theta_{astig})\big)\]
where \(k = \sqrt{k_x^2 + k_y^2}\) is the spatial frequency, \(\theta_k\) its azimuthal angle, \(\Delta f\) is defocus, \(C_s\) is spherical aberration, and \(A_{astig}\) is astigmatism magnitude with orientation \(\theta_{astig}\).
The probe in k-space is:
\[\tilde{P}(k) = A(k)\, e^{i\chi(k)}\]
where \(A(k)\) is the aperture function (sigmoid for soft edges).
Transform to real space via inverse FFT:
\[P(x,y) = \mathcal{F}^{-1}\big[\tilde{P}(k)\big]\]
Step 2: Construct object from amplitude and phase
\[O(x,y) = A(x,y)\, e^{i\phi(x,y)}\]
Step 3: Extract object patches and compute exit waves
For each scan position \(j\):
\[\psi_j(x,y) = O_j(x,y) \cdot P(x,y)\]
where \(O_j\) is the probe-sized patch of \(O\) cropped at scan position \((y_j, x_j)\).
Step 4: Propagate to detector via FFT
\[\Psi_j(u,v) = \mathcal{F}\big[\psi_j(x,y)\big]\]
Step 5: Compute predicted intensity
\[I_j(u,v) = |\Psi_j(u,v)|^2\]
Step 6: Loss function
\[\mathcal{L} = \operatorname{mean}_{j,u,v}\Big( s_j^{pred}\, I_j(u,v) - s_j^{meas}\, I_j^{meas}(u,v) \Big)^2\]
where \(s_j = \text{counts} / \text{mean}(I_j)\) is the per-position scaling factor (applied to both the predicted and measured intensities).
Step 7: Backward pass and optimization
PyTorch computes gradients via autograd and updates all trainable parameters using the Adam optimizer.
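As a reference, here is a condensed, runnable PyTorch sketch of steps 2-6 on random data; variable names mirror the actual implementation later in this document, and the random tensors stand in for real measurements:
import torch

B, N, P, D = 4096, 512, 80, 256                        # positions, object, probe, detector sizes
amp = torch.rand(N, N, device='cuda')                  # object amplitude
phase = torch.randn(N, N, device='cuda')               # object phase
probe = torch.randn(P, P, dtype=torch.complex64, device='cuda')
pos = torch.randint(0, N - P, (B, 2), device='cuda')   # scan positions (y, x), top-left corners
grid = torch.arange(P, device='cuda')
y_idx = pos[:, 0, None, None] + grid[None, :, None]    # [B, P, 1]
x_idx = pos[:, 1, None, None] + grid[None, None, :]    # [B, 1, P]
I_meas = torch.rand(B, D, D, device='cuda')            # placeholder measured intensities
counts = 1.0e6                                         # placeholder photon budget
meas_scale = counts / I_meas.mean(dim=(1, 2), keepdim=True)

obj = amp * torch.exp(1j * phase)                      # Step 2: complex object
tiles = obj[y_idx, x_idx]                              # Step 3: [B, P, P] patches (batch gather)
exit_waves = tiles * probe.unsqueeze(0)                # Step 3: multiply by the probe
Psi = torch.fft.fft2(exit_waves, s=(D, D))             # Step 4: propagate (padded to detector size)
I_pred = torch.abs(Psi) ** 2                           # Step 5: predicted intensities
pred_scale = counts / I_pred.mean(dim=(1, 2), keepdim=True)
loss = ((I_pred * pred_scale - I_meas * meas_scale) ** 2).mean()   # Step 6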
Note
Caveat: Autograd overhead is significant. For 4,096 scan positions, PyTorch builds a computation graph with thousands of nodes, costing ~50% of wall-clock time. Custom CUDA kernels bypass autograd for the forward pass, then use PyTorch’s built-in backward implementations.
Trainable parameters being optimized
The object is stored as two separate real-valued arrays (amplitude and phase) but combined into complex-valued form during forward pass:
Why separate storage? Gradients flow better through separate amplitude and phase parameters. PyTorch’s autograd computes \(\frac{\partial \mathcal{L}}{\partial A}\) and \(\frac{\partial \mathcal{L}}{\partial \phi}\) independently, then updates each via Adam. If we stored the complex object directly, gradients would couple real/imaginary parts in a less physically meaningful way.
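A small sketch of this point: the forward pass builds the complex object on the fly, yet autograd returns independent FP32 gradients for the amplitude and phase parameters (toy sizes, illustrative only):
import torch

amp = torch.rand(4, 4, device='cuda', requires_grad=True)
phase = torch.randn(4, 4, device='cuda', requires_grad=True)

obj = amp * torch.exp(1j * phase)                 # complex64, built on the fly
loss = obj.real.sum() + obj.imag.pow(2).sum()     # any real-valued scalar loss
loss.backward()

print(amp.grad.dtype, phase.grad.dtype)           # both float32: dL/dA and dL/dphi, updated independently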
| Parameter (stored) | Physical meaning | Shape | Dtype |
|---|---|---|---|
| Object parameters (stored separately) | | | |
| object_amplitude | Transmission amplitude (0-1 range) | [512, 512] | float32 |
| object_phase | Phase shift in radians | [512, 512] | float32 |
| Computed during forward | | | |
| object_complex | \(A \cdot e^{i\phi}\) (not stored, computed on-the-fly) | [512, 512] | complex64 |
| Probe aberration parameters | | | |
| defocus | Defocus distance in nm (C1) | scalar | float32 |
| Cs | Spherical aberration in mm (C3) | scalar | float32 |
| astig_mag | Astigmatism magnitude in nm | scalar | float32 |
| astig_angle | Astigmatism orientation in radians | scalar | float32 |
| aperture_smooth | Aperture edge smoothness parameter | scalar | float32 |
Total trainable parameters: 524,293
Object: 512×512×2 = 524,288 parameters (amplitude + phase stored separately)
Probe aberrations: 5 parameters (defocus, Cs, astigmatism magnitude, astigmatism angle, aperture smoothness)
Optimization strategy: Two separate Adam optimizers with fused kernels update object and probe parameters independently after each iteration.
Warning
Gotcha: Object is stored as two separate real-valued arrays (amplitude and phase) but must be combined into complex form during forward pass. This separation enables better gradient flow but requires careful handling to avoid creating unnecessary intermediate tensors.
What stays in PyTorch vs what moves to custom CUDA
Keep in PyTorch (already optimal or difficult to replace):
FFT operations: torch.fft.fft2() and torch.fft.ifft2()
Uses cuFFT internally (NVIDIA's highly optimized library)
Cannot be beaten with custom implementation
Accounts for 49.4% of GPU time (will remain constant)
Adam optimizer: torch.optim.Adam(fused=True)
Single CUDA kernel with numerically stable updates
Well-tested, supports weight decay
Replacing offers minimal gains
Backward pass: PyTorch’s automatic differentiation
Highly optimized for common operations
Maintains numerical stability
Custom backward kernels offer ~20% gains but add complexity
Move to custom CUDA kernels (high-impact optimizations):
Probe generation: \(\chi(k) \rightarrow \tilde{P}(k)\) preparation before IFFT
Fuses defocus, spherical aberration, astigmatism calculations
Runs 1× per iteration (low priority)
Exit wave computation: \(O[indices] \cdot P\)
Fuses tile gather + amplitude/phase → complex + complex multiply
Runs 4,096× per iteration (CRITICAL)
FP16 implementation uses Tensor Cores for 2-3× speedup
Intensity and loss: \(|\Psi|^2 \rightarrow \text{scaling} \rightarrow \text{loss}\)
Fuses magnitude squared, per-batch scaling, loss reduction
Runs 4,096× per iteration (IMPORTANT)
Eliminates 6 separate kernel launches
Note
Key insight: With 4,096 scan positions per iteration, operations that run once per position dominate the total time. Prioritize optimizing exit wave and intensity loss over probe generation.
The actual implementation
Here’s the complete code being optimized. Current performance: 0.11s per iteration
Aberration-parameterized probe model
import torch
import torch.nn as nn
import numpy as np
class AberrationProbe(nn.Module):
"""Physics-informed probe with aberration parameterization."""
def __init__(self, size, pixel_size, wavelength, convergence_angle, device='cuda'):
super().__init__()
self.size = size
self.pixel_size = pixel_size
self.wavelength = wavelength
self.convergence_angle = convergence_angle
self.device = device
# Create k-space grids - registered as buffers (no gradient retention)
kx = torch.fft.fftfreq(size, d=pixel_size, device=device)
ky = torch.fft.fftfreq(size, d=pixel_size, device=device)
kx, ky = torch.meshgrid(kx, ky, indexing='ij')
self.register_buffer('k', torch.sqrt(kx**2 + ky**2))
self.register_buffer('k_angle', torch.atan2(ky, kx))
# Trainable aberration coefficients
self.defocus = nn.Parameter(torch.tensor(50.0)) # nm (C1)
self.Cs = nn.Parameter(torch.tensor(1.0)) # mm (C3)
self.astig_mag = nn.Parameter(torch.tensor(10.0)) # nm
self.astig_angle = nn.Parameter(torch.tensor(0.0)) # radians
self.aperture_smooth = nn.Parameter(torch.tensor(0.1))
def compute_aberration_function(self):
"""χ(k) = π λ Δf k² + (1/2) π λ³ Cs k⁴ + astigmatism"""
k = self.k
wavelength = self.wavelength
Cs_nm = self.Cs * 1e6 # mm → nm
# Defocus term
chi_defocus = np.pi * wavelength * self.defocus * k**2
# Spherical aberration term
chi_Cs = 0.5 * np.pi * wavelength**3 * Cs_nm * k**4
# Astigmatism
chi_astig = (np.pi * wavelength * self.astig_mag * k**2 *
torch.cos(2 * (self.k_angle - self.astig_angle)))
return chi_defocus + chi_Cs + chi_astig
def compute_aperture(self):
"""Soft aperture function with sigmoid rolloff."""
k_max = self.convergence_angle / self.wavelength
smoothness = torch.abs(self.aperture_smooth) + 0.01
return torch.sigmoid((k_max - self.k) / (k_max * smoothness))
def forward(self):
"""Generate probe: P = IFFT[A(k) × exp(iχ(k))]"""
chi = self.compute_aberration_function()
aperture = self.compute_aperture()
# Probe in k-space
probe_k = aperture * torch.exp(1j * chi)
# Transform to real space
probe_real = torch.fft.ifftshift(torch.fft.ifft2(probe_k))
# Normalize
probe_real = probe_real / torch.sqrt(torch.sum(torch.abs(probe_real)**2))
return probe_real
def get_probe_crop(self, crop_size):
"""Get cropped probe WITH gradients for optimization."""
probe_full = self.forward()
center = self.size // 2
half = crop_size // 2
return probe_full[center-half:center+half, center-half:center+half]
Physics-informed ptychography model
class PhysicsInformedPtychography(nn.Module):
"""Ptychography model with batched operations."""
def __init__(self, object_size, probe_size, diffraction_size, scan_positions,
pixel_size, wavelength, convergence_angle, device='cuda'):
super().__init__()
self.object_size = object_size
self.probe_size = probe_size
self.diffraction_size = diffraction_size
self.device = device
# Precompute coordinate grids for batched gather
dy = torch.arange(probe_size, device=device, dtype=torch.long)
dx = torch.arange(probe_size, device=device, dtype=torch.long)
grid_y, grid_x = torch.meshgrid(dy, dx, indexing='ij')
self.register_buffer('grid_y', grid_y)
self.register_buffer('grid_x', grid_x)
# Convert scan positions to tensor
self.register_buffer('scan_positions',
torch.tensor(scan_positions, dtype=torch.long, device=device))
# Trainable object parameters
self.object_amplitude = nn.Parameter(
0.9 * torch.ones(object_size, object_size, device=device) +
0.1 * torch.randn(object_size, object_size, device=device)
)
self.object_phase = nn.Parameter(
0.1 * torch.randn(object_size, object_size, device=device)
)
# Probe model with trainable aberrations
self.probe_model = AberrationProbe(
diffraction_size, pixel_size, wavelength, convergence_angle, device=device
)
def get_object_complex(self):
"""Construct complex object from amplitude and phase."""
return self.object_amplitude * torch.exp(1j * self.object_phase)
def forward(self, batch_indices):
"""Forward pass with batched gather."""
if not isinstance(batch_indices, torch.Tensor):
batch_indices = torch.tensor(batch_indices, dtype=torch.long, device=self.device)
batch_indices = batch_indices.detach()
positions = self.scan_positions[batch_indices]
# Get complex object
object_complex = self.get_object_complex()
# Get probe (regenerated every call - BOTTLENECK!)
probe_complex = self.probe_model.get_probe_crop(self.probe_size)
# Batched gather - extract all patches at once
y_idx = positions[:, 0, None, None] + self.grid_y[None, :, :]
x_idx = positions[:, 1, None, None] + self.grid_x[None, :, :]
tiles = object_complex[y_idx, x_idx] # [B, probe_size, probe_size]
# Exit waves
exit_waves = tiles * probe_complex.unsqueeze(0)
# FFT with internal padding (cuFFT handles this optimally)
Psi = torch.fft.fft2(exit_waves, s=(self.diffraction_size, self.diffraction_size))
return Psi
Training loop
# Model setup
model = PhysicsInformedPtychography(
object_size=512,
probe_size=80,
diffraction_size=256,
scan_positions=scan_positions, # 4,096 positions
pixel_size=0.5,
wavelength=0.0197, # 300 keV electrons
convergence_angle=0.02,
device='cuda'
).to('cuda')
# Fused Adam optimizers (single CUDA kernel each)
optimizer_object = torch.optim.Adam(
[model.object_amplitude, model.object_phase],
lr=0.01,
fused=True
)
optimizer_probe = torch.optim.Adam(
model.probe_model.parameters(),
lr=0.001,
fused=True
)
# Precompute constants
I_meas = (diffraction_patterns ** 2).contiguous()
meas_scale = (counts / I_meas.mean(dim=(1,2), keepdim=True)).contiguous()
# JIT-compiled loss function
@torch.jit.script
def compute_intensity_loss(I_pred, I_meas, meas_scale, counts):
pred_scale = counts / I_pred.mean(dim=(1,2), keepdim=True)
diff = I_pred * pred_scale - I_meas * meas_scale
return (diff * diff).mean()
# Training
torch.backends.cudnn.benchmark = True
batch_indices = torch.arange(len(scan_positions), dtype=torch.long, device='cuda')
for iteration in range(num_iterations):
# Forward
Psi = model(batch_indices)
I_pred = torch.abs(Psi) ** 2
loss = compute_intensity_loss(I_pred, I_meas, meas_scale, counts)
# Backward
optimizer_object.zero_grad(set_to_none=True)
optimizer_probe.zero_grad(set_to_none=True)
loss.backward()
# Update
optimizer_object.step()
optimizer_probe.step()
Current baseline: Already implemented optimizations
Before profiling and writing custom kernels, several standard PyTorch optimizations are already in place:
Batched gather operations: All 4,096 object patches extracted in single GPU operation (no Python loops)
Registered buffers: Precomputed coordinate grids (k-space, scan positions) stored on GPU, not recomputed
FFT internal padding: Using s parameter to delegate padding to cuFFT (more efficient than explicit torch.nn.functional.pad)
Fused Adam optimizer: torch.optim.Adam(fused=True) uses single CUDA kernel for parameter updates
JIT-compiled loss function: TorchScript compilation eliminates Python interpreter overhead
cuDNN autotuning: torch.backends.cudnn.benchmark = True selects optimal convolution algorithms for tensor shapes
Impact: These optimizations reduced iteration time from ~0.5s to 0.11s per iteration (4.5× improvement).
Challenge: To reach < 0.01s requires another 11× speedup. Standard PyTorch optimizations exhausted, custom CUDA required.
Warning
Gotcha: fused=True in Adam only works when all parameters are on the same device and contiguous in memory. If parameters are scattered across tensors, PyTorch falls back to unfused implementation without warning. Check with torch.profiler to verify single adam_kernel launch.
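One way to check this (a sketch using torch.profiler; the exact fused-kernel name varies by PyTorch version, and optimizer_object is the object optimizer from the training loop above):
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    optimizer_object.step()

# With fused=True active you should see one fused Adam kernel per step,
# not one small kernel per parameter tensor.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))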
Profiling: Identifying the bottlenecks
Running torch.profiler for 100 iterations reveals where time is spent:
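A sketch of how such a profile can be collected; the schedule values are illustrative and train_iteration stands for one forward/backward/step as in the benchmark later in this document:
from torch.profiler import profile, schedule, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=5, warmup=5, active=100),
) as prof:
    for iteration in range(110):
        train_iteration(model, batch_indices)   # one forward/backward/optimizer step
        prof.step()

print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))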
GPU time breakdown:
| Name | Self CPU % | Self CUDA % |
|---|---|---|
| aten::_fft_c2c | 0.11% | 24.68% |
| aten::mul | 0.60% | 19.07% |
| aten::abs | 0.04% | 5.95% |
| | 0.04% | 0.00% |
| | 0.02% | 0.00% |
| cuFFT internal kernel (regular_fft / vector_fft) | 0.00% | 12.98% |
| | 0.02% | 0.00% |
| | 0.01% | 0.00% |
| cuFFT internal kernel (regular_fft / vector_fft) | 0.00% | 11.70% |
| | 0.04% | 0.00% |
| | 0.02% | 0.00% |
Total times: Self CPU time: 11.011s | Self CUDA time: 11.030s
The real bottleneck: CPU time ≈ GPU time ≈ 11s for 100 iterations (0.11s per iteration).
Why CPU time matters: The user experiences wall-clock time, not just GPU time. PyTorch operations involve:
Python overhead - Each operation calls Python → C++ → CUDA
Kernel launch overhead - CPU must launch each GPU kernel (cudaLaunchKernel takes ~5-10μs)
Autograd graph - CPU builds and manages backward computation graph
Tensor bookkeeping - CPU tracks shapes, dtypes, memory, reference counts
Synchronization - CPU waits for GPU results between operations
With 5,140 mul calls, 620 abs calls, etc., these CPU overheads accumulate to match the GPU compute time.
Custom CUDA eliminates CPU overhead: A single kernel launch does everything on GPU. CPU launches once, then waits for the final result. No Python loops, no intermediate tensors, no autograd graph during forward pass.
What PyTorch does internally for every operation
CPU-side overhead (happens for EVERY PyTorch operation):
Python interpreter overhead
Function call from Python → C++ extension
Argument parsing and validation (type checking, shape checking)
GIL (Global Interpreter Lock) acquisition/release
Reference counting for Python objects
Cost: ~1-5μs per operation
PyTorch dispatcher overhead
Determine which backend to use (CPU/CUDA/MPS)
Check if operation is in-place or requires new allocation
Dispatch to correct kernel implementation
Handle device placement and data type promotion
Cost: ~0.5-2μs per operation
Memory management
Allocate output tensor (unless in-place)
Update CUDA memory allocator bookkeeping
Track tensor metadata (shape, stride, dtype, device)
Increment/decrement reference counts
Register tensor for garbage collection
Cost: ~2-10μs per allocation
Autograd graph construction
Create autograd node for backward pass
Store input tensors (or their versions) for gradient computation
Link node into computation graph
Register hook functions if any
Track requires_grad status through operations
Cost: ~5-15μs per operation with gradients
Kernel launch
Call the cudaLaunchKernel() system call
Transfer kernel arguments to GPU (via driver)
Enqueue kernel in CUDA stream
Update stream synchronization state
Cost: ~5-10μs per kernel launch
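A rough way to measure this per-operation overhead on your own machine (a sketch; the numbers depend on CPU, driver, and PyTorch version):
import time
import torch

x = torch.randn(16, device='cuda')        # tiny tensor: runtime is dominated by launch/dispatch overhead
torch.cuda.synchronize()

t0 = time.perf_counter()
for _ in range(10_000):
    x = x * 1.0001                        # one kernel launch per iteration
torch.cuda.synchronize()
per_op_us = (time.perf_counter() - t0) / 10_000 * 1e6
print(f"~{per_op_us:.1f} us per tiny elementwise op (CPU launch + dispatch overhead)")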
GPU-side overhead (for operations PyTorch doesn’t fuse):
Memory bandwidth bottleneck
Each operation reads inputs from global memory (slow)
Each operation writes outputs to global memory (slow)
No data reuse between operations (cache misses)
Example: A * B reads A and B and writes the result; the next operation reads that result again.
Kernel launch latency
GPU must wait for kernel to be scheduled
Streaming multiprocessors (SMs) must load kernel code
Thread blocks must be assigned to SMs
Cost: ~1-5μs per kernel
Synchronization overhead
Operations on same stream serialize (can’t overlap)
torch.cuda.synchronize() blocks the CPU waiting for the GPU
Implicit syncs when copying data back to CPU
Example: Single complex multiply in PyTorch
# This one line:
result = a_complex * b_complex # complex64 tensors
# Triggers on CPU:
# - Python call overhead: ~3μs
# - Dispatcher: ~1μs
# - Memory allocation for result: ~5μs
# - Autograd node creation: ~10μs
# - 4× cudaLaunchKernel() calls: ~40μs (one per real/imag component)
# Total CPU time: ~59μs
# Triggers on GPU:
# - 4 kernel launches: ~4μs
# - Read a_complex (real + imag): 2 memory reads
# - Read b_complex (real + imag): 2 memory reads
# - Write result (real + imag): 2 memory writes
# - Actual compute: <1μs (multiplication is fast)
# Total GPU time: ~5μs
# Wall-clock time: max(CPU, GPU) ≈ 59μs for ONE multiply
With 5,140 complex multiplies:
CPU overhead: 5,140 × 59μs = 303ms
GPU overhead: 5,140 × 5μs = 26ms
Total added to iteration: ~330ms
Custom CUDA eliminates almost all of this:
// Single kernel launch from Python: ~20μs CPU overhead total
// GPU: Single kernel reads data once, writes result once
// Wall-clock time: ~5μs for ALL 5,140 multiplies fused together
Why this matters:
The profiling shows CPU time (11.01s) ≈ GPU time (11.03s). This means the code is spending equal time on:
CPU: Launching kernels, building autograd graphs, managing memory
GPU: Actually computing results
Custom CUDA collapses hundreds of operations into single kernels, eliminating most CPU overhead and allowing GPU to work continuously without waiting for CPU to launch the next kernel.
Mapping profiling to custom kernels
What can’t be optimized:
aten::_fft_c2c (24.7%): This IS cuFFT. No custom kernel beats cuFFT.
regular_fft + vector_fft (24.7%): cuFFT's internal kernels. Optimal.
Total FFT: 49.4% - keep using torch.fft.fft2()
What custom kernels will target:
GPU operations to fuse:
aten::mul (19.1%, 5,140 calls): Complex multiply scattered everywhere
aten::abs (5.95%, 620 calls): Complex → real conversion
Power and backward operations: Squaring in loss + gradient computations
Total fusible GPU operations: ~25-30% of GPU time
CPU overhead to eliminate:
5,140 kernel launches for mul operations
Autograd graph building and traversal
Tensor allocation and memory management
Python/C++/CUDA call stack overhead
Total CPU overhead: ~50% of wall-clock time (11.01s CPU ≈ 11.03s GPU)
Total speedup potential:
3× from kernel fusion (eliminate redundant memory operations on GPU)
2× from eliminating CPU overhead (single kernel launch, no Python intermediaries, no autograd graph)
Total: 6× speedup → 0.018s per iteration
What does “fuse” mean in GPU computing?
Kernel fusion means combining multiple operations into a single GPU kernel to eliminate intermediate memory reads/writes.
Unfused operations (what PyTorch does):
# Three separate operations
a = torch.exp(x) # Operation 1: read x, write a
b = a * y # Operation 2: read a, read y, write b
c = torch.abs(b) # Operation 3: read b, write c
# GPU memory traffic:
# Op 1: Read x [100MB], Write a [100MB] → 200MB
# Op 2: Read a [100MB], Read y [100MB], Write b [100MB] → 300MB
# Op 3: Read b [100MB], Write c [100MB] → 200MB
# Total: 700MB of memory traffic
# GPU execution:
# Launch kernel 1 → wait → launch kernel 2 → wait → launch kernel 3
# Each kernel launch: ~5-10μs CPU overhead + kernel execution time
Fused operations (custom CUDA):
__global__ void fused_kernel(float* x, float* y, float* c, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float a = expf(x[idx]); // Compute in register
float b = a * y[idx]; // Compute in register
c[idx] = fabsf(b); // Write only final result
}
}
// GPU memory traffic:
// Read x [100MB], Read y [100MB], Write c [100MB] → 300MB
// Total: 300MB of memory traffic (2.3× less than unfused)
// GPU execution:
// Single kernel launch → compute everything → done
// One kernel launch: ~5-10μs CPU overhead (not 3×)
Why this matters:
Memory bandwidth is the bottleneck - Modern GPUs can compute much faster than they can move data. Reducing memory traffic by 2-3× directly speeds up execution by 2-3×.
Intermediate results stay in registers - GPU registers are ~1000× faster than global memory. Fused kernels keep a and b in registers, never writing them to slow global memory.
CPU launches fewer kernels - One launch instead of three eliminates 20-30μs of CPU overhead per fused group.
Real example from ptychography:
# Unfused (PyTorch):
amp = object_amplitude[indices] # Read, write → 4MB
phase = object_phase[indices] # Read, write → 4MB
obj_real = amp * torch.cos(phase) # Read amp, read phase, write → 12MB
obj_imag = amp * torch.sin(phase) # Read amp, read phase, write → 12MB
obj_complex = torch.complex(obj_real, obj_imag) # Read both, write → 12MB
exit_wave = obj_complex * probe # Read obj, read probe, write → 12MB
# Total: 56MB memory traffic, 6 kernel launches
// Fused (custom CUDA):
__global__ void fused_exit_wave_kernel(...) {
float amp = obj_amplitude[obj_idx]; // Read once
float phase = obj_phase[obj_idx]; // Read once
cuFloatComplex obj = make_cuFloatComplex(
amp * cosf(phase), // Compute in register
amp * sinf(phase) // Compute in register
);
cuFloatComplex probe_val = probe[idx]; // Read once
exit_wave[idx] = cuCmulf(obj, probe_val); // Write once
}
// Total: 12MB memory traffic (4.7× less), 1 kernel launch
Fusion is not optimization - it’s necessity:
For 4,096 scan positions, unfused operations create:
24,576 kernel launches (4,096 positions × 6 operations)
229GB of memory traffic (56MB × 4,096)
250ms of CPU overhead just launching kernels
Fused into one kernel:
1 kernel launch
49GB of memory traffic (4.7× reduction)
10μs of CPU overhead
Result: 4.7× faster GPU compute + elimination of CPU bottleneck = 5-6× total speedup.
Why PyTorch cannot fuse complex number operations
PyTorch’s fundamental limitation: The JIT compiler and torch.compile() do not fuse operations involving complex tensors. This is not a bug - it is an architectural limitation.
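You can sanity-check this claim on your own setup with a small experiment (a sketch; behavior depends on the PyTorch version and its TorchInductor complex-number support):
import time
import torch

def complex_chain(a, b):
    return torch.abs(torch.exp(1j * a) * b) ** 2   # the kind of chain used in ptychography

a = torch.randn(4096, 80, 80, device='cuda')
b = torch.randn(4096, 80, 80, dtype=torch.complex64, device='cuda')

compiled = torch.compile(complex_chain)
for fn, label in [(complex_chain, "eager"), (compiled, "torch.compile")]:
    fn(a, b)                                       # warmup (includes compilation for the compiled version)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(50):
        fn(a, b)
    torch.cuda.synchronize()
    print(f"{label}: {(time.perf_counter() - t0) / 50 * 1e3:.2f} ms")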
What happens with a simple complex multiply:
# This single line of PyTorch code:
exit_waves = object_complex * probe_complex # Both are complex64
# Becomes FOUR separate GPU kernel launches:
# Kernel 1: Extract real parts from object_complex
# Kernel 2: Extract imaginary parts from object_complex
# Kernel 3: Compute real_out = real_a * real_b - imag_a * imag_b
# Kernel 4: Compute imag_out = real_a * imag_b + imag_a * real_b
# Each kernel:
# - Reads from global memory (slow)
# - Writes to global memory (slow)
# - Requires CPU to launch it (~5-10μs overhead)
# - Allocates intermediate tensors
# - Registers operation in autograd graph
The cascade effect with 5,140 complex multiplies:
5,140 complex multiplies × 4 kernels each = 20,560 kernel launches
20,560 launches × 10μs CPU overhead = 205ms just in launch overhead
Plus: 20,560 memory allocations, 20,560 autograd registrations
Result: CPU time (11.01s) ≈ GPU time (11.03s)
Custom CUDA eliminates this completely:
// Single kernel, single launch, everything in registers:
cuFloatComplex obj_val = make_cuFloatComplex(
amp * cosf(phase), // Real part
amp * sinf(phase) // Imaginary part
);
cuFloatComplex probe_val = probe[idx];
exit_waves[idx] = cuCmulf(obj_val, probe_val); // One instruction
// CPU overhead: Single cudaLaunchKernel() call (~10μs total)
// GPU overhead: No intermediate memory reads/writes
// Result: 3× faster GPU compute + 2× faster CPU overhead = 6× total
Why this matters for ptychography:
The forward pass has three cascading complex operations:
\(A(x,y) \cdot e^{i\phi(x,y)}\) → object construction (complex exp + multiply)
\(O[indices] \cdot P(x,y)\) → exit waves (gather + complex multiply)
\(|\Psi(u,v)|^2\) → intensity (complex absolute value + square)
Each gets fragmented into multiple kernels. Backward pass makes it worse with gradient computations through each operation.
Custom CUDA is not optional - it is necessary. PyTorch fundamentally cannot optimize this workload.
The three custom kernels to write
Kernel 1: fused_exit_wave_kernel
Eliminates 19.1% overhead
This kernel computes the exit waves for all 4,096 scan positions in a single GPU call. For each position, it constructs the complex object from separate amplitude and phase arrays, extracts the relevant patch using precomputed indices, and multiplies by the probe function. PyTorch does this in three separate kernel launches with intermediate tensor allocations. The custom kernel fuses everything into one pass, eliminating memory roundtrips and CPU overhead.
Fuses: \(A(x,y) \cdot e^{i\phi(x,y)}\) at indices → multiply by \(P(x,y)\)
| Parameter | Description | Shape/Type |
|---|---|---|
| Input | | |
| obj_amplitude | Object amplitude (real-valued, 0-1 range) | [512, 512] float32 |
| obj_phase | Object phase (real-valued, radians) | [512, 512] float32 |
| y_indices | Vertical positions for all patches | [4096, 80, 80] (int) |
| x_indices | Horizontal positions for all patches | [4096, 80, 80] (int) |
| probe | Illumination function (complex) | [80, 80] complex64 |
| Output | | |
| exit_waves | Transmitted wavefield for each position | [4096, 80, 80] complex64 |
Replaces: 3 separate PyTorch ops (exp, gather, mul) + 3 kernel launches + 2 intermediate tensors
Kernel 2: fused_intensity_loss_kernel
Eliminates 30.5% overhead
This kernel computes the complete loss function in a single GPU kernel with streaming reduction. It converts the complex diffraction patterns to intensities, applies per-position scaling factors, computes squared differences against measured data, and reduces to a scalar loss - all without writing intermediate results to global memory. PyTorch does this as six separate operations, each allocating full-size temporary tensors. The fused kernel processes data on-the-fly using shared memory reductions.
Fuses: \(|\Psi|^2 \rightarrow \text{scale} \rightarrow \text{diff}^2 \rightarrow \text{mean}\)
| Parameter | Description | Shape/Type |
|---|---|---|
| Input | | |
| Psi | Diffracted wavefields (complex) | [4096, 256, 256] complex64 |
| I_meas | Measured intensities (precomputed) | [4096, 256, 256] float32 |
| meas_scale | Scaling for measured data (precomputed) | [4096, 1, 1] float32 |
| counts | Total photon counts per pattern | scalar float |
| Output | | |
| loss | Mean squared error (scalar) | [1] float32 |
| I_pred | Predicted intensities (saved for backward) | [4096, 256, 256] float32 |
Replaces: 6 separate PyTorch ops (abs, pow, mean, mul, sub, mean) + 6 kernel launches + 5 intermediate tensors
Kernel 3: fused_exit_wave_backward_kernel
Speeds up backward pass
This kernel computes gradients for the exit wave operation by reversing the forward computation. Given gradients with respect to the exit waves, it efficiently backpropagates through the complex multiplication and object construction. The challenge is handling atomic additions for overlapping patches - each object pixel is visited by multiple scan positions, so gradients must accumulate atomically. PyTorch’s autograd handles this correctly but builds a computational graph with overhead. The custom kernel implements the exact derivatives directly, using atomic operations for the scatter-add pattern.
Fuses: Gradients for obj_amplitude, obj_phase, and probe from grad_exit_waves
| Parameter | Description | Shape/Type |
|---|---|---|
| Input | | |
| grad_exit_waves | Gradients from loss w.r.t. exit waves | [4096, 80, 80] complex64 |
| obj_amplitude | Forward pass object amplitude (for chain rule) | [512, 512] float32 |
| obj_phase | Forward pass object phase (for chain rule) | [512, 512] float32 |
| probe | Forward pass probe (for chain rule) | [80, 80] complex64 |
| y_indices | Vertical positions (same as forward) | [4096, 80, 80] (int) |
| x_indices | Horizontal positions (same as forward) | [4096, 80, 80] (int) |
| Output | | |
| grad_obj_amplitude | Gradients w.r.t. object amplitude | [512, 512] float32 |
| grad_obj_phase | Gradients w.r.t. object phase | [512, 512] float32 |
| grad_probe | Gradients w.r.t. probe function | [80, 80] complex64 |
Replaces: PyTorch’s autograd through complex ops (general-purpose graph traversal) with direct derivative computation using atomic scatter-add
Summary
Everything else (probe generation with aberrations, Adam, autograd) stays in PyTorch for now, but could be moved to CUDA for further speedup.
FFT itself (49.4% of GPU time) can’t be improved - cuFFT is already optimal. But the operations around it can be fused, and more importantly, moving everything to custom CUDA eliminates PyTorch’s CPU overhead (kernel launches, autograd graph, tensor bookkeeping) which roughly doubles the wall-clock time.
Expected total speedup:
2.5-3× from kernel fusion (GPU compute)
2× from eliminating CPU overhead (PyTorch → pure CUDA)
Total: 5-6× → 0.018-0.022s per iteration
Custom kernel 1: Probe generation in k-space (low priority, but needed)
Reality check: Probe aberration parameters (defocus, Cs, astigmatism) ARE trainable in ptychography - that’s the whole point! But this kernel is still low priority because it runs only 1× per iteration vs 4,096× for exit waves.
The probe generation code:
# Runs once per iteration after optimizer updates aberration parameters
chi_defocus = np.pi * wavelength * self.defocus * k**2
chi_Cs = 0.5 * np.pi * wavelength**3 * Cs_nm * k**4
chi_astig = np.pi * wavelength * self.astig_mag * k**2 * torch.cos(2*(k_angle - astig_angle))
chi = chi_defocus + chi_Cs + chi_astig
aperture = torch.sigmoid((k_max - self.k) / (k_max * aperture_smooth))
probe_k = aperture * torch.exp(1j * chi)
probe_real = torch.fft.ifft2(probe_k)
Can PyTorch fuse these operations automatically?
Short answer: No, not effectively.
Why PyTorch can’t fuse this:
JIT limitations: PyTorch's torch.jit can't optimize through:
torch.exp(1j * chi) - complex exponential with imaginary unit
torch.sigmoid() - elementwise nonlinearity
torch.cos() with scalar multiplication
Multiple intermediate tensor allocations
TorchScript restrictions:
@torch.jit.script
def generate_probe_k(defocus, Cs, astig_mag, k, k_angle):
    # TorchScript can trace this, but won't fuse kernels
    chi = math.pi * wavelength * defocus * k**2 + ...   # Separate kernel
    aperture = torch.sigmoid((k_max - k) / smooth)      # Separate kernel
    probe_k = aperture * torch.exp(1j * chi)            # Separate kernel
    return probe_k
Result: Still 3+ separate CUDA kernel launches, just with less Python overhead.
PyTorch 2.0 torch.compile():
@torch.compile(mode="max-autotune")
def generate_probe_k(defocus, Cs, astig_mag, k, k_angle):
    chi = math.pi * wavelength * defocus * k**2 + ...
    aperture = torch.sigmoid((k_max - k) / smooth)
    probe_k = aperture * torch.exp(1j * chi)
    return probe_k
Might fuse some ops, but limited by TorchInductor’s complex number support
Complex exponential exp(1j * x) often falls back to eager mode
Gains: ~10-20% at best (vs 1.1-1.2× with custom CUDA)
Worth trying first since it’s just one decorator!
Custom CUDA kernel gives full control:
__global__ void generate_probe_k_kernel(
const float* __restrict__ k,
const float* __restrict__ k_angle,
cuFloatComplex* __restrict__ probe_k,
float defocus, float Cs, float astig_mag, float astig_angle,
float wavelength, float k_max, float aperture_smooth,
int size
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size * size) {
float k_val = k[idx];
float angle = k_angle[idx];
// All ops fused in registers - no intermediate memory writes
float chi = M_PI * wavelength * defocus * k_val * k_val +
0.5f * M_PI * wavelength * wavelength * wavelength * Cs *
k_val * k_val * k_val * k_val +
M_PI * wavelength * astig_mag * k_val * k_val *
cosf(2.0f * (angle - astig_angle));
float aperture = 1.0f / (1.0f + expf((k_val - k_max) / (k_max * aperture_smooth)));
probe_k[idx] = make_cuFloatComplex(
aperture * cosf(chi),
aperture * sinf(chi)
);
}
}
Expected speedup: 1.1-1.2× (5-6 PyTorch kernel launches → 1 custom kernel)
But still low priority because:
Runs 1× per iteration (after optimizer step)
Exit wave kernel runs 4,096× per iteration (every scan position)
Even at 10× slower, it’s only ~10% of total time
Focus on Kernel 2 (exit wave FP16) first - that’s where 2-3× gains are
Recommendation: Implement exit wave FP16 kernel (Kernel 2) first. Come back to this only if you need that extra 10-20%.
Should I implement mixed precision (FP16) for the probe IFFT?
No - not worth it.
Why skip FP16 for probe IFFT:
Runs only 1× per iteration (same frequency as probe generation)
Minimal impact even with 2× speedup
Exit wave FFT runs 4,096× per iteration - focus there instead
cuFFT conversion overhead kills the gains:
Need FP32 → FP16 conversion before IFFT
Need FP16 → FP32 conversion after IFFT
Each conversion is a separate kernel launch
cuFFT FP32 is already memory-bandwidth optimal
FFT is memory-bound, not compute-bound:
Tensor Cores don’t help (FFT is not matrix multiply)
FP16 only saves bandwidth if data stays FP16 throughout
But probe_k comes from FP32 kernel, and probe_real goes to FP32 ops
Conversions negate the bandwidth savings
The math:
Probe IFFT: 80×80 = 6,400 complex values
Takes ~0.02-0.05ms (vs 110ms total iteration time)
Even 2× speedup saves only ~0.01-0.025ms
Impact: <0.02% of total time
Verdict: Keep probe IFFT in FP32. Focus on Kernel 2 (exit wave FP16) where you get 2-3× speedup on operations that run 4,096× per iteration.
Exception: If your exit wave pipeline keeps everything in FP16 (exit_wave_fp16 → FFT_fp16 → intensity_fp16), then the forward FFT could benefit. But the probe IFFT still runs too infrequently to matter.
Custom kernel 2: Fused exit wave computation with FP16 (CRITICAL)
This is the main optimization! Runs once per scan position (e.g., 4,096× per iteration).
Problem:
GPU time: aten::mul takes 2.104s (19.1%) for 5,140 separate kernel calls
CPU time: Each kernel launch costs ~5-10μs, plus autograd bookkeeping
The current code does this:
# Three separate operations in FP32
object_complex = object_amplitude * torch.exp(1j * object_phase) # mul #1
tiles = object_complex[y_idx, x_idx] # gather
exit_waves = tiles * probe_complex.unsqueeze(0) # mul #2 (complex)
# Result: 3 kernel launches, 3 memory writes, all in FP32
Solution: Fuse all three operations into one FP16 kernel using custom complex math.
FP16 complex math primitives
First, define FP16 complex number operations using __half2:
#include <cuda_fp16.h>
typedef __half2 complex_half; // Stores (real, imag) in one __half2
// Complex multiply: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
__device__ __forceinline__ complex_half cmul_fp16(complex_half a, complex_half b) {
__half a_real = __low2half(a);
__half a_imag = __high2half(a);
__half b_real = __low2half(b);
__half b_imag = __high2half(b);
__half real = __hsub(__hmul(a_real, b_real), __hmul(a_imag, b_imag));
__half imag = __hadd(__hmul(a_real, b_imag), __hmul(a_imag, b_real));
return __halves2half2(real, imag);
}
// Amplitude + phase → complex FP16
__device__ __forceinline__ complex_half amp_phase_to_complex_fp16(
__half amp, __half phase
) {
__half real = __hmul(amp, hcos(phase));
__half imag = __hmul(amp, hsin(phase));
return __halves2half2(real, imag);
}
Why this is fast: On Ampere and Hopper GPUs, packed __half2 arithmetic processes the real and imaginary parts in one instruction, roughly doubling arithmetic throughput over FP32 and halving memory traffic.
Fused gather + multiply kernel
__global__ void gather_multiply_fp16(
const __half* __restrict__ obj_amplitude, // [H, W] FP16
const __half* __restrict__ obj_phase, // [H, W] FP16
const __half2* __restrict__ probe, // [probe_size²] complex FP16
const int* __restrict__ y_indices, // [B * probe_size²]
const int* __restrict__ x_indices, // [B * probe_size²]
__half2* __restrict__ output, // [B * probe_size²] complex FP16
int batch_size, int probe_size, int object_width
) {
int batch_idx = blockIdx.x;
int pixel_idx = threadIdx.x + blockIdx.y * blockDim.x;
int probe_area = probe_size * probe_size;
if (pixel_idx < probe_area) {
int global_idx = batch_idx * probe_area + pixel_idx;
// Step 1: Gather object pixel coordinates
int y = y_indices[global_idx];
int x = x_indices[global_idx];
int obj_idx = y * object_width + x;
// Step 2: Convert amplitude + phase → complex (fused, in registers)
__half amp = obj_amplitude[obj_idx];
__half phase = obj_phase[obj_idx];
complex_half obj_val = amp_phase_to_complex_fp16(amp, phase);
// Step 3: Complex multiply with probe (fused, uses Tensor Cores!)
complex_half probe_val = probe[pixel_idx];
output[global_idx] = cmul_fp16(obj_val, probe_val);
}
}
PyTorch wrapper:
import torch
from torch.utils.cpp_extension import load_inline
# Compile CUDA kernel
cuda_source = """
// ... kernel code above ...
torch::Tensor fused_exit_wave(
torch::Tensor obj_amplitude,
torch::Tensor obj_phase,
torch::Tensor y_indices,
torch::Tensor x_indices,
    torch::Tensor probe
) {
    // Allocate the [B, probe_size, probe_size] FP16 complex output and launch
    // the gather_multiply_fp16 kernel (grid/block setup shown in the launch helper below).
    ...
}
"""
What this achieves:
3 operations → 1 kernel: Amplitude/phase conversion + gather + complex multiply
FP32 → FP16: 2× memory bandwidth, 2× compute throughput (Tensor Cores)
No intermediate memory: Everything computed in registers
1 kernel launch instead of 3: Eliminates CPU overhead
Launch configuration and PyTorch binding
extern "C" {
void launch_gather_multiply_fp16(
const void* obj_amplitude,
const void* obj_phase,
const void* probe,
const int* y_indices,
const int* x_indices,
void* output,
int batch_size,
int probe_size,
int object_width,
cudaStream_t stream
) {
int probe_area = probe_size * probe_size;
// Grid: [batch_size, (probe_area + 255) / 256]
dim3 grid(batch_size, (probe_area + 255) / 256);
dim3 block(256);
gather_multiply_fp16<<<grid, block, 0, stream>>>(
(const __half*)obj_amplitude,
(const __half*)obj_phase,
(const __half2*)probe,
y_indices,
x_indices,
(__half2*)output,
batch_size,
probe_size,
object_width
);
}
} // extern "C"
Use in PyTorch:
import torch
import ptychography_fp16 # Your compiled extension
# Convert FP32 → FP16 for kernel input
obj_amp_fp16 = object_amplitude.to(torch.float16)
obj_phase_fp16 = object_phase.to(torch.float16)
probe_fp16 = probe.to(torch.float16)
# Single fused FP16 kernel call
exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
obj_amp_fp16, obj_phase_fp16, probe_fp16, y_idx, x_idx
)
# Continue with FFT in FP32
Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64))
Expected speedup: 2-3× compared to PyTorch FP32 (gather + multiply operations)
Performance breakdown: - Kernel fusion: 1.3-1.5× (eliminates 2 extra kernel launches) - FP16 compute: 1.5-2× (Tensor Cores, 2× throughput) - Combined: 2-3× for this operation
Helper kernels for FP16 conversion
Additional utility kernels for converting between FP32/FP16 and packing/unpacking complex numbers:
// Compute intensity: |ψ|² (FP16 input, FP32 output for stability)
__global__ void compute_intensity_fp16(
const __half* psi_real,
const __half* psi_imag,
float* intensity,
int size
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
float real_f32 = __half2float(psi_real[idx]);
float imag_f32 = __half2float(psi_imag[idx]);
intensity[idx] = real_f32 * real_f32 + imag_f32 * imag_f32;
}
}
// Pack separate real/imag → __half2
__global__ void pack_complex_fp16(
const __half* real,
const __half* imag,
__half2* output,
int size
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
output[idx] = __halves2half2(real[idx], imag[idx]);
}
}
Why compute intensity in FP32: Squaring FP16 values can cause overflow. Convert to FP32 for the final multiplication to maintain numerical stability.
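For example, FP16 has a maximum representable value of about 65,504, so even modest field magnitudes overflow when squared (a tiny illustration):
import torch

psi = torch.tensor([300.0, 4.0], dtype=torch.float16, device='cuda')
print(psi * psi)                    # tensor([inf, 16.]) -> 300^2 = 90000 overflows FP16
print(psi.float() * psi.float())    # tensor([90000., 16.]) -> safe in FP32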
exit_waves = fused_module.fused_exit_wave(
    self.object_amplitude, self.object_phase, self.y_idx, self.x_idx, self.probe_cache
)
# FFT (can't fuse this - cuFFT is optimal)
Psi = torch.fft.fft2(exit_waves, s=(self.diffraction_size, self.diffraction_size))
return torch.abs(Psi)
What this kernel achieves:
Eliminates separate amplitude * exp(i*phase) operation
Eliminates separate gather operation
Eliminates separate complex multiply operation
All done in one pass with coalesced memory access
Expected speedup: 1.3-1.5×
New time: 0.110s → 0.073-0.085s
Why this works: Instead of materializing the full complex object in memory, then gathering patches, then multiplying by probe (three separate kernel launches, three memory round-trips), we compute everything on-the-fly. Each thread reads amplitude and phase values, computes the complex object value, and immediately multiplies by the probe - all in registers.
Custom kernel 3: Fused intensity loss
Problem: Abs + power operations take 30.45% combined:
aten::abs: 656ms (5.95%)
AbsBackward0: 1.524s (13.82%)
PowBackward0: 1.175s (10.65%)
The current code:
# Three separate operations
I_pred = torch.abs(Psi) ** 2 # abs, then pow
pred_scale = counts / I_pred.mean(dim=(1,2), keepdim=True) # mean, div
diff = I_pred * pred_scale - I_meas * meas_scale # mul, sub
loss = (diff * diff).mean() # mul, mean
Custom CUDA kernel fuses everything:
__global__ void fused_intensity_loss_kernel(
    const cuFloatComplex* __restrict__ Psi,      // [B, H, W]
    const float* __restrict__ I_meas,            // [B, H, W]
    const float* __restrict__ meas_scale,        // [B, 1, 1]
    float counts,
    float* __restrict__ loss_output,             // [1], zero-initialized before launch
    float* __restrict__ I_pred_output,           // [B, H, W] (saved for backward)
    int B, int H, int W
) {
    // One block per diffraction pattern: blockIdx.x = b, 256 threads stride over its H*W pixels
    __shared__ float shared_sum[256];            // assumes blockDim.x == 256
    int b = blockIdx.x;
    int tid = threadIdx.x;
    int n_pixels = H * W;
    const cuFloatComplex* Psi_b = Psi + (size_t)b * n_pixels;
    const float* I_meas_b = I_meas + (size_t)b * n_pixels;
    float* I_pred_b = I_pred_output + (size_t)b * n_pixels;
    // Phase 1: I_pred = |Psi|^2, plus the per-pattern sum needed for the mean
    float local_sum = 0.0f;
    for (int i = tid; i < n_pixels; i += blockDim.x) {
        cuFloatComplex psi_val = Psi_b[i];
        float I_pred = psi_val.x * psi_val.x + psi_val.y * psi_val.y;
        I_pred_b[i] = I_pred;                    // store for backward pass
        local_sum += I_pred;
    }
    shared_sum[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_sum[tid] += shared_sum[tid + s];
        }
        __syncthreads();
    }
    // pred_scale = counts / mean(I_pred) for this pattern
    float pred_scale = counts / (shared_sum[0] / n_pixels + 1e-12f);
    float m_scale = meas_scale[b];
    __syncthreads();
    // Phase 2: squared difference of the scaled intensities
    float local_loss = 0.0f;
    for (int i = tid; i < n_pixels; i += blockDim.x) {
        float diff = I_pred_b[i] * pred_scale - I_meas_b[i] * m_scale;
        local_loss += diff * diff;
    }
    shared_sum[tid] = local_loss;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_sum[tid] += shared_sum[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {
        // This block's contribution to the mean over all B*H*W elements
        atomicAdd(loss_output, shared_sum[0] / ((float)B * n_pixels));
    }
}
PyTorch wrapper for forward pass:
def fused_intensity_loss(Psi, I_meas, meas_scale, counts):
"""
Fused forward pass - computes loss in a single kernel.
Forward pass only: wrap in a torch.autograd.Function to supply the backward.
"""
B, H, W = Psi.shape
loss = torch.zeros(1, device=Psi.device)
I_pred = torch.empty_like(I_meas)
threads = 256  # must match the shared_sum size in the kernel
blocks = B     # one block per diffraction pattern
fused_intensity_loss_kernel[blocks, threads](
    Psi, I_meas, meas_scale, counts, loss, I_pred, B, H, W
)
return loss
# Use in training: wrap the kernel in a torch.autograd.Function with a custom
# backward (using the saved I_pred) so that loss.backward() can flow through it
loss = fused_intensity_loss(Psi, I_meas, meas_scale, counts)
loss.backward()
Expected speedup: 1.5-1.8× (eliminates 6 separate operations + all intermediate memory)
New time: 0.073s → 0.041-0.049s
Summary: Two critical kernels + one optional
Priority ranking by impact:
Kernel 2 (exit wave FP16) - CRITICAL: Runs 4,096× per iteration → 2-3× speedup
Kernel 3 (intensity loss) - IMPORTANT: Runs 4,096× per iteration → 1.5-1.8× speedup
Kernel 1 (probe generation) - OPTIONAL: Runs 1× per iteration, only if optimizing probe aberrations
The forward pass:
[Optional] Kernel 1: probe_k generation → [IFFT] → probe_real (1× per iter)
↓
[CRITICAL] Kernel 2: Exit wave FP16 (4,096× per iter)
tile gather + amp/phase → complex + multiply
→ [FFT] → Psi
↓
[IMPORTANT] Kernel 3: Intensity loss (4,096× per iter)
|Psi|², scaling, loss
Why Kernel 2 is most important:
Runs 4,096 times per iteration (once per scan position)
Uses FP16 with Tensor Cores → 2-3× speedup
Fuses 3 operations into 1 kernel
Biggest performance impact
Why Kernel 1 is optional:
Runs 1 time per iteration (or never if probe is fixed)
Most reconstructions use fixed probe (no optimization needed)
Even if trainable, small fraction of total time (~10%)
Low priority - skip unless you need it
Performance gains:
Kernel 1 (optional): 1.1-1.2× IF optimizing probe aberrations
Kernel 2 (critical): 2-3× speedup with FP16 + kernel fusion
Kernel 3 (important): 1.5-1.8× speedup
Realistic combined: 2-3× from Kernels 2+3 alone
Recommendation: Implement Kernel 2 first, then Kernel 3. Skip Kernel 1 unless you’re optimizing probe aberrations.
Optional: when the forward pass runs in FP16, scale the loss with torch.cuda.amp.GradScaler before the backward pass:
scaler = torch.cuda.amp.GradScaler()
loss = compute_loss(Psi, I_meas)
scaler.scale(loss).backward()  # Scale gradients to prevent underflow
scaler.step(optimizer)
scaler.update()
Performance gains:
1.3-1.5× speedup on Ampere/Hopper GPUs
2× memory reduction for activations
No accuracy loss for this application
Final combined speedup: 2.5-3.5× (kernels) × 1.3-1.5× (FP16) = 3-5× total
Result: 0.11s → 0.022-0.037s per iteration
Backward pass: Use PyTorch autograd
Strategy: Focus optimization on forward pass (Kernels 2+3), use PyTorch’s built-in tools for backward.
optimizer = torch.optim.Adam([object_complex, probe], lr=0.01)
for iteration in range(max_iterations):
optimizer.zero_grad()
# Forward pass with custom FP16 kernels
exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
obj_amp_fp16, obj_phase_fp16, probe_fp16, y_idx, x_idx
)
Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64))
loss = compute_loss(Psi, I_meas)
loss.backward() # PyTorch autograd handles gradients
optimizer.step()
Why this is good:
PyTorch's autograd is already highly optimized
The Adam optimizer is well-tested and numerically stable
Less code to maintain
Forward pass (Kernels 2+3) is where the biggest gains are (2-3× each)
Implementation guide
Compile custom FP16 CUDA kernels
Create setup.py for PyTorch extension:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='ptychography_fp16',
ext_modules=[
CUDAExtension(
'ptychography_fp16',
['fp16_complex_ops.cu'],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '--use_fast_math', '-arch=sm_80'] # Ampere
}
)
],
cmdclass={'build_ext': BuildExtension}
)
Compile and install:
python setup.py install
Integrate into PyTorch training loop
import torch
import torch.nn as nn
import ptychography_fp16 # Compiled custom kernels
class OptimizedPtychography(nn.Module):
def __init__(self, object_size, probe_size, scan_positions, device='cuda'):
super().__init__()
# Master copies in FP32
self.object_amplitude = nn.Parameter(torch.randn(object_size, object_size, device=device))
self.object_phase = nn.Parameter(torch.randn(object_size, object_size, device=device))
self.probe = nn.Parameter(torch.randn(probe_size, probe_size, dtype=torch.complex64, device=device))
# Precomputed indices
self.register_buffer('y_idx', ...)
self.register_buffer('x_idx', ...)
def forward(self, batch_indices):
# Convert to FP16 for custom kernel
object_fp16 = ptychography_fp16.amp_phase_to_complex_fp16(
self.object_amplitude.to(torch.float16),
self.object_phase.to(torch.float16)
)
probe_fp16 = self.probe.to(torch.float16)
# Custom FP16 gather + multiply kernel
exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
object_fp16, probe_fp16, self.y_idx, self.x_idx
)
# FFT in FP32 for stability
Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64))
return Psi
# Training loop
model = OptimizedPtychography(...).to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for iteration in range(num_iterations):
optimizer.zero_grad()
Psi = model(batch_indices)
loss = compute_loss(Psi, I_meas)
loss.backward() # PyTorch autograd handles gradients
optimizer.step()
Validate FP16 accuracy
# Reference (PyTorch ops)
with torch.no_grad():
    Psi_ref = model_pytorch(batch_indices)
    loss_ref = compute_loss_pytorch(Psi_ref, ...)

# Custom kernels
with torch.no_grad():
    Psi_opt = model_optimized(batch_indices)
    loss_opt = compute_loss_optimized(Psi_opt, ...)

# Check numerical accuracy
max_diff = torch.max(torch.abs(Psi_ref - Psi_opt))
rel_error = max_diff / torch.max(torch.abs(Psi_ref))

print(f"Max absolute difference: {max_diff:.6e}")
print(f"Relative error: {rel_error:.6e}")

assert rel_error < 1e-3, "Custom kernels changed results!"
Expected errors with mixed precision: < 1e-3 (acceptable for scientific computing)
Profile to verify speedup
import time
import numpy as np
# Warmup
for _ in range(10):
_ = model(batch_indices)
# Benchmark
torch.cuda.synchronize()
times = []
for _ in range(100):
t0 = time.perf_counter()
loss = train_iteration(model, ...)
torch.cuda.synchronize()
times.append(time.perf_counter() - t0)
print(f"Mean: {np.mean(times):.4f}s")
print(f"Std: {np.std(times):.4f}s")
print(f"Target < 0.01s: {'✅' if np.mean(times) < 0.01 else '❌'}")
Summary
Three custom CUDA kernels + FP16 = 3-5× speedup
Why three separate kernels:
Kernel 1: Probe generation (aberrations) → [IFFT] → probe
Kernel 2: Exit wave (gather + multiply) → [FFT] → Psi
Kernel 3: Intensity loss (|Psi|², scale, loss)
Cannot fuse through FFT/IFFT - global operations, cuFFT optimal.
What each kernel fuses:
Probe generation: Defocus + Cs + astigmatism + aperture → FP32 probe_k → [IFFT]
Speedup: 1.1-1.2×
Exit wave (FP16): tile gather + amp/phase → complex + complex multiply in FP16
Speedup: 2-3× (uses Tensor Cores)
Custom __half2 complex math; fuses 3 operations into 1 kernel
Intensity loss: |Psi|² + scaling + difference + reduction
Speedup: 1.5-1.8×
FP16 implementation details:
Uses __half2 for complex numbers: (real, imag) packed
Custom cmul_fp16() for complex multiply (Tensor Cores)
Master copies in FP32, compute in FP16
FFT stays FP32 for stability
Typical error: < 1e-3 (acceptable)
Backward pass:
Use PyTorch autograd (already optimized)
ADAM optimizer
Focus on forward pass optimization
Final result: 0.11s → 0.022-0.037s (3-5× speedup)
Key takeaway: Custom FP16 kernels + kernel fusion up to the FFT boundaries = maximum performance with minimal code.
# PyTorch Operations Deep Dive: Ptychography Forward Pass
This document explains every PyTorch operation in the ptychography forward pass with detailed visual illustrations.
## Table of Contents
1. [The Big Picture: What is Ptychography?](#1-the-big-picture)
2. [Setup: Initial Data Structures](#2-setup)
3. [Operation 1: Gather (Extracting Patches)](#3-gather-operation)
4. [Operation 2: Complex Multiplication](#4-complex-multiplication)
5. Operation 3: FFT (Fourier Transform)
6. Operation 4: Intensity Calculation
7. Performance Optimizations Explained
The big picture: What is ptychography?
PTYCHOGRAPHY: Reconstructing an object by scanning a probe over it
Step 1: Physical Setup
┌─────────────────────────────────────────┐
│ │
│ Electron Beam (Probe) │
│ ↓↓↓↓↓↓↓ │
│ ┌─────────────┐ │
│ │ Probe │ ← 128×128 pixels │
│ │ (complex) │ │
│ └─────────────┘ │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ │ │
│ │ Sample/Object │ ← 512×512 │
│ │ (unknown!) │ │
│ │ │ │
│ └─────────────────────┘ │
│ ↓ │
│ ┌─────────────┐ │
│ │ Diffraction │ ← Detector 256×256 │
│ │ Pattern │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────┘
Step 2: Scanning Process
Position 1 Position 2 Position 3
┌────┐ ┌────┐ ┌────┐
│Beam│ │Beam│ │Beam│
└─┬──┘ └─┬──┘ └─┬──┘
↓ ↓ ↓
┌─────────────────────────────────────┐
│ ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ │
│ ░░░Sample moves, beam scans!░░░░░░ │
│ ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ │
└─────────────────────────────────────┘
We scan 4096 positions (typical) and record diffraction at each!
Goal: Given diffraction patterns, reconstruct the object.
Forward model (what we’re implementing):
Object   +   Probe    →   Exit Wave   →   FFT       →   Diffraction Intensity
512×512      128×128      128×128         256×256        256×256
complex      complex      complex         complex        real
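For reference, here is a minimal sketch of that forward model for a single scan position using PyTorch's native complex tensors. The zero-padding of the FFT up to the 256×256 detector size is an assumption (one common way to reconcile the 128×128 exit wave with the detector); the real pipeline may handle this step differently.

```python
import torch

obj = torch.complex(torch.randn(512, 512), torch.randn(512, 512))    # unknown object
probe = torch.complex(torch.randn(128, 128), torch.randn(128, 128))  # illumination

y_start, x_start = 36, 50                       # top-left corner of one 128×128 patch
patch = obj[y_start:y_start + 128, x_start:x_start + 128]

exit_wave = patch * probe                       # [128, 128] complex
Psi = torch.fft.fft2(exit_wave, s=(256, 256))   # pad to the detector size (assumption)
intensity = Psi.abs() ** 2                      # [256, 256] real diffraction pattern
```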
Setup: Initial data structures
# The object we want to reconstruct (complex-valued)
object_real = torch.randn(512, 512) # Real part
object_imag = torch.randn(512, 512) # Imaginary part
# The probe (illumination, known or optimized)
# Here assume it's already IFFTed from probe(k)
probe_real = torch.randn(128, 128) # Real part
probe_imag = torch.randn(128, 128) # Imaginary part
# Scan positions (where the probe illuminates)
batch_size = 4096 # Number of scanned positions
We also have the scan positions:
Object:
┌─────────────────────┐
│ ● ● ● │ ← Row 1: (100,50), (100,150), (100,250)
│ │
│ ● ● ● │ ← Row 2: (200,50), (200,150), (200,250)
│ │
│ ● ● ● │ ← Row 3: (300,50), (300,150), (300,250)
└─────────────────────┘
y_centers = [100, 100, 100, 200, 200, 200, 300, 300, 300]
x_centers = [50, 150, 250, 50, 150, 250, 50, 150, 250]
In the full problem these become tensors of 4,096 coordinates each:
y_centers = torch.tensor([100, 150, 200, ...])  # 4096 y-coordinates
x_centers = torch.tensor([50, 75, 100, ...])    # 4096 x-coordinates
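The scan pattern itself is not spelled out here. Purely as an illustration, a 64×64 raster of centers (64 × 64 = 4,096 positions) that keeps every 128×128 patch inside the 512×512 object could be generated as below; real experiments often use different spacings or add jitter, so treat the values as placeholders.

```python
import torch

# Centers at least 64 px from every edge so each 128×128 patch stays in bounds.
ys = torch.linspace(64, 447, 64).round().long()
xs = torch.linspace(64, 447, 64).round().long()
gy, gx = torch.meshgrid(ys, xs, indexing='ij')

y_centers = gy.reshape(-1)   # [4096]
x_centers = gx.reshape(-1)   # [4096]
```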
Object (512×512):
┌─────────────────────────────────────┐
│ 0 1 2 3 4 5 6 ... 510 511 │ ← Row 0
│ 0 1 2 3 4 5 6 ... 510 511 │ ← Row 1
│ : : : : : : : : : │
│ 0 1 2 3 4 5 6 ... 510 511 │ ← Row 511
└─────────────────────────────────────┘
Each cell is a complex number: real + i*imag
Probe (128×128):
┌──────────────────┐
│ Gaussian-like │
│ beam shape │
│ (128×128 px) │
└──────────────────┘
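Side note: PyTorch also has native complex dtypes, so the separate real/imaginary arrays above could be stored as single complex tensors. This walkthrough keeps them split to mirror the custom FP16 kernels; a quick sketch of the equivalent complex form:

```python
import torch

object_real = torch.randn(512, 512)
object_imag = torch.randn(512, 512)
probe_real = torch.randn(128, 128)
probe_imag = torch.randn(128, 128)

object_c = torch.complex(object_real, object_imag)   # complex64, [512, 512]
probe_c = torch.complex(probe_real, probe_imag)      # complex64, [128, 128]
```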
Gather operation: Extracting patches (detailed explanation)
Problem: We need to extract a 128×128 patch from the 512×512 object at each scan position.
# Scan position 0: center at (100, 50)
# Scan position 1: center at (150, 75)
# Scan position 2: center at (200, 100)
# ... 4093 more positions
Object (512×512):
x_center=50
↓
┌──────────────┼──────────────┐
│ │ │
│ │ │
y=36 ├──────────────●──────────────┤ ← Top-left of patch
│ ┌────────┼────────┐ │
│ │ Patch │128×128 │ │
y_center=100──┼────────●────────┼─────┤ ← Center
│ │ │ │ │
│ └────────┼────────┘ │
y=164 ├──────────────┼──────────────┤ ← Bottom-right
│ │ │
└──────────────┼──────────────┘
x=36 x=50 x=164
Top-left corner: (y_start, x_start) = (y_center - 64, x_center - 64)
                                                  ↑
                               Note: 128/2 = 64 (half the probe size)
Scan centers must stay at least 64 pixels from the object edges so the full
128×128 patch fits inside the 512×512 object. The slicing example below uses
the top-left corner (y_start, x_start) = (36, 50).
patches = []
for i in range(4096):
    y_start = y_centers[i] - 64
    x_start = x_centers[i] - 64

    # Extract 128×128 patch
    patch_real = object_real[y_start:y_start+128, x_start:x_start+128]
    patch_imag = object_imag[y_start:y_start+128, x_start:x_start+128]
    patches.append((patch_real, patch_imag))

# Problem: 4096 separate operations! Very slow on GPU.
Here is a visual for patch_real = object_real[36:164, 50:178]:
object_real (512×512):
Columns →
0 1 2 ... 50 51 ... 177 178 ... 511
R 0 ░ ░ ░ ... ░ ░ ... ░ ░ ... ░
o 1 ░ ░ ░ ... ░ ░ ... ░ ░ ... ░
w : : : : : : : : :
s 36 ░ ░ ░ ... █ █ ... █ ░ ... ░ ← Start row (y_start)
↓ 37 ░ ░ ░ ... █ █ ... █ ░ ... ░
38 ░ ░ ░ ... █ █ ... █ ░ ... ░
: : : : : : : : :
163 ░ ░ ░ ... █ █ ... █ ░ ... ░ ← End row (y_start+127)
164 ░ ░ ░ ... ░ ░ ... ░ ░ ... ░ ← NOT included!
: : : : : : : : :
511 ░ ░ ░ ... ░ ░ ... ░ ░ ... ░
↑ ↑
Start col End col
(x_start=50) (x_start+127=177)
178 NOT included!
█ = Extracted patch (128×128) ░ = Not extracted
# Step 1: Create a meshgrid of offsets
# This is the pattern of offsets within the 128×128 patch
y_offset = torch.arange(128)  # [0, 1, 2, ..., 127]
x_offset = torch.arange(128)  # [0, 1, 2, ..., 127]

grid_y, grid_x = torch.meshgrid(y_offset, x_offset, indexing='ij')
# grid_y: [[0,   0,   0, ...,   0],     grid_x: [[0, 1, 2, ..., 127],
#          [1,   1,   1, ...,   1],              [0, 1, 2, ..., 127],
#          ...                                   ...
#          [127, 127,    ..., 127]]              [0, 1, 2, ..., 127]]
Visual representation of grid:
grid_y (128×128):          grid_x (128×128):
┌───────────────┐          ┌───────────────┐
│  0   0   0 … 0│          │ 0  1  2 … 127 │
│  1   1   1 … 1│          │ 0  1  2 … 127 │
│  2   2   2 … 2│          │ 0  1  2 … 127 │
│  :   :   :   :│          │ :  :  :     : │
│127 127   … 127│          │ 0  1  2 … 127 │
└───────────────┘          └───────────────┘
Think of these as: “relative position within patch”
# Step 2: Convert centers to top-left corners, then add the offset grid
# to get absolute coordinates. Shape: [4096, 128, 128]
y_start = y_centers - 64   # half the 128-pixel patch size
x_start = x_centers - 64
y_indices = y_start[:, None, None] + grid_y[None, :, :]
x_indices = x_start[:, None, None] + grid_x[None, :, :]

# Broadcasting explained:
# y_start[:, None, None]   # [4096, 1, 1]
# grid_y[None, :, :]       # [1, 128, 128]
# ────────────────────────
# y_indices                # [4096, 128, 128]  (broadcast!)

Broadcasting visualization:

y_start[:, None, None]:        grid_y[None, :, :]:
┌────┐                         ┌───────────────┐
│ 36 │                         │  0   0   0 … 0│
│ 86 │      Broadcast →        │  1   1   1 … 1│
│136 │      ──────────         │  :   :   :   :│
│ :  │                         │127 127   … 127│
│236 │                         └───────────────┘
└────┘
[4096, 1, 1]                   [1, 128, 128]

Result y_indices [4096, 128, 128]:

Position 0 (y_center = 100, y_start = 36):
[[ 36,  36, ...,  36],   ← y_start + offset 0
 [ 37,  37, ...,  37],   ← y_start + offset 1
 ...
 [163, 163, ..., 163]]   ← y_start + offset 127

Position 1 (y_center = 150, y_start = 86):
[[ 86,  86, ...,  86],   ← y_start + offset 0
 [ 87,  87, ...,  87],
 ...
 [213, 213, ..., 213]]
# Step 3: Gather using advanced indexing (MAGIC! ✨)
patches_real = object_real[y_indices, x_indices]
patches_imag = object_imag[y_indices, x_indices]

# Shape: [4096, 128, 128]
# This is a SINGLE GPU operation! All 4096 patches are extracted at once.
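A quick sanity check (a small sketch, not part of the pipeline) confirms that the batched gather matches the per-position Python loop exactly:

```python
import torch

object_real = torch.randn(512, 512)
y_starts = torch.randint(0, 512 - 128, (16,))   # 16 random top-left corners
x_starts = torch.randint(0, 512 - 128, (16,))

offsets = torch.arange(128)
gy, gx = torch.meshgrid(offsets, offsets, indexing='ij')
y_idx = y_starts[:, None, None] + gy[None, :, :]
x_idx = x_starts[:, None, None] + gx[None, :, :]

batched = object_real[y_idx, x_idx]                          # one gather, [16, 128, 128]
looped = torch.stack([object_real[y:y + 128, x:x + 128]
                      for y, x in zip(y_starts, x_starts)])  # 16 separate slices

assert torch.equal(batched, looped)
```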
Concrete example: Advanced indexing with real numbers
Let’s use a small 8×8 object and extract 2 patches of 3×3:
# Object (8×8) with easy-to-track values
object_real = torch.tensor([
    [10, 11, 12, 13, 14, 15, 16, 17],  # Row 0
    [20, 21, 22, 23, 24, 25, 26, 27],  # Row 1
    [30, 31, 32, 33, 34, 35, 36, 37],  # Row 2
    [40, 41, 42, 43, 44, 45, 46, 47],  # Row 3
    [50, 51, 52, 53, 54, 55, 56, 57],  # Row 4
    [60, 61, 62, 63, 64, 65, 66, 67],  # Row 5
    [70, 71, 72, 73, 74, 75, 76, 77],  # Row 6
    [80, 81, 82, 83, 84, 85, 86, 87],  # Row 7
])

# Extract 2 patches (3×3 each)
# Patch 0: center at (2, 2) → top-left = (1, 1)
# Patch 1: center at (4, 5) → top-left = (3, 4)

# Step 1: Create offset grid for a 3×3 patch
grid_y = torch.tensor([[0, 0, 0],
                       [1, 1, 1],
                       [2, 2, 2]])
grid_x = torch.tensor([[0, 1, 2],
                       [0, 1, 2],
                       [0, 1, 2]])

# Step 2: Centers and top-left corners
y_centers = torch.tensor([2, 4])
x_centers = torch.tensor([2, 5])
y_start = y_centers - 1  # [1, 3]
x_start = x_centers - 1  # [1, 4]

# Step 3: Build indices [2, 3, 3]
y_indices = y_start[:, None, None] + grid_y[None, :, :]
x_indices = x_start[:, None, None] + grid_x[None, :, :]

# y_indices:              x_indices:
# [[[1, 1, 1],            [[[1, 2, 3],
#   [2, 2, 2],              [1, 2, 3],
#   [3, 3, 3]],             [1, 2, 3]],
#
#  [[3, 3, 3],             [[4, 5, 6],
#   [4, 4, 4],              [4, 5, 6],
#   [5, 5, 5]]]             [4, 5, 6]]]

# Step 4: Advanced indexing!
patches_real = object_real[y_indices, x_indices]
What happens internally:

object_real is the full 8×8 array above; for example, object_real[2, 2] = 32. The center of patch 0 lands on that value, while its position (0,0) picks up 21, (0,1) picks up 22, (0,2) picks up 23, and so on.

# For patch 0, position (0,0):
#   y_indices[0,0,0] = 1, x_indices[0,0,0] = 1
#   patches_real[0,0,0] = object_real[1, 1] = 21

# For patch 0, position (0,1):
#   y_indices[0,0,1] = 1, x_indices[0,0,1] = 2
#   patches_real[0,0,1] = object_real[1, 2] = 22

# For patch 0, position (1,1):
#   y_indices[0,1,1] = 2, x_indices[0,1,1] = 2
#   patches_real[0,1,1] = object_real[2, 2] = 32

# ... and so on for all 18 positions (2 patches × 3×3)
Result:

patches_real = [
    [[21, 22, 23],   # Patch 0
     [31, 32, 33],   # ← center at (2,2) = 32
     [41, 42, 43]],

    [[44, 45, 46],   # Patch 1
     [54, 55, 56],   # ← center at (4,5) = 55
     [64, 65, 66]],
]
Visual:
Object (8×8):
10 11 12 13 14 15 16 17
20 21 22 23 24 25 26 27      Patch 0 (red):    Patch 1 (blue):
30 31 32 33 34 35 36 37      21 22 23          44 45 46
40 41 42 43 44 45 46 47      31 32 33          54 55 56
50 51 52 53 54 55 56 57      41 42 43          64 65 66
60 61 62 63 64 65 66 67
70 71 72 73 74 75 76 77
80 81 82 83 84 85 86 87
The key: y_indices and x_indices act as a “lookup table” telling PyTorch exactly which element to grab from object_real for each position in the output.
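The whole 8×8 example is small enough to run end to end; this sketch reproduces the two patches shown above:

```python
import torch

# Rows 10..17, 20..27, ..., 80..87 — same values as the example object
object_real = 10 * torch.arange(1, 9)[:, None] + torch.arange(8)[None, :]

grid_y, grid_x = torch.meshgrid(torch.arange(3), torch.arange(3), indexing='ij')
y_start = torch.tensor([2, 4]) - 1   # centers (2, 2) and (4, 5) → top-left corners
x_start = torch.tensor([2, 5]) - 1

y_idx = y_start[:, None, None] + grid_y[None, :, :]   # [2, 3, 3]
x_idx = x_start[:, None, None] + grid_x[None, :, :]

patches = object_real[y_idx, x_idx]
print(patches)
# tensor([[[21, 22, 23], [31, 32, 33], [41, 42, 43]],
#         [[44, 45, 46], [54, 55, 56], [64, 65, 66]]])
```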
How advanced indexing works
object_real[y_indices, x_indices]
PyTorch internally does:

for b in range(4096):
    for i in range(128):
        for j in range(128):
            y = y_indices[b, i, j]   # e.g., 100
            x = x_indices[b, i, j]   # e.g., 50
            result[b, i, j] = object_real[y, x]

But this runs as PARALLEL threads on the GPU, not as a loop!
GPU parallelization:
CPU Loop (sequential): GPU (parallel):
┌───┐ ┌─┬─┬─┬─┬─┬─┬─┬─┐
│ 1 │ → │ 2 │ → │ 3 │ → ... │1│2│3│4│5│6│7│8│ ← All at once!
└───┘ └─┴─┴─┴─┴─┴─┴─┴─┘
Time: 4096 × 128×128 ops Time: ~1ms (massively parallel)
Complex multiplication: Object × probe
Complex number: z = a + bi
where: a = real part, b = imaginary part, i = √(-1)
Multiplication rule:
(a + bi) × (c + di) = (ac - bd) + (ad + bc)i
                      ─────────   ─────────
                         real        imag
In our code:
# We have:
object_patch = object_real[y_indices, x_indices] + i * object_imag[y_indices, x_indices]
probe = probe_real + i * probe_imag
# Exit wave = object_patch × probe
# Using formula: (a+bi)×(c+di) = (ac-bd) + (ad+bc)i
exit_real = object_real[y_indices, x_indices] * probe_real \
- object_imag[y_indices, x_indices] * probe_imag
exit_imag = object_real[y_indices, x_indices] * probe_imag \
+ object_imag[y_indices, x_indices] * probe_real
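As a check (a small sketch outside the main code path), the split real/imaginary formula above agrees with PyTorch's native complex multiplication:

```python
import torch

ar, ai = torch.randn(4, 128, 128), torch.randn(4, 128, 128)   # object patches (real, imag)
br, bi = torch.randn(128, 128), torch.randn(128, 128)         # probe (real, imag)

exit_real = ar * br - ai * bi
exit_imag = ar * bi + ai * br

native = torch.complex(ar, ai) * torch.complex(br, bi)
assert torch.allclose(native.real, exit_real, atol=1e-5)
assert torch.allclose(native.imag, exit_imag, atol=1e-5)
```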
Visual example (single pixel)
Object patch at position (0,0): 0.8 + 0.3i
Probe at position (0,0): 0.9 + 0.1i
Exit wave = (0.8 + 0.3i) × (0.9 + 0.1i)
= (0.8×0.9 - 0.3×0.1) + (0.8×0.1 + 0.3×0.9)i
= (0.72 - 0.03) + (0.08 + 0.27)i
= 0.69 + 0.35i
exit_real[0,0,0] = 0.69
exit_imag[0,0,0] = 0.35
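The same single-pixel arithmetic can be checked directly with a complex scalar:

```python
import torch

exit_px = torch.tensor(0.8 + 0.3j) * torch.tensor(0.9 + 0.1j)
print(exit_px)   # tensor(0.6900+0.3500j)
```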
Shape transformation
Input:
object_real[y_indices, x_indices] [4096, 128, 128]
object_imag[y_indices, x_indices] [4096, 128, 128]
probe_real [128, 128]
probe_imag [128, 128]
Broadcasting (probe gets repeated):
probe_real → [4096, 128, 128] (same value for all 4096 batches)
Output:
exit_real [4096, 128, 128]
exit_imag [4096, 128, 128]
Broadcasting visualization:
Batch dimension:
probe_real [128, 128]
↓ broadcast
┌─────┬─────┬─────┬─────┬─────┬─────┐
│ #0 │ #1 │ #2 │ ... │4094 │4095 │
└─────┴─────┴─────┴─────┴─────┴─────┘
↑ ↑ ↑ ↑ ↑
Same probe repeated 4096 times (implicitly)
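A small shape sketch of this broadcast (values are random placeholders): the probe is implicitly treated as [1, 128, 128] and reused for every batch entry, without being materialized 4,096 times.

```python
import torch

patches_real = torch.randn(4096, 128, 128)
probe_real = torch.randn(128, 128)

out = patches_real * probe_real                        # probe broadcast over the batch dim
out_explicit = patches_real * probe_real.unsqueeze(0)  # same result, broadcast made explicit

print(out.shape)                                       # torch.Size([4096, 128, 128])
assert torch.equal(out, out_explicit)
```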