.. _gpu-optimization:

GPU optimization
================

**Objective:** Reduce ptychography reconstruction time from 0.11s to < 0.01s per iteration on a single L40S GPU, which means about 1 second for 100 iterations. GPUs with higher memory bandwidth, such as the H100 (3.0-3.3 TB/s vs the L40S's 864 GB/s), could run faster still.

**Current performance:** 0.11s per iteration

- 512×512 object (amplitude + phase)
- 80×80 probe
- 256×256 diffraction patterns
- 4,096 scan positions per iteration

**Target performance:** < 0.01s per iteration (11× speedup required)

**Strategy:** Profile to identify bottlenecks, write custom FP16 CUDA kernels, and fuse operations across FFT boundaries.

**GPU memory bandwidth reference:**

- L40S: 864 GB/s
- A100 80GB: 2.0 TB/s
- H100: 3.0-3.3 TB/s

This document covers the actual implementation with measured profiling data. For CUDA programming fundamentals, see :ref:`cuda-tutorials`.

PyTorch optimization techniques
-------------------------------------

This section covers general PyTorch optimization patterns used throughout the ptychography implementation. Understanding these techniques is essential before diving into the custom CUDA kernels.

.. dropdown:: Using batch gather to obtain object patches (foundational technique, used once)
   :open:

   **Problem: Extracting multiple patches from a large array**

   In ptychography, we need to extract 4,096 overlapping patches from a large object (512×512 pixels) at different scan positions. This illustration uses 128×128 patches; in general the patch size must match the probe size (80×80 in the reconstruction below).

   Visualization:

   .. code-block:: text

      Object (512×512):              Extract 4,096 patches (128×128):
      ┌────────────────────┐         ┌──────┐ ┌──────┐ ┌──────┐
      │                    │         │Patch │ │Patch │ │Patch │
      │  ┌──────┐          │  ───→   │  1   │ │  2   │ │  3   │
      │  │      │          │         └──────┘ └──────┘ └──────┘
      │  └──────┘          │              ...
      │        ┌──────┐    │         ┌──────┐
      │        │      │    │         │Patch │
      │        └──────┘    │         │ 4096 │
      └────────────────────┘         └──────┘

      Each patch: 128×128 pixels at position (y_i, x_i)

   **Naive approach with Python loops (extremely slow)**

   .. code-block:: python

      # Bad: sequential extraction with a Python loop
      patches = []
      for i in range(4096):
          y_start, x_start = scan_positions[i]
          patch = obj[y_start:y_start + 128, x_start:x_start + 128]
          patches.append(patch)
      patches = torch.stack(patches)  # [4096, 128, 128]

      # Problem: 4,096 Python-level slicing calls; with scan_positions on the
      # GPU, converting each index to a Python int also forces a GPU→CPU sync
      # Time: ~500ms on T4 GPU

   **Why this is slow:**

   - Each loop iteration pays Python, dispatcher, and indexing overhead separately
   - The Python loop prevents vectorization
   - Overhead: 4,096 × ~0.12ms ≈ 492ms wasted

   **Optimized approach with batch gather (fast)**

   .. code-block:: python

      import torch

      # Setup: precompute indices once
      obj = torch.randn(512, 512, device='cuda')  # Large 2D array on GPU
      patch_size = 128  # Must match the probe size (80 in the reconstruction below)
      # Example (y, x) positions that keep every patch inside the object
      scan_positions = torch.randint(0, 512 - patch_size + 1, (4096, 2), device='cuda')

      # Create coordinate grids for patch extraction
      grid_y = torch.arange(patch_size, device='cuda')
      grid_x = torch.arange(patch_size, device='cuda')
      grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij')

      # Broadcast scan positions + local grid
      y_indices = scan_positions[:, 0:1, None] + grid_y[None, :, :]  # [4096, 128, 128]
      x_indices = scan_positions[:, 1:2, None] + grid_x[None, :, :]  # [4096, 128, 128]

      # Single batched gather operation
      patches = obj[y_indices, x_indices]  # [4096, 128, 128]
      # Time: ~0.8ms on T4 GPU (625× faster!)

   **How it works:**

   .. code-block:: text

      Step 1: Create local coordinate grids (128×128); grid_x shown,
              grid_y holds the row offsets instead
      ┌─────────────────┐
      │ 0   1   2   ... │
      │ 0   1   2   ... │
      │ 0   1   2   ... │
      │ ...             │
      └─────────────────┘

      Step 2: Add each scan position offset to the grids (x_indices shown)
      Position 1: (y=50, x=100)    Position 2: (y=75, x=150)
      ┌─────────────────┐          ┌─────────────────┐
      │ 100 101 102 ... │          │ 150 151 152 ... │
      │ 100 101 102 ... │          │ 150 151 152 ... │
      │ ...             │          │ ...             │
      └─────────────────┘          └─────────────────┘

      Step 3: Use advanced indexing to gather all 4,096 patches in parallel
      obj[y_indices, x_indices] → single gather kernel launch

   **Performance comparison**

   .. list-table::
      :widths: 40 20 20 20
      :header-rows: 1

      * - Approach
        - Time (ms)
        - Indexing calls
        - Speedup
      * - Python loop
        - ~500
        - 4,096 (+ 1 stack)
        - 1×
      * - Batch gather
        - ~0.8
        - 1
        - **625×**

   **Key insights**

   1. **Avoid Python loops over GPU operations.** Always vectorize when possible.
   2. **Advanced indexing is highly optimized.** PyTorch's ``tensor[indices]`` uses efficient GPU gather kernels.
   3. **Precompute coordinate grids.** Create indices once and reuse them across iterations.
   4. **Broadcasting eliminates loops.** Use broadcasting to generate all indices simultaneously.

   .. warning::

      **Gotcha:** Make sure indices are on the same device as the tensor. Mixed CPU/GPU indices force a device transfer on every call:

      .. code-block:: python

         # Bad: indices on CPU, object on GPU
         patches = object_gpu[indices_cpu]  # Transfers indices CPU→GPU every call

         # Good: both on GPU
         patches = object_gpu[indices_gpu]  # Fast, no transfers

When to write custom CUDA kernels
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

PyTorch core developers write custom CUDA kernels for operations not covered by cuBLAS or cuDNN, including element-wise operations, reductions, indexing, and specialized functions. Users can also write custom CUDA kernels with PyTorch's C++ extension API or with tools like Triton (which compiles Python-like code to GPU kernels).

Most PyTorch users never write kernels, since the built-in operations cover typical neural network needs. However, custom kernels enable specialized optimizations for novel architectures or unique algorithms.

**When custom CUDA justifies the complexity:** For specialized algorithms where memory access patterns matter more than ease of development, CUDA offers performance gains that justify the implementation complexity.
PyTorch's general-purpose kernels cannot provide the precise control over GPU resources that some production systems need.

**Why use custom CUDA when PyTorch is already fast?** PyTorch provides high-level operations, but custom CUDA kernels give fine-grained control over memory access patterns, thread synchronization, and data layout. Optimizations like memory tiling can achieve 3-8× speedups. Projects such as FlashAttention, xFormers (Meta), DeepSpeed (Microsoft), and vLLM rely on custom CUDA C++ kernels for maximum performance.

Mixing PyTorch operations with custom CUDA kernels
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can mix PyTorch operations with custom CUDA kernels. For example, PyTorch handles FFTs using cuFFT while your kernel implements specialized phase retrieval updates. PyTorch tensors expose raw GPU memory pointers that your CUDA kernel can access directly through the C++ extension API. This hybrid approach lets you leverage PyTorch's optimized FFT implementation while adding custom operations that aren't built in, like constraint projection or error reduction.

.. code-block:: python

   import torch
   from torch.utils.cpp_extension import load

   # Compile and load a custom CUDA extension; custom_kernel.cu is expected to
   # define and bind the (hypothetical) apply_constraint function
   custom_kernel = load(
       name='custom_ops',
       sources=['custom_kernel.cu'],
       extra_cuda_cflags=['-O3'],
   )

   # Mix PyTorch and custom operations
   x = torch.randn(512, 512, dtype=torch.complex64, device='cuda')

   # Use PyTorch's optimized FFT (cuFFT)
   x_fft = torch.fft.fft2(x)

   # Apply custom CUDA kernel
   x_processed = custom_kernel.apply_constraint(x_fft)

   # Back to PyTorch for inverse FFT
   result = torch.fft.ifft2(x_processed)

Where are PyTorch CUDA kernels in source code?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

PyTorch CUDA kernels live in ``aten/src/ATen/native/cuda/`` for core operations and ``aten/src/ATen/native/cudnn/`` for cuDNN wrappers. Files like ``BinaryOps.cu`` contain element-wise operations and ``Activation.cu`` has functions like ReLU.
Each ``.cu`` file contains CUDA C++ code with kernel launches and device functions, organized by operation type. You can browse them at `pytorch/pytorch on GitHub `_.

PyTorch vs CuPy: When to use automatic differentiation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**The problem:** CuPy handles standard GPU operations (FFT, array math, linear algebra) but lacks automatic differentiation. Optimization problems, including loss function minimization, parameter fitting, and deep learning, need autodiff, and writing gradients by hand is error-prone and slow. Cross-platform support for NVIDIA, Mac, and AMD also matters for many workflows.

**The solution:** PyTorch provides GPU arrays with automatic differentiation (autodiff). You define a computation like ``loss = ((model(x) - y)**2).mean()`` and then call ``loss.backward()``; PyTorch computes all gradients automatically, and GPU-accelerated optimizers update the parameters efficiently. PyTorch runs on NVIDIA (CUDA), Mac (Metal, via the MPS backend), and AMD (ROCm).

**When to use PyTorch vs CuPy:**

- **Use PyTorch** for: deep learning, training neural networks on diffraction data, parameter optimization, gradient-based fitting, Mac compatibility
- **Use CuPy** for: pure array operations (FFT, virtual imaging) on NVIDIA GPUs without needing gradients

**Performance comparison example**

256×256 4D-STEM FFT on RTX 4090:

.. list-table::
   :widths: 40 30 30
   :header-rows: 1

   * - Framework
     - Time (seconds)
     - Notes
   * - CuPy
     - 0.8
     - cuFFT directly
   * - PyTorch
     - 0.85
     - cuFFT through PyTorch
   * - NumPy (CPU)
     - 45
     - 56× slower than CuPy

Both call cuFFT underneath with nearly identical speed, so choose by features, not performance.

Background: Ptychography reconstruction
---------------------------------------------

Ptychography reconstructs a high-resolution complex-valued object :math:`O(x,y)` and probe :math:`P(x,y)` from diffraction intensity measurements. The object and probe are both trainable parameters updated via gradient descent.

**Problem setup:**

- **Object:** Complex-valued transmission function at each pixel (amplitude and phase)
- **Probe:** Complex-valued electron beam wavefront (contains aberrations)
- **Measurements:** Diffraction intensity patterns at the detector (magnitude only, phase lost)
- **Goal:** Recover object and probe from intensity-only measurements

Forward model: Seven computational steps per iteration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**Step 1: Generate probe from aberration parameters** (k-space → real space)

The probe starts in reciprocal space with the aberration function:

.. math::

   \chi(k_x, k_y) = \pi \lambda \Delta f k^2 + \frac{1}{2}\pi \lambda^3 C_s k^4 + \pi \lambda A_{astig} k^2 \cos\left(2(\theta - \theta_{astig})\right)

where :math:`k = \sqrt{k_x^2 + k_y^2}` is the spatial frequency, :math:`\theta = \operatorname{atan2}(k_y, k_x)` is its azimuthal angle, :math:`\Delta f` is defocus, :math:`C_s` is spherical aberration, and :math:`A_{astig}` and :math:`\theta_{astig}` are the astigmatism magnitude and orientation.

The probe in k-space is:

.. math::

   \tilde{P}(k_x, k_y) = A(k) \cdot e^{i\chi(k_x, k_y)}

where :math:`A(k)` is the aperture function (a sigmoid for soft edges).

Transform to real space via inverse FFT:

.. math::

   P(x,y) = \mathcal{F}^{-1}\{\tilde{P}(k_x, k_y)\}

**Step 2: Construct object from amplitude and phase**

.. math::

   O(x,y) = A(x,y) \cdot e^{i\phi(x,y)}

**Step 3: Extract object patches and compute exit waves**

For each scan position :math:`j`:

.. math::

   T_j(x,y) = O(x + x_j, y + y_j) \cdot P(x,y)

**Step 4: Propagate to detector via FFT**

.. math::

   \Psi_j(u,v) = \mathcal{F}\{T_j(x,y)\}

where :math:`T_j` is zero-padded from the 80×80 probe window to the 256×256 detector grid before the FFT.

**Step 5: Compute predicted intensity**

.. math::

   I_j^{\text{pred}}(u,v) = |\Psi_j(u,v)|^2

**Step 6: Loss function**

.. math::

   \mathcal{L} = \frac{1}{N} \sum_{j=1}^{N} \frac{1}{HW} \sum_{u,v} \left(I_j^{\text{pred}}(u,v) \cdot s_j^{\text{pred}} - I_j^{\text{meas}}(u,v) \cdot s_j^{\text{meas}}\right)^2

where :math:`s_j = \text{counts} / \text{mean}(I_j)` is the per-position scaling factor, computed separately for the predicted and measured patterns, and :math:`H = W = 256` is the detector size.
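
Steps 3-6 can be checked against a small pure-PyTorch reference before any kernel is written. The sketch below is illustrative only: ``forward_intensity`` and ``scaled_mse`` are hypothetical helper names, the complex object and probe are assumed to be already built, and the inputs are toy CPU-sized tensors. The final ``.mean()`` performs the :math:`\frac{1}{N}\frac{1}{HW}` averaging from Step 6.

.. code-block:: python

   import torch


   def forward_intensity(obj, probe, positions, det_size):
       """Steps 3-5: gather patches, multiply by the probe, zero-padded FFT, |Psi|^2.

       obj: complex [H_obj, W_obj]; probe: complex [p, p]; positions: long [B, 2] (y, x).
       """
       p = probe.shape[-1]
       offsets = torch.arange(p, device=positions.device)
       y = positions[:, 0, None, None] + offsets[None, :, None]  # [B, p, 1]
       x = positions[:, 1, None, None] + offsets[None, None, :]  # [B, 1, p]
       exit_waves = obj[y, x] * probe                            # Step 3: T_j = O(shifted) * P
       psi = torch.fft.fft2(exit_waves, s=(det_size, det_size))  # Step 4: zero-padded FFT
       return psi.abs() ** 2                                     # Step 5: I_pred


   def scaled_mse(I_pred, I_meas, counts):
       """Step 6: per-position scaling s_j = counts / mean(I_j), then the mean squared error."""
       s_pred = counts / I_pred.mean(dim=(1, 2), keepdim=True)
       s_meas = counts / I_meas.mean(dim=(1, 2), keepdim=True)
       return ((I_pred * s_pred - I_meas * s_meas) ** 2).mean()


   # Toy usage: 6 positions, 8x8 probe, 32x32 object, 16x16 detector
   torch.manual_seed(0)
   obj = torch.polar(torch.rand(32, 32) + 0.5, torch.randn(32, 32))
   probe = torch.randn(8, 8, dtype=torch.complex64)
   positions = torch.randint(0, 32 - 8 + 1, (6, 2))
   I_pred = forward_intensity(obj, probe, positions, det_size=16)
   loss = scaled_mse(I_pred, 3.0 * I_pred, counts=1.0)  # ~0: per-position scaling removes global scale

Because each pattern is normalized by its own mean, a measured pattern that differs from the prediction only by a constant factor gives (numerically) zero loss, which is the point of the :math:`s_j` factors.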

**Step 7: Backward pass and optimization**

PyTorch computes gradients via autograd and updates all trainable parameters using the Adam optimizer:

.. math::

   \frac{\partial \mathcal{L}}{\partial \theta} \quad \text{for all} \quad \theta \in \{\text{object amplitude, object phase, probe aberrations}\}

.. note::

   **Caveat:** Autograd overhead is significant. PyTorch records a graph node for every differentiable operation in the forward pass and replays them in the backward pass, and the profiling below attributes up to roughly half of wall-clock time to this kind of CPU-side overhead. A custom CUDA kernel can be wrapped in a ``torch.autograd.Function`` so that the forward pass records a single node; its ``backward`` method must then supply the gradients, either with PyTorch operations or with a custom backward kernel.

Trainable parameters being optimized
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The object is stored as **two separate real-valued arrays** (amplitude and phase) but combined into **complex-valued form** during the forward pass:

.. math::

   O(x,y) = A(x,y) \cdot e^{i\phi(x,y)} = A(x,y) \cdot (\cos\phi + i\sin\phi)

**Why separate storage?** Gradients flow better through separate amplitude and phase parameters. PyTorch's autograd computes :math:`\frac{\partial \mathcal{L}}{\partial A}` and :math:`\frac{\partial \mathcal{L}}{\partial \phi}` independently, then updates each via Adam. If we stored the complex object directly, gradients would couple the real and imaginary parts in a less physically meaningful way.
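
To make the two gradients concrete, take a toy loss :math:`\mathcal{L} = \sum |A e^{i\phi} - T|^2` with an arbitrary complex target :math:`T` (an assumption for illustration, not the ptychography loss). With :math:`z = A e^{i\phi} - T`, the chain rule gives :math:`\partial\mathcal{L}/\partial A = 2\,\mathrm{Re}(\bar{z}\, e^{i\phi})` and :math:`\partial\mathcal{L}/\partial\phi = 2\,\mathrm{Re}(\bar{z}\, iA e^{i\phi})`. The sketch below, built around a hypothetical ``amp_phase_grads`` helper, checks that autograd returns exactly these real-valued gradients for the two separately stored parameters.

.. code-block:: python

   import torch


   def amp_phase_grads(amplitude, phase, target):
       """Autograd gradients of L = sum |A exp(i*phi) - T|^2 w.r.t. A and phi."""
       A = amplitude.detach().clone().requires_grad_(True)
       phi = phase.detach().clone().requires_grad_(True)
       obj = A * torch.exp(1j * phi)             # same construction as get_object_complex()
       loss = (obj - target).abs().pow(2).sum()  # real-valued toy loss
       loss.backward()
       return A.grad, phi.grad                   # real tensors with the shapes of A and phi


   torch.manual_seed(0)
   A = torch.rand(8, 8) + 0.5
   phi = torch.randn(8, 8)
   T = torch.randn(8, 8, dtype=torch.complex64)
   grad_A, grad_phi = amp_phase_grads(A, phi, T)

   # Analytic gradients from the chain rule
   e = torch.exp(1j * phi)
   z = A * e - T
   expected_grad_A = 2 * (z.conj() * e).real
   expected_grad_phi = 2 * (z.conj() * 1j * A * e).real

Both gradients come back as ordinary ``float32`` tensors, so Adam can treat amplitude and phase with separate step sizes if needed.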

.. list-table::
   :widths: 30 25 25 20
   :header-rows: 1

   * - Parameter (stored)
     - Physical meaning
     - Shape
     - Dtype
   * - **Object parameters (stored separately)**
     -
     -
     -
   * - ``object_amplitude`` (A)
     - Transmission amplitude (0-1 range)
     - ``[512, 512]``
     - float32
   * - ``object_phase`` (φ)
     - Phase shift in radians
     - ``[512, 512]``
     - float32
   * - **Computed during forward**
     -
     -
     -
   * - ``object_complex`` (O)
     - :math:`A \cdot e^{i\phi}` (not stored, computed on the fly)
     - ``[512, 512]``
     - complex64
   * - **Probe aberration parameters**
     -
     -
     -
   * - ``defocus`` (Δf)
     - Defocus distance in nm (C1)
     - ``[]`` (scalar)
     - float32
   * - ``Cs``
     - Spherical aberration in mm (C3)
     - ``[]`` (scalar)
     - float32
   * - ``astig_mag``
     - Astigmatism magnitude in nm
     - ``[]`` (scalar)
     - float32
   * - ``astig_angle``
     - Astigmatism orientation in radians
     - ``[]`` (scalar)
     - float32
   * - ``aperture_smooth``
     - Aperture edge smoothness parameter
     - ``[]`` (scalar)
     - float32

**Total trainable parameters:** 524,293

- Object: 512×512×2 = 524,288 parameters (amplitude + phase stored separately)
- Probe aberrations: 5 parameters (defocus, Cs, astigmatism magnitude, astigmatism angle, aperture smoothness)

**Optimization strategy:** Two separate Adam optimizers with fused kernels update the object and probe parameters independently after each iteration.

.. warning::

   **Gotcha:** The object is stored as two separate real-valued arrays (amplitude and phase) but must be combined into complex form during the forward pass. This separation enables better gradient flow but requires careful handling to avoid creating unnecessary intermediate tensors.
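
One way to address this gotcha, sketched here as an assumption rather than what the implementation below does, is ``torch.polar(abs, angle)``: a single PyTorch operation that returns :math:`A(\cos\phi + i\sin\phi)`. By contrast, ``amplitude * torch.exp(1j * phase)`` first materializes the complex tensor ``1j * phase``, then its ``exp``, then the product. Both forms are differentiable with respect to amplitude and phase; ``object_from_amp_phase`` is a hypothetical helper name.

.. code-block:: python

   import torch


   def object_from_amp_phase(amplitude, phase):
       """Build the complex object O = A * exp(i*phi) with one torch.polar call."""
       # torch.polar(abs, angle) computes abs * (cos(angle) + i*sin(angle))
       return torch.polar(amplitude, phase)


   # Compare against the exp-based construction used in get_object_complex()
   torch.manual_seed(0)
   amp_a = torch.rand(16, 16, requires_grad=True)
   phase_a = torch.randn(16, 16, requires_grad=True)
   amp_b = amp_a.detach().clone().requires_grad_(True)
   phase_b = phase_a.detach().clone().requires_grad_(True)
   target = torch.randn(16, 16, dtype=torch.complex64)

   obj_polar = object_from_amp_phase(amp_a, phase_a)
   obj_exp = amp_b * torch.exp(1j * phase_b)

   # Same toy loss through both constructions; the gradients should agree
   (obj_polar - target).abs().pow(2).sum().backward()
   (obj_exp - target).abs().pow(2).sum().backward()

This only trims the intermediate tensors in eager PyTorch; the fused exit-wave kernel described later goes further by never storing the complex object at all.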

What stays in PyTorch vs what moves to custom CUDA
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**Keep in PyTorch (already optimal or difficult to replace):**

- FFT operations: ``torch.fft.fft2()`` and ``torch.fft.ifft2()``

  - Use cuFFT internally (NVIDIA's highly optimized library)
  - Very hard to beat with a custom implementation
  - A large share of GPU time (see the profile below) that stays roughly constant while other stages shrink

- Adam optimizer: ``torch.optim.Adam(fused=True)``

  - Fused CUDA implementation of the Adam update
  - Well-tested, with weight decay and AMSGrad options (gradient clipping is a separate step)
  - Replacing it offers minimal gains

- Backward pass: PyTorch's automatic differentiation

  - Highly optimized for common operations
  - Maintains numerical stability
  - Custom backward kernels offer ~20% gains but add complexity

**Move to custom CUDA kernels (high-impact optimizations):**

1. **Probe generation:** :math:`\chi(k) \rightarrow \tilde{P}(k)` preparation before the IFFT

   - Fuses defocus, spherical aberration, and astigmatism calculations
   - Runs once per iteration on a single 256×256 grid (low priority)

2. **Exit wave computation:** :math:`O[\text{indices}] \cdot P`

   - Fuses tile gather + amplitude/phase → complex + complex multiply
   - Processes all 4,096 positions every iteration (CRITICAL)
   - An FP16 implementation halves the bytes moved by this memory-bound kernel (Tensor Cores accelerate matrix multiplies, not element-wise products like this one)

3. **Intensity and loss:** :math:`|\Psi|^2 \rightarrow \text{scaling} \rightarrow \text{loss}`

   - Fuses magnitude squared, per-position scaling, and loss reduction
   - Processes all 4,096 diffraction patterns every iteration (IMPORTANT)
   - Eliminates 6 separate kernel launches

.. note::

   **Key insight:** With 4,096 scan positions per iteration, operations whose cost scales with the number of positions dominate the total time. Prioritize optimizing the exit wave and intensity loss over probe generation.

The actual implementation
------------------------------

Here's the complete code being optimized. **Current performance:** 0.11s per iteration

.. dropdown:: Aberration-parameterized probe model
   :color: info
   :icon: code

   .. code-block:: python

      import torch
      import torch.nn as nn
      import numpy as np

      class AberrationProbe(nn.Module):
          """Physics-informed probe with aberration parameterization."""

          def __init__(self, size, pixel_size, wavelength, convergence_angle, device='cuda'):
              super().__init__()
              self.size = size
              self.pixel_size = pixel_size
              self.wavelength = wavelength
              self.convergence_angle = convergence_angle
              self.device = device

              # Create k-space grids - registered as buffers (no gradient retention)
              kx = torch.fft.fftfreq(size, d=pixel_size, device=device)
              ky = torch.fft.fftfreq(size, d=pixel_size, device=device)
              kx, ky = torch.meshgrid(kx, ky, indexing='ij')
              self.register_buffer('k', torch.sqrt(kx**2 + ky**2))
              self.register_buffer('k_angle', torch.atan2(ky, kx))

              # Trainable aberration coefficients
              self.defocus = nn.Parameter(torch.tensor(50.0))       # nm (C1)
              self.Cs = nn.Parameter(torch.tensor(1.0))             # mm (C3)
              self.astig_mag = nn.Parameter(torch.tensor(10.0))     # nm
              self.astig_angle = nn.Parameter(torch.tensor(0.0))    # radians
              self.aperture_smooth = nn.Parameter(torch.tensor(0.1))

          def compute_aberration_function(self):
              """χ(k) = π λ Δf k² + (1/2) π λ³ Cs k⁴ + astigmatism"""
              k = self.k
              wavelength = self.wavelength
              Cs_nm = self.Cs * 1e6  # mm → nm

              # Defocus term
              chi_defocus = np.pi * wavelength * self.defocus * k**2

              # Spherical aberration term
              chi_Cs = 0.5 * np.pi * wavelength**3 * Cs_nm * k**4

              # Astigmatism
              chi_astig = (np.pi * wavelength * self.astig_mag * k**2 *
                           torch.cos(2 * (self.k_angle - self.astig_angle)))

              return chi_defocus + chi_Cs + chi_astig

          def compute_aperture(self):
              """Soft aperture function with sigmoid rolloff."""
              k_max = self.convergence_angle / self.wavelength
              smoothness = torch.abs(self.aperture_smooth) + 0.01
              return torch.sigmoid((k_max - self.k) / (k_max * smoothness))

          def forward(self):
              """Generate probe: P = IFFT[A(k) × exp(iχ(k))]"""
              chi = self.compute_aberration_function()
              aperture = self.compute_aperture()

              # Probe in k-space
              probe_k = aperture * torch.exp(1j * chi)

              # Transform to real space
              probe_real = torch.fft.ifftshift(torch.fft.ifft2(probe_k))

              # Normalize
              probe_real = probe_real / torch.sqrt(torch.sum(torch.abs(probe_real)**2))

              return probe_real

          def get_probe_crop(self, crop_size):
              """Get cropped probe WITH gradients for optimization."""
              probe_full = self.forward()
              center = self.size // 2
              half = crop_size // 2
              return probe_full[center-half:center+half, center-half:center+half]

.. dropdown:: Physics-informed ptychography model
   :color: info
   :icon: code

   .. code-block:: python

      class PhysicsInformedPtychography(nn.Module):
          """Ptychography model with batched operations."""

          def __init__(self, object_size, probe_size, diffraction_size, scan_positions,
                       pixel_size, wavelength, convergence_angle, device='cuda'):
              super().__init__()
              self.object_size = object_size
              self.probe_size = probe_size
              self.diffraction_size = diffraction_size
              self.device = device

              # Precompute coordinate grids for batched gather
              dy = torch.arange(probe_size, device=device, dtype=torch.long)
              dx = torch.arange(probe_size, device=device, dtype=torch.long)
              grid_y, grid_x = torch.meshgrid(dy, dx, indexing='ij')
              self.register_buffer('grid_y', grid_y)
              self.register_buffer('grid_x', grid_x)

              # Convert scan positions to tensor
              self.register_buffer('scan_positions',
                                   torch.tensor(scan_positions, dtype=torch.long, device=device))

              # Trainable object parameters
              self.object_amplitude = nn.Parameter(
                  0.9 * torch.ones(object_size, object_size, device=device) +
                  0.1 * torch.randn(object_size, object_size, device=device)
              )
              self.object_phase = nn.Parameter(
                  0.1 * torch.randn(object_size, object_size, device=device)
              )

              # Probe model with trainable aberrations
              self.probe_model = AberrationProbe(
                  diffraction_size, pixel_size, wavelength, convergence_angle, device=device
              )

          def get_object_complex(self):
              """Construct complex object from amplitude and phase."""
              return self.object_amplitude * torch.exp(1j * self.object_phase)

          def forward(self, batch_indices):
              """Forward pass with batched gather."""
              if not isinstance(batch_indices, torch.Tensor):
                  batch_indices = torch.tensor(batch_indices, dtype=torch.long, device=self.device)
              batch_indices = batch_indices.detach()
              positions = self.scan_positions[batch_indices]

              # Get complex object
              object_complex = self.get_object_complex()

              # Get probe (regenerated every call - BOTTLENECK!)
              probe_complex = self.probe_model.get_probe_crop(self.probe_size)

              # Batched gather - extract all patches at once
              y_idx = positions[:, 0, None, None] + self.grid_y[None, :, :]
              x_idx = positions[:, 1, None, None] + self.grid_x[None, :, :]
              tiles = object_complex[y_idx, x_idx]  # [B, probe_size, probe_size]

              # Exit waves
              exit_waves = tiles * probe_complex.unsqueeze(0)

              # FFT with zero-padding to the detector size, delegated to fft2 via s=
              Psi = torch.fft.fft2(exit_waves, s=(self.diffraction_size, self.diffraction_size))

              return Psi

.. dropdown:: Training loop
   :color: info
   :icon: code

   .. code-block:: python

      # Model setup
      model = PhysicsInformedPtychography(
          object_size=512,
          probe_size=80,
          diffraction_size=256,
          scan_positions=scan_positions,  # 4,096 positions
          pixel_size=0.5,
          wavelength=0.0197,  # 300 keV electrons: λ ≈ 1.97 pm = 0.0197 Å; keep all lengths in one unit
          convergence_angle=0.02,
          device='cuda'
      ).to('cuda')

      # Fused Adam optimizers (a few kernel launches per step)
      optimizer_object = torch.optim.Adam(
          [model.object_amplitude, model.object_phase],
          lr=0.01,
          fused=True
      )
      optimizer_probe = torch.optim.Adam(
          model.probe_model.parameters(),
          lr=0.001,
          fused=True
      )

      # Precompute constants
      I_meas = (diffraction_patterns ** 2).contiguous()
      meas_scale = (counts / I_meas.mean(dim=(1, 2), keepdim=True)).contiguous()

      # TorchScript-compiled loss function
      # (counts must be a tensor: TorchScript types unannotated arguments as Tensor)
      @torch.jit.script
      def compute_intensity_loss(I_pred, I_meas, meas_scale, counts):
          pred_scale = counts / I_pred.mean(dim=(1, 2), keepdim=True)
          diff = I_pred * pred_scale - I_meas * meas_scale
          return (diff * diff).mean()

      # Training
      torch.backends.cudnn.benchmark = True  # Only affects cuDNN convolutions (none in this model)
      batch_indices = torch.arange(len(scan_positions), dtype=torch.long, device='cuda')

      for iteration in range(num_iterations):
          # Forward
          Psi = model(batch_indices)
          I_pred = torch.abs(Psi) ** 2
          loss = compute_intensity_loss(I_pred, I_meas, meas_scale, counts)

          # Backward
          optimizer_object.zero_grad(set_to_none=True)
          optimizer_probe.zero_grad(set_to_none=True)
          loss.backward()

          # Update
          optimizer_object.step()
          optimizer_probe.step()

Current baseline: Already implemented optimizations
---------------------------------------------------------

Before profiling and writing custom kernels, several standard PyTorch optimizations are already in place:

- **Batched gather operations:** All 4,096 object patches are extracted in a single GPU operation (no Python loops).
- **Registered buffers:** Precomputed coordinate grids (k-space, scan positions) are stored on the GPU and not recomputed.
- **FFT internal padding:** The ``s`` parameter delegates zero-padding to the FFT call (more efficient than an explicit ``torch.nn.functional.pad``).
- **Fused Adam optimizer:** ``torch.optim.Adam(fused=True)`` fuses the parameter updates into a few CUDA kernel launches.
- **JIT-compiled loss function:** TorchScript compilation reduces Python interpreter overhead.
- **cuDNN autotuning:** ``torch.backends.cudnn.benchmark = True`` is set, but it only selects convolution algorithms, so it has no effect on this convolution-free model.

**Impact:** These optimizations reduced iteration time from ~0.5s to **0.11s per iteration** (4.5× improvement).

**Challenge:** Reaching < 0.01s requires another 11× speedup. The standard PyTorch optimizations are exhausted, so custom CUDA is required.

.. warning::

   **Gotcha:** ``fused=True`` in Adam requires every parameter to be a floating-point tensor on a supported device (CUDA here); recent PyTorch versions raise a ``RuntimeError`` otherwise. Use ``torch.profiler`` to confirm that the optimizer step shows up as a few fused kernel launches rather than many per-parameter element-wise kernels.

Profiling: Identifying the bottlenecks
--------------------------------------------

Running ``torch.profiler`` for 100 iterations reveals where time is spent; the GPU time breakdown is summarized in the table that follows.
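
A profiling run of this kind can be sketched as follows. This is an assumed setup, not the exact script behind the numbers in the table: ``profile_steps`` is a hypothetical helper, ``step_fn`` stands in for one training iteration, and CUDA activity is recorded only when a GPU is present. ``key_averages().table()`` is what produces the Name / Self CPU % / Self CUDA % columns.

.. code-block:: python

   import torch
   from torch.profiler import ProfilerActivity, profile


   def profile_steps(step_fn, num_iterations=100, row_limit=12):
       """Profile num_iterations calls of step_fn and return the summary table as text."""
       use_cuda = torch.cuda.is_available()
       activities = [ProfilerActivity.CPU]
       if use_cuda:
           activities.append(ProfilerActivity.CUDA)

       with profile(activities=activities) as prof:
           for _ in range(num_iterations):
               step_fn()
           if use_cuda:
               torch.cuda.synchronize()  # make sure queued kernels finish inside the trace

       # Sort by GPU self time when a GPU is present, otherwise by CPU self time
       sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"
       return prof.key_averages().table(sort_by=sort_key, row_limit=row_limit)


   # Toy usage: profile a complex multiply (replace the lambda with one training iteration)
   a = torch.randn(256, 256, dtype=torch.complex64)
   b = torch.randn(256, 256, dtype=torch.complex64)
   print(profile_steps(lambda: a * b, num_iterations=10))

When reading such a table, keep in mind that operator rows (``aten::...``) report the device time of the kernels they launch, and those kernels can also appear as their own rows.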

.. list-table:: GPU time breakdown (top entries, 100 iterations)
   :widths: 50 25 25
   :header-rows: 1

   * - Name
     - Self CPU %
     - Self CUDA %
   * - ``aten::_fft_c2c``
     - 0.11%
     - 24.68%
   * - ``aten::mul``
     - 0.60%
     - 19.07%
   * - ``aten::abs``
     - 0.04%
     - 5.95%
   * - ``autograd::engine::evaluate_function: AbsBackward0``
     - 0.04%
     - 0.00%
   * - ``AbsBackward0``
     - 0.02%
     - 0.00%
   * - ``void regular_fft<256u, ...>``
     - 0.00%
     - 12.98%
   * - ``autograd::engine::evaluate_function: FftC2CBackward0``
     - 0.02%
     - 0.00%
   * - ``FftC2CBackward0``
     - 0.01%
     - 0.00%
   * - ``void vector_fft<256u, ...>``
     - 0.00%
     - 11.70%
   * - ``autograd::engine::evaluate_function: PowBackward0``
     - 0.04%
     - 0.00%
   * - ``PowBackward0``
     - 0.02%
     - 0.00%

**Total times:** Self CPU time: 11.011s | Self CUDA time: 11.030s

**The real bottleneck:** CPU time ≈ GPU time ≈ 11s for 100 iterations (**0.11s per iteration**).

**Why CPU time matters:** The user experiences wall-clock time, not just GPU time. PyTorch operations involve:

1. **Python overhead** - Each operation calls Python → C++ → CUDA
2. **Kernel launch overhead** - The CPU must launch each GPU kernel (``cudaLaunchKernel`` takes ~5-10μs)
3. **Autograd graph** - The CPU builds and manages the backward computation graph
4. **Tensor bookkeeping** - The CPU tracks shapes, dtypes, memory, and reference counts
5. **Synchronization** - Implicit syncs (for example ``.item()`` or device-to-host copies) make the CPU wait for the GPU

With 5,140 ``mul`` calls and 620 ``abs`` calls over the 100 profiled iterations (about 51 and 6 per iteration), these per-operation CPU costs add up. Matching CPU and GPU totals is ambiguous, though: it can mean the CPU is the bottleneck, or that the CPU spends much of its time waiting on a busy GPU queue. The ``aten::mul`` row gives a quick check: 19.07% × 11.03s ≈ 2.1s of GPU time over 5,140 calls, or about 0.4ms per call, far more than a ~5-10μs launch. Reducing GPU memory traffic therefore matters at least as much as reducing launches.

**Custom CUDA reduces CPU overhead:** A single kernel launch does the whole fused computation on the GPU. The CPU launches once instead of once per operation, no intermediate tensors are allocated, and (wrapped in a custom ``torch.autograd.Function``) the fused forward pass records a single autograd node.

.. dropdown:: What PyTorch does internally for every operation
   :color: info
   :icon: tools

   **CPU-side overhead (happens for EVERY PyTorch operation):**

   1. **Python interpreter overhead**

      - Function call from Python → C++ extension
      - Argument parsing and validation (type checking, shape checking)
      - GIL (Global Interpreter Lock) acquisition/release
      - Reference counting for Python objects
      - **Cost:** ~1-5μs per operation

   2. **PyTorch dispatcher overhead**

      - Determine which backend to use (CPU/CUDA/MPS)
      - Check whether the operation is in-place or requires a new allocation
      - Dispatch to the correct kernel implementation
      - Handle device placement and data type promotion
      - **Cost:** ~0.5-2μs per operation

   3. **Memory management**

      - Allocate the output tensor (unless in-place)
      - Update the CUDA caching allocator's bookkeeping
      - Track tensor metadata (shape, stride, dtype, device)
      - Increment/decrement reference counts
      - **Cost:** ~2-10μs per allocation

   4. **Autograd graph construction**

      - Create an autograd node for the backward pass
      - Save input tensors (or their versions) for gradient computation
      - Link the node into the computation graph
      - Register hook functions, if any
      - Track ``requires_grad`` status through operations
      - **Cost:** ~5-15μs per operation with gradients

   5. **Kernel launch**

      - Call ``cudaLaunchKernel()`` (a CUDA runtime call into the driver)
      - Pass kernel arguments to the driver
      - Enqueue the kernel in a CUDA stream
      - Update stream synchronization state
      - **Cost:** ~5-10μs per kernel launch

   **GPU-side overhead (for operations PyTorch doesn't fuse):**

   1. **Memory bandwidth bottleneck**

      - Each operation reads its inputs from global memory (slow)
      - Each operation writes its outputs to global memory (slow)
      - No data reuse between operations (cache misses)
      - **Example:** ``A * B`` reads A and B and writes the result; the next operation reads that result again.

   2. **Kernel launch latency**

      - The GPU must wait for the kernel to be scheduled
      - Streaming multiprocessors (SMs) must load the kernel code
      - Thread blocks must be assigned to SMs
      - **Cost:** ~1-5μs per kernel

   3. **Synchronization overhead**

      - Operations on the same stream serialize (can't overlap)
      - ``torch.cuda.synchronize()`` blocks the CPU until the GPU finishes
      - Implicit syncs occur when copying data back to the CPU

   **Example: Single complex multiply in PyTorch**

   .. code-block:: python

      # This one line:
      result = a_complex * b_complex  # complex64 tensors

      # Triggers on CPU (rough estimates):
      # - Python call overhead: ~3μs
      # - Dispatcher: ~1μs
      # - Memory allocation for result: ~5μs
      # - Autograd node creation: ~10μs
      # - 1× cudaLaunchKernel() call: ~10μs (complex mul is one element-wise kernel)
      # Total CPU time: ~29μs

      # Triggers on GPU:
      # - 1 kernel launch: ~1μs
      # - Read a_complex: 8 bytes per element (real + imag)
      # - Read b_complex: 8 bytes per element
      # - Write result: 8 bytes per element
      # - Arithmetic: negligible next to the memory traffic

      # Small tensors: wall-clock ≈ max(CPU, GPU) ≈ 29μs, i.e. launch-bound
      # Large tensors: the 24 bytes/element of memory traffic dominates instead

   **With 5,140** ``aten::mul`` **calls (all 100 profiled iterations):**

   .. code-block:: text

      CPU overhead: 5,140 × ~29μs ≈ 149ms   (≈1.5ms per iteration)
      GPU time:     19.07% × 11.03s ≈ 2.1s  (≈0.4ms per call, ≈21ms per iteration)

   **Custom CUDA collapses each group of operations:**

   .. code-block:: cpp

      // One kernel launch per fused group: ~10-20μs of CPU overhead instead of
      // one launch (plus allocation and autograd node) per PyTorch operation
      // GPU: the fused kernel reads each input once and writes the final result once,
      // keeping intermediates in registers

   **Why this matters:** The profiling shows CPU time (11.01s) ≈ GPU time (11.03s). Both sides have costs to cut:

   - **CPU:** Launching kernels, building autograd graphs, managing memory
   - **GPU:** Computing results, mostly limited by memory traffic

   Custom CUDA collapses many operations into single kernels, cutting the per-operation CPU overhead and the intermediate global-memory traffic, so the GPU can work continuously instead of waiting for the CPU to launch the next kernel.

Mapping profiling to custom kernels
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**What can't be optimized:**

- ``aten::_fft_c2c`` (24.7%): This is cuFFT. A custom kernel is very unlikely to beat it.
- ``regular_fft`` + ``vector_fft`` (12.98% + 11.70% = 24.68%): cuFFT's internal kernels, launched by ``aten::_fft_c2c``. The profiler lists their time under the kernel names and again under the launching operator, so these rows repeat the same 24.7% rather than adding to it.
- **Total FFT: ~24.7% of GPU time** - keep using ``torch.fft.fft2()``

**What custom kernels will target:**

GPU operations to fuse:

- ``aten::mul`` (19.1%, 5,140 calls): complex multiplies scattered through the forward and backward passes
- ``aten::abs`` (5.95%, 620 calls): complex → real conversion
- Power and backward operations: squaring in the loss + gradient computations
- **Total fusible GPU operations: ~25-30% of GPU time**

CPU overhead to reduce:

- 5,140 kernel launches for ``mul`` operations (about 51 per iteration)
- Autograd graph building and traversal
- Tensor allocation and memory management
- Python/C++/CUDA call stack overhead
- **Estimated CPU overhead: up to ~50% of wall-clock time, if the CPU is on the critical path (11.01s CPU ≈ 11.03s GPU)**

**Total speedup potential (estimate):**

- **3× from kernel fusion** (eliminate redundant memory operations on the GPU)
- **2× from eliminating CPU overhead** (single kernel launch, no Python intermediaries, no per-operation autograd nodes)
- **Total: 6× speedup → ~0.018s per iteration**, still short of the < 0.01s target

.. dropdown:: What does "fuse" mean in GPU computing?
   :color: success
   :icon: zap

   **Kernel fusion** means combining multiple operations into a single GPU kernel to eliminate intermediate memory reads and writes.

   **Unfused operations (what PyTorch does):**

   .. code-block:: python

      # Three separate operations
      a = torch.exp(x)   # Operation 1: read x, write a
      b = a * y          # Operation 2: read a, read y, write b
      c = torch.abs(b)   # Operation 3: read b, write c

      # GPU memory traffic (each tensor is 100MB):
      # Op 1: Read x [100MB], Write a [100MB] → 200MB
      # Op 2: Read a [100MB], Read y [100MB], Write b [100MB] → 300MB
      # Op 3: Read b [100MB], Write c [100MB] → 200MB
      # Total: 700MB of memory traffic

      # GPU execution:
      # Kernel 1 → kernel 2 → kernel 3 (each waits for the previous one on the same stream)
      # Each kernel launch: ~5-10μs CPU overhead + kernel execution time

   **Fused operations (custom CUDA):**

   .. code-block:: cuda

      __global__ void fused_kernel(const float* x, const float* y, float* c, int n) {
          int idx = blockIdx.x * blockDim.x + threadIdx.x;
          if (idx < n) {
              float a = expf(x[idx]);   // Compute in register
              float b = a * y[idx];     // Compute in register
              c[idx] = fabsf(b);        // Write only the final result
          }
      }

      // GPU memory traffic:
      // Read x [100MB], Read y [100MB], Write c [100MB] → 300MB
      // Total: 300MB of memory traffic (2.3× less than unfused)

      // GPU execution:
      // Single kernel launch → compute everything → done
      // One kernel launch: ~5-10μs CPU overhead (instead of three)

   **Why this matters:**

   1. **Memory bandwidth is the bottleneck.** Modern GPUs can compute much faster than they can move data, so for memory-bound kernels, cutting memory traffic by 2-3× speeds up execution by roughly 2-3×.
   2. **Intermediate results stay in registers.** Registers have orders of magnitude lower latency than global memory. Fused kernels keep ``a`` and ``b`` in registers and never write them to global memory.
   3. **The CPU launches fewer kernels.** One launch instead of three saves two launches, roughly 10-20μs of CPU overhead per fused group.

   **Real example from ptychography:**

   .. code-block:: python

      # Unfused (PyTorch), bytes moved per gathered element (float32 = 4 B, complex64 = 8 B):
      amp = object_amplitude[indices]                  # read 4, write 4  → 8 B
      phase = object_phase[indices]                    # read 4, write 4  → 8 B
      obj_real = amp * torch.cos(phase)                # read 8, write 4  → 12 B
      obj_imag = amp * torch.sin(phase)                # read 8, write 4  → 12 B
      obj_complex = torch.complex(obj_real, obj_imag)  # read 8, write 8  → 16 B
      exit_wave = obj_complex * probe                  # read 16, write 8 → 24 B
      # Total: ~80 B per element across 6+ kernel launches
      # (amp * torch.cos(phase) is itself two kernels in eager mode)

   .. code-block:: cuda

      // Fused (custom CUDA):
      __global__ void fused_exit_wave_kernel(...)
{ float amp = obj_amplitude[obj_idx]; // Read once float phase = obj_phase[obj_idx]; // Read once cuFloatComplex obj = make_cuFloatComplex( amp * cosf(phase), // Compute in register amp * sinf(phase) // Compute in register ); cuFloatComplex probe_val = probe[idx]; // Read once exit_wave[idx] = cuCmulf(obj, probe_val); // Write once } // Total: 12MB memory traffic (4.7× less), 1 kernel launch **Fusion is not optimization - it's necessity:** For 4,096 scan positions, unfused operations create: - 24,576 kernel launches (4,096 positions × 6 operations) - 229GB of memory traffic (56MB × 4,096) - 250ms of CPU overhead just launching kernels Fused into one kernel: - 1 kernel launch - 49GB of memory traffic (4.7× reduction) - 10μs of CPU overhead **Result:** 4.7× faster GPU compute + elimination of CPU bottleneck = 5-6× total speedup. .. dropdown:: Why PyTorch cannot fuse complex number operations :color: warning :icon: alert **PyTorch's fundamental limitation:** The JIT compiler and ``torch.compile()`` do not fuse operations involving complex tensors. This is not a bug - it is an architectural limitation. **What happens with a simple complex multiply:** .. code-block:: python # This single line of PyTorch code: exit_waves = object_complex * probe_complex # Both are complex64 # Becomes FOUR separate GPU kernel launches: # Kernel 1: Extract real parts from object_complex # Kernel 2: Extract imaginary parts from object_complex # Kernel 3: Compute real_out = real_a * real_b - imag_a * imag_b # Kernel 4: Compute imag_out = real_a * imag_b + imag_a * real_b # Each kernel: # - Reads from global memory (slow) # - Writes to global memory (slow) # - Requires CPU to launch it (~5-10μs overhead) # - Allocates intermediate tensors # - Registers operation in autograd graph **The cascade effect with 5,140 complex multiplies:** .. 
code-block:: text 5,140 complex multiplies × 4 kernels each = 20,560 kernel launches 20,560 launches × 10μs CPU overhead = 205ms just in launch overhead Plus: 20,560 memory allocations, 20,560 autograd registrations Result: CPU time (11.01s) ≈ GPU time (11.03s) **Custom CUDA eliminates this completely:** .. code-block:: cuda // Single kernel, single launch, everything in registers: cuFloatComplex obj_val = make_cuFloatComplex( amp * cosf(phase), // Real part amp * sinf(phase) // Imaginary part ); cuFloatComplex probe_val = probe[idx]; exit_waves[idx] = cuCmulf(obj_val, probe_val); // One instruction // CPU overhead: Single cudaLaunchKernel() call (~10μs total) // GPU overhead: No intermediate memory reads/writes // Result: 3× faster GPU compute + 2× faster CPU overhead = 6× total **Why this matters for ptychography:** The forward pass has three cascading complex operations: 1. :math:`A(x,y) \cdot e^{i\phi(x,y)}` → object construction (complex exp + multiply) 2. :math:`O[indices] \cdot P(x,y)` → exit waves (gather + complex multiply) 3. :math:`|\Psi(u,v)|^2` → intensity (complex absolute value + square) Each gets fragmented into multiple kernels. Backward pass makes it worse with gradient computations through each operation. **Custom CUDA is not optional - it is necessary.** PyTorch fundamentally cannot optimize this workload. The three custom kernels to write ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Kernel 1: ``fused_exit_wave_kernel`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ **Eliminates 19.1% overhead** This kernel computes the exit waves for all 4,096 scan positions in a single GPU call. For each position, it constructs the complex object from separate amplitude and phase arrays, extracts the relevant patch using precomputed indices, and multiplies by the probe function. PyTorch does this in three separate kernel launches with intermediate tensor allocations. The custom kernel fuses everything into one pass, eliminating memory roundtrips and CPU overhead. 
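To pin down exactly what ``fused_exit_wave_kernel`` must produce, a plain NumPy reference is a handy correctness oracle when testing the CUDA version. The sketch below is illustrative and not part of the original code: the names ``patch_indices`` and ``exit_waves_reference`` are hypothetical, and the shapes are assumed to follow the parameter table (``[H, W]`` object arrays, ``[P, P]`` complex probe, ``[B, P, P]`` integer indices).

.. code-block:: python

   import numpy as np


   def patch_indices(scan_positions, probe_size):
       """Build [B, P, P] row/column indices from [B, 2] (y, x) top-left corners."""
       grid = np.arange(probe_size)
       gy, gx = np.meshgrid(grid, grid, indexing="ij")
       y = scan_positions[:, 0, None, None] + gy[None]  # [B, 1, 1] + [1, P, P]
       x = scan_positions[:, 1, None, None] + gx[None]
       return y, x


   def exit_waves_reference(obj_amplitude, obj_phase, probe, y_indices, x_indices):
       """Unfused reference: gather A and phi, form A * exp(i*phi), multiply by the probe."""
       amp = obj_amplitude[y_indices, x_indices]   # gather -> [B, P, P]
       phase = obj_phase[y_indices, x_indices]     # gather -> [B, P, P]
       obj_patches = amp * np.exp(1j * phase)      # complex object patches
       return obj_patches * probe[None, :, :]      # broadcast the probe over the batch


   # Small example: 16×16 object, 4×4 probe, three scan positions
   rng = np.random.default_rng(0)
   amp = rng.random((16, 16))
   phase = rng.uniform(-np.pi, np.pi, (16, 16))
   probe = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
   y, x = patch_indices(np.array([[0, 0], [3, 5], [12, 12]]), probe_size=4)
   waves = exit_waves_reference(amp, phase, probe, y, x)
   print(waves.shape)  # (3, 4, 4)

Comparing the kernel's output (converted back to complex) against this reference, with a tolerance suited to the kernel's precision, catches indexing and memory-layout mistakes before they show up as a slowly diverging reconstruction.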
**Fuses:** :math:`A(x,y) \cdot e^{i\phi(x,y)}` at indices → multiply by :math:`P(x,y)`

.. list-table::
   :widths: 20 50 30
   :header-rows: 1

   * - Parameter
     - Description
     - Shape/Type
   * - **Input**
     -
     -
   * - ``obj_amplitude``
     - Object amplitude (real-valued, 0-1 range)
     - ``[512, 512]`` float
   * - ``obj_phase``
     - Object phase (real-valued, radians)
     - ``[512, 512]`` float
   * - ``y_indices``
     - Vertical positions for all patches
     - ``[4096, 80, 80]`` int
   * - ``x_indices``
     - Horizontal positions for all patches
     - ``[4096, 80, 80]`` int
   * - ``probe``
     - Illumination function (complex)
     - ``[80, 80]`` complex64
   * - **Output**
     -
     -
   * - ``exit_waves``
     - Transmitted wavefield for each position
     - ``[4096, 80, 80]`` complex64

**Replaces:** 3 separate PyTorch ops (``exp``, ``gather``, ``mul``) + 3 kernel launches + 2 intermediate tensors

Kernel 2: ``fused_intensity_loss_kernel``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**Targets ~30.4% of GPU time** (``aten::abs``, ``AbsBackward0``, ``PowBackward0``; the backward share is only recovered if the backward pass is fused as well)

This kernel computes the complete loss function with streaming reductions. It converts the complex diffraction patterns to intensities, applies per-position scaling factors, computes squared differences against the measured data, and reduces to a scalar loss, without writing intermediate results (other than the intensities saved for backward) to global memory. PyTorch does this as six separate operations, each allocating a full-size temporary tensor. The fused version processes data on the fly using shared-memory reductions.

**Fuses:** :math:`|\Psi|^2 \rightarrow \text{scale} \rightarrow \text{diff}^2 \rightarrow \text{mean}`

.. list-table::
   :widths: 20 50 30
   :header-rows: 1

   * - Parameter
     - Description
     - Shape/Type
   * - **Input**
     -
     -
   * - ``Psi``
     - Diffracted wavefields (complex)
     - ``[4096, 256, 256]`` complex64
   * - ``I_meas``
     - Measured intensities (precomputed)
     - ``[4096, 256, 256]`` float
   * - ``meas_scale``
     - Scaling for measured data (precomputed)
     - ``[4096, 1, 1]`` float
   * - ``counts``
     - Total photon counts per pattern
     - ``[4096]`` float
   * - **Output**
     -
     -
   * - ``loss``
     - Mean squared error (scalar)
     - ``[]`` float
   * - ``I_pred``
     - Predicted intensities (saved for backward)
     - ``[4096, 256, 256]`` float

**Replaces:** 6 separate PyTorch ops (``abs``, ``pow``, ``mean``, ``mul``, ``sub``, ``mean``) + 6 kernel launches + 5 intermediate tensors

Kernel 3: ``fused_exit_wave_backward_kernel``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**Speeds up the backward pass**

This kernel computes gradients for the exit-wave operation by reversing the forward computation. Given the gradients with respect to the exit waves, it backpropagates through the complex multiplication and the object construction. The challenge is accumulation for overlapping patches: each object pixel is visited by multiple scan positions, so its gradient contributions must be summed with atomic additions. PyTorch's autograd handles this correctly but builds a computational graph with overhead. The custom kernel implements the exact derivatives directly and uses atomic operations for the scatter-add pattern.

**Fuses:** Gradients for ``obj_amplitude``, ``obj_phase``, and ``probe`` from ``grad_exit_waves``

.. list-table::
   :widths: 20 50 30
   :header-rows: 1

   * - Parameter
     - Description
     - Shape/Type
   * - **Input**
     -
     -
   * - ``grad_exit_waves``
     - Gradients of the loss w.r.t. the exit waves
     - ``[4096, 80, 80]`` complex64
   * - ``obj_amplitude``
     - Forward-pass object amplitude (for the chain rule)
     - ``[512, 512]`` float
   * - ``obj_phase``
     - Forward-pass object phase (for the chain rule)
     - ``[512, 512]`` float
   * - ``probe``
     - Forward-pass probe (for the chain rule)
     - ``[80, 80]`` complex64
   * - ``y_indices``
     - Vertical positions (same as forward)
     - ``[4096, 80, 80]`` int
   * - ``x_indices``
     - Horizontal positions (same as forward)
     - ``[4096, 80, 80]`` int
   * - **Output**
     -
     -
   * - ``grad_obj_amplitude``
     - Gradients w.r.t. object amplitude
     - ``[512, 512]`` float
   * - ``grad_obj_phase``
     - Gradients w.r.t. object phase
     - ``[512, 512]`` float
   * - ``grad_probe``
     - Gradients w.r.t. probe function
     - ``[80, 80]`` complex64

**Replaces:** PyTorch's autograd through complex ops (general-purpose graph traversal) with direct derivative computation using atomic scatter-add

Summary
^^^^^^^

Everything else (probe generation with aberrations, Adam, autograd) stays in PyTorch for now but could be moved to CUDA for further speedup. The FFT itself (49.4% of GPU time) is left alone: cuFFT is already highly optimized, so there is little to gain by replacing it. The operations around it can be fused, and, more importantly, **moving this work into custom CUDA removes PyTorch's CPU overhead** (kernel launches, autograd graph, tensor bookkeeping), which roughly doubles the wall-clock time.

**Expected total speedup (estimate):**

- 2.5-3× from kernel fusion (GPU compute)
- 2× from eliminating CPU overhead (PyTorch → pure CUDA)
- **Total: 5-6× → 0.018-0.022s per iteration**

Custom kernel 1: Probe generation in k-space (low priority, but needed)
------------------------------------------------------------------------

**Reality check:** Probe aberration parameters (defocus, Cs, astigmatism) can be trainable in ptychography, and refining them is often part of the reconstruction. This kernel is still **low priority** because probe generation runs only **once per iteration**, while the exit-wave computation is performed for all 4,096 scan positions every iteration.

**The probe generation code:**

.. code-block:: python

   # Runs once per iteration after the optimizer updates the aberration parameters
   chi_defocus = np.pi * wavelength * self.defocus * k**2
   chi_Cs = 0.5 * np.pi * wavelength**3 * Cs_nm * k**4
   chi_astig = np.pi * wavelength * self.astig_mag * k**2 * torch.cos(2 * (k_angle - astig_angle))
   chi = chi_defocus + chi_Cs + chi_astig

   aperture = torch.sigmoid((k_max - k) / (k_max * aperture_smooth))
   probe_k = aperture * torch.exp(1j * chi)
   probe_real = torch.fft.ifft2(probe_k)

**Can PyTorch fuse these operations automatically?**

**Short answer:** Not effectively.

**Why PyTorch can't fuse this:**

1. **JIT limitations:** The real-valued pieces (the polynomial in ``k``, ``torch.cos()``, ``torch.sigmoid()``) are ordinary elementwise ops that a fuser could combine in principle. The blockers are:

   - ``torch.exp(1j * chi)``: a complex exponential, which the fusers handle poorly
   - The final multiply, which mixes a real aperture with a complex tensor
   - Multiple intermediate tensor allocations along the way

2. **TorchScript restrictions:**

   .. code-block:: python

      @torch.jit.script
      def generate_probe_k(defocus, Cs, astig_mag, k, k_angle):
          # TorchScript can compile this, but won't fuse the complex part
          chi = math.pi * wavelength * defocus * k**2 + ...   # Separate kernel
          aperture = torch.sigmoid((k_max - k) / smooth)      # Separate kernel
          probe_k = aperture * torch.exp(1j * chi)            # Separate kernel
          return probe_k

   Result: still 3+ separate CUDA kernel launches, just with less Python overhead.

3. **PyTorch 2.0 ``torch.compile()``:**

   .. code-block:: python

      @torch.compile(mode="max-autotune")
      def generate_probe_k(defocus, Cs, astig_mag, k, k_angle):
          chi = math.pi * wavelength * defocus * k**2 + ...
          aperture = torch.sigmoid((k_max - k) / smooth)
          probe_k = aperture * torch.exp(1j * chi)
          return probe_k

   - It may fuse some ops, but it is limited by TorchInductor's complex-number support
   - The complex exponential ``exp(1j * x)`` often falls back to eager mode
   - **Expected gains are modest** (on the order of 10-20% for this function)
   - Worth trying first, since it is just one decorator!

**A custom CUDA kernel gives full control:**

.. code-block:: cuda

   __global__ void generate_probe_k_kernel(
       const float* __restrict__ k,
       const float* __restrict__ k_angle,
       cuFloatComplex* __restrict__ probe_k,
       float defocus, float Cs, float astig_mag, float astig_angle,
       float wavelength, float k_max, float aperture_smooth,
       int size
   ) {
       const float PI_F = 3.14159265f;  // float literal avoids double-precision math from M_PI
       int idx = blockIdx.x * blockDim.x + threadIdx.x;
       if (idx < size * size) {
           float k_val = k[idx];
           float k2 = k_val * k_val;
           float angle = k_angle[idx];

           // All ops fused in registers - no intermediate memory writes
           float chi = PI_F * wavelength * defocus * k2
                     + 0.5f * PI_F * wavelength * wavelength * wavelength * Cs * k2 * k2
                     + PI_F * wavelength * astig_mag * k2 * cosf(2.0f * (angle - astig_angle));

           // sigmoid((k_max - k) / (k_max * smooth)), written as 1 / (1 + exp(-x))
           float aperture = 1.0f / (1.0f + expf((k_val - k_max) / (k_max * aperture_smooth)));

           probe_k[idx] = make_cuFloatComplex(
               aperture * cosf(chi),
               aperture * sinf(chi)
           );
       }
   }

**Expected speedup:** at most 1.1-1.2× on total iteration time (5-6 PyTorch kernel launches → 1 custom kernel)

**It stays low priority because:**

- It runs **once per iteration** (after the optimizer step)
- The exit-wave kernel processes **4,096 scan positions per iteration**
- Even unoptimized, probe generation is only ~10% of total time
- **Focus on Kernel 2 (exit wave FP16) first; that's where the 2-3× gains are**

**Recommendation:** Start with **Kernel 2 (exit wave FP16)**. Come back to this kernel only if you need the extra 10-20%.

.. dropdown:: Should I implement mixed precision (FP16) for the probe IFFT?
   :color: secondary
   :icon: question

   **No, it is not worth it.**

   **Why skip FP16 for the probe IFFT:**

   1. **It runs only once per iteration** (same frequency as probe generation)

      - Minimal impact even with a 2× speedup
      - The exit-wave FFT covers all 4,096 scan positions per iteration; focus there instead

   2. **Conversion overhead eats the gains:**

      - Need an FP32 → FP16 conversion before the IFFT
      - Need an FP16 → FP32 conversion after the IFFT
      - Each conversion is a separate kernel launch
      - cuFFT in FP32 is already close to memory-bandwidth bound

   3. **FFT is memory-bound, not compute-bound:**

      - Tensor Cores don't help (an FFT is not a matrix multiply)
      - FP16 only saves bandwidth if the data stays FP16 throughout
      - But ``probe_k`` comes from an FP32 kernel, and ``probe_real`` feeds FP32 ops
      - The conversions negate the bandwidth savings

   4. **The math:**

      - Probe IFFT: 80×80 = 6,400 complex values
      - Takes ~0.02-0.05ms (vs 110ms total iteration time)
      - Even a 2× speedup saves only ~0.01-0.025ms
      - **Impact: roughly 0.01-0.02% of total time**

   **Verdict:** Keep the probe IFFT in FP32. Focus on Kernel 2 (exit wave FP16), where you get a 2-3× speedup on the work done for all 4,096 scan positions per iteration.

   **Exception:** If your exit-wave pipeline keeps everything in FP16 (``exit_wave_fp16`` → FFT in FP16 → intensity in FP16), the forward FFT could benefit. The probe IFFT still runs too infrequently to matter.

Custom kernel 2: Fused exit wave computation with FP16 (CRITICAL)
------------------------------------------------------------------

**This is the main optimization!** The exit wave is computed for **every scan position** (4,096 per iteration).

**Problem:**

- **GPU time:** ``aten::mul`` takes 2.104s (19.1%) across 5,140 separate kernel calls
- **CPU time:** Each kernel launch costs ~5-10μs, plus autograd bookkeeping

The current code does this:

.. code-block:: python

   # Separate operations, all in FP32
   object_complex = object_amplitude * torch.exp(1j * object_phase)  # mul #1 (plus exp)
   tiles = object_complex[y_idx, x_idx]                              # gather
   exit_waves = tiles * probe_complex.unsqueeze(0)                   # mul #2 (complex)

   # Result: at least 3 kernel launches and 3 full-size memory writes, all in FP32

**Solution:** Fuse all three operations into one FP16 kernel using custom complex math.
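Before writing the FP16 kernel, it is worth estimating how much precision the FP16 path gives up. The sketch below emulates it on the CPU with NumPy's ``float16``, rounding after every multiply and add in the same order as the scalar ``__half`` code. It is only an approximation: NumPy evaluates ``sin``/``cos`` in higher precision and rounds the result, whereas the GPU's ``hsin``/``hcos`` have their own error bounds. The helper names are hypothetical.

.. code-block:: python

   import numpy as np

   F16 = np.float16


   def cmul_fp16(a_re, a_im, b_re, b_im):
       """(a_re + i a_im)(b_re + i b_im) with every product and sum rounded to
       float16, following the operation order of the CUDA cmul_fp16."""
       a_re, a_im, b_re, b_im = (v.astype(F16) for v in (a_re, a_im, b_re, b_im))
       real = a_re * b_re - a_im * b_im
       imag = a_re * b_im + a_im * b_re
       return real, imag


   def fp16_exit_wave_error(amp, phase, probe):
       """Max error of the emulated FP16 exit wave, relative to the peak exact value."""
       amp16, phase16 = amp.astype(F16), phase.astype(F16)
       obj_re = amp16 * np.cos(phase16)   # float16 in, float16 out
       obj_im = amp16 * np.sin(phase16)
       re, im = cmul_fp16(obj_re, obj_im, probe.real, probe.imag)
       approx = re.astype(np.float64) + 1j * im.astype(np.float64)
       exact = amp * np.exp(1j * phase) * probe
       return np.max(np.abs(approx - exact)) / np.max(np.abs(exact))


   rng = np.random.default_rng(1)
   amp = rng.random((80, 80))
   phase = rng.uniform(-np.pi, np.pi, (80, 80))
   probe = rng.standard_normal((80, 80)) + 1j * rng.standard_normal((80, 80))
   print(f"relative error: {fp16_exit_wave_error(amp, phase, probe):.1e}")

With FP16's ~3 significant decimal digits, errors on the order of 1e-3 are expected, so the accuracy check for the real kernel needs a tolerance of that order rather than an FP32-style 1e-6.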
FP16 complex math primitives
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, define FP16 complex number operations using ``__half2``:

.. code-block:: cuda

   #include <cuda_fp16.h>

   typedef __half2 complex_half;  // Stores (real, imag) in one __half2

   // Complex multiply: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
   __device__ __forceinline__ complex_half cmul_fp16(complex_half a, complex_half b) {
       __half a_real = __low2half(a);
       __half a_imag = __high2half(a);
       __half b_real = __low2half(b);
       __half b_imag = __high2half(b);

       __half real = __hsub(__hmul(a_real, b_real), __hmul(a_imag, b_imag));
       __half imag = __hadd(__hmul(a_real, b_imag), __hmul(a_imag, b_real));

       return __halves2half2(real, imag);
   }

   // Amplitude + phase → complex FP16
   __device__ __forceinline__ complex_half amp_phase_to_complex_fp16(
       __half amp, __half phase
   ) {
       __half real = __hmul(amp, hcos(phase));
       __half imag = __hmul(amp, hsin(phase));
       return __halves2half2(real, imag);
   }

**Why this is fast:** The kernel is memory-bound, so the main win is that FP16 halves the bytes read and written per value compared with FP32. These scalar ``__half`` operations run on the regular CUDA cores, not on Tensor Cores (Tensor Cores only execute matrix multiply-accumulate instructions). Packed ``__half2`` intrinsics such as ``__hmul2`` and ``__hfma2`` can additionally double arithmetic throughput on architectures with fast FP16 math.

Fused gather + multiply kernel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: cuda

   __global__ void gather_multiply_fp16(
       const __half* __restrict__ obj_amplitude,   // [H, W] FP16
       const __half* __restrict__ obj_phase,       // [H, W] FP16
       const __half2* __restrict__ probe,          // [probe_size²] complex FP16
       const int* __restrict__ y_indices,          // [B * probe_size²]
       const int* __restrict__ x_indices,          // [B * probe_size²]
       __half2* __restrict__ output,               // [B * probe_size²] complex FP16
       int batch_size,
       int probe_size,
       int object_width
   ) {
       int batch_idx = blockIdx.x;
       int pixel_idx = threadIdx.x + blockIdx.y * blockDim.x;
       int probe_area = probe_size * probe_size;

       if (pixel_idx < probe_area) {
           int global_idx = batch_idx * probe_area + pixel_idx;

           // Step 1: Gather object pixel coordinates
           int y = y_indices[global_idx];
           int x = x_indices[global_idx];
           int obj_idx = y * object_width + x;

           // Step 2: Convert amplitude + phase → complex (fused, in registers)
           __half amp = obj_amplitude[obj_idx];
           __half phase = obj_phase[obj_idx];
           complex_half obj_val = amp_phase_to_complex_fp16(amp, phase);

           // Step 3: Complex multiply with probe (fused, in registers)
           complex_half probe_val = probe[pixel_idx];
           output[global_idx] = cmul_fp16(obj_val, probe_val);
       }
   }

**PyTorch wrapper:** The host function below launches the kernel on PyTorch's current CUDA stream and is compiled with ``load_inline``. It is a sketch that assumes contiguous CUDA inputs, a probe passed as a ``[P, P, 2]`` FP16 tensor (the ``torch.view_as_real`` layout, which matches ``__half2``), and ``int32`` indices.

.. code-block:: python

   import torch
   from torch.utils.cpp_extension import load_inline

   cuda_source = r"""
   #include <cuda_fp16.h>
   #include <ATen/cuda/CUDAContext.h>
   // ... kernel code above (cmul_fp16, amp_phase_to_complex_fp16, gather_multiply_fp16) ...

   torch::Tensor fused_exit_wave(torch::Tensor obj_amplitude, torch::Tensor obj_phase,
                                 torch::Tensor y_indices, torch::Tensor x_indices,
                                 torch::Tensor probe) {
       const int B = y_indices.size(0), P = y_indices.size(1), W = obj_amplitude.size(1);
       auto out = torch::empty({B, P, P, 2}, obj_amplitude.options());  // (real, imag) pairs
       dim3 grid(B, (P * P + 255) / 256), block(256);
       gather_multiply_fp16<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
           reinterpret_cast<const __half*>(obj_amplitude.data_ptr<at::Half>()),
           reinterpret_cast<const __half*>(obj_phase.data_ptr<at::Half>()),
           reinterpret_cast<const __half2*>(probe.data_ptr<at::Half>()),
           y_indices.data_ptr<int>(), x_indices.data_ptr<int>(),
           reinterpret_cast<__half2*>(out.data_ptr<at::Half>()), B, P, W);
       return out;
   }
   """

   fused_module = load_inline(
       name="fused_exit_wave_ext",
       cpp_sources="torch::Tensor fused_exit_wave(torch::Tensor, torch::Tensor, "
                   "torch::Tensor, torch::Tensor, torch::Tensor);",
       cuda_sources=cuda_source,
       functions=["fused_exit_wave"],
   )

**What this achieves:**

- **3 operations → 1 kernel:** amplitude/phase conversion + gather + complex multiply
- **FP32 → FP16:** half the bytes per value, so roughly 2× less memory traffic for this memory-bound kernel
- **No intermediate memory:** everything is computed in registers
- **1 kernel launch instead of 3:** eliminates CPU overhead

Launch configuration and PyTorch binding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: cuda

   extern "C" {

   void launch_gather_multiply_fp16(
       const void* obj_amplitude,
       const void* obj_phase,
       const void* probe,
       const int* y_indices,
       const int* x_indices,
       void* output,
       int batch_size,
       int probe_size,
       int object_width,
       cudaStream_t stream
   ) {
       int probe_area = probe_size * probe_size;

       // Grid: [batch_size, (probe_area + 255) / 256]
       dim3 grid(batch_size, (probe_area + 255) / 256);
       dim3 block(256);

       gather_multiply_fp16<<<grid, block, 0, stream>>>(
           (const __half*)obj_amplitude,
           (const __half*)obj_phase,
           (const __half2*)probe,
           y_indices,
           x_indices,
           (__half2*)output,
           batch_size,
           probe_size,
           object_width
       );
   }

   }  // extern "C"

Use in PyTorch:

.. code-block:: python

   import torch
   import ptychography_fp16  # Your compiled extension

   # Convert FP32 → FP16 for kernel input
   obj_amp_fp16 = object_amplitude.to(torch.float16)
   obj_phase_fp16 = object_phase.to(torch.float16)
   # .to(torch.float16) on a complex tensor would drop the imaginary part;
   # view_as_real gives [80, 80, 2] (real, imag) pairs matching the __half2 layout
   probe_fp16 = torch.view_as_real(probe).to(torch.float16)

   # Single fused FP16 kernel call → [B, 80, 80, 2] float16
   # (y_idx / x_idx must be int32: the kernel reads them as int*)
   exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16(
       obj_amp_fp16, obj_phase_fp16, probe_fp16, y_idx, x_idx
   )

   # Continue with FFT in FP32, zero-padded to the 256×256 detector size
   exit_waves = torch.view_as_complex(exit_waves_fp16.float())
   Psi = torch.fft.fft2(exit_waves, s=(256, 256))

**Expected speedup:** 2-3× compared with PyTorch FP32 (gather + multiply operations)

**Performance breakdown (estimate):**

- Kernel fusion: 1.3-1.5× (eliminates 2 extra kernel launches and their memory round trips)
- FP16 storage: 1.5-2× (half the bytes moved; this kernel is memory-bound)
- **Combined: 2-3× for this operation**

Helper kernels for FP16 conversion
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Additional utility kernels for converting between FP32/FP16 and packing/unpacking complex numbers:

.. code-block:: cuda

   // Compute intensity: |ψ|² (FP16 input, FP32 output for stability)
   __global__ void compute_intensity_fp16(
       const __half* psi_real,
       const __half* psi_imag,
       float* intensity,
       int size
   ) {
       int idx = blockIdx.x * blockDim.x + threadIdx.x;
       if (idx < size) {
           float real_f32 = __half2float(psi_real[idx]);
           float imag_f32 = __half2float(psi_imag[idx]);
           intensity[idx] = real_f32 * real_f32 + imag_f32 * imag_f32;
       }
   }

   // Pack separate real/imag → __half2
   __global__ void pack_complex_fp16(
       const __half* real,
       const __half* imag,
       __half2* output,
       int size
   ) {
       int idx = blockIdx.x * blockDim.x + threadIdx.x;
       if (idx < size) {
           output[idx] = __halves2half2(real[idx], imag[idx]);
       }
   }

**Why compute intensity in FP32:** FP16's largest finite value is 65,504, so squaring any component larger than about 256 overflows to ``inf``. Converting to FP32 before the multiplication keeps the result finite.
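The overflow threshold is easy to check numerically. It is also not exotic here: an unnormalized FFT's zero-frequency term sums all 6,400 pixels of an 80×80 exit wave, so how large ``|Psi|`` gets depends mainly on the probe normalization. The NumPy sketch below (helper names are hypothetical) contrasts squaring in FP16 with the FP16-storage/FP32-arithmetic approach that ``compute_intensity_fp16`` uses.

.. code-block:: python

   import numpy as np

   FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0


   def intensity_all_fp16(psi_re, psi_im):
       """|psi|^2 with float16 storage AND float16 arithmetic (can overflow to inf)."""
       re16, im16 = psi_re.astype(np.float16), psi_im.astype(np.float16)
       with np.errstate(over="ignore"):  # silence NumPy's overflow warning
           return re16 * re16 + im16 * im16


   def intensity_fp16_storage(psi_re, psi_im):
       """Mirror of compute_intensity_fp16: float16 inputs, float32 arithmetic."""
       re32 = psi_re.astype(np.float16).astype(np.float32)
       im32 = psi_im.astype(np.float16).astype(np.float32)
       return re32 * re32 + im32 * im32


   psi_re = np.array([3.0, 300.0])
   psi_im = np.array([4.0, 0.0])
   print(intensity_all_fp16(psi_re, psi_im).tolist())      # [25.0, inf]
   print(intensity_fp16_storage(psi_re, psi_im).tolist())  # [25.0, 90000.0]

Note that the value 300 itself is representable in FP16; only its square is not. Storage in FP16 is therefore safe for moderate magnitudes, while the squaring step needs the wider FP32 range.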
Integrated into the model, the fused call replaces the three PyTorch operations (excerpt from the model's ``forward``; ``fused_module`` is the compiled extension):

.. code-block:: python

   # Fused amplitude/phase → complex, gather, and probe multiply
   exit_waves = fused_module.fused_exit_wave(
       self.object_amplitude, self.object_phase,
       self.y_idx, self.x_idx, self.probe_cache
   )

   # FFT (not fused: cuFFT is already highly optimized)
   Psi = torch.fft.fft2(exit_waves, s=(self.diffraction_size, self.diffraction_size))
   return torch.abs(Psi)

**What this kernel achieves (fusion alone, before FP16):**

- Eliminates the separate ``amplitude * exp(i*phase)`` operation
- Eliminates the separate gather operation
- Eliminates the separate complex multiply operation
- Does everything in one pass with coalesced memory access

**Expected speedup:** 1.3-1.5× on total iteration time

**New time:** 0.110s → 0.073-0.085s per iteration

Why this works: instead of materializing the full complex object in memory, then gathering patches, then multiplying by the probe (three separate kernel launches, three memory round trips), we compute everything on the fly. Each thread reads the amplitude and phase values, computes the complex object value, and immediately multiplies by the probe, all in registers.

Custom kernel 3: Fused intensity loss
--------------------------------------

**Problem:** The abs and power operations take 30.4% of GPU time combined:

- ``aten::abs``: 656ms (5.95%)
- ``AbsBackward0``: 1.524s (13.82%)
- ``PowBackward0``: 1.175s (10.65%)

The current code:

.. code-block:: python

   # Separate operations, each a kernel launch with a full-size temporary
   I_pred = torch.abs(Psi) ** 2                                                 # abs, then pow
   pred_scale = counts.view(-1, 1, 1) / I_pred.mean(dim=(1, 2), keepdim=True)  # mean, div
   diff = I_pred * pred_scale - I_meas * meas_scale                             # mul, sub
   loss = (diff * diff).mean()                                                  # mul, mean

**Custom CUDA kernels** fuse everything. A single launch cannot compute each pattern's mean and then use it: CUDA blocks cannot synchronize with each other inside a kernel, and each 256×256 pattern spans 256 blocks of 256 threads, so a per-block shared-memory sum only sees part of a pattern. The work is therefore split into two passes, each with a block-level reduction:

.. code-block:: cuda

   // Pass 1: I_pred = |Psi|^2 and per-pattern sums of I_pred.
   // Grid: dim3((H*W + 255) / 256, B), block: 256 threads (power of two).
   // blockIdx.y selects the pattern, so each block only reduces pixels of one pattern.
   __global__ void intensity_sum_kernel(
       const cuFloatComplex* __restrict__ Psi,     // [B, H, W]
       float* __restrict__ I_pred_output,          // [B, H, W] (saved for backward)
       float* __restrict__ pattern_sums,           // [B], zeroed before launch
       int HW
   ) {
       __shared__ float shared_sum[256];
       int b = blockIdx.y;
       int i = blockIdx.x * blockDim.x + threadIdx.x;
       int tid = threadIdx.x;

       float I_pred = 0.0f;
       if (i < HW) {
           cuFloatComplex psi_val = Psi[b * HW + i];
           I_pred = psi_val.x * psi_val.x + psi_val.y * psi_val.y;  // |Psi|^2
           I_pred_output[b * HW + i] = I_pred;
       }

       // Block-level tree reduction in shared memory
       shared_sum[tid] = I_pred;
       __syncthreads();
       for (int s = blockDim.x / 2; s > 0; s >>= 1) {
           if (tid < s) shared_sum[tid] += shared_sum[tid + s];
           __syncthreads();
       }
       if (tid == 0) atomicAdd(&pattern_sums[b], shared_sum[0]);
   }

   // Pass 2: scaled squared differences, reduced to the scalar mean.
   // Same grid/block shape as pass 1; launched after pass 1 completes.
   __global__ void scaled_loss_kernel(
       const float* __restrict__ I_pred,           // [B, H, W] from pass 1
       const float* __restrict__ I_meas,           // [B, H, W]
       const float* __restrict__ meas_scale,       // [B] (the [B, 1, 1] tensor)
       const float* __restrict__ counts,           // [B]
       const float* __restrict__ pattern_sums,     // [B] from pass 1
       float* __restrict__ loss_output,            // [1], zeroed before launch
       int HW, int total                           // total = B * H * W
   ) {
       __shared__ float shared_sum[256];
       int b = blockIdx.y;
       int i = blockIdx.x * blockDim.x + threadIdx.x;
       int tid = threadIdx.x;

       float local = 0.0f;
       if (i < HW) {
           int idx = b * HW + i;
           float pred_scale = counts[b] / (pattern_sums[b] / HW);  // counts / mean(I_pred)
           float diff = I_pred[idx] * pred_scale - I_meas[idx] * meas_scale[b];
           local = diff * diff;
       }

       shared_sum[tid] = local;
       __syncthreads();
       for (int s = blockDim.x / 2; s > 0; s >>= 1) {
           if (tid < s) shared_sum[tid] += shared_sum[tid + s];
           __syncthreads();
       }
       if (tid == 0) atomicAdd(loss_output, shared_sum[0] / total);
   }

**PyTorch wrapper:** The raw kernel output has no autograd history, so ``loss.backward()`` cannot differentiate through it directly. The sketch below wraps it in a ``torch.autograd.Function`` whose backward recomputes the loss with PyTorch ops; ``ptychography_fp16.fused_intensity_loss`` is a hypothetical binding that zeroes the buffers and launches both passes.

.. code-block:: python

   import torch
   import ptychography_fp16  # compiled extension (see "Implementation guide")


   def intensity_loss_reference(Psi, I_meas, meas_scale, counts):
       """Same loss with plain PyTorch ops (used for backward and for validation)."""
       I_pred = Psi.real ** 2 + Psi.imag ** 2
       pred_scale = counts.view(-1, 1, 1) / I_pred.mean(dim=(1, 2), keepdim=True)
       diff = I_pred * pred_scale - I_meas * meas_scale
       return (diff * diff).mean()


   class FusedIntensityLoss(torch.autograd.Function):
       """Fused CUDA forward pass; backward recomputes with PyTorch autograd."""

       @staticmethod
       def forward(ctx, Psi, I_meas, meas_scale, counts):
           ctx.save_for_backward(Psi, I_meas, meas_scale, counts)
           # Launches both passes with grid = ((H*W + 255) // 256, B), 256 threads
           return ptychography_fp16.fused_intensity_loss(Psi, I_meas, meas_scale, counts)

       @staticmethod
       def backward(ctx, grad_loss):
           Psi, I_meas, meas_scale, counts = ctx.saved_tensors
           with torch.enable_grad():
               Psi = Psi.detach().requires_grad_(True)
               loss = intensity_loss_reference(Psi, I_meas, meas_scale, counts)
               (grad_Psi,) = torch.autograd.grad(loss, Psi, grad_loss.reshape(loss.shape))
           return grad_Psi, None, None, None


   # Use in training with PyTorch autograd
   loss = FusedIntensityLoss.apply(Psi, I_meas, meas_scale, counts)
   loss.backward()

**Expected speedup:** 1.5-1.8× (eliminates the separate forward operations and all their intermediate memory; a fused backward kernel is needed to also recover the ``AbsBackward0``/``PowBackward0`` share)

**New time:** 0.073s → 0.041-0.049s per iteration

Summary: Two critical kernels + one optional
---------------------------------------------

**Priority ranking by impact:**

1. **Kernel 2 (exit wave FP16) - CRITICAL:** covers all 4,096 scan positions per iteration → 2-3× speedup
2. **Kernel 3 (intensity loss) - IMPORTANT:** covers all 4,096 diffraction patterns per iteration → 1.5-1.8× speedup
3. **Kernel 1 (probe generation) - OPTIONAL:** runs once per iteration; only needed when optimizing probe aberrations

**The forward pass:**

.. code-block:: text

   [Optional]  Kernel 1: probe_k generation → [IFFT] → probe_real        (1× per iter)
                     ↓
   [CRITICAL]  Kernel 2: Exit wave FP16                                   (4,096 positions per iter)
               tile gather + amp/phase → complex + multiply → [FFT] → Psi
                     ↓
   [IMPORTANT] Kernel 3: Intensity loss                                   (4,096 patterns per iter)
               |Psi|², scaling, loss

**Why Kernel 2 is most important:**

- Processes **4,096 scan positions per iteration**
- Uses FP16 storage (half the memory traffic) → 2-3× speedup
- Fuses 3 operations into 1 kernel
- **Biggest performance impact**

**Why Kernel 1 is optional:**

- Runs **once per iteration** (or never, if the probe is fixed)
- Many reconstructions keep the probe fixed
- Even when trainable, it is a small fraction of total time (~10%)
- **Low priority; skip unless you need it**

**Performance gains (estimates):**

- Kernel 1 (optional): 1.1-1.2×, only when optimizing probe aberrations
- Kernel 2 (critical): **2-3× speedup** with FP16 + kernel fusion
- Kernel 3 (important): 1.5-1.8× speedup
- **Realistic combined: 2-3× from Kernels 2+3 alone**
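Per-kernel speedups do not multiply directly into an overall speedup: each one applies only to the fraction of runtime it touches (Amdahl's law). The small helper below, a hypothetical name and not part of the codebase, makes that explicit. As an illustrative assumption, plugging in the GPU-time fractions from the profile (``aten::mul`` at 19.1% sped up 2.5×, the abs/pow group at ~30.4% sped up 1.65×, the 49.4% FFT unchanged) gives only about 1.3× on GPU time alone, which is why the larger estimates in this document also rely on removing PyTorch's CPU overhead.

.. code-block:: python

   def amdahl_speedup(parts):
       """Overall speedup when each (fraction, speedup) pair accelerates that
       fraction of the original runtime; the remaining fraction is unchanged."""
       accelerated = sum(fraction for fraction, _ in parts)
       if not 0.0 <= accelerated <= 1.0:
           raise ValueError("fractions must sum to a value in [0, 1]")
       new_time = (1.0 - accelerated) + sum(f / s for f, s in parts)
       return 1.0 / new_time


   # GPU-time fractions from the profile; speedups are midpoints of the estimates above
   gpu_only = amdahl_speedup([(0.191, 2.5), (0.304, 1.65)])
   print(f"GPU-time speedup: {gpu_only:.2f}x")  # GPU-time speedup: 1.31x

The same helper can be rerun with measured fractions after each kernel lands, which keeps the running estimate honest as the profile shifts toward the FFT.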
**Recommendation:** Implement Kernel 2 first, then Kernel 3. Skip Kernel 1 unless you're optimizing probe aberrations. Backward pass: Use PyTorch autograd loss = compute_loss(Psi, I_meas) scaler.scale(loss).backward() # Scale gradients to prevent underflow scaler.step(optimizer) scaler.update() **Performance gains:** - 1.3-1.5× speedup on Ampere/Hopper GPUs - 2× memory reduction for activations - No accuracy loss for this application **Final combined speedup:** 2.5-3.5× (kernels) × 1.3-1.5× (FP16) = **3-5× total** **Result:** 0.11s → 0.022-0.037s per iteration Backward pass: Use PyTorch autograd ------------------------------------ **Strategy:** Focus optimization on forward pass, use PyTorch's built-in tools for backward. **Strategy:** Focus optimization on forward pass (Kernels 2+3), use PyTorch's built-in tools for backward. .. code-block:: python optimizer = torch.optim.Adam([object_complex, probe], lr=0.01) for iteration in range(max_iterations): optimizer.zero_grad() # Forward pass with custom FP16 kernels exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16( obj_amp_fp16, obj_phase_fp16, probe_fp16, y_idx, x_idx ) Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64)) loss = compute_loss(Psi, I_meas) loss.backward() # PyTorch autograd handles gradients optimizer.step() **Why this is good:** - PyTorch's autograd is already highly optimized - ADAM optimizer is well-tested and numerically stable - Less code to maintain - Forward pass (Kernels 2+3) is where the biggest gains are (2-3× each) Implementation guide -------------------- Compile custom FP16 CUDA kernels ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Create ``setup.py`` for PyTorch extension: .. 
code-block:: python from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup( name='ptychography_fp16', ext_modules=[ CUDAExtension( 'ptychography_fp16', ['fp16_complex_ops.cu'], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': ['-O3', '--use_fast_math', '-arch=sm_80'] # Ampere } ) ], cmdclass={'build_ext': BuildExtension} ) Compile and install: .. code-block:: bash python setup.py install Integrate into PyTorch training loop ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: python import torch import torch.nn as nn import ptychography_fp16 # Compiled custom kernels class OptimizedPtychography(nn.Module): def __init__(self, object_size, probe_size, scan_positions, device='cuda'): super().__init__() # Master copies in FP32 self.object_amplitude = nn.Parameter(torch.randn(object_size, device=device)) self.object_phase = nn.Parameter(torch.randn(object_size, device=device)) self.probe = nn.Parameter(torch.randn(probe_size, dtype=torch.complex64, device=device)) # Precomputed indices self.register_buffer('y_idx', ...) self.register_buffer('x_idx', ...) 
def forward(self, batch_indices): # Convert to FP16 for custom kernel object_fp16 = ptychography_fp16.amp_phase_to_complex_fp16( self.object_amplitude.to(torch.float16), self.object_phase.to(torch.float16) ) probe_fp16 = self.probe.to(torch.float16) # Custom FP16 gather + multiply kernel exit_waves_fp16 = ptychography_fp16.gather_multiply_fp16( object_fp16, probe_fp16, self.y_idx, self.x_idx ) # FFT in FP32 for stability Psi = torch.fft.fft2(exit_waves_fp16.to(torch.complex64)) return Psi # Training loop model = OptimizedPtychography(...).to('cuda') optimizer = torch.optim.Adam(model.parameters(), lr=0.01) for iteration in range(num_iterations): optimizer.zero_grad() Psi = model(batch_indices) loss = compute_loss(Psi, I_meas) loss.backward() # PyTorch autograd handles gradients optimizer.step() Validate FP16 accuracy ~~~~~~~~~~~~~~~~~~~~~~~ # Reference (PyTorch ops) with torch.no_grad(): Psi_ref = model_pytorch(batch_indices) loss_ref = compute_loss_pytorch(Psi_ref, ...) # Custom kernels with torch.no_grad(): Psi_opt = model_optimized(batch_indices) loss_opt = compute_loss_optimized(Psi_opt, ...) # Check numerical accuracy max_diff = torch.max(torch.abs(Psi_ref - Psi_opt)) rel_error = max_diff / torch.max(torch.abs(Psi_ref)) print(f"Max absolute difference: {max_diff:.6e}") print(f"Relative error: {rel_error:.6e}") assert rel_error < 1e-4, "Custom kernels changed results!" Expected errors with mixed precision: < 1e-3 (acceptable for scientific computing) Step 4: Profile to verify speedup ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: python import time # Warmup for _ in range(10): _ = model(batch_indices) # Benchmark torch.cuda.synchronize() times = [] for _ in range(100): t0 = time.perf_counter() loss = train_iteration(model, ...) 
torch.cuda.synchronize() times.append(time.perf_counter() - t0) print(f"Mean: {np.mean(times):.4f}s") print(f"Std: {np.std(times):.4f}s") print(f"Target < 0.01s: {'✅' if np.mean(times) < 0.01 else '❌'}") Summary ------- **Three custom CUDA kernels + FP16 = 3-5× speedup** **Why three separate kernels:** .. code-block:: text Kernel 1: Probe generation (aberrations) → [IFFT] → probe Kernel 2: Exit wave (gather + multiply) → [FFT] → Psi Kernel 3: Intensity loss (|Psi|², scale, loss) Cannot fuse through FFT/IFFT - global operations, cuFFT optimal. **What each kernel fuses:** 1. **Probe generation:** Defocus + Cs + astigmatism + aperture → FP32 probe_k → [IFFT] - Speedup: 1.1-1.2× 2. **Exit wave (FP16):** Tile gather + amp/phase → complex + complex multiply in FP16 - Speedup: 2-3× (uses Tensor Cores) - Custom ``__half2`` complex math - Fuses 3 operations into 1 kernel 3. **Intensity loss:** |Psi|² + scaling + difference + reduction - Speedup: 1.5-1.8× **FP16 implementation details:** - Uses ``__half2`` for complex numbers: ``(real, imag)`` packed - Custom ``cmul_fp16()`` for complex multiply (Tensor Cores) - Master copies in FP32, compute in FP16 - FFT stays FP32 for stability - Typical error: < 1e-3 (acceptable) **Backward pass:** - Use PyTorch autograd (already optimized) - ADAM optimizer - Focus on forward pass optimization **Final result:** 0.11s → **0.022-0.037s** (3-5× speedup) **Key takeaway:** Custom FP16 kernels + kernel fusion + FFT boundaries = maximum performance with minimal code. # PyTorch Operations Deep Dive: Ptychography Forward Pass This document explains **every PyTorch operation** in the ptychography forward pass with detailed visual illustrations. --- ## Table of Contents 1. [The Big Picture: What is Ptychography?](#1-the-big-picture) 2. [Setup: Initial Data Structures](#2-setup) 3. [Operation 1: Gather (Extracting Patches)](#3-gather-operation) 4. [Operation 2: Complex Multiplication](#4-complex-multiplication) 5. 
The big picture: What is ptychography?
---------------------------------------

.. code-block:: text

    PTYCHOGRAPHY: Reconstructing an object by scanning a probe over it

    Step 1: Physical setup

       Electron beam (probe)
            ↓↓↓↓↓↓↓
       ┌─────────────┐
       │    Probe    │ ← 128×128 pixels
       │  (complex)  │
       └─────────────┘
              ↓
       ┌─────────────────────┐
       │                     │
       │    Sample/Object    │ ← 512×512
       │     (unknown!)      │
       │                     │
       └─────────────────────┘
              ↓
       ┌─────────────┐
       │ Diffraction │ ← Detector 256×256
       │   Pattern   │
       └─────────────┘

    Step 2: Scanning process

       Position 1        Position 2        Position 3
         ┌────┐            ┌────┐            ┌────┐
         │Beam│            │Beam│            │Beam│
         └─┬──┘            └─┬──┘            └─┬──┘
           ↓                 ↓                 ↓
    ┌─────────────────────────────────────────────────┐
    │ ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ │
    │ ░░░░░░░░░ The beam scans across the sample ░░░░ │
    │ ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ │
    └─────────────────────────────────────────────────┘

A typical scan covers 4096 positions, and a diffraction pattern is recorded at each one.

**Goal:** given the diffraction patterns, reconstruct the object.

Forward model (what we're implementing):

.. code-block:: text

    Object   ×  Probe    →  Exit wave  →  FFT      →  Diffraction intensity
    512×512     128×128     128×128       256×256     256×256
    complex     complex     complex       complex     real

Setup: Initial data structures
-------------------------------

Input data
^^^^^^^^^^

.. code-block:: python

    import torch

    # The object we want to reconstruct (complex-valued)
    object_real = torch.randn(512, 512)  # Real part
    object_imag = torch.randn(512, 512)  # Imaginary part

    # The probe (illumination; known or optimized).
    # Assumed to be already inverse-FFTed from its Fourier-space form probe(k)
    probe_real = torch.randn(128, 128)  # Real part
    probe_imag = torch.randn(128, 128)  # Imaginary part

    # Scan positions (where the probe illuminates)
    batch_size = 4096  # Number of scan positions

The scan positions are the probe centers on the object:
.. code-block:: text

    Object:
    ┌─────────────────────┐
    │ ●      ●      ●     │  ← Row 1: (100,50), (100,150), (100,250)
    │                     │
    │ ●      ●      ●     │  ← Row 2: (200,50), (200,150), (200,250)
    │                     │
    │ ●      ●      ●     │  ← Row 3: (300,50), (300,150), (300,250)
    └─────────────────────┘

    y_centers = [100, 100, 100, 200, 200, 200, 300, 300, 300]
    x_centers = [50, 150, 250, 50, 150, 250, 50, 150, 250]

In code, the 4096 centers are stored as two 1D tensors (illustrative values, not the grid above):

.. code-block:: python

    y_centers = torch.tensor([100, 150, 200])  # ... 4096 y-coordinates in total
    x_centers = torch.tensor([50, 75, 100])    # ... 4096 x-coordinates in total

Memory layout
^^^^^^^^^^^^^

.. code-block:: text

    Object (512×512), numbers are column indices:
    ┌─────────────────────────────────────┐
    │ 0  1  2  3  4  5  6 ... 510 511     │ ← Row 0
    │ 0  1  2  3  4  5  6 ... 510 511     │ ← Row 1
    │ :  :  :  :  :  :  : ...  :   :      │
    │ 0  1  2  3  4  5  6 ... 510 511     │ ← Row 511
    └─────────────────────────────────────┘
    Each cell holds one complex value, stored as object_real + i*object_imag

    Probe (128×128):
    ┌──────────────────┐
    │  Gaussian-like   │
    │   beam shape     │
    │  (128×128 px)    │
    └──────────────────┘

.. dropdown:: Gather operation: Extracting patches (detailed explanation)
    :color: info

    What is "gather"?
    ^^^^^^^^^^^^^^^^^

    Problem: we need to extract a 128×128 patch from the 512×512 object at each scan position.

    .. code-block:: python

        # Scan position 0: center at (100, 50)
        # Scan position 1: center at (150, 75)
        # Scan position 2: center at (200, 100)
        # ... 4093 more positions

    Visual example (position 0)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^

    .. code-block:: text

        Object (512×512):
                            x_center=50
                                 ↓
         y=36  ┌─────────────────┼─────────────────┐ ← top edge of patch (row 36)
               │   Patch 128×128 │                 │
               │                 │                 │
         y=100 ├─────────────────●─────────────────┤ ← center (y_center=100)
               │                 │                 │
               │                 │                 │
         y=164 └─────────────────┼─────────────────┘ ← bottom edge (row 164 not included)
             x=-14             x=50              x=114

        Top-left corner: (y_center - 64, x_center - 64) = (100 - 64, 50 - 64) = (36, -14)
        (64 = 128/2, half the probe size)

    The negative x start means this patch would extend 14 pixels past the left edge of the object, so in practice every center must stay at least 64 pixels from the edges, or the object must be padded.
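    The sketch below shows one way to handle such edge positions, assuming the region outside the object may be filled with a constant: pad the object by half the probe size on every side, so that in padded coordinates each patch's top-left corner equals its center. The zero pad value is an assumption for illustration (for a transmission function, vacuum is 1 + 0i, so a real-part pad of 1.0 may fit better). It also shows why skipping this step is dangerous: PyTorch wraps negative indices around to the opposite edge instead of raising an error.

    .. code-block:: python

        import torch
        import torch.nn.functional as F

        torch.manual_seed(0)
        half = 64                          # half the probe size (128 // 2)
        object_real = torch.randn(512, 512)

        # Pitfall: a negative index silently wraps to the opposite edge (column 512 - 14 = 498)
        wrapped = object_real[torch.tensor([0]), torch.tensor([-14])]
        assert wrapped.item() == object_real[0, 498].item()

        # Pad `half` pixels on every side; for a 2D tensor the tuple is (left, right, top, bottom)
        pad_value = 0.0                    # assumption: what "outside the object" means
        padded = F.pad(object_real, (half, half, half, half), value=pad_value)  # [640, 640]

        # Centers from the example above; position 0 (x_center = 50) reaches past the left edge
        y_centers = torch.tensor([100, 150, 200])
        x_centers = torch.tensor([50, 75, 100])

        # In padded coordinates: top-left = (center - half) + half = center
        offsets = torch.arange(2 * half)
        y_idx = y_centers[:, None, None] + offsets[None, :, None]  # [3, 128, 1]
        x_idx = x_centers[:, None, None] + offsets[None, None, :]  # [3, 1, 128]
        patches = padded[y_idx, x_idx]     # index tensors broadcast to [3, 128, 128]
        print(patches.shape)  # torch.Size([3, 128, 128])

    For position 0, the first 14 columns of the patch now hold the pad value instead of pixels from the far right edge of the object.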
    How gather works in PyTorch
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    Method 1: Naive loop (SLOW)
    """""""""""""""""""""""""""""

    .. code-block:: python

        patches = []
        for i in range(4096):
            y_start = int(y_centers[i]) - 64
            x_start = int(x_centers[i]) - 64

            # Extract one 128×128 patch
            patch_real = object_real[y_start:y_start+128, x_start:x_start+128]
            patch_imag = object_imag[y_start:y_start+128, x_start:x_start+128]
            patches.append((patch_real, patch_imag))

        # Problem: 4096 loop iterations, each launching its own small GPU operations. Very slow!

    Here is what ``patch_real = object_real[36:164, 50:178]`` selects, i.e. the patch whose top-left corner is (36, 50):

    .. code-block:: text

        object_real (512×512):

        Columns →      0   1   2 ...  50  51 ... 177 178 ... 511
        Rows ↓
                   0   ░   ░   ░ ...   ░   ░ ...   ░   ░ ...   ░
                   1   ░   ░   ░ ...   ░   ░ ...   ░   ░ ...   ░
                   :   :   :   :       :   :       :   :       :
                  36   ░   ░   ░ ...   █   █ ...   █   ░ ...   ░   ← start row (y_start = 36)
                  37   ░   ░   ░ ...   █   █ ...   █   ░ ...   ░
                  38   ░   ░   ░ ...   █   █ ...   █   ░ ...   ░
                   :   :   :   :       :   :       :   :       :
                 163   ░   ░   ░ ...   █   █ ...   █   ░ ...   ░   ← end row (y_start + 127 = 163)
                 164   ░   ░   ░ ...   ░   ░ ...   ░   ░ ...   ░   ← NOT included!
                   :   :   :   :       :   :       :   :       :
                 511   ░   ░   ░ ...   ░   ░ ...   ░   ░ ...   ░
                                       ↑           ↑
                            start col 50   end col 177 (= x_start + 127); col 178 NOT included

        █ = extracted patch (128×128)    ░ = not extracted

    Method 2: Advanced indexing (FAST)
    """"""""""""""""""""""""""""""""""""

    .. code-block:: python

        # Step 1: Create a meshgrid of offsets: the position of every
        # pixel relative to the patch's top-left corner
        y_offset = torch.arange(128)  # [0, 1, 2, ..., 127]
        x_offset = torch.arange(128)  # [0, 1, 2, ..., 127]
        grid_y, grid_x = torch.meshgrid(y_offset, x_offset, indexing='ij')

        # grid_y: [[0,0,0,...,0],          grid_x: [[0,1,2,...,127],
        #          [1,1,1,...,1],                   [0,1,2,...,127],
        #          ...                              ...
        #          [127,127,...,127]]               [0,1,2,...,127]]

    Visual representation of the grids:

    .. code-block:: text

        grid_y (128×128):       grid_x (128×128):
        ┌───────────────┐       ┌───────────────┐
        │  0   0 ...   0│       │  0   1 ... 127│
        │  1   1 ...   1│       │  0   1 ... 127│
        │  2   2 ...   2│       │  0   1 ... 127│
        │  :   :       :│       │  :   :       :│
        │127 127 ... 127│       │  0   1 ... 127│
        └───────────────┘       └───────────────┘

    Think of these as the relative position of each pixel within a patch.
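    The ``indexing='ij'`` argument matters here. With ``indexing='xy'``, ``torch.meshgrid`` swaps the roles of the two axes and the grids come out transposed; because the patches are square the shapes still match, so the mistake would silently transpose every patch. A quick check on a 4×4 grid (size chosen only for readability):

    .. code-block:: python

        import torch

        offsets = torch.arange(4)

        # 'ij' (matrix) indexing: grid_y[i, j] == i (row), grid_x[i, j] == j (column)
        grid_y, grid_x = torch.meshgrid(offsets, offsets, indexing="ij")

        # 'xy' (Cartesian) indexing swaps the roles of the two axes
        grid_y_xy, grid_x_xy = torch.meshgrid(offsets, offsets, indexing="xy")

        print(grid_y)
        print(grid_y_xy)
        assert torch.equal(grid_y_xy, grid_y.T)  # same shape, transposed content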
    .. code-block:: python

        # Step 2: Add each patch's top-left corner to the offsets to get absolute
        # coordinates. Top-left corner = center - 64 (half the probe size).
        y_starts = y_centers - 64  # [4096]
        x_starts = x_centers - 64  # [4096]

        # Shape: [4096, 128, 128]
        y_indices = y_starts[:, None, None] + grid_y[None, :, :]
        x_indices = x_starts[:, None, None] + grid_x[None, :, :]

        # Broadcasting explained:
        #   y_starts[:, None, None]   # [4096, 1, 1]
        #   grid_y[None, :, :]        # [1, 128, 128]
        #   ───────────────────────
        #   y_indices                 # [4096, 128, 128] (broadcast!)

    Broadcasting visualization:

    .. code-block:: text

        y_starts[:, None, None]:          grid_y[None, :, :]:
        ┌────┐                            ┌───────────────┐
        │ 36 │                            │  0   0 ...   0│
        │ 86 │       broadcast add        │  1   1 ...   1│
        │136 │    ─────────────────→      │  :   :       :│
        │ :  │                            │127 127 ... 127│
        └────┘                            └───────────────┘
        [4096, 1, 1]                      [1, 128, 128]

        Result y_indices [4096, 128, 128]:
        Position 0: [[ 36,  36, ...,  36],    ← y_start=36 + offset 0
                     [ 37,  37, ...,  37],    ← y_start=36 + offset 1
                     ...
                     [163, 163, ..., 163]]    ← y_start=36 + offset 127
        Position 1: [[ 86,  86, ...,  86],    ← y_start=86 + offset 0
                     [ 87,  87, ...,  87],
                     ...
                     [213, 213, ..., 213]]

    .. code-block:: python

        # Step 3: Gather with advanced indexing
        patches_real = object_real[y_indices, x_indices]
        patches_imag = object_imag[y_indices, x_indices]
        # Shape: [4096, 128, 128]
        # Each line is ONE batched indexing operation: all 4096 patches are
        # extracted at once, with no Python loop.
        # Caution: negative indices (e.g. x_start = 50 - 64 = -14) wrap around to
        # the opposite edge instead of raising an error, so keep every center at
        # least 64 px from the edges or pad the object first.
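    A quick way to confirm that Method 2 returns exactly what Method 1 does is to run both on a small random object. The sizes below (64×64 object, 16×16 patches, 32 random centers kept inside the object) are arbitrary choices for the check:

    .. code-block:: python

        import torch

        torch.manual_seed(0)
        size, patch, n_pos = 64, 16, 32
        half = patch // 2
        obj = torch.randn(size, size)

        # Random centers at least `half` px from every edge, so all patches stay in bounds
        y_centers = torch.randint(half, size - half + 1, (n_pos,))
        x_centers = torch.randint(half, size - half + 1, (n_pos,))

        # Method 1: Python loop with slicing
        loop = torch.stack([
            obj[int(yc) - half:int(yc) + half, int(xc) - half:int(xc) + half]
            for yc, xc in zip(y_centers, x_centers)
        ])

        # Method 2: one advanced-indexing gather
        grid_y, grid_x = torch.meshgrid(torch.arange(patch), torch.arange(patch), indexing="ij")
        y_idx = (y_centers - half)[:, None, None] + grid_y[None]
        x_idx = (x_centers - half)[:, None, None] + grid_x[None]
        gathered = obj[y_idx, x_idx]

        assert torch.equal(loop, gathered)
        print(gathered.shape)  # torch.Size([32, 16, 16])

    Since both methods read the same elements, the results match bit for bit; only the number of operations launched differs.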
    **Concrete example: advanced indexing with actual numbers**

    Let's use a small 8×8 object and extract 2 patches of 3×3:

    .. code-block:: python

        import torch

        # Object (8×8) with easy-to-track values: row r, column c holds 10*(r+1) + c
        object_real = torch.tensor([
            [10, 11, 12, 13, 14, 15, 16, 17],  # Row 0
            [20, 21, 22, 23, 24, 25, 26, 27],  # Row 1
            [30, 31, 32, 33, 34, 35, 36, 37],  # Row 2
            [40, 41, 42, 43, 44, 45, 46, 47],  # Row 3
            [50, 51, 52, 53, 54, 55, 56, 57],  # Row 4
            [60, 61, 62, 63, 64, 65, 66, 67],  # Row 5
            [70, 71, 72, 73, 74, 75, 76, 77],  # Row 6
            [80, 81, 82, 83, 84, 85, 86, 87],  # Row 7
        ])

        # Extract 2 patches (3×3 each)
        # Patch 0: center at (2, 2) → top-left = (1, 1)
        # Patch 1: center at (4, 5) → top-left = (3, 4)

        # Step 1: Create offset grid for a 3×3 patch
        grid_y = torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
        grid_x = torch.tensor([[0, 1, 2], [0, 1, 2], [0, 1, 2]])

        # Step 2: Centers and top-left corners (half of the 3-pixel patch is 1)
        y_centers = torch.tensor([2, 4])
        x_centers = torch.tensor([2, 5])
        y_start = y_centers - 1  # [1, 3]
        x_start = x_centers - 1  # [1, 4]

        # Step 3: Build indices, shape [2, 3, 3]
        y_indices = y_start[:, None, None] + grid_y[None, :, :]
        x_indices = x_start[:, None, None] + grid_x[None, :, :]
        # y_indices:          x_indices:
        # [[[1, 1, 1],        [[[1, 2, 3],
        #   [2, 2, 2],          [1, 2, 3],
        #   [3, 3, 3]],         [1, 2, 3]],
        #
        #  [[3, 3, 3],         [[4, 5, 6],
        #   [4, 4, 4],          [4, 5, 6],
        #   [5, 5, 5]]]         [4, 5, 6]]]

        # Step 4: Advanced indexing!
        patches_real = object_real[y_indices, x_indices]

    What happens internally, using patch 0 (centered on the value 32): output position (0,0) reads 21, (0,1) reads 22, (0,2) reads 23, and so on. Every lookup goes into the full 8×8 ``object_real``:

    .. code-block:: python

        # Patch 0, output position (0, 0): row 1, column 1
        assert y_indices[0, 0, 0] == 1 and x_indices[0, 0, 0] == 1
        assert patches_real[0, 0, 0] == object_real[1, 1] == 21

        # Patch 0, output position (0, 1): row 1, column 2
        assert y_indices[0, 0, 1] == 1 and x_indices[0, 0, 1] == 2
        assert patches_real[0, 0, 1] == object_real[1, 2] == 22

        # Patch 0, output position (1, 1): row 2, column 2 (the center)
        assert y_indices[0, 1, 1] == 2 and x_indices[0, 1, 1] == 2
        assert patches_real[0, 1, 1] == object_real[2, 2] == 32

        # ... and so on for all 18 positions (2 patches × 3×3)

    Result:
    .. code-block:: python

        # patches_real now holds (shape [2, 3, 3]):
        expected = torch.tensor([
            [[21, 22, 23],   # Patch 0
             [31, 32, 33],   # ← center at (2, 2) = 32
             [41, 42, 43]],

            [[44, 45, 46],   # Patch 1
             [54, 55, 56],   # ← center at (4, 5) = 55
             [64, 65, 66]],
        ])
        assert torch.equal(patches_real, expected)

    Visual:

    .. code-block:: text

        Object (8×8):                 Patch 0:      Patch 1:
        10 11 12 13 14 15 16 17
        20 21 22 23 24 25 26 27       21 22 23      44 45 46
        30 31 32 33 34 35 36 37       31 32 33      54 55 56
        40 41 42 43 44 45 46 47       41 42 43      64 65 66
        50 51 52 53 54 55 56 57
        60 61 62 63 64 65 66 67
        70 71 72 73 74 75 76 77
        80 81 82 83 84 85 86 87

    The key: ``y_indices`` and ``x_indices`` act as a lookup table that tells PyTorch exactly which element of ``object_real`` to read for each position in the output.

    **How advanced indexing works**

    .. code-block:: text

        object_real[y_indices, x_indices]

        Conceptually, PyTorch computes:

        for b in range(4096):
            for i in range(128):
                for j in range(128):
                    y = y_indices[b, i, j]   # absolute row in the object
                    x = x_indices[b, i, j]   # absolute column in the object
                    result[b, i, j] = object_real[y, x]

        On the GPU this is not a loop: the output elements are independent,
        so GPU threads compute them in parallel.

        GPU parallelization:

        CPU loop (sequential):          GPU (parallel):
        ┌───┐   ┌───┐   ┌───┐           ┌─┬─┬─┬─┬─┬─┬─┬─┐
        │ 1 │ → │ 2 │ → │ 3 │ → ...     │1│2│3│4│5│6│7│8│ ← all at once!
        └───┘   └───┘   └───┘           └─┴─┴─┴─┴─┴─┴─┴─┘
        Time: 4096×128×128 steps        Time: ~1 ms (massively parallel)

Complex multiplication: Object × probe
---------------------------------------

Complex number basics
^^^^^^^^^^^^^^^^^^^^^

.. code-block:: text

    Complex number: z = a + bi
    where: a = real part, b = imaginary part, i = √(-1)

    Multiplication rule:
    (a + bi) × (c + di) = (ac - bd) + (ad + bc)i
                          ─────────   ─────────
                            real        imag

**In our code**
.. code-block:: python

    # Complex form, for reference: object patch and probe as complex tensors
    object_patch = torch.complex(patches_real, patches_imag)  # [4096, 128, 128]
    probe = torch.complex(probe_real, probe_imag)              # [128, 128]

    # Exit wave = object_patch × probe, written out with
    # (a+bi)×(c+di) = (ac-bd) + (ad+bc)i.
    # Reuse patches_real/patches_imag from the gather step instead of
    # re-indexing object_real/object_imag for every term.
    exit_real = patches_real * probe_real - patches_imag * probe_imag
    exit_imag = patches_real * probe_imag + patches_imag * probe_real

**Visual example (single pixel)**

.. code-block:: text

    Object patch at position (0,0): 0.8 + 0.3i
    Probe at position (0,0):        0.9 + 0.1i

    Exit wave = (0.8 + 0.3i) × (0.9 + 0.1i)
              = (0.8×0.9 - 0.3×0.1) + (0.8×0.1 + 0.3×0.9)i
              = (0.72 - 0.03) + (0.08 + 0.27)i
              = 0.69 + 0.35i

    exit_real[0,0,0] = 0.69
    exit_imag[0,0,0] = 0.35

**Shape transformation**

.. code-block:: text

    Input:
      patches_real  (object_real[y_indices, x_indices])   [4096, 128, 128]
      patches_imag  (object_imag[y_indices, x_indices])   [4096, 128, 128]
      probe_real                                          [128, 128]
      probe_imag                                          [128, 128]

    Broadcasting (the probe is repeated virtually, without a copy):
      probe_real → [4096, 128, 128] (the same probe for all 4096 batch elements)

    Output:
      exit_real   [4096, 128, 128]
      exit_imag   [4096, 128, 128]

Broadcasting visualization:

.. code-block:: text

    Batch dimension:

    probe_real [128, 128]
          ↓ broadcast
    ┌─────┬─────┬─────┬─────┬─────┬─────┐
    │ #0  │ #1  │ #2  │ ... │4094 │4095 │
    └─────┴─────┴─────┴─────┴─────┴─────┘
       ↑     ↑     ↑           ↑     ↑
       Same probe repeated 4096 times (implicitly)

    Each batch element: object_patch[b] × probe → exit_wave[b]
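A quick way to check the real/imaginary bookkeeping is to compare it with PyTorch's native complex arithmetic, which applies the same multiplication rule. The sketch below uses arbitrary illustrative sizes (8 patches of 16×16), confirms that the probe broadcasts over the batch dimension, and reproduces the single-pixel example:

.. code-block:: python

    import torch

    torch.manual_seed(0)
    patches_real, patches_imag = torch.randn(8, 16, 16), torch.randn(8, 16, 16)
    probe_real, probe_imag = torch.randn(16, 16), torch.randn(16, 16)

    # Split form: (a+bi)(c+di) = (ac - bd) + (ad + bc)i, probe broadcast over the batch
    exit_real = patches_real * probe_real - patches_imag * probe_imag
    exit_imag = patches_real * probe_imag + patches_imag * probe_real

    # Native complex form
    exit_wave = torch.complex(patches_real, patches_imag) * torch.complex(probe_real, probe_imag)

    assert exit_wave.shape == (8, 16, 16)
    assert torch.allclose(exit_wave.real, exit_real, atol=1e-5)
    assert torch.allclose(exit_wave.imag, exit_imag, atol=1e-5)

    # Single-pixel example from above: (0.8 + 0.3i)(0.9 + 0.1i) = 0.69 + 0.35i
    pixel = (0.8 + 0.3j) * (0.9 + 0.1j)
    print(round(pixel.real, 6), round(pixel.imag, 6))  # 0.69 0.35

Keeping real and imaginary parts in separate tensors mirrors how the custom kernel packs them into ``(real, imag)`` pairs; the native complex form is a convenient reference when validating it.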