GPU optimization

Objective: Reduce ptychography reconstruction time from 0.11s to < 0.01s per iteration on a single L40S GPU, i.e. 1 second for 100 iterations. On more powerful GPUs such as the H100, this could be even faster due to higher memory bandwidth (3.0-3.3 TB/s vs the L40S's 864 GB/s).

Current performance: 0.11s per iteration

  • 512×512 object (amplitude + phase)

  • 80×80 probe

  • 256×256 diffraction patterns

  • 4,096 scan positions per iteration

Target performance: < 0.01s per iteration (11× speedup required)

Strategy: Profile to identify bottlenecks, write custom FP16 CUDA kernels, fuse operations across FFT boundaries.

GPU memory bandwidth reference:

  • L40S: 864 GB/s

  • A100 80GB: 2.0 TB/s

  • H100: 3.0-3.3 TB/s

This document covers the actual implementation with measured profiling data. For CUDA programming fundamentals, see CUDA tutorials.

PyTorch optimization techniques

This section covers general PyTorch optimization patterns used throughout the ptychography implementation. Understanding these techniques is essential before diving into the custom CUDA kernels.

Using batch gather to obtain object patches (foundational technique, used once)

In ptychography, we need to extract 4,096 overlapping patches from a large object (512×512 pixels) at different scan positions. This walkthrough uses 128×128 patches for illustration; the actual reconstruction extracts 80×80 patches to match the probe size.

Visualization:

Object (512×512):                    Extract 4,096 patches (128×128):
┌────────────────────┐              ┌──────┐ ┌──────┐ ┌──────┐
│                    │              │Patch │ │Patch │ │Patch │
│    ┌──────┐        │    ───→      │  1   │ │  2   │ │  3   │
│    │      │        │              └──────┘ └──────┘ └──────┘
│    └──────┘        │                    ...
│         ┌──────┐   │              ┌──────┐
│         │      │   │              │Patch │
│         └──────┘   │              │ 4096 │
└────────────────────┘              └──────┘

Each patch: 128×128 pixels at position (y_i, x_i)

Naive approach with Python loops (extremely slow)

# Bad: Sequential extraction with Python loop
patches = []
for i in range(4096):
    y_start, x_start = scan_positions[i]
    patch = object[y_start:y_start+128, x_start:x_start+128]
    patches.append(patch)

patches = torch.stack(patches)  # [4096, 128, 128]

# Problem: 4,096 separate CPU→GPU→CPU round trips
# Time: ~500ms on T4 GPU

Why this is slow:

  • Each iteration launches a separate GPU kernel

  • Python loop prevents vectorization

  • Overhead: 4,096 × ~0.12ms = 492ms wasted

Optimized approach with batch gather (fast)

import torch

# Setup: precompute indices once
object = torch.randn(512, 512).cuda()  # Large 2D array on GPU
scan_positions = torch.tensor([...]).cuda()  # [4096, 2] (y, x) positions

# Create coordinate grids for patch extraction
patch_size = 128  # Must match probe size
grid_y = torch.arange(patch_size).cuda()
grid_x = torch.arange(patch_size).cuda()
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij')

# Broadcast scan positions + local grid
y_indices = scan_positions[:, 0:1, None] + grid_y[None, :, :]  # [4096, 128, 128]
x_indices = scan_positions[:, 1:2, None] + grid_x[None, :, :]  # [4096, 128, 128]

# Single batched gather operation
patches = object[y_indices, x_indices]  # [4096, 128, 128]

# Time: ~0.8ms on T4 GPU (625× faster!)

How it works:

Step 1: Create local coordinate grid (128×128)
┌─────────────┐
│ 0  1  2 ... │
│ 1  2  3 ... │
│ 2  3  4 ... │
│ ...         │
└─────────────┘

Step 2: Add scan position offset to each grid
Position 1: (y=50, x=100)     Position 2: (y=75, x=150)
┌─────────────┐               ┌─────────────┐
│ 50 51 52... │               │ 75 76 77... │
│ 51 52 53... │               │ 76 77 78... │
│ ...         │               │ ...         │
└─────────────┘               └─────────────┘

Step 3: Use advanced indexing to gather all 4,096 patches in parallel
object[y_indices, x_indices] → Single GPU kernel launch

Performance comparison

Approach        Time (ms)   Kernel launches   Speedup
Python loop     ~500 ms     4,096             1× (baseline)
Batch gather    ~0.8 ms     1                 625×

Key insights

  1. Avoid Python loops over GPU operations - Always vectorize when possible

  2. Advanced indexing is highly optimized - PyTorch’s tensor[indices] uses efficient GPU gather kernels

  3. Precompute coordinate grids - Create indices once, reuse across iterations

  4. Broadcasting eliminates loops - Use broadcasting to generate all indices simultaneously

Warning

Gotcha: Make sure indices are on the same device as the tensor. Mixed CPU/GPU indices force expensive device transfers:

# Bad: indices on CPU, object on GPU
patches = object_gpu[indices_cpu]  # Transfers indices CPU→GPU every call

# Good: both on GPU
patches = object_gpu[indices_gpu]  # Fast, no transfers

When to write custom CUDA kernels

PyTorch core developers write custom CUDA kernels for operations not covered by cuBLAS or cuDNN. These include element-wise operations, reductions, indexing, and specialized functions. Users can also write custom CUDA kernels using PyTorch’s C++ extension API or tools like Triton (which compiles Python-like code to GPU kernels).

Most PyTorch users never write kernels since the built-in operations cover typical neural network needs. However, custom kernels enable specialized optimizations for novel architectures or unique algorithms.

When custom CUDA justifies the complexity:

For specialized algorithms where memory access patterns matter more than ease of development, CUDA offers performance gains that justify implementation complexity. PyTorch’s general-purpose kernels cannot provide precise control over GPU resources that production systems need.

Why use custom CUDA when PyTorch is already fast?

PyTorch provides high-level operations but custom CUDA kernels give fine-grained control over memory access patterns, thread synchronization, and data layout. Optimizations like memory tiling can achieve 3-8× speedups. Production systems like FlashAttention (GPT-4), xFormers (Llama), DeepSpeed (Microsoft), and vLLM use custom CUDA C++ for maximum performance.

Mixing PyTorch operations with custom CUDA kernels

You can mix PyTorch operations with custom CUDA kernels. For example, PyTorch handles FFTs using cuFFT while your kernel implements specialized phase retrieval updates. PyTorch tensors expose raw GPU memory pointers that your CUDA kernel can access directly through the C++ extension API.

This hybrid approach lets you leverage PyTorch’s optimized FFT implementation while adding custom operations that aren’t built-in, like constraint projection or error reduction.

import torch
from torch.utils.cpp_extension import load

# Load custom CUDA kernel
custom_kernel = load(
    name='custom_ops',
    sources=['custom_kernel.cu'],
    extra_cuda_cflags=['-O3']
)

# Mix PyTorch and custom operations
x = torch.randn(512, 512, dtype=torch.complex64).cuda()

# Use PyTorch's optimized FFT
x_fft = torch.fft.fft2(x)

# Apply custom CUDA kernel
x_processed = custom_kernel.apply_constraint(x_fft)

# Back to PyTorch for inverse FFT
result = torch.fft.ifft2(x_processed)

Where are PyTorch CUDA kernels in source code?

PyTorch CUDA kernels live in aten/src/ATen/native/cuda/ for core operations and aten/src/ATen/native/cudnn/ for cuDNN wrappers. Files like BinaryOps.cu contain element-wise operations and Activation.cu has functions like ReLU. Each .cu file contains CUDA C++ code with kernel launches and device functions organized by operation type. You can browse them at pytorch/pytorch on GitHub.

PyTorch vs CuPy: When to use automatic differentiation

The problem:

CuPy handles standard GPU operations (FFT, array math, linear algebra) but lacks automatic differentiation. For optimization problems including loss function minimization, parameter fitting, and deep learning, you need autodiff. Writing gradients manually is error-prone and slow. Additionally, cross-platform support for NVIDIA, Mac, and AMD is important for many workflows.

The solution:

PyTorch provides GPU arrays with automatic differentiation (autodiff). You define your computation like loss = (model(x) - y)**2 then call loss.backward() where PyTorch computes all gradients automatically. GPU-accelerated optimizers update parameters efficiently. PyTorch works on NVIDIA (CUDA), Mac (Metal), and AMD (ROCm).
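A minimal sketch of that loss.backward() workflow (tensor names are illustrative):

import torch

# Fit a single scale factor to measured intensities; autograd supplies the gradient.
scale = torch.tensor(1.0, requires_grad=True, device='cuda')
I_pred = torch.rand(256, 256, device='cuda')
I_meas = torch.rand(256, 256, device='cuda')

loss = ((scale * I_pred - I_meas) ** 2).mean()
loss.backward()        # autograd fills scale.grad automatically
print(scale.grad)      # no manually derived gradient needed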

When to use PyTorch vs CuPy:

  • Use PyTorch for: Deep learning, training neural networks on diffraction data, parameter optimization, gradient-based fitting, Mac compatibility

  • Use CuPy for: Pure array operations (FFT, virtual imaging) on NVIDIA GPUs without needing gradients

Performance comparison example

256×256 4D-STEM FFT on RTX 4090:

Framework     Time (seconds)   Notes
CuPy          0.8 s            cuFFT directly
PyTorch       0.85 s           cuFFT through PyTorch
NumPy (CPU)   45 s             56× slower

Both call cuFFT underneath with nearly identical speed, so choose by features not performance.
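For illustration, the same 2D FFT through both front-ends (both dispatch to cuFFT on NVIDIA GPUs); array names are illustrative:

import cupy as cp
import torch

# CuPy: array lives on the GPU by default
a_cp = cp.random.rand(256, 256).astype(cp.complex64)
f_cp = cp.fft.fft2(a_cp)

# PyTorch: same transform, same backing library
a_pt = torch.randn(256, 256, dtype=torch.complex64, device='cuda')
f_pt = torch.fft.fft2(a_pt)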

Background: Ptychography reconstruction

Ptychography reconstructs a high-resolution complex-valued object \(O(x,y)\) and probe \(P(x,y)\) from diffraction intensity measurements. The object and probe are both trainable parameters updated via gradient descent.

Problem setup:

  • Object: Complex-valued transmission function at each pixel (amplitude and phase)

  • Probe: Complex-valued electron beam wavefront (contains aberrations)

  • Measurements: Diffraction intensity patterns at detector (magnitude only, phase lost)

  • Goal: Recover object and probe from intensity-only measurements

Forward model: Seven computational steps per iteration

Step 1: Generate probe from aberration parameters (k-space → real space)

The probe starts in reciprocal space with aberration function:

\[\chi(k_x, k_y) = \pi \lambda \Delta f \, k^2 + \frac{1}{2}\pi \lambda^3 C_s k^4 + \pi \lambda A_{astig} k^2 \cos\big(2(\theta - \theta_a)\big)\]

where \(k = \sqrt{k_x^2 + k_y^2}\) is the spatial frequency, \(\theta = \operatorname{atan2}(k_y, k_x)\) is the azimuthal angle in k-space, \(\Delta f\) is defocus, \(C_s\) is spherical aberration, \(A_{astig}\) is the astigmatism magnitude, and \(\theta_a\) is the astigmatism angle.

The probe in k-space is:

\[\tilde{P}(k_x, k_y) = A(k) \cdot e^{i\chi(k_x, k_y)}\]

where \(A(k)\) is the aperture function (sigmoid for soft edges).

Transform to real space via inverse FFT:

\[P(x,y) = \mathcal{F}^{-1}\{\tilde{P}(k_x, k_y)\}\]

Step 2: Construct object from amplitude and phase

\[O(x,y) = A(x,y) \cdot e^{i\phi(x,y)}\]

Step 3: Extract object patches and compute exit waves

For each scan position \(j\):

\[T_j(x,y) = O(x + x_j, y + y_j) \cdot P(x,y)\]

Step 4: Propagate to detector via FFT

\[\Psi_j(u,v) = \mathcal{F}\{T_j(x,y)\}\]

Step 5: Compute predicted intensity

\[I_j^{\text{pred}}(u,v) = |\Psi_j(u,v)|^2\]

Step 6: Loss function

\[\mathcal{L} = \frac{1}{N} \sum_{j=1}^{N} \frac{1}{HW} \sum_{u,v} \left(I_j^{\text{pred}}(u,v) \cdot s_j^{\text{pred}} - I_j^{\text{meas}}(u,v) \cdot s_j^{\text{meas}}\right)^2\]

where \(s_j = \text{counts} / \text{mean}(I_j)\) is the per-position scaling factor.

Step 7: Backward pass and optimization

PyTorch computes gradients via autograd and updates all trainable parameters using the Adam optimizer:

\[\frac{\partial \mathcal{L}}{\partial \theta} \quad \text{for all} \quad \theta \in \{\text{object amplitude, object phase, probe aberrations}\}\]

Note

Caveat: Autograd overhead is significant. For 4,096 scan positions, PyTorch builds a computation graph with thousands of nodes, costing ~50% of wall-clock time. Custom CUDA kernels bypass autograd for the forward pass, then use PyTorch’s built-in backward implementations.

Trainable parameters being optimized

The object is stored as two separate real-valued arrays (amplitude and phase) but combined into complex-valued form during forward pass:

\[O(x,y) = A(x,y) \cdot e^{i\phi(x,y)} = A(x,y) \cdot (\cos\phi + i\sin\phi)\]

Why separate storage? Gradients flow better through separate amplitude and phase parameters. PyTorch’s autograd computes \(\frac{\partial \mathcal{L}}{\partial A}\) and \(\frac{\partial \mathcal{L}}{\partial \phi}\) independently, then updates each via Adam. If we stored the complex object directly, gradients would couple real/imaginary parts in a less physically meaningful way.
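A minimal sketch of that behavior (values are illustrative):

import torch

# Amplitude and phase are independent real leaf parameters
A = torch.tensor(0.9, requires_grad=True)
phi = torch.tensor(0.3, requires_grad=True)

O = A * torch.exp(1j * phi)      # complex object value, built on the fly
loss = (O.real - 0.5) ** 2       # any real-valued loss
loss.backward()

print(A.grad, phi.grad)          # separate gradients for Adam to update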

Parameter (stored)     Physical meaning                                          Shape         Dtype

Object parameters (stored separately)
object_amplitude (A)   Transmission amplitude (0-1 range)                        [512, 512]    float32
object_phase (φ)       Phase shift in radians                                    [512, 512]    float32

Computed during forward
object_complex (O)     \(A \cdot e^{i\phi}\) (not stored, computed on-the-fly)   [512, 512]    complex64

Probe aberration parameters
defocus (Δf)           Defocus distance in nm (C1)                               [] (scalar)   float32
Cs                     Spherical aberration in mm (C3)                           [] (scalar)   float32
astig_mag              Astigmatism magnitude in nm                               [] (scalar)   float32
astig_angle            Astigmatism orientation in radians                        [] (scalar)   float32
aperture_smooth        Aperture edge smoothness parameter                        [] (scalar)   float32

Total trainable parameters: 524,293

  • Object: 512×512×2 = 524,288 parameters (amplitude + phase stored separately)

  • Probe aberrations: 5 parameters (defocus, Cs, astigmatism magnitude, astigmatism angle, aperture smoothness)

Optimization strategy: Two separate Adam optimizers with fused kernels update object and probe parameters independently after each iteration.

Warning

Gotcha: Object is stored as two separate real-valued arrays (amplitude and phase) but must be combined into complex form during forward pass. This separation enables better gradient flow but requires careful handling to avoid creating unnecessary intermediate tensors.

What stays in PyTorch vs what moves to custom CUDA

Keep in PyTorch (already optimal or difficult to replace):

  • FFT operations: torch.fft.fft2() and torch.fft.ifft2()

    • Uses cuFFT internally (NVIDIA’s highly optimized library)

    • Cannot be beaten with custom implementation

    • Accounts for 49.4% of GPU time (will remain constant)

  • Adam optimizer: torch.optim.Adam(fused=True)

    • Single CUDA kernel with numerically stable updates

    • Well-tested, includes gradient clipping and weight decay

    • Replacing offers minimal gains

  • Backward pass: PyTorch’s automatic differentiation

    • Highly optimized for common operations

    • Maintains numerical stability

    • Custom backward kernels offer ~20% gains but add complexity

Move to custom CUDA kernels (high-impact optimizations):

  1. Probe generation: \(\chi(k) → \tilde{P}(k)\) preparation before IFFT

    • Fuses defocus, spherical aberration, astigmatism calculations

    • Runs 1× per iteration (low priority)

  2. Exit wave computation: \(O[indices] \cdot P\)

    • Fuses tile gather + amplitude/phase → complex + complex multiply

    • Runs 4,096× per iteration (CRITICAL)

    • FP16 implementation uses Tensor Cores for 2-3× speedup

  3. Intensity and loss: \(|\Psi|^2 → \text{scaling} → \text{loss}\)

    • Fuses magnitude squared, per-batch scaling, loss reduction

    • Runs 4,096× per iteration (IMPORTANT)

    • Eliminates 6 separate kernel launches

Note

Key insight: With 4,096 scan positions per iteration, operations that run once per position dominate the total time. Prioritize optimizing exit wave and intensity loss over probe generation.

The actual implementation

Here’s the complete code being optimized. Current performance: 0.11s per iteration

Aberration-parameterized probe model

import torch
import torch.nn as nn
import numpy as np

class AberrationProbe(nn.Module):
    """Physics-informed probe with aberration parameterization."""

    def __init__(self, size, pixel_size, wavelength, convergence_angle, device='cuda'):
        super().__init__()
        self.size = size
        self.pixel_size = pixel_size
        self.wavelength = wavelength
        self.convergence_angle = convergence_angle
        self.device = device

        # Create k-space grids - registered as buffers (no gradient retention)
        kx = torch.fft.fftfreq(size, d=pixel_size, device=device)
        ky = torch.fft.fftfreq(size, d=pixel_size, device=device)
        kx, ky = torch.meshgrid(kx, ky, indexing='ij')
        self.register_buffer('k', torch.sqrt(kx**2 + ky**2))
        self.register_buffer('k_angle', torch.atan2(ky, kx))

        # Trainable aberration coefficients
        self.defocus = nn.Parameter(torch.tensor(50.0))      # nm (C1)
        self.Cs = nn.Parameter(torch.tensor(1.0))            # mm (C3)
        self.astig_mag = nn.Parameter(torch.tensor(10.0))    # nm
        self.astig_angle = nn.Parameter(torch.tensor(0.0))   # radians
        self.aperture_smooth = nn.Parameter(torch.tensor(0.1))

    def compute_aberration_function(self):
        """χ(k) = π λ Δf k² + (1/2) π λ³ Cs k⁴ + astigmatism"""
        k = self.k
        wavelength = self.wavelength
        Cs_nm = self.Cs * 1e6  # mm → nm

        # Defocus term
        chi_defocus = np.pi * wavelength * self.defocus * k**2

        # Spherical aberration term
        chi_Cs = 0.5 * np.pi * wavelength**3 * Cs_nm * k**4

        # Astigmatism
        chi_astig = (np.pi * wavelength * self.astig_mag * k**2 *
                    torch.cos(2 * (self.k_angle - self.astig_angle)))

        return chi_defocus + chi_Cs + chi_astig

    def compute_aperture(self):
        """Soft aperture function with sigmoid rolloff."""
        k_max = self.convergence_angle / self.wavelength
        smoothness = torch.abs(self.aperture_smooth) + 0.01
        return torch.sigmoid((k_max - self.k) / (k_max * smoothness))

    def forward(self):
        """Generate probe: P = IFFT[A(k) × exp(iχ(k))]"""
        chi = self.compute_aberration_function()
        aperture = self.compute_aperture()

        # Probe in k-space
        probe_k = aperture * torch.exp(1j * chi)

        # Transform to real space
        probe_real = torch.fft.ifftshift(torch.fft.ifft2(probe_k))

        # Normalize
        probe_real = probe_real / torch.sqrt(torch.sum(torch.abs(probe_real)**2))
        return probe_real

    def get_probe_crop(self, crop_size):
        """Get cropped probe WITH gradients for optimization."""
        probe_full = self.forward()
        center = self.size // 2
        half = crop_size // 2
        return probe_full[center-half:center+half, center-half:center+half]

Physics-informed ptychography model

class PhysicsInformedPtychography(nn.Module):
    """Ptychography model with batched operations."""

    def __init__(self, object_size, probe_size, diffraction_size, scan_positions,
                 pixel_size, wavelength, convergence_angle, device='cuda'):
        super().__init__()
        self.object_size = object_size
        self.probe_size = probe_size
        self.diffraction_size = diffraction_size
        self.device = device

        # Precompute coordinate grids for batched gather
        dy = torch.arange(probe_size, device=device, dtype=torch.long)
        dx = torch.arange(probe_size, device=device, dtype=torch.long)
        grid_y, grid_x = torch.meshgrid(dy, dx, indexing='ij')
        self.register_buffer('grid_y', grid_y)
        self.register_buffer('grid_x', grid_x)

        # Convert scan positions to tensor
        self.register_buffer('scan_positions',
                           torch.tensor(scan_positions, dtype=torch.long, device=device))

        # Trainable object parameters
        self.object_amplitude = nn.Parameter(
            0.9 * torch.ones(object_size, object_size, device=device) +
            0.1 * torch.randn(object_size, object_size, device=device)
        )
        self.object_phase = nn.Parameter(
            0.1 * torch.randn(object_size, object_size, device=device)
        )

        # Probe model with trainable aberrations
        self.probe_model = AberrationProbe(
            diffraction_size, pixel_size, wavelength, convergence_angle, device=device
        )

    def get_object_complex(self):
        """Construct complex object from amplitude and phase."""
        return self.object_amplitude * torch.exp(1j * self.object_phase)

    def forward(self, batch_indices):
        """Forward pass with batched gather."""
        if not isinstance(batch_indices, torch.Tensor):
            batch_indices = torch.tensor(batch_indices, dtype=torch.long, device=self.device)

        batch_indices = batch_indices.detach()
        positions = self.scan_positions[batch_indices]

        # Get complex object
        object_complex = self.get_object_complex()

        # Get probe (regenerated every call - BOTTLENECK!)
        probe_complex = self.probe_model.get_probe_crop(self.probe_size)

        # Batched gather - extract all patches at once
        y_idx = positions[:, 0, None, None] + self.grid_y[None, :, :]
        x_idx = positions[:, 1, None, None] + self.grid_x[None, :, :]
        tiles = object_complex[y_idx, x_idx]  # [B, probe_size, probe_size]

        # Exit waves
        exit_waves = tiles * probe_complex.unsqueeze(0)

        # FFT with internal padding (cuFFT handles this optimally)
        Psi = torch.fft.fft2(exit_waves, s=(self.diffraction_size, self.diffraction_size))

        return Psi

Training loop

# Model setup
model = PhysicsInformedPtychography(
    object_size=512,
    probe_size=80,
    diffraction_size=256,
    scan_positions=scan_positions,  # 4,096 positions
    pixel_size=0.5,
    wavelength=0.0197,  # 300 keV electrons
    convergence_angle=0.02,
    device='cuda'
).to('cuda')

# Fused Adam optimizers (single CUDA kernel each)
optimizer_object = torch.optim.Adam(
    [model.object_amplitude, model.object_phase],
    lr=0.01,
    fused=True
)
optimizer_probe = torch.optim.Adam(
    model.probe_model.parameters(),
    lr=0.001,
    fused=True
)

# Precompute constants
I_meas = (diffraction_patterns ** 2).contiguous()
meas_scale = (counts / I_meas.mean(dim=(1,2), keepdim=True)).contiguous()

# JIT-compiled loss function
@torch.jit.script
def compute_intensity_loss(I_pred, I_meas, meas_scale, counts):
    pred_scale = counts / I_pred.mean(dim=(1,2), keepdim=True)
    diff = I_pred * pred_scale - I_meas * meas_scale
    return (diff * diff).mean()

# Training
torch.backends.cudnn.benchmark = True
batch_indices = torch.arange(len(scan_positions), dtype=torch.long, device='cuda')

for iteration in range(num_iterations):
    # Forward
    Psi = model(batch_indices)
    I_pred = torch.abs(Psi) ** 2
    loss = compute_intensity_loss(I_pred, I_meas, meas_scale, counts)

    # Backward
    optimizer_object.zero_grad(set_to_none=True)
    optimizer_probe.zero_grad(set_to_none=True)
    loss.backward()

    # Update
    optimizer_object.step()
    optimizer_probe.step()

Current baseline: Already implemented optimizations

Before profiling and writing custom kernels, several standard PyTorch optimizations are already in place:

Batched gather operations: All 4,096 object patches extracted in single GPU operation (no Python loops)

Registered buffers: Precomputed coordinate grids (k-space, scan positions) stored on GPU, not recomputed

FFT internal padding: Using s parameter to delegate padding to cuFFT (more efficient than explicit torch.nn.functional.pad)

Fused Adam optimizer: torch.optim.Adam(fused=True) uses single CUDA kernel for parameter updates

JIT-compiled loss function: TorchScript compilation eliminates Python interpreter overhead

cuDNN autotuning: torch.backends.cudnn.benchmark = True selects optimal convolution algorithms for tensor shapes

Impact: These optimizations reduced iteration time from ~0.5s to 0.11s per iteration (4.5× improvement).

Challenge: To reach < 0.01s requires another 11× speedup. Standard PyTorch optimizations exhausted, custom CUDA required.

Warning

Gotcha: fused=True in Adam only works when all parameters are on the same device and contiguous in memory. If parameters are scattered across tensors, PyTorch falls back to unfused implementation without warning. Check with torch.profiler to verify single adam_kernel launch.
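One way to check, sketched with torch.profiler (optimizer_object is the fused Adam instance from the training loop above):

import torch
from torch.profiler import profile, ProfilerActivity

# Profile a single optimizer step and look at the kernels launched.
# With fused=True in effect, the Adam update should show up as one fused
# kernel rather than one small kernel per parameter tensor.
with profile(activities=[ProfilerActivity.CUDA]) as prof:
    optimizer_object.step()

for evt in prof.key_averages():
    if "adam" in evt.key.lower():
        print(evt.key, evt.count)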

Profiling: Identifying the bottlenecks

Running torch.profiler for 100 iterations reveals where time is spent:
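For reference, a profiling run like the one summarized below can be produced as follows; this is a sketch that reuses the model, loss function, and optimizers from the implementation above:

import torch
from torch.profiler import profile, ProfilerActivity

# Profile 100 full iterations (forward, loss, backward, optimizer steps)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        Psi = model(batch_indices)
        I_pred = torch.abs(Psi) ** 2
        loss = compute_intensity_loss(I_pred, I_meas, meas_scale, counts)
        optimizer_object.zero_grad(set_to_none=True)
        optimizer_probe.zero_grad(set_to_none=True)
        loss.backward()
        optimizer_object.step()
        optimizer_probe.step()

print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))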

GPU time breakdown:

Name                                                    Self CPU %   Self CUDA %
aten::_fft_c2c                                               0.11%        24.68%
aten::mul                                                    0.60%        19.07%
aten::abs                                                    0.04%         5.95%
autograd::engine::evaluate_function: AbsBackward0            0.04%         0.00%
AbsBackward0                                                 0.02%         0.00%
void regular_fft<256u, ...>                                  0.00%        12.98%
autograd::engine::evaluate_function: FftC2CBackward0         0.02%         0.00%
FftC2CBackward0                                              0.01%         0.00%
void vector_fft<256u, ...>                                   0.00%        11.70%
autograd::engine::evaluate_function: PowBackward0            0.04%         0.00%
PowBackward0                                                 0.02%         0.00%

Total times: Self CPU time: 11.011s | Self CUDA time: 11.030s

The real bottleneck: CPU time ≈ GPU time ≈ 11s for 100 iterations (0.11s per iteration).

Why CPU time matters: The user experiences wall-clock time, not just GPU time. PyTorch operations involve:

  1. Python overhead - Each operation calls Python → C++ → CUDA

  2. Kernel launch overhead - CPU must launch each GPU kernel (cudaLaunchKernel takes ~5-10μs)

  3. Autograd graph - CPU builds and manages backward computation graph

  4. Tensor bookkeeping - CPU tracks shapes, dtypes, memory, reference counts

  5. Synchronization - CPU waits for GPU results between operations

With 5,140 mul calls, 620 abs calls, etc., these CPU overheads accumulate to match the GPU compute time.

Custom CUDA eliminates CPU overhead: A single kernel launch does everything on GPU. CPU launches once, then waits for the final result. No Python loops, no intermediate tensors, no autograd graph during forward pass.

What PyTorch does internally for every operation

CPU-side overhead (happens for EVERY PyTorch operation):

  1. Python interpreter overhead

    • Function call from Python → C++ extension

    • Argument parsing and validation (type checking, shape checking)

    • GIL (Global Interpreter Lock) acquisition/release

    • Reference counting for Python objects

    • Cost: ~1-5μs per operation

  2. PyTorch dispatcher overhead

    • Determine which backend to use (CPU/CUDA/MPS)

    • Check if operation is in-place or requires new allocation

    • Dispatch to correct kernel implementation

    • Handle device placement and data type promotion

    • Cost: ~0.5-2μs per operation

  3. Memory management

    • Allocate output tensor (unless in-place)

    • Update CUDA memory allocator bookkeeping

    • Track tensor metadata (shape, stride, dtype, device)

    • Increment/decrement reference counts

    • Register tensor for garbage collection

    • Cost: ~2-10μs per allocation

  4. Autograd graph construction

    • Create autograd node for backward pass

    • Store input tensors (or their versions) for gradient computation

    • Link node into computation graph

    • Register hook functions if any

    • Track requires_grad status through operations

    • Cost: ~5-15μs per operation with gradients

  5. Kernel launch

    • Call cudaLaunchKernel() system call

    • Transfer kernel arguments to GPU (via driver)

    • Enqueue kernel in CUDA stream

    • Update stream synchronization state

    • Cost: ~5-10μs per kernel launch

GPU-side overhead (for operations PyTorch doesn’t fuse):

  1. Memory bandwidth bottleneck

    • Each operation reads inputs from global memory (slow)

    • Each operation writes outputs to global memory (slow)

    • No data reuse between operations (cache misses)

    • Example: A * B reads A and B, writes result. Next operation reads result again.

  2. Kernel launch latency

    • GPU must wait for kernel to be scheduled

    • Streaming multiprocessors (SMs) must load kernel code

    • Thread blocks must be assigned to SMs

    • Cost: ~1-5μs per kernel

  3. Synchronization overhead

    • Operations on same stream serialize (can’t overlap)

    • torch.cuda.synchronize() blocks CPU waiting for GPU

    • Implicit syncs when copying data back to CPU

Example: Single complex multiply in PyTorch

# This one line:
result = a_complex * b_complex  # complex64 tensors

# Triggers on CPU:
# - Python call overhead: ~3μs
# - Dispatcher: ~1μs
# - Memory allocation for result: ~5μs
# - Autograd node creation: ~10μs
# - 4× cudaLaunchKernel() calls: ~40μs (one per real/imag component)
# Total CPU time: ~59μs

# Triggers on GPU:
# - 4 kernel launches: ~4μs
# - Read a_complex (real + imag): 2 memory reads
# - Read b_complex (real + imag): 2 memory reads
# - Write result (real + imag): 2 memory writes
# - Actual compute: <1μs (multiplication is fast)
# Total GPU time: ~5μs

# Wall-clock time: max(CPU, GPU) ≈ 59μs for ONE multiply

With 5,140 complex multiplies:

CPU overhead: 5,140 × 59μs = 303ms
GPU overhead: 5,140 × 5μs = 26ms
Total added to iteration: ~330ms

Custom CUDA eliminates almost all of this:

// Single kernel launch from Python: ~20μs CPU overhead total
// GPU: Single kernel reads data once, writes result once
// Wall-clock time: ~5μs for ALL 5,140 multiplies fused together

Why this matters:

The profiling shows CPU time (11.01s) ≈ GPU time (11.03s). This means the code is spending equal time on:

  • CPU: Launching kernels, building autograd graphs, managing memory

  • GPU: Actually computing results

Custom CUDA collapses hundreds of operations into single kernels, eliminating most CPU overhead and allowing GPU to work continuously without waiting for CPU to launch the next kernel.

Mapping profiling to custom kernels

What can’t be optimized:

  • aten::_fft_c2c (24.7%): This IS cuFFT. No custom kernel beats cuFFT.

  • regular_fft + vector_fft (24.7%): cuFFT’s internal kernels. Optimal.

  • Total FFT: 49.4% - Keep using torch.fft.fft2()

What custom kernels will target:

GPU operations to fuse:

  • aten::mul (19.1%, 5,140 calls): Complex multiply scattered everywhere

  • aten::abs (5.95%, 620 calls): Complex → real conversion

  • Power and backward operations: Squaring in loss + gradient computations

  • Total fusible GPU operations: ~25-30% of GPU time

CPU overhead to eliminate:

  • 5,140 kernel launches for mul operations

  • Autograd graph building and traversal

  • Tensor allocation and memory management

  • Python/C++/CUDA call stack overhead

  • Total CPU overhead: ~50% of wall-clock time (11.01s CPU ≈ 11.03s GPU)

Total speedup potential:

  • 3× from kernel fusion (eliminate redundant memory operations on GPU)

  • 2× from eliminating CPU overhead (single kernel launch, no Python intermediaries, no autograd graph)

  • Total: 6× speedup → 0.018s per iteration

What does “fuse” mean in GPU computing?

Kernel fusion means combining multiple operations into a single GPU kernel to eliminate intermediate memory reads/writes.

Unfused operations (what PyTorch does):

# Three separate operations
a = torch.exp(x)        # Operation 1: read x, write a
b = a * y               # Operation 2: read a, read y, write b
c = torch.abs(b)        # Operation 3: read b, write c

# GPU memory traffic:
# Op 1: Read x [100MB], Write a [100MB] → 200MB
# Op 2: Read a [100MB], Read y [100MB], Write b [100MB] → 300MB
# Op 3: Read b [100MB], Write c [100MB] → 200MB
# Total: 700MB of memory traffic

# GPU execution:
# Launch kernel 1 → wait → launch kernel 2 → wait → launch kernel 3
# Each kernel launch: ~5-10μs CPU overhead + kernel execution time

Fused operations (custom CUDA):

__global__ void fused_kernel(float* x, float* y, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float a = expf(x[idx]);      // Compute in register
        float b = a * y[idx];        // Compute in register
        c[idx] = fabsf(b);           // Write only final result
    }
}

// GPU memory traffic:
// Read x [100MB], Read y [100MB], Write c [100MB] → 300MB
// Total: 300MB of memory traffic (2.3× less than unfused)

// GPU execution:
// Single kernel launch → compute everything → done
// One kernel launch: ~5-10μs CPU overhead (not 3×)

Why this matters:

  1. Memory bandwidth is the bottleneck - Modern GPUs can compute much faster than they can move data. Reducing memory traffic by 2-3× directly speeds up execution by 2-3×.

  2. Intermediate results stay in registers - GPU registers are ~1000× faster than global memory. Fused kernels keep a and b in registers, never writing them to slow global memory.

  3. CPU launches fewer kernels - One launch instead of three eliminates 20-30μs of CPU overhead per fused group.

Real example from ptychography:

# Unfused (PyTorch):
amp = object_amplitude[indices]           # Read, write → 4MB
phase = object_phase[indices]             # Read, write → 4MB
obj_real = amp * torch.cos(phase)         # Read amp, read phase, write → 12MB
obj_imag = amp * torch.sin(phase)         # Read amp, read phase, write → 12MB
obj_complex = torch.complex(obj_real, obj_imag)  # Read both, write → 12MB
exit_wave = obj_complex * probe           # Read obj, read probe, write → 12MB
# Total: 56MB memory traffic, 6 kernel launches

// Fused (custom CUDA):
__global__ void fused_exit_wave_kernel(...) {
    float amp = obj_amplitude[obj_idx];      // Read once
    float phase = obj_phase[obj_idx];        // Read once
    cuFloatComplex obj = make_cuFloatComplex(
        amp * cosf(phase),  // Compute in register
        amp * sinf(phase)   // Compute in register
    );
    cuFloatComplex probe_val = probe[idx];   // Read once
    exit_wave[idx] = cuCmulf(obj, probe_val);  // Write once
}
// Total: 12MB memory traffic (4.7× less), 1 kernel launch

Fusion is not optimization - it’s necessity:

For 4,096 scan positions, unfused operations create:

  • 24,576 kernel launches (4,096 positions × 6 operations)

  • 229GB of memory traffic (56MB × 4,096)

  • 250ms of CPU overhead just launching kernels

Fused into one kernel:

  • 1 kernel launch

  • 49GB of memory traffic (4.7× reduction)

  • 10μs of CPU overhead

Result: 4.7× faster GPU compute + elimination of CPU bottleneck = 5-6× total speedup.

Why PyTorch cannot fuse complex number operations

PyTorch’s fundamental limitation: The JIT compiler and torch.compile() do not fuse operations involving complex tensors. This is not a bug - it is an architectural limitation.

What happens with a simple complex multiply:

# This single line of PyTorch code:
exit_waves = object_complex * probe_complex  # Both are complex64

# Becomes FOUR separate GPU kernel launches:
# Kernel 1: Extract real parts from object_complex
# Kernel 2: Extract imaginary parts from object_complex
# Kernel 3: Compute real_out = real_a * real_b - imag_a * imag_b
# Kernel 4: Compute imag_out = real_a * imag_b + imag_a * real_b

# Each kernel:
# - Reads from global memory (slow)
# - Writes to global memory (slow)
# - Requires CPU to launch it (~5-10μs overhead)
# - Allocates intermediate tensors
# - Registers operation in autograd graph

The cascade effect with 5,140 complex multiplies:

5,140 complex multiplies × 4 kernels each = 20,560 kernel launches
20,560 launches × 10μs CPU overhead = 205ms just in launch overhead
Plus: 20,560 memory allocations, 20,560 autograd registrations
Result: CPU time (11.01s) ≈ GPU time (11.03s)

Custom CUDA eliminates this completely:

// Single kernel, single launch, everything in registers:
cuFloatComplex obj_val = make_cuFloatComplex(
    amp * cosf(phase),  // Real part
    amp * sinf(phase)   // Imaginary part
);
cuFloatComplex probe_val = probe[idx];
exit_waves[idx] = cuCmulf(obj_val, probe_val);  // One instruction

// CPU overhead: Single cudaLaunchKernel() call (~10μs total)
// GPU overhead: No intermediate memory reads/writes
// Result: 3× faster GPU compute + 2× faster CPU overhead = 6× total

Why this matters for ptychography:

The forward pass has three cascading complex operations:

  1. \(A(x,y) \cdot e^{i\phi(x,y)}\) → object construction (complex exp + multiply)

  2. \(O[indices] \cdot P(x,y)\) → exit waves (gather + complex multiply)

  3. \(|\Psi(u,v)|^2\) → intensity (complex absolute value + square)

Each gets fragmented into multiple kernels. Backward pass makes it worse with gradient computations through each operation.

Custom CUDA is not optional - it is necessary. PyTorch fundamentally cannot optimize this workload.

The three custom kernels to write

fused_exit_wave_kernel (forward exit wave)

Eliminates 19.1% overhead

This kernel computes the exit waves for all 4,096 scan positions in a single GPU call. For each position, it constructs the complex object from separate amplitude and phase arrays, extracts the relevant patch using precomputed indices, and multiplies by the probe function. PyTorch does this in three separate kernel launches with intermediate tensor allocations. The custom kernel fuses everything into one pass, eliminating memory roundtrips and CPU overhead.

Fuses: \(A(x,y) \cdot e^{i\phi(x,y)}\) at indices → multiply by \(P(x,y)\)

         Parameter        Description                                  Shape/Type
Input    obj_amplitude    Object amplitude (real-valued, 0-1 range)    [512, 512] float
         obj_phase        Object phase (real-valued, radians)          [512, 512] float
         y_indices        Vertical positions for all patches           [4096, 80, 80] int
         x_indices        Horizontal positions for all patches         [4096, 80, 80] int
         probe            Illumination function (complex)              [80, 80] complex64
Output   exit_waves       Transmitted wavefield for each position      [4096, 80, 80] complex64

Replaces: 3 separate PyTorch ops (exp, gather, mul) + 3 kernel launches + 2 intermediate tensors

fused_intensity_loss_kernel (intensity and loss)

Eliminates 30.5% overhead

This kernel computes the complete loss function in a single GPU kernel with streaming reduction. It converts the complex diffraction patterns to intensities, applies per-position scaling factors, computes squared differences against measured data, and reduces to a scalar loss - all without writing intermediate results to global memory. PyTorch does this as six separate operations, each allocating full-size temporary tensors. The fused kernel processes data on-the-fly using shared memory reductions.

Fuses: \(|\Psi|^2 \rightarrow \text{scale} \rightarrow \text{diff}^2 \rightarrow \text{mean}\)

         Parameter    Description                                  Shape/Type
Input    Psi          Diffracted wavefields (complex)              [4096, 256, 256] complex64
         I_meas       Measured intensities (precomputed)           [4096, 256, 256] float
         meas_scale   Scaling for measured data (precomputed)      [4096, 1, 1] float
         counts       Total photon counts per pattern              [4096] float
Output   loss         Mean squared error (scalar)                  [] float
         I_pred       Predicted intensities (saved for backward)   [4096, 256, 256] float

Replaces: 6 separate PyTorch ops (abs, pow, mean, mul, sub, mean) + 6 kernel launches + 5 intermediate tensors

fused_exit_wave_backward_kernel (backward)

Speeds up backward pass

This kernel computes gradients for the exit wave operation by reversing the forward computation. Given gradients with respect to the exit waves, it efficiently backpropagates through the complex multiplication and object construction. The challenge is handling atomic additions for overlapping patches - each object pixel is visited by multiple scan positions, so gradients must accumulate atomically. PyTorch’s autograd handles this correctly but builds a computational graph with overhead. The custom kernel implements the exact derivatives directly, using atomic operations for the scatter-add pattern.

Fuses: Gradients for obj_amplitude, obj_phase, and probe from grad_exit_waves

         Parameter            Description                                       Shape/Type
Input    grad_exit_waves      Gradients from loss w.r.t. exit waves             [4096, 80, 80] complex64
         obj_amplitude        Forward pass object amplitude (for chain rule)    [512, 512] float
         obj_phase            Forward pass object phase (for chain rule)        [512, 512] float
         probe                Forward pass probe (for chain rule)               [80, 80] complex64
         y_indices            Vertical positions (same as forward)              [4096, 80, 80] int
         x_indices            Horizontal positions (same as forward)            [4096, 80, 80] int
Output   grad_obj_amplitude   Gradients w.r.t. object amplitude                 [512, 512] float
         grad_obj_phase       Gradients w.r.t. object phase                     [512, 512] float
         grad_probe           Gradients w.r.t. probe function                   [80, 80] complex64

Replaces: PyTorch’s autograd through complex ops (general-purpose graph traversal) with direct derivative computation using atomic scatter-add
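The scatter-add pattern this kernel implements with atomicAdd can be sketched with PyTorch ops for reference. This is a minimal sketch, not the kernel itself; it reuses names from the forward pass (object_amplitude, object_phase, probe, y_idx, x_idx, grad_exit_waves), shows only the amplitude gradient, and assumes PyTorch's complex-gradient convention:

import torch

# Chain rule through exit_wave = A * exp(i*phase) * probe, amplitude part only
phase_tiles = object_phase[y_idx, x_idx]                        # [B, 80, 80]
dexit_damp = probe.unsqueeze(0) * torch.exp(1j * phase_tiles)   # ∂(exit wave)/∂A
grad_amp_tiles = (grad_exit_waves * dexit_damp.conj()).real     # per-patch gradients

# Scatter-add: overlapping patches accumulate into the same object pixels,
# which is exactly what the CUDA kernel does with atomicAdd
grad_obj_amplitude = torch.zeros_like(object_amplitude)
grad_obj_amplitude.index_put_((y_idx, x_idx), grad_amp_tiles, accumulate=True)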

Summary

Everything else (probe generation with aberrations, Adam, autograd) stays in PyTorch for now, but could be moved to CUDA for further speedup.

FFT itself (49.4% of GPU time) can’t be improved - cuFFT is already optimal. But the operations around it can be fused, and more importantly, moving everything to custom CUDA eliminates PyTorch’s CPU overhead (kernel launches, autograd graph, tensor bookkeeping) which roughly doubles the wall-clock time.

Expected total speedup:

  • 2.5-3× from kernel fusion (GPU compute)

  • 2× from eliminating CPU overhead (PyTorch → pure CUDA)

  • Total: 5-6× → 0.018-0.022s per iteration

Custom kernel 1: Probe generation in k-space (low priority, but needed)

Reality check: Probe aberration parameters (defocus, Cs, astigmatism) ARE trainable in ptychography - that’s the whole point! But this kernel is still low priority because it runs only 1× per iteration vs 4,096× for exit waves.

The probe generation code:

# Runs once per iteration after optimizer updates aberration parameters
chi_defocus = np.pi * wavelength * self.defocus * k**2
chi_Cs = 0.5 * np.pi * wavelength**3 * Cs_nm * k**4
chi_astig = np.pi * wavelength * self.astig_mag * k**2 * torch.cos(2*(k_angle - astig_angle))
chi = chi_defocus + chi_Cs + chi_astig
aperture = torch.sigmoid((k_max - self.k) / (k_max * aperture_smooth))
probe_k = aperture * torch.exp(1j * chi)
probe_real = torch.fft.ifft2(probe_k)

Can PyTorch fuse these operations automatically?

Short answer: No, not effectively.

Why PyTorch can’t fuse this:

  1. JIT limitations: PyTorch’s torch.jit can’t optimize through:

    • torch.exp(1j * chi) - complex exponential with imaginary unit

    • torch.sigmoid() - elementwise nonlinearity

    • torch.cos() with scalar multiplication

    • Multiple intermediate tensor allocations

  2. TorchScript restrictions:

    @torch.jit.script
    def generate_probe_k(defocus, Cs, astig_mag, k, k_angle):
        # TorchScript can trace this, but won't fuse kernels
        chi = math.pi * wavelength * defocus * k**2 + ...  # Separate kernel
        aperture = torch.sigmoid((k_max - k) / smooth)     # Separate kernel
        probe_k = aperture * torch.exp(1j * chi)           # Separate kernel
        return probe_k
    

    Result: Still 3+ separate CUDA kernel launches, just with less Python overhead.

  3. PyTorch 2.0 torch.compile():

    @torch.compile(mode="max-autotune")
    def generate_probe_k(defocus, Cs, astig_mag, k, k_angle):
        chi = math.pi * wavelength * defocus * k**2 + ...
        aperture = torch.sigmoid((k_max - k) / smooth)
        probe_k = aperture * torch.exp(1j * chi)
        return probe_k
    
    • Might fuse some ops, but limited by TorchInductor’s complex number support

    • Complex exponential exp(1j * x) often falls back to eager mode

    • Gains: ~10-20% at best (vs 1.1-1.2× with custom CUDA)

    • Worth trying first since it’s just one decorator!

Custom CUDA kernel gives full control:

__global__ void generate_probe_k_kernel(
    const float* __restrict__ k,
    const float* __restrict__ k_angle,
    cuFloatComplex* __restrict__ probe_k,
    float defocus, float Cs, float astig_mag, float astig_angle,
    float wavelength, float k_max, float aperture_smooth,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size * size) {
        float k_val = k[idx];
        float angle = k_angle[idx];

        // All ops fused in registers - no intermediate memory writes
        float chi = M_PI * wavelength * defocus * k_val * k_val +
                   0.5f * M_PI * wavelength * wavelength * wavelength * Cs *
                   k_val * k_val * k_val * k_val +
                   M_PI * wavelength * astig_mag * k_val * k_val *
                   cosf(2.0f * (angle - astig_angle));

        float aperture = 1.0f / (1.0f + expf((k_val - k_max) / (k_max * aperture_smooth)));

        probe_k[idx] = make_cuFloatComplex(
            aperture * cosf(chi),
            aperture * sinf(chi)
        );
    }
}

Expected speedup: 1.1-1.2× (5-6 PyTorch kernel launches → 1 custom kernel)

But still low priority because:

  • Runs 1× per iteration (after optimizer step)

  • Exit wave kernel runs 4,096× per iteration (every scan position)

  • Even if this step ran 10× slower than necessary, it would still account for only ~10% of total time

  • Focus on Kernel 2 (exit wave FP16) first - that’s where 2-3× gains are

Recommendation: Implement the exit wave FP16 kernel (Kernel 2) first - that's where the real gains are (2-3× speedup on work repeated for all 4,096 scan positions every iteration). Come back to this probe-generation kernel only if you need the extra 10-20%.

Should I implement mixed precision (FP16) for the probe IFFT?

No - not worth it.

Why skip FP16 for probe IFFT:

  1. Runs only 1× per iteration (same frequency as probe generation)

    • Minimal impact even with 2× speedup

    • Exit wave FFT runs 4,096× per iteration - focus there instead

  2. cuFFT conversion overhead kills the gains:

    • Need FP32 → FP16 conversion before IFFT

    • Need FP16 → FP32 conversion after IFFT

    • Each conversion is a separate kernel launch

    • cuFFT FP32 is already memory-bandwidth optimal

  3. FFT is memory-bound, not compute-bound:

    • Tensor Cores don’t help (FFT is not matrix multiply)

    • FP16 only saves bandwidth if data stays FP16 throughout

    • But probe_k comes from FP32 kernel, and probe_real goes to FP32 ops

    • Conversions negate the bandwidth savings

  4. The math:

    • Probe IFFT: 80×80 = 6,400 complex values

    • Takes ~0.02-0.05ms (vs 110ms total iteration time)

    • Even 2× speedup saves only ~0.01-0.025ms

    • Impact: <0.02% of total time

Verdict: Keep probe IFFT in FP32. Focus on Kernel 2 (exit wave FP16) where you get 2-3× speedup on operations that run 4,096× per iteration.

Exception: If your exit wave pipeline keeps everything in FP16 (exit_wave_fp16 → FFT_fp16 → intensity_fp16), then the forward FFT could benefit. But the probe IFFT still runs too infrequently to matter.

Custom kernel 2: Fused exit wave computation with FP16 (CRITICAL)

This is the main optimization! Runs once per scan position (e.g., 4,096× per iteration).

Problem:

  • GPU time: aten::mul takes 2.104s (19.1%) for 5,140 separate kernel calls

  • CPU time: Each kernel launch costs ~5-10μs, plus autograd bookkeeping

The current code does this:

# Three separate operations in FP32
object_complex = object_amplitude * torch.exp(1j * object_phase)  # mul #1
tiles = object_complex[y_idx, x_idx]  # gather
exit_waves = tiles * probe_complex.unsqueeze(0)  # mul #2 (complex)

# Result: 3 kernel launches, 3 memory writes, all in FP32

Solution: Fuse all three operations into one FP16 kernel using custom complex math.

FP16 complex math primitives

First, define FP16 complex number operations using __half2:

#include <cuda_fp16.h>

typedef __half2 complex_half;  // Stores (real, imag) in one __half2

// Complex multiply: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
__device__ __forceinline__ complex_half cmul_fp16(complex_half a, complex_half b) {
    __half a_real = __low2half(a);
    __half a_imag = __high2half(a);
    __half b_real = __low2half(b);
    __half b_imag = __high2half(b);

    __half real = __hsub(__hmul(a_real, b_real), __hmul(a_imag, b_imag));
    __half imag = __hadd(__hmul(a_real, b_imag), __hmul(a_imag, b_real));

    return __halves2half2(real, imag);
}

// Amplitude + phase → complex FP16
__device__ __forceinline__ complex_half amp_phase_to_complex_fp16(
    __half amp, __half phase
) {
    __half real = __hmul(amp, hcos(phase));
    __half imag = __hmul(amp, hsin(phase));
    return __halves2half2(real, imag);
}

Why this is fast: __half2 packs two FP16 values per register, so on Ampere/Hopper GPUs FP16 arithmetic runs at up to 2× the FP32 rate while halving memory traffic.

Fused gather + multiply kernel

__global__ void gather_multiply_fp16(
    const __half* __restrict__ obj_amplitude,  // [H, W] FP16
    const __half* __restrict__ obj_phase,      // [H, W] FP16
    const __half2* __restrict__ probe,         // [probe_size²] complex FP16
    const int* __restrict__ y_indices,         // [B * probe_size²]
    const int* __restrict__ x_indices,         // [B * probe_size²]
    __half2* __restrict__ output,              // [B * probe_size²] complex FP16
    int batch_size, int probe_size, int object_width
) {
    int batch_idx = blockIdx.x;
    int pixel_idx = threadIdx.x + blockIdx.y * blockDim.x;
    int probe_area = probe_size * probe_size;

    if (pixel_idx < probe_area) {
        int global_idx = batch_idx * probe_area + pixel_idx;

        // Step 1: Gather object pixel coordinates
        int y = y_indices[global_idx];
        int x = x_indices[global_idx];
        int obj_idx = y * object_width + x;

        // Step 2: Convert amplitude + phase → complex (fused, in registers)
        __half amp = obj_amplitude[obj_idx];
        __half phase = obj_phase[obj_idx];
        complex_half obj_val = amp_phase_to_complex_fp16(amp, phase);

        // Step 3: Complex multiply with probe (fused, uses Tensor Cores!)
        complex_half probe_val = probe[pixel_idx];
        output[global_idx] = cmul_fp16(obj_val, probe_val);
    }

}

PyTorch wrapper:

import torch
from torch.utils.cpp_extension import load_inline

# C++/CUDA source: the kernel above plus an entry point that allocates the
# output tensor and launches it (launch configuration shown further below)
cuda_source = """
// ... gather_multiply_fp16 kernel code above ...

torch::Tensor fused_exit_wave(
    torch::Tensor obj_amplitude,
    torch::Tensor obj_phase,
    torch::Tensor y_indices,
    torch::Tensor x_indices,
    torch::Tensor probe
) {
    // Allocate the [B, probe_size, probe_size] complex output and launch the kernel
    // ...
}
"""

cpp_source = """
torch::Tensor fused_exit_wave(torch::Tensor obj_amplitude, torch::Tensor obj_phase,
                              torch::Tensor y_indices, torch::Tensor x_indices,
                              torch::Tensor probe);
"""

fused_module = load_inline(
    name='fused_exit_wave_ext',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['fused_exit_wave'],
    extra_cuda_cflags=['-O3'],
)

What this achieves:

  • 3 operations → 1 kernel: Amplitude/phase conversion + gather + complex multiply

  • FP32 → FP16: 2× memory bandwidth, 2× compute throughput (Tensor Cores)

  • No intermediate memory: Everything computed in registers

  • 1 kernel launch instead of 3: Eliminates CPU overhead

Launch configuration and PyTorch binding

extern "C" {

void launch_gather_multiply_fp16(
    const void* obj_amplitude,
    const void* obj_phase,
    const void* probe,
    const int* y_indices,
    const int* x_indices,
    void* output,
    int batch_size,
    int probe_size,
    int object_width,
    cudaStream_t stream
) {
    int probe_area = probe_size * probe_size;

    // Grid: [batch_size, (probe_area + 255) / 256]
    dim3 grid(batch_size, (probe_area + 255) / 256);
    dim3 block(256);

    gather_multiply_fp16<<<grid, block, 0, stream>>>(
        (const __half*)obj_amplitude,
        (const __half*)obj_phase,
        (const __half2*)probe,
        y_indices,
        x_indices,
        (__half2*)output,
        batch_size,
        probe_size,
        object_width
    );
}

}  // extern "C"

Use in PyTorch:

import torch
import ptychography_fp16  # Your compiled extension

# Convert FP32 → FP16 for kernel input. The complex probe is viewed as
# interleaved (real, imag) half pairs so its memory layout matches the
# __half2 values the kernel expects.
obj_amp_fp16 = object_amplitude.to(torch.float16)
obj_phase_fp16 = object_phase.to(torch.float16)
probe_fp16 = torch.view_as_real(probe).to(torch.float16).contiguous()

# Single fused FP16 kernel call (output is interleaved real/imag halves)
exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
    obj_amp_fp16, obj_phase_fp16, probe_fp16, y_idx, x_idx
)

# Continue with FFT in FP32
exit_waves = torch.view_as_complex(exit_waves_fp16.to(torch.float32))
Psi = torch.fft.fft2(exit_waves)

Expected speedup: 2-3× compared to PyTorch FP32 (gather + multiply operations)

Performance breakdown:

  • Kernel fusion: 1.3-1.5× (eliminates 2 extra kernel launches)

  • FP16 compute: 1.5-2× (2× arithmetic throughput, half the memory traffic)

  • Combined: 2-3× for this operation
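A quick way to validate the FP16 kernel is to compare it against the plain FP32 PyTorch path. A minimal sketch, reusing the tensor and extension names assumed above (ptychography_fp16, object_amplitude, object_phase, probe, y_idx, x_idx):

import torch

# FP32 reference exit waves via plain PyTorch ops
obj_complex = object_amplitude * torch.exp(1j * object_phase)
ref = obj_complex[y_idx, x_idx] * probe.unsqueeze(0)

# Fused FP16 kernel, then view the interleaved halves as complex again
out_fp16 = ptychography_fp16.gather_multiply_fp16(
    object_amplitude.half(), object_phase.half(),
    torch.view_as_real(probe).half().contiguous(), y_idx, x_idx
)
out = torch.view_as_complex(out_fp16.float())

# FP16 carries ~3 significant digits, so expect a relative error around 1e-3
print((out - ref).abs().max() / ref.abs().max())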

Helper kernels for FP16 conversion

Additional utility kernels for converting between FP32/FP16 and packing/unpacking complex numbers:

// Compute intensity: |ψ|² (FP16 input, FP32 output for stability)
__global__ void compute_intensity_fp16(
    const __half* psi_real,
    const __half* psi_imag,
    float* intensity,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        float real_f32 = __half2float(psi_real[idx]);
        float imag_f32 = __half2float(psi_imag[idx]);
        intensity[idx] = real_f32 * real_f32 + imag_f32 * imag_f32;
    }
}

// Pack separate real/imag → __half2
__global__ void pack_complex_fp16(
    const __half* real,
    const __half* imag,
    __half2* output,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        output[idx] = __halves2half2(real[idx], imag[idx]);
    }
}

Why compute intensity in FP32: Squaring FP16 values can cause overflow. Convert to FP32 for the final multiplication to maintain numerical stability.
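A two-line demonstration of the overflow risk (FP16 tops out at about 65,504):

import torch

psi = torch.tensor([300.0], dtype=torch.float16)
print(psi * psi)                    # inf  (300² = 90,000 exceeds the FP16 range)
print(psi.float() * psi.float())    # 90000.0 in FP32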
Usage in the model's forward pass (with fused_module compiled via load_inline above):

exit_waves = fused_module.fused_exit_wave(
    self.object_amplitude, self.object_phase, self.y_idx, self.x_idx, self.probe_cache
)

# FFT (can't fuse this - cuFFT is optimal)
Psi = torch.fft.fft2(exit_waves, s=(self.diffraction_size, self.diffraction_size))
return torch.abs(Psi)

What this kernel achieves:

  • Eliminates separate amplitude * exp(i*phase) operation

  • Eliminates separate gather operation

  • Eliminates separate complex multiply operation

  • All done in one pass with coalesced memory access

Expected speedup: 1.3-1.5×

New time: 0.110s → 0.073-0.085s

Why this works: Instead of materializing the full complex object in memory, then gathering patches, then multiplying by probe (three separate kernel launches, three memory round-trips), we compute everything on-the-fly. Each thread reads amplitude and phase values, computes the complex object value, and immediately multiplies by the probe - all in registers.

Custom kernel 3: Fused intensity loss

Problem: Abs + power operations take 30.45% combined:

  • aten::abs: 656ms (5.95%)

  • AbsBackward0: 1.524s (13.82%)

  • PowBackward0: 1.175s (10.65%)

The current code:

# Three separate operations
I_pred = torch.abs(Psi) ** 2  # abs, then pow
pred_scale = counts / I_pred.mean(dim=(1,2), keepdim=True)  # mean, div
diff = I_pred * pred_scale - I_meas * meas_scale  # mul, sub
loss = (diff * diff).mean()  # mul, mean

Custom CUDA kernel fuses everything:

__global__ void fused_intensity_loss_kernel(
    const cuFloatComplex* __restrict__ Psi,  // [B, H, W]
    const float* __restrict__ I_meas,        // [B, H, W]
    const float* __restrict__ meas_scale,    // [B, 1, 1]
    float counts,
    float* __restrict__ loss_output,         // [1], zero-initialized before launch
    float* __restrict__ I_pred_output,       // [B, H, W] (saved for backward)
    int B, int H, int W
) {
    // One block per diffraction pattern (blockIdx.x = batch index b);
    // 256 threads stride over that pattern's H*W pixels.
    __shared__ float shared_sum[256];

    int b = blockIdx.x;
    int tid = threadIdx.x;
    int n_pixels = H * W;

    const cuFloatComplex* Psi_b = Psi + (size_t)b * n_pixels;
    const float* I_meas_b = I_meas + (size_t)b * n_pixels;
    float* I_pred_b = I_pred_output + (size_t)b * n_pixels;

    // Phase 1: I_pred = |Psi|^2, plus the per-pattern sum needed for the mean
    float local_sum = 0.0f;
    for (int i = tid; i < n_pixels; i += blockDim.x) {
        cuFloatComplex psi_val = Psi_b[i];
        float I_pred = psi_val.x * psi_val.x + psi_val.y * psi_val.y;
        I_pred_b[i] = I_pred;             // store for the backward pass
        local_sum += I_pred;
    }
    shared_sum[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) shared_sum[tid] += shared_sum[tid + s];
        __syncthreads();
    }

    // pred_scale = counts / mean(I_pred) for this pattern
    float pred_scale = counts / (shared_sum[0] / n_pixels);
    __syncthreads();   // all threads have read shared_sum[0]; safe to reuse it

    // Phase 2: scaled squared difference, reduced within the block
    float local_loss = 0.0f;
    for (int i = tid; i < n_pixels; i += blockDim.x) {
        float diff = I_pred_b[i] * pred_scale - I_meas_b[i] * meas_scale[b];
        local_loss += diff * diff;
    }
    shared_sum[tid] = local_loss;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) shared_sum[tid] += shared_sum[tid + s];
        __syncthreads();
    }

    // One atomicAdd per pattern accumulates the mean over all B*H*W elements
    if (tid == 0) {
        atomicAdd(loss_output, shared_sum[0] / ((float)B * n_pixels));
    }
}

PyTorch wrapper for forward pass:

def fused_intensity_loss(Psi, I_meas, meas_scale, counts):
    """
    Fused forward pass - computes the loss in a single kernel launch.
    The kernel is assumed to be exposed through the compiled extension
    (here the same ptychography_fp16 module; the binding name
    fused_intensity_loss_launch is a placeholder). A raw CUDA C++ kernel
    cannot be launched with Python [] indexing syntax.
    """
    B, H, W = Psi.shape
    loss = torch.zeros(1, device=Psi.device)   # kernel accumulates into this with atomicAdd
    I_pred = torch.empty_like(I_meas)          # saved for the backward pass

    # The C++ side launches one 256-thread block per diffraction pattern (grid = B)
    ptychography_fp16.fused_intensity_loss_launch(
        Psi, I_meas, meas_scale, counts, loss, I_pred
    )
    return loss, I_pred

# Use in training: to make loss.backward() work through the custom kernel,
# wrap the call in a torch.autograd.Function (sketch below)
loss, I_pred = fused_intensity_loss(Psi, I_meas, meas_scale, counts)
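A minimal sketch of that autograd glue, assuming both the forward launch above and a matching gradient entry point (the name fused_intensity_loss_backward is a placeholder, not an existing API) are exposed through the compiled extension:

import torch
import ptychography_fp16  # compiled extension (assumed name from this document)

class FusedIntensityLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Psi, I_meas, meas_scale, counts):
        loss, I_pred = fused_intensity_loss(Psi, I_meas, meas_scale, counts)
        ctx.save_for_backward(Psi, I_pred, meas_scale)
        ctx.counts = counts
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        Psi, I_pred, meas_scale = ctx.saved_tensors
        # Placeholder: a custom gradient kernel (or equivalent PyTorch ops)
        # computes dL/dPsi from the tensors saved during forward
        grad_Psi = ptychography_fp16.fused_intensity_loss_backward(
            grad_loss, Psi, I_pred, meas_scale, ctx.counts
        )
        return grad_Psi, None, None, None

loss = FusedIntensityLoss.apply(Psi, I_meas, meas_scale, counts)
loss.backward()   # gradients now flow to whatever produced Psi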

Expected speedup: 1.5-1.8× (eliminates 6 separate operations + all intermediate memory)

New time: 0.073s → 0.041-0.049s

Summary: Two critical kernels + one optional

Priority ranking by impact:

  1. Kernel 2 (exit wave FP16) - CRITICAL: Runs 4,096× per iteration → 2-3× speedup

  2. Kernel 3 (intensity loss) - IMPORTANT: Runs 4,096× per iteration → 1.5-1.8× speedup

  3. Kernel 1 (probe generation) - OPTIONAL: Runs 1× per iteration, only if optimizing probe aberrations

The forward pass:

[Optional] Kernel 1: probe_k generation  →  [IFFT]  →  probe_real (1× per iter)
                                          ↓
[CRITICAL] Kernel 2: Exit wave FP16      (4,096× per iter)
                     tile gather + amp/phase → complex + multiply
                                          →  [FFT]  →  Psi
                                          ↓
[IMPORTANT] Kernel 3: Intensity loss     (4,096× per iter)
                      |Psi|², scaling, loss

Why Kernel 2 is most important:

  • Runs 4,096 times per iteration (once per scan position)

  • Uses FP16 with Tensor Cores → 2-3× speedup

  • Fuses 3 operations into 1 kernel

  • Biggest performance impact

Why Kernel 1 is optional:

  • Runs 1 time per iteration (or never if probe is fixed)

  • Most reconstructions use fixed probe (no optimization needed)

  • Even if trainable, small fraction of total time (~10%)

  • Low priority - skip unless you need it

Performance gains:

  • Kernel 1 (optional): 1.1-1.2× IF optimizing probe aberrations

  • Kernel 2 (critical): 2-3× speedup with FP16 + kernel fusion

  • Kernel 3 (important): 1.5-1.8× speedup

  • Realistic combined: 2-3× from Kernels 2+3 alone

Recommendation: Implement Kernel 2 first, then Kernel 3. Skip Kernel 1 unless you’re optimizing probe aberrations.

Backward pass: Use PyTorch autograd

scaler = torch.cuda.amp.GradScaler()  # create once, before the training loop

loss = compute_loss(Psi, I_meas)

scaler.scale(loss).backward()  # scale gradients to prevent FP16 underflow
scaler.step(optimizer)         # unscale gradients and apply the optimizer step
scaler.update()                # adjust the loss scale for the next iteration
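For context, a minimal sketch of the standard torch.cuda.amp pattern these calls come from; model, compute_loss, batch_indices, I_meas, optimizer, and max_iterations are placeholders from the surrounding examples, and the autocast region can be replaced by the custom FP16 forward:

import torch

scaler = torch.cuda.amp.GradScaler()

for iteration in range(max_iterations):
    optimizer.zero_grad()
    with torch.autocast('cuda', dtype=torch.float16):
        Psi = model(batch_indices)           # forward pass in mixed precision
        loss = compute_loss(Psi, I_meas)
    scaler.scale(loss).backward()            # scaled gradients avoid FP16 underflow
    scaler.step(optimizer)                   # unscale, then optimizer step
    scaler.update()                          # adapt the scale factor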

Performance gains:

  • 1.3-1.5× speedup on Ampere/Hopper GPUs

  • 2× memory reduction for activations

  • No accuracy loss for this application

Final combined speedup: 2.5-3.5× (kernels) × 1.3-1.5× (FP16) = 3-5× total

Result: 0.11s → 0.022-0.037s per iteration


Strategy: Focus optimization on forward pass (Kernels 2+3), use PyTorch’s built-in tools for backward.

optimizer = torch.optim.Adam([object_complex, probe], lr=0.01)

for iteration in range(max_iterations):
    optimizer.zero_grad()

    # Forward pass with custom FP16 kernels
    exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
        obj_amp_fp16, obj_phase_fp16, probe_fp16, y_idx, x_idx
    )

    Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64))
    loss = compute_loss(Psi, I_meas)

    loss.backward()  # PyTorch autograd handles gradients
    optimizer.step()

Why this is good:

  • PyTorch’s autograd is already highly optimized

  • The Adam optimizer is well-tested and numerically stable

  • Less code to maintain

  • The forward pass (Kernels 2+3) is where the biggest gains are (2-3× each)

Implementation guide

Compile custom FP16 CUDA kernels

Create setup.py for PyTorch extension:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='ptychography_fp16',
    ext_modules=[
        CUDAExtension(
            'ptychography_fp16',
            ['fp16_complex_ops.cu'],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': ['-O3', '--use_fast_math', '-arch=sm_89']  # Ada (L40S); use sm_80 for A100, sm_90 for H100
            }
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)

Compile and install:

python setup.py install
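Alternatively, for faster iteration during development, the same extension can be JIT-compiled at import time with torch.utils.cpp_extension.load. A sketch using the fp16_complex_ops.cu file from the setup above:

from torch.utils.cpp_extension import load

# JIT-compile the extension on first import (result is cached afterwards)
ptychography_fp16 = load(
    name='ptychography_fp16',
    sources=['fp16_complex_ops.cu'],
    extra_cuda_cflags=['-O3', '--use_fast_math'],
)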

Integrate into PyTorch training loop

import torch
import torch.nn as nn
import ptychography_fp16  # Compiled custom kernels

class OptimizedPtychography(nn.Module):
    def __init__(self, object_size, probe_size, scan_positions, device='cuda'):
        super().__init__()

        # Master copies in FP32
        self.object_amplitude = nn.Parameter(torch.randn(object_size, device=device))
        self.object_phase = nn.Parameter(torch.randn(object_size, device=device))
        self.probe = nn.Parameter(torch.randn(probe_size, dtype=torch.complex64, device=device))

        # Precomputed indices
        self.register_buffer('y_idx', ...)
        self.register_buffer('x_idx', ...)

    def forward(self, batch_indices):
        # Probe as packed (real, imag) FP16 pairs, matching the kernel's __half2 layout
        probe_fp16 = torch.view_as_real(self.probe).to(torch.float16)

        # Custom fused FP16 kernel: amp/phase → complex patches × probe
        exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
            self.object_amplitude.to(torch.float16),
            self.object_phase.to(torch.float16),
            probe_fp16, self.y_idx, self.x_idx
        )

        # FFT in FP32 for stability
        Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64))

        return Psi

# Training loop
model = OptimizedPtychography(...).to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for iteration in range(num_iterations):
    optimizer.zero_grad()

    Psi = model(batch_indices)
    loss = compute_loss(Psi, I_meas)

    loss.backward()  # PyTorch autograd handles gradients
    optimizer.step()

Validate FP16 accuracy

# Reference (pure PyTorch ops)
with torch.no_grad():
    Psi_ref = model_pytorch(batch_indices)
    loss_ref = compute_loss_pytorch(Psi_ref, ...)

# Custom kernels
with torch.no_grad():
    Psi_opt = model_optimized(batch_indices)
    loss_opt = compute_loss_optimized(Psi_opt, ...)

# Check numerical accuracy
max_diff = torch.max(torch.abs(Psi_ref - Psi_opt))
rel_error = max_diff / torch.max(torch.abs(Psi_ref))

print(f"Max absolute difference: {max_diff:.6e}")
print(f"Relative error: {rel_error:.6e}")

assert rel_error < 1e-3, "Custom kernels changed results!"

Expected errors with mixed precision: < 1e-3 (acceptable for scientific computing)

Profile to verify speedup

import time
import numpy as np

# Warmup
for _ in range(10):
    _ = model(batch_indices)

# Benchmark
torch.cuda.synchronize()
times = []
for _ in range(100):
    t0 = time.perf_counter()
    loss = train_iteration(model, ...)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - t0)

print(f"Mean: {np.mean(times):.4f}s")
print(f"Std: {np.std(times):.4f}s")
print(f"Target < 0.01s: {'✅' if np.mean(times) < 0.01 else '❌'}")

Summary

Three custom CUDA kernels + FP16 = 3-5× speedup

Why three separate kernels:

Kernel 1: Probe generation (aberrations) → [IFFT] → probe
Kernel 2: Exit wave (gather + multiply)  → [FFT]  → Psi
Kernel 3: Intensity loss (|Psi|², scale, loss)

Cannot fuse across the FFT/IFFT boundaries: these are global operations, and cuFFT is already optimal.

What each kernel fuses:

  1. Probe generation: Defocus + Cs + astigmatism + aperture → FP32 probe_k → [IFFT]

    • Speedup: 1.1-1.2×

  2. Exit wave (FP16): Tile gather + amp/phase → complex + complex multiply in FP16

    • Speedup: 2-3× (uses Tensor Cores)

    • Custom __half2 complex math

    • Fuses 3 operations into 1 kernel

  3. Intensity loss: |Psi|² + scaling + difference + reduction

    • Speedup: 1.5-1.8×

FP16 implementation details:

  • Uses __half2 for complex numbers: (real, imag) packed

  • Custom cmul_fp16() for complex multiply (Tensor Cores)

  • Master copies in FP32, compute in FP16

  • FFT stays FP32 for stability

  • Typical error: < 1e-3 (acceptable)

Backward pass:

  • Use PyTorch autograd (already optimized)

  • ADAM optimizer

  • Focus on forward pass optimization

Final result: 0.11s → 0.022-0.037s (3-5× speedup)

Key takeaway: Custom FP16 kernels + kernel fusion + FFT boundaries = maximum performance with minimal code.

# PyTorch Operations Deep Dive: Ptychography Forward Pass

This document explains every PyTorch operation in the ptychography forward pass with detailed visual illustrations.

## Table of Contents

1. [The Big Picture: What is Ptychography?](#1-the-big-picture)
2. [Setup: Initial Data Structures](#2-setup)
3. [Operation 1: Gather (Extracting Patches)](#3-gather-operation)
4. [Operation 2: Complex Multiplication](#4-complex-multiplication)
5. Operation 3: FFT (Fourier Transform)
6. Operation 4: Intensity Calculation
7. Performance Optimizations Explained

The big picture: What is ptychography?

PTYCHOGRAPHY: Reconstructing an object by scanning a probe over it

Step 1: Physical Setup
┌─────────────────────────────────────────┐
│                                         │
│   Electron Beam (Probe)                │
│         ↓↓↓↓↓↓↓                        │
│   ┌─────────────┐                      │
│   │   Probe     │  ← 128×128 pixels    │
│   │  (complex)  │                      │
│   └─────────────┘                      │
│         ↓                               │
│   ┌─────────────────────┐              │
│   │                     │              │
│   │    Sample/Object    │ ← 512×512    │
│   │     (unknown!)      │              │
│   │                     │              │
│   └─────────────────────┘              │
│         ↓                               │
│   ┌─────────────┐                      │
│   │ Diffraction │  ← Detector 256×256  │
│   │  Pattern    │                      │
│   └─────────────┘                      │
│                                         │
└─────────────────────────────────────────┘

Step 2: Scanning Process
        Position 1     Position 2     Position 3
        ┌────┐        ┌────┐         ┌────┐
        │Beam│        │Beam│         │Beam│
        └─┬──┘        └─┬──┘         └─┬──┘
          ↓             ↓              ↓
    ┌─────────────────────────────────────┐
    │ ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ │
    │ ░░░Sample moves, beam scans!░░░░░░ │
    │ ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ │
    └─────────────────────────────────────┘

We scan 4096 positions (typical) and record diffraction at each!

Goal: Given diffraction patterns, reconstruct the object.

Forward model (what we’re implementing):

Object + Probe → Exit Wave → FFT → Diffraction Intensity
  ↑       ↑          ↑         ↑          ↑
 512×512 128×128  128×128   256×256   256×256
complex complex  complex   complex     real
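Put together, the forward model is only a few PyTorch lines. This is a sketch: forward_model is a placeholder name, y_idx/x_idx are precomputed gather indices, and resampling the 128×128 exit wave onto the 256×256 detector is omitted:

import torch

def forward_model(object_c, probe_c, y_idx, x_idx):
    # object_c: [512, 512] complex64, probe_c: [128, 128] complex64
    patches = object_c[y_idx, x_idx]        # [B, 128, 128] batched gather
    exit_waves = patches * probe_c          # broadcast complex multiply
    Psi = torch.fft.fft2(exit_waves)        # far-field propagation
    I_pred = Psi.real**2 + Psi.imag**2      # diffraction intensities |Psi|^2
    return I_pred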

Setup: Initial data structures

# The object we want to reconstruct (complex-valued)
object_real = torch.randn(512, 512)  # Real part
object_imag = torch.randn(512, 512)  # Imaginary part

# The probe (illumination, known or optimized)
# Here assume it's already IFFTed from probe(k)
probe_real = torch.randn(128, 128)   # Real part
probe_imag = torch.randn(128, 128)   # Imaginary part

# Scan positions (where the probe illuminates)
batch_size = 4096  # Number of scanned positions
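When convenient, the real/imaginary pairs above can also be combined into native complex tensors (a small sketch using torch.complex; the rest of this walkthrough keeps real and imaginary parts separate):

# Combine real/imag parts into complex64 tensors
object_complex = torch.complex(object_real, object_imag)  # [512, 512] complex64
probe_complex = torch.complex(probe_real, probe_imag)     # [128, 128] complex64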

And we have the scanned positions:

Object:
┌─────────────────────┐
│  ●    ●    ●        │  ← Row 1: (100,50), (100,150), (100,250)
│                     │
│  ●    ●    ●        │  ← Row 2: (200,50), (200,150), (200,250)
│                     │
│  ●    ●    ●        │  ← Row 3: (300,50), (300,150), (300,250)
└─────────────────────┘

# The 3×3 toy grid above, as lists:
y_centers = [100, 100, 100, 200, 200, 200, 300, 300, 300]
x_centers = [50, 150, 250, 50, 150, 250, 50, 150, 250]

# The real scan uses 4,096 positions:
y_centers = torch.tensor([100, 150, 200, ...])  # 4096 y-coordinates
x_centers = torch.tensor([50, 75, 100, ...])    # 4096 x-coordinates

Object (512×512):
┌─────────────────────────────────────┐
│ 0  1  2  3  4  5  6  ... 510  511  │ ← Row 0
│ 0  1  2  3  4  5  6  ... 510  511  │ ← Row 1
│ :  :  :  :  :  :  :       :    :   │
│ 0  1  2  3  4  5  6  ... 510  511  │ ← Row 511
└─────────────────────────────────────┘
Each cell is a complex number: real + i*imag

Probe (128×128):
┌──────────────────┐
│ Gaussian-like    │
│   beam shape     │
│  (128×128 px)    │
└──────────────────┘
Gather operation: Extracting patches (detailed explanation)

Problem: We need to extract a 128×128 patch from the 512×512 object at each scan position.

# Scan position 0: center at (100, 50)
# Scan position 1: center at (150, 75)
# Scan position 2: center at (200, 100)
# ... 4093 more positions
Object (512×512):
                    x_center=50
                        ↓
         ┌──────────────┼──────────────┐
         │              │              │
         │              │              │
    y=36 ├──────────────●──────────────┤ ← Top-left of patch
         │     ┌────────┼────────┐     │
         │     │ Patch  │128×128 │     │
 y_center=100──┼────────●────────┼─────┤ ← Center
         │     │        │        │     │
         │     └────────┼────────┘     │
   y=164 ├──────────────┼──────────────┤ ← Bottom-right
         │              │              │
         └──────────────┼──────────────┘
                  x=-14  x=50   x=113

Top-left corner: (y_center - 64, x_center - 64) = (36, -14)
                  ↑
Note: 128/2 = 64 (half the probe size). An x_center of 50 would place part of
the patch outside the object, so real scan grids keep every center at least
64 pixels from the borders. The slice visual below uses an in-bounds patch
whose top-left corner is (36, 50).

Naive approach (Python loop over all 4,096 positions):
patches = []
for i in range(4096):
    y_start = y_centers[i] - 64
    x_start = x_centers[i] - 64

    # Extract 128×128 patch
    patch_real = object_real[y_start:y_start+128, x_start:x_start+128]
    patch_imag = object_imag[y_start:y_start+128, x_start:x_start+128]

    patches.append((patch_real, patch_imag))

# Problem: 4096 separate operations! Very slow on GPU.

Here is a visual for patch_real = object_real[36:164, 50:178]:

object_real (512×512):
     Columns →
     0  1  2  ...  50  51  ...  177 178 ... 511
R  0 ░  ░  ░  ...  ░   ░   ...  ░   ░  ...  ░
o  1 ░  ░  ░  ...  ░   ░   ...  ░   ░  ...  ░
w  :  :  :  :       :   :        :   :       :
s 36 ░  ░  ░  ...  █   █   ...  █   ░  ...  ░  ← Start row (y_start)
↓ 37 ░  ░  ░  ...  █   █   ...  █   ░  ...  ░
  38 ░  ░  ░  ...  █   █   ...  █   ░  ...  ░
   :  :  :  :       :   :        :   :       :
 163 ░  ░  ░  ...  █   █   ...  █   ░  ...  ░  ← End row (y_start+127)
 164 ░  ░  ░  ...  ░   ░   ...  ░   ░  ...  ░  ← NOT included!
   :  :  :  :       :   :        :   :       :
 511 ░  ░  ░  ...  ░   ░   ...  ░   ░  ...  ░
              ↑                   ↑
         Start col            End col
        (x_start=50)      (x_start+127=177)
                            178 NOT included!

█ = Extracted patch (128×128) ░ = Not extracted


# Step 1: Create a meshgrid of offsets
# This is the pattern of offsets within the 128×128 patch
y_offset = torch.arange(128)  # [0, 1, 2, ..., 127]
x_offset = torch.arange(128)  # [0, 1, 2, ..., 127]

grid_y, grid_x = torch.meshgrid(y_offset, x_offset, indexing='ij')
# grid_y: [[  0,  0,  0, ...,  0],     grid_x: [[0, 1, 2, ..., 127],
#          [  1,  1,  1, ...,  1],              [0, 1, 2, ..., 127],
#          ...                                  ...
#          [127,127,127, ...,127]]              [0, 1, 2, ..., 127]]

Visual representation of grid:


grid_y (128×128):           grid_x (128×128):
┌────────────────┐          ┌────────────────┐
│  0   0   0 …  0│          │  0  1  2 … 127 │
│  1   1   1 …  1│          │  0  1  2 … 127 │
│  2   2   2 …  2│          │  0  1  2 … 127 │
│  :   :   :   : │          │  :  :  :    :  │
│127 127 127 …127│          │  0  1  2 … 127 │
└────────────────┘          └────────────────┘

Think of these as: “relative position within patch”


# Step 2: Convert centers to top-left corners, then add the offset grid
# Shape: [4096, 128, 128]
y_start = y_centers - 64   # top-left row of each patch
x_start = x_centers - 64   # top-left column of each patch

y_indices = y_start[:, None, None] + grid_y[None, :, :]
x_indices = x_start[:, None, None] + grid_x[None, :, :]

# Broadcasting explained:
# y_start[:, None, None]   # [4096, 1, 1]
# grid_y[None, :, :]       # [1, 128, 128]
# ───────────────────────
# y_indices                # [4096, 128, 128]  (broadcast!)

Broadcasting visualization:


y_start[:, None, None]:          grid_y[None, :, :]:
┌────┐                           ┌───────────────┐
│ 36 │                           │  0   0  …   0 │
│ 86 │       Broadcast           │  1   1  …   1 │
│136 │       ────────→           │  :   :      : │
│  : │                           │127 127  … 127 │
│236 │                           └───────────────┘
└────┘
[4096, 1, 1]                     [1, 128, 128]

Result y_indices [4096, 128, 128]:

Position 0 (y_center=100, y_start=36):
    [[ 36,  36, …,  36],   ← y_start + offset 0
     [ 37,  37, …,  37],   ← y_start + offset 1
      …
     [163, 163, …, 163]]   ← y_start + offset 127

Position 1 (y_center=150, y_start=86):
    [[ 86,  86, …,  86],
     [ 87,  87, …,  87],
      …
     [213, 213, …, 213]]


# Step 3: Gather using advanced indexing (MAGIC! ✨)
patches_real = object_real[y_indices, x_indices]
patches_imag = object_imag[y_indices, x_indices]

# Shape: [4096, 128, 128]
# This is a SINGLE GPU operation - all 4096 patches extracted at once!
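A quick sanity check (a sketch, reusing the tensors from the steps above and assuming every center sits at least 64 pixels from the borders) that the batched gather matches explicit slicing:

# Compare the batched gather against explicit slicing for a few positions
for i in range(3):
    ys = int(y_centers[i]) - 64
    xs = int(x_centers[i]) - 64
    assert torch.equal(patches_real[i], object_real[ys:ys+128, xs:xs+128])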

Concrete example: Advanced indexing with real numbers

Let’s use a small 8×8 object and extract 2 patches of 3×3:


# Object (8×8) with easy-to-track values
object_real = torch.tensor([
    [10, 11, 12, 13, 14, 15, 16, 17],  # Row 0
    [20, 21, 22, 23, 24, 25, 26, 27],  # Row 1
    [30, 31, 32, 33, 34, 35, 36, 37],  # Row 2
    [40, 41, 42, 43, 44, 45, 46, 47],  # Row 3
    [50, 51, 52, 53, 54, 55, 56, 57],  # Row 4
    [60, 61, 62, 63, 64, 65, 66, 67],  # Row 5
    [70, 71, 72, 73, 74, 75, 76, 77],  # Row 6
    [80, 81, 82, 83, 84, 85, 86, 87],  # Row 7
])

# Extract 2 patches (3×3 each)
# Patch 0: center at (2, 2) → top-left = (1, 1)
# Patch 1: center at (4, 5) → top-left = (3, 4)

# Step 1: Create offset grid for a 3×3 patch
grid_y = torch.tensor([[0, 0, 0],
                       [1, 1, 1],
                       [2, 2, 2]])

grid_x = torch.tensor([[0, 1, 2],
                       [0, 1, 2],
                       [0, 1, 2]])

# Step 2: Centers and top-left corners
y_centers = torch.tensor([2, 4])
x_centers = torch.tensor([2, 5])
y_start = y_centers - 1  # [1, 3]
x_start = x_centers - 1  # [1, 4]

# Step 3: Build indices, shape [2, 3, 3]
y_indices = y_start[:, None, None] + grid_y[None, :, :]
x_indices = x_start[:, None, None] + grid_x[None, :, :]

# y_indices:              x_indices:
# [[[1, 1, 1],            [[[1, 2, 3],
#   [2, 2, 2],              [1, 2, 3],
#   [3, 3, 3]],             [1, 2, 3]],
#
#  [[3, 3, 3],             [[4, 5, 6],
#   [4, 4, 4],              [4, 5, 6],
#   [5, 5, 5]]]             [4, 5, 6]]]

# Step 4: Advanced indexing!
patches_real = object_real[y_indices, x_indices]

What happens internally:

For patch 0, the center value is object_real[2, 2] = 32; its top row reads 21, 22, 23, and so on. Here object_real is the full 8×8 array defined above.


# For patch 0, position (0,0):
#   y_indices[0,0,0] = 1, x_indices[0,0,0] = 1
#   patches_real[0,0,0] = object_real[1, 1] = 21

# For patch 0, position (0,1):
#   y_indices[0,0,1] = 1, x_indices[0,0,1] = 2
#   patches_real[0,0,1] = object_real[1, 2] = 22

# For patch 0, position (1,1):
#   y_indices[0,1,1] = 2, x_indices[0,1,1] = 2
#   patches_real[0,1,1] = object_real[2, 2] = 32

# ... and so on for all 18 positions (2 patches × 3×3)

Result:


patches_real = [
    [[21, 22, 23],   # Patch 0
     [31, 32, 33],   # ← center at (2,2) = 32
     [41, 42, 43]],

    [[44, 45, 46],   # Patch 1
     [54, 55, 56],   # ← center at (4,5) = 55
     [64, 65, 66]]
]

Visual:


Object (8×8):                      Patch 0 (red):    Patch 1 (blue):
10 11 12 13 14 15 16 17            21 22 23          44 45 46
20 21 22 23 24 25 26 27            31 32 33          54 55 56
30 31 32 33 34 35 36 37            41 42 43          64 65 66
40 41 42 43 44 45 46 47
50 51 52 53 54 55 56 57
60 61 62 63 64 65 66 67
70 71 72 73 74 75 76 77
80 81 82 83 84 85 86 87

The key: y_indices and x_indices act as a “lookup table” telling PyTorch exactly which element to grab from object_real for each position in the output.
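The whole toy example, consolidated into a runnable snippet (a sketch; the 8×8 object is built with arange arithmetic instead of typing out the literal rows):

import torch

# Self-contained version of the 8×8 toy example above
object_real = torch.arange(8).view(8, 1) * 10 + 10 + torch.arange(8)  # rows 10..17, 20..27, ...

grid_y, grid_x = torch.meshgrid(torch.arange(3), torch.arange(3), indexing='ij')

y_centers = torch.tensor([2, 4])
x_centers = torch.tensor([2, 5])
y_start = y_centers - 1
x_start = x_centers - 1

y_indices = y_start[:, None, None] + grid_y[None, :, :]   # [2, 3, 3]
x_indices = x_start[:, None, None] + grid_x[None, :, :]   # [2, 3, 3]

patches_real = object_real[y_indices, x_indices]
print(patches_real[0])   # [[21, 22, 23], [31, 32, 33], [41, 42, 43]]
print(patches_real[1])   # [[44, 45, 46], [54, 55, 56], [64, 65, 66]]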

How advanced indexing works


object_real[y_indices, x_indices]

# Conceptually, PyTorch does this:
for b in range(4096):
    for i in range(128):
        for j in range(128):
            y = y_indices[b, i, j]            # e.g., 100
            x = x_indices[b, i, j]            # e.g., 50
            result[b, i, j] = object_real[y, x]

# But on the GPU this runs as parallel threads, not a Python loop!

GPU parallelization:

CPU loop (sequential):            GPU (parallel):
┌───┐                             ┌─┬─┬─┬─┬─┬─┬─┬─┐
 1 2 3 ...                        │1│2│3│4│5│6│7│8│  All at once!
└───┘                             └─┴─┴─┴─┴─┴─┴─┴─┘
Time: 4096 × 128×128 sequential ops     Time: ~1 ms (massively parallel)

Complex multiplication: Object × probe

Complex number: z = a + bi
where: a = real part, b = imaginary part, i = √(-1)

Multiplication rule:
(a + bi) × (c + di) = (ac - bd) + (ad + bc)i
                       ─────────   ─────────
                         real        imag

In our code

# We have:
object_patch = object_real[y_indices, x_indices] + i * object_imag[y_indices, x_indices]
probe = probe_real + i * probe_imag

# Exit wave = object_patch × probe
# Using formula: (a+bi)×(c+di) = (ac-bd) + (ad+bc)i

exit_real = object_real[y_indices, x_indices] * probe_real \
          - object_imag[y_indices, x_indices] * probe_imag

exit_imag = object_real[y_indices, x_indices] * probe_imag \
          + object_imag[y_indices, x_indices] * probe_real

Visual example (single pixel)

Object patch at position (0,0): 0.8 + 0.3i
Probe at position (0,0): 0.9 + 0.1i

Exit wave = (0.8 + 0.3i) × (0.9 + 0.1i)
          = (0.8×0.9 - 0.3×0.1) + (0.8×0.1 + 0.3×0.9)i
          = (0.72 - 0.03) + (0.08 + 0.27)i
          = 0.69 + 0.35i

exit_real[0,0,0] = 0.69
exit_imag[0,0,0] = 0.35
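The same pixel, cross-checked with PyTorch's native complex arithmetic (a quick sketch):

import torch

# Verify the manual formula with native complex multiplication
obj = torch.complex(torch.tensor(0.8), torch.tensor(0.3))
prb = torch.complex(torch.tensor(0.9), torch.tensor(0.1))
print(obj * prb)  # tensor(0.6900+0.3500j)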

Shape transformation

Input:
  object_real[y_indices, x_indices]  [4096, 128, 128]
  object_imag[y_indices, x_indices]  [4096, 128, 128]
  probe_real                         [128, 128]
  probe_imag                         [128, 128]

Broadcasting (probe gets repeated):
  probe_real → [4096, 128, 128] (same value for all 4096 batches)

Output:
  exit_real  [4096, 128, 128]
  exit_imag  [4096, 128, 128]

Broadcasting visualization:

Batch dimension:
                probe_real [128, 128]
                     ↓ broadcast
┌─────┬─────┬─────┬─────┬─────┬─────┐
│ #0  │ #1  │ #2  │ ... │4094 │4095 │
└─────┴─────┴─────┴─────┴─────┴─────┘
  ↑     ↑     ↑           ↑     ↑

Same probe repeated 4096 times (implicitly)

Each batch: object_patch[b] × probe → exit_wave[b]