Phase retrieval methods
Three approaches
Different techniques exploit different aspects of the 4D-STEM dataset to recover phase:
- Differential Phase Contrast (DPC)
Measures how the electron beam deflects as it passes through the sample. At each scan position, compute the center of mass (CoM) of the diffraction pattern. The CoM shift is proportional to the phase gradient \(\nabla\phi\). Integrate these gradients to reconstruct the full phase map \(\phi(x,y)\).
Strength: Fast, simple algorithm; works well for weak-phase samples
Limitation: Assumes weak-phase approximation; breaks down for thick or strongly scattering samples
- Beam-tilt corrected bright-field STEM (tcBF-STEM)
Measures virtual image shifts caused by microscope aberrations. For each detector angle, extract a virtual bright-field image. These images are shifted relative to each other due to aberrations (defocus, astigmatism, coma). Fit a polynomial aberration model to the shifts, correct each image, then combine all corrected views using upsampling to achieve super-resolution.
Strength: Super-resolution through multi-view combination; computationally cheaper than ptychography
Limitation: Primarily corrects aberrations rather than retrieving full complex transmission function
- Ptychography
Iteratively refines both the object \(O(x,y)\) and probe \(P(x,y)\) to match measured diffraction intensities across overlapping scan positions. Uses gradient-based optimization to minimize the error between predicted and measured intensities.
Strength: Retrieves complete complex transmission function (amplitude + phase); handles arbitrary scattering strengths; simultaneously reconstructs unknown probe aberrations
Limitation: Computationally intensive; requires overlapping scan positions
The following sections describe each technique in detail.
Differential Phase Contrast (DPC)
What is DPC?
Differential Phase Contrast (DPC), developed in 1974, retrieves phase information by measuring how electron waves deflect as they pass through a sample. The technique measures the center of mass of the diffraction pattern at each scan position, which is proportional to the local phase gradient—and thus to the electron deflection caused by the sample’s electrostatic potential.
How does DPC work?
At each scan position \((x,y)\), measure the center of mass of the diffraction pattern \(I(\mathbf{k}; x,y)\):
\[\mathbf{k}_{\mathrm{CoM}}(x,y) = \frac{\iint \mathbf{k}\, I(\mathbf{k}; x,y)\, d\mathbf{k}}{\iint I(\mathbf{k}; x,y)\, d\mathbf{k}}\]The center of mass is proportional to the phase gradient:
\[\mathbf{k}_{\mathrm{CoM}}(x,y) = C\, \nabla\phi(x,y)\]where \(C\) is a calibration constant. Integrate the measured gradients to reconstruct the full phase map \(\phi(x,y)\); a Fourier-space integration is sketched below.
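A minimal PyTorch sketch of this workflow, assuming the 4D dataset is indexed as (scan_y, scan_x, k_y, k_x) and integrating the CoM gradient field in Fourier space (the calibration constant is a placeholder):

```python
import math
import torch

def dpc_phase(datacube: torch.Tensor, calibration: float = 1.0) -> torch.Tensor:
    """Reconstruct a DPC phase map from a 4D-STEM datacube.

    datacube: real-valued tensor of shape (scan_y, scan_x, k_y, k_x).
    """
    ny, nx, qy, qx = datacube.shape

    # Detector coordinate grids, centered on the optic axis
    ky = torch.arange(qy, dtype=torch.float32) - qy / 2
    kx = torch.arange(qx, dtype=torch.float32) - qx / 2
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")

    # Center of mass of each diffraction pattern (proportional to grad(phi))
    total = datacube.sum(dim=(-2, -1))
    com_y = (datacube * KY).sum(dim=(-2, -1)) / total
    com_x = (datacube * KX).sum(dim=(-2, -1)) / total

    # Integrate the gradient field in Fourier space:
    # F{phi} = (fx * F{g_x} + fy * F{g_y}) / (2*pi*i * (fx^2 + fy^2))
    fy = torch.fft.fftfreq(ny).reshape(-1, 1)
    fx = torch.fft.fftfreq(nx).reshape(1, -1)
    denom = fx**2 + fy**2
    denom[0, 0] = 1.0  # avoid dividing by zero; the DC (mean phase) is arbitrary
    numer = fx * torch.fft.fft2(com_x) + fy * torch.fft.fft2(com_y)
    phi_hat = (numer / denom) / (2j * math.pi)
    phi = torch.fft.ifft2(phi_hat).real
    return calibration * phi
```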
What assumptions does DPC make?
DPC assumes small phase shifts (weak-phase approximation), so the object's transmission function can be linearized:
\[O(x,y) = e^{i\phi(x,y)} \approx 1 + i\,\phi(x,y), \qquad |\phi(x,y)| \ll 1\]
This approximation breaks down for thick or strongly scattering samples.
Why calculate the rotation angle?
The rotation angle aligns the coordinate system of the diffraction pattern with the scan coordinate system. It is found by minimizing the divergence of the center of mass vector field.
Ptychography
Introduction
Ptychography is an imaging technique that addresses two fundamental problems in electron microscopy: image resolution is limited by lens aberrations, and detectors record only intensities, so the phase of the scattered wave is lost. By solving the phase problem computationally, ptychography bypasses these lens limitations and can achieve sub-Ångström resolution. Unlike DPC or tcBF-STEM, ptychography works for thick samples, handles strong scattering, recovers the complete wavefunction, and corrects aberrations computationally.
- What problem does ptychography solve that DPC or tcBF-STEM cannot?
Ptychography fully reconstructs the complex transmission function (both amplitude and phase) of a specimen with high resolution and sensitivity. Unlike DPC, which relies on the weak-phase approximation and fails for thick or strongly scattering samples, ptychography handles arbitrary scattering strengths. Unlike tcBF-STEM, which focuses on aberration correction and super-resolution, ptychography retrieves the complete complex wavefunction.
Key components
- The probe \(P(d_x,d_y)\)
The focused electron beam has finite size and shape, known aberrations, and is scanned across the sample.
- The object \(O(d_x,d_y)\)
The sample’s transmission function is complex-valued (amplitude and phase), represents the sample’s structure, and is what we want to reconstruct.
- The measurements
Diffraction patterns are recorded in reciprocal space at each scan position. They contain mixed probe and object information. The key is that patterns from overlapping regions provide redundancy.
Note
Each diffraction pattern views the same area from a different “angle”, providing the redundancy needed to solve the phase problem.
The reconstruction algorithm
The algorithm iteratively refines both the probe and object to match measured diffraction patterns.
- Initialization
We start with initial guesses for both components. The object \(O(d_x,d_y) = A(d_x,d_y) e^{i \phi(d_x,d_y)}\) usually starts with uniform amplitude and phase, or an informed guess from other methods. The probe is modeled in reciprocal space, where the aperture function sets the shape, aberrations appear as phase shifts, and parameters are based on microscope specifications.
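For illustration, a uniform starting object could be allocated as follows (grid size and device are placeholders):

```python
import torch

ny, nx = 512, 512  # object grid size (placeholder)
object_guess = torch.ones((ny, nx), dtype=torch.complex64, device="cuda",
                          requires_grad=True)  # amplitude 1, phase 0 everywhere
```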
- Forward model
At each scan position \(\mathbf{R} = (R_x, R_y)\), we calculate the exit wave \(\psi_{\text{exit}}(x,y) = P(x - R_x, y - R_y) \cdot O(x,y)\), propagate to the detector \(\tilde{\psi}(g_x,g_y) = \mathcal{F}\{\psi_{\text{exit}}\}\), and calculate the intensity \(I_{\text{calc}}(g_x,g_y) = |\tilde{\psi}(g_x,g_y)|^2\).
The probe is modeled in reciprocal space as:
\[\tilde P(\mathbf{k}) = a(\mathbf{k}) \, e^{i \chi(\mathbf{k})}\]where \(a(\mathbf{k})\) is the aperture function and \(\chi(\mathbf{k})\) is the aberration phase. The aberration function is expanded as a polynomial in spatial-frequency coordinates \((k_x,k_y)\):
\[\chi(\mathbf{k}) = a_0 + a_1 k_x + a_2 k_y + a_3 (k_x^2 + k_y^2) + a_4 (k_x^2 - k_y^2) + a_5 (2 k_x k_y) + \ldots\]The linear terms \((a_1, a_2)\) represent beam tilt, the radial quadratic term \(a_3\) represents defocus, the anisotropic quadratic terms \((a_4, a_5)\) represent astigmatism, and higher-order terms capture coma and spherical aberration. We model the probe in reciprocal space because aberrations naturally act as phase modulations in the pupil plane. The coefficients \(a_i\) are the optimization variables we refine during reconstruction. To use the probe in the exit wave calculation, we inverse Fourier transform to real space:
\[P(x,y) = \mathcal{F}^{-1}\{a(\mathbf{k}) e^{i \chi(\mathbf{k})}\} = \mathcal{F}^{-1}\{\tilde P(\mathbf{k})\}\]This real-space probe \(P(x,y)\) is then used for the pointwise multiplication with the object.
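A sketch of this probe construction with illustrative parameters (grid size, reciprocal-space sampling, and aperture radius are assumptions):

```python
import torch

def make_probe(shape, dk, k_aperture, a1=0.0, a2=0.0, a3=0.0, a4=0.0, a5=0.0):
    """Build a real-space probe from an aperture and an aberration polynomial.

    shape:      (ny, nx) pixels of the probe array
    dk:         reciprocal-space pixel size
    k_aperture: aperture cutoff radius (same units as dk)
    a1..a5:     tilt, defocus, and astigmatism coefficients (see chi above)
    """
    ny, nx = shape
    ky = torch.fft.fftfreq(ny) * ny * dk          # reciprocal-space grid
    kx = torch.fft.fftfreq(nx) * nx * dk
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")

    chi = (a1 * KX + a2 * KY
           + a3 * (KX**2 + KY**2)
           + a4 * (KX**2 - KY**2)
           + a5 * (2 * KX * KY))
    aperture = (KX**2 + KY**2 <= k_aperture**2).to(torch.complex64)

    probe_k = aperture * torch.exp(1j * chi.to(torch.complex64))
    probe = torch.fft.ifft2(probe_k)              # P(x, y), complex-valued
    return probe / torch.sqrt((probe.abs() ** 2).sum())  # unit total intensity
```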
- Exit wave computation
At each scan position \(R = (R_x, R_y)\), the probe illuminates a local region of the object. The exit wave is the element-wise multiplication of the shifted probe and object:
\[\Psi_{\mathrm{exit}}(x,y;R) = P(x - R_x, y - R_y) \cdot O(x,y)\]This multiplication is pointwise (not convolution). For example, if the probe is a 64×64 array positioned at \(R=(10,15)\) on a 512×512 object, the probe is shifted to center at (10,15) and multiplied element-by-element with the corresponding 64×64 patch of the object. Since both \(P\) and \(O\) are complex-valued (stored as complex64 in PyTorch), the resulting exit wave \(\Psi_{\mathrm{exit}}\) is also complex. This process repeats for every scan position.
- Predicted intensity
The detector cannot measure complex amplitudes directly—it only records intensity, which is real-valued. The predicted intensity is obtained by Fourier transforming the complex exit wave and computing its squared magnitude:
\[I_{\mathrm{pred}}(k_x,k_y;R) = \big|\mathcal{F}\{\Psi_{\mathrm{exit}}(x,y;R)\}\big|^2\]In practice, compute the 2D FFT of \(\Psi_{\mathrm{exit}}\) (which is complex), then calculate the intensity as
torch.abs(fft_result)**2 or equivalently fft_result.real**2 + fft_result.imag**2. This gives a real-valued diffraction pattern that can be compared with measured intensities; a sketch combining the exit-wave and intensity steps follows below.
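A minimal sketch of the forward model for a single scan position, using the probe and object arrays from the earlier sketches and indexing the illuminated patch by its top-left corner (integer pixels) for simplicity:

```python
import torch

def forward_single(probe: torch.Tensor, obj: torch.Tensor, ry: int, rx: int) -> torch.Tensor:
    """Predict one diffraction pattern: exit wave -> FFT -> intensity.

    probe: complex (py, px) array; obj: complex (ny, nx) array;
    (ry, rx): top-left corner of the illuminated patch (integer pixels).
    """
    py, px = probe.shape
    obj_patch = obj[ry:ry + py, rx:rx + px]       # local region of the object
    exit_wave = probe * obj_patch                  # pointwise multiplication
    fft_result = torch.fft.fft2(exit_wave)         # propagate to the detector
    intensity = torch.abs(fft_result) ** 2         # real-valued pattern
    return intensity
```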
- Optimization
The reconstruction minimizes the squared error between measured and predicted intensities:
\[L = \sum_{R} \sum_{k} \left| I_{\mathrm{meas}}(R;k) - I_{\mathrm{pred}}(R;k) \right|^2\]where \(R\) indexes scan positions and \(k=(k_x,k_y)\) indexes detector pixels. Gradient-based optimizers (Adam, SGD) update both the object pixels and probe aberration coefficients. Because the probe is low-dimensional and sensitive to updates, it typically uses a smaller learning rate than the object. Iterative algorithms like ePIE or DM provide alternative update schemes. Convergence yields a high-resolution reconstruction of the specimen’s complex transmission function and an estimate of the probe with its aberrations.
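One way to set this up with PyTorch autograd, reusing the hypothetical forward_single helper above and assuming precomputed reciprocal-space grids KX, KY and an aperture mask; the probe is rebuilt from the aberration coefficients each iteration so that gradients flow into them:

```python
import torch

# Assumed setup (hypothetical names):
# I_meas:    (n_pos, qy, qx) float32 tensor of measured intensities
# positions: list of (ry, rx) integer patch corners
# obj:       complex64 object tensor with requires_grad=True (see above)
# coeffs:    float32 tensor of aberration coefficients a1..a5, requires_grad=True
# KX, KY:    reciprocal-space coordinate grids; aperture: real-valued mask

def build_probe(coeffs, KX, KY, aperture):
    """Differentiable probe from aberration coefficients (sketch)."""
    a1, a2, a3, a4, a5 = coeffs
    chi = (a1 * KX + a2 * KY + a3 * (KX**2 + KY**2)
           + a4 * (KX**2 - KY**2) + a5 * (2 * KX * KY))
    probe_k = aperture * torch.exp(torch.complex(torch.zeros_like(chi), chi))
    return torch.fft.ifft2(probe_k)

optimizer = torch.optim.Adam([
    {"params": [obj],    "lr": 1e-2},   # object pixels
    {"params": [coeffs], "lr": 1e-4},   # probe coefficients: smaller learning rate
])

for epoch in range(200):
    optimizer.zero_grad()
    probe = build_probe(coeffs, KX, KY, aperture)
    loss = 0.0
    for (ry, rx), I_m in zip(positions, I_meas):
        I_pred = forward_single(probe, obj, ry, rx)   # defined earlier
        loss = loss + ((I_pred - I_m) ** 2).sum()     # squared-error data term
    loss.backward()
    optimizer.step()
```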
Mixed state probes
The standard single-mode ptychographic model assumes a single coherent probe \(P(x,y)\). However, real electron beams have partial coherence due to source size, chromatic aberration, and instabilities. Mixed state reconstruction models the probe as a weighted sum of orthogonal modes:
\[P = \sum_{m} c_m P_m\]where \(P_m(x,y)\) are the probe modes and \(c_m\) are their relative weights satisfying \(\sum_m |c_m|^2 = 1\). Each mode represents a different coherent state of the illumination.
The measured intensity at each scan position becomes the incoherent sum over all modes:
\[I_j(\mathbf{q}) = \sum_{m=1}^{M} w_m \left| \mathcal{F}\{P_m(\mathbf{r} - \mathbf{r}_j) O(\mathbf{r})\} \right|^2\]where \(w_m = |c_m|^2\) are the mode weights and \(\mathbf{q}\) represents reciprocal space coordinates. The incoherent sum reflects that different modes do not interfere with each other at the detector.
During reconstruction, all probe modes \(P_m(x,y)\) and their weights \(w_m\) are optimized simultaneously along with the object \(O(x,y)\). This decomposition captures partial coherence effects and improves reconstruction quality when the beam has significant incoherence. Typical reconstructions use 2-5 probe modes, with the first mode dominating and additional modes capturing incoherent contributions.
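As an illustration, the incoherent sum over probe modes for one scan position might be sketched as follows (mode stack and weights are assumed inputs):

```python
import torch

def forward_mixed_probe(probe_modes, weights, obj, ry, rx):
    """Incoherent sum of intensities over probe modes.

    probe_modes: complex tensor (M, py, px); weights: real tensor (M,), sums to 1.
    """
    M, py, px = probe_modes.shape
    obj_patch = obj[ry:ry + py, rx:rx + px]
    exit_waves = probe_modes * obj_patch                 # broadcast over modes
    patterns = torch.abs(torch.fft.fft2(exit_waves)) ** 2
    return (weights.view(M, 1, 1) * patterns).sum(dim=0)
```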
Mixed state objects
The standard ptychographic model assumes a single deterministic object \(O(x,y)\). However, real samples may exhibit variations due to damage, dynamics, or compositional heterogeneity. Mixed state object reconstruction models the sample as a statistical mixture:
\[O = \sum_{n} w_n O_n\]where \(O_n(x,y)\) are the object states and \(w_n\) are their probabilities with \(\sum_n w_n = 1\). Each state represents a different possible configuration of the sample.
The measured intensity becomes:
\[I_j(\mathbf{q}) = \sum_{n=1}^{N} w_n \left| \mathcal{F}\{P(\mathbf{r} - \mathbf{r}_j) O_n(\mathbf{r})\} \right|^2\]This models scenarios where the sample changes between measurements or exists in multiple states simultaneously. During reconstruction, all object states \(O_n(x,y)\) and their weights \(w_n\) are optimized. This approach is particularly useful for beam-sensitive materials, samples undergoing phase transitions, or specimens with compositional variations.
Combined mixed states
For maximum generality, both probe and object can be represented as mixed states simultaneously:
\[I_j(\mathbf{q}) = \sum_{m=1}^{M} \sum_{n=1}^{N} w_m w_n \left| \mathcal{F}\{P_m(\mathbf{r} - \mathbf{r}_j) O_n(\mathbf{r})\} \right|^2\]This double sum captures both partial coherence in the illumination and statistical variations in the sample. Each combination of probe mode \(m\) and object state \(n\) contributes incoherently to the measured intensity. While computationally more expensive, this formulation provides the most complete physical model and can significantly improve reconstruction quality for challenging experimental conditions.
Computational implementation
- Data layout and complex arithmetic
Both the probe and object are complex-valued functions, requiring careful memory management. In PyTorch, use torch.complex64 (32-bit real + 32-bit imaginary) for the probe \(P(x,y)\) and object \(O(x,y)\). This balances memory usage with numerical precision. For higher accuracy, use torch.complex128, though this doubles memory consumption.
A 512×512 complex64 object requires 2 MB (512² × 8 bytes). Store these arrays in GPU memory to avoid CPU-GPU transfers during iteration. Measured diffraction patterns are real-valued and can use float32.
PyTorch natively supports complex multiplication (*), FFTs (torch.fft.fft2, torch.fft.ifft2), and absolute values (torch.abs). The intensity computation \(|z|^2\) can be efficiently computed as torch.abs(z)**2.
Standard mixed-precision training (torch.cuda.amp) does not support complex dtypes as of PyTorch 2.x, so keep complex arrays in full precision (complex64) and use float16 only for real-valued intermediate tensors if needed. Use pinned (page-locked) host memory for faster CPU-GPU transfers when loading new diffraction patterns, and consider asynchronous transfers to overlap data loading with computation, as in the sketch below.
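For example, a small sketch of pinned-memory staging with an asynchronous host-to-device copy (tensor sizes are placeholders):

```python
import torch

# Keep measured patterns on the host as float32 and pin the memory so that
# host-to-device copies can run asynchronously with respect to GPU compute.
patterns_cpu = torch.empty(4096, 128, 128, dtype=torch.float32).pin_memory()

# non_blocking=True only overlaps the copy with computation when the source is pinned
batch_gpu = patterns_cpu[0:64].to("cuda", non_blocking=True)
```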
- Efficient batching
Process multiple positions at once to use the GPU’s parallel computing power. Balance memory versus speed—larger batches provide faster processing but need more GPU memory. Typical batch sizes range from 16 to 256 positions.
For memory optimization, use mixed precision where possible (for real-valued tensors), reuse memory buffers, and stream data efficiently: load the next batch while the current one is processed, overlap computation with transfers, and use pinned memory for faster transfers. A batched forward-model sketch follows this list.
For algorithm improvements, use better initial guesses, smarter update schemes, adaptive step sizes, and early stopping when converged.
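A sketch of a batched forward pass that gathers many object patches and runs the FFTs as a single batched call (integer-pixel corner positions and the helper name are assumptions):

```python
import torch

def forward_batch(probe, obj, corners):
    """Predict diffraction patterns for a batch of scan positions.

    probe:   complex (py, px); obj: complex (ny, nx)
    corners: integer tensor (B, 2) of top-left patch corners (ry, rx)
    """
    py, px = probe.shape
    patches = torch.stack([obj[ry:ry + py, rx:rx + px]
                           for ry, rx in corners.tolist()])    # (B, py, px)
    exit_waves = probe.unsqueeze(0) * patches                   # broadcast the probe
    return torch.abs(torch.fft.fft2(exit_waves)) ** 2           # (B, qy, qx)

# Example: process batches of 64 positions
# for start in range(0, corners.shape[0], 64):
#     I_pred = forward_batch(probe, obj, corners[start:start + 64])
```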
- Gradient accumulation
If you want an effective batch size of 256 but can only fit 64 positions in memory, process four mini-batches of 64, accumulating the gradients for the object \(\partial L/\partial O\) and probe coefficients \(\partial L/\partial a_i\) after each mini-batch. Update parameters only after processing all four mini-batches. This provides the training stability of large batches without the memory cost. In PyTorch, gradients accumulate automatically across
.backward() calls until you call optimizer.step() and optimizer.zero_grad(), as in the sketch below.
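A sketch of this accumulation pattern with four mini-batches of 64, reusing the hypothetical forward_batch and build_probe helpers (and the optimizer, corners, and I_meas tensors) from the earlier sketches:

```python
effective_batch, micro_batch = 256, 64
optimizer.zero_grad()

for start in range(0, effective_batch, micro_batch):
    probe = build_probe(coeffs, KX, KY, aperture)   # fresh graph for each backward
    idx = slice(start, start + micro_batch)
    I_pred = forward_batch(probe, obj, corners[idx])             # mini-batch forward
    loss = ((I_pred - I_meas[idx].to(I_pred.device)) ** 2).sum()
    loss.backward()        # gradients for obj and probe coefficients accumulate

optimizer.step()           # single update using the accumulated gradients
optimizer.zero_grad()
```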
- Multi-GPU acceleration strategies
Here are the main approaches for utilizing multiple GPUs:
Data parallelism splits scan positions across GPUs. Each GPU processes its positions and combines results through averaging. This needs careful gradient synchronization.
Model parallelism splits the object across GPUs. Each GPU handles a region and exchanges boundary information. This is good for very large objects.
Hybrid strategies combine both approaches, balance computation and communication, adapt to hardware architecture, and use efficient collective operations.
Data parallelism in detail
Distribute scan positions across GPUs. Each GPU processes a different subset of positions with a replica of the full object and probe. For example, with 4 GPUs and 256 scan positions, GPU 0 processes positions 0-63, GPU 1 processes 64-127, and so on. Each GPU compares predicted intensities \(I_{\mathrm{pred}}\) with measured intensities \(I_{\mathrm{meas}}\) for its subset of positions, computes its local loss contribution, then computes gradients with respect to the object and probe parameters.
The gradients are averaged across GPUs using all-reduce: sum all gradients, then divide by the number of GPUs. This averaging is critical—without it, the effective learning rate would scale by the number of GPUs (4× larger update steps with 4 GPUs). Averaging ensures that the gradient magnitude matches single-GPU training, keeping the learning rate consistent regardless of how many GPUs you use. All GPUs then apply the same averaged gradient to update their local copies of the parameters, keeping all replicas synchronized. The total loss is the sum of losses from all GPUs.
Note that some frameworks sum gradients without dividing and instead expect you to scale the learning rate by \(1/N_{\mathrm{GPUs}}\) to achieve the same effect; PyTorch's DistributedDataParallel averages gradients across processes by default. Data parallelism is the simplest approach and works well when the object fits in a single GPU's memory; a minimal sketch of manual gradient averaging with torch.distributed follows.
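A minimal sketch of this manual gradient averaging with torch.distributed (process-group setup, rank launch, and the per-rank position split are assumed to happen elsewhere):

```python
import torch
import torch.distributed as dist

def average_gradients(params):
    """All-reduce and average gradients across all ranks."""
    world_size = dist.get_world_size()
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size   # average so the learning rate stays consistent

# Per iteration, on every rank (each rank owns a different slice of positions):
#   loss = local_loss(obj, coeffs, my_positions)   # hypothetical local loss
#   loss.backward()
#   average_gradients([obj, coeffs])
#   optimizer.step(); optimizer.zero_grad()
```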
Model parallelism in detail
When the object is too large for one GPU, partition it spatially across devices. Each GPU owns a region of the object and communicates boundary information with neighbors. This requires careful handling of probe positions that overlap partition boundaries.
Hybrid approach
Combine data and model parallelism. Partition the object across GPUs and process multiple scan positions per GPU simultaneously. Use communication libraries (NCCL, MPI) for efficient gradient synchronization.
Why must all GPUs start with identical parameters in data parallelism?
Each GPU computes gradients with respect to its current state of the object and probe. If GPUs start with different initial values, their gradients point in different directions in parameter space—averaging such gradients is meaningless because they represent updates from different starting points.
During the first iteration, all GPUs initialize with identical object \(O_0\) and probe parameters (aberration coefficients \(a_i^{(0)}\)). Each GPU computes \(\partial L/\partial O\) and \(\partial L/\partial a_i\) based on its subset of scan positions, but these gradients are evaluated at the shared state \((O_0, a_i^{(0)})\). The gradients are averaged across GPUs. All GPUs apply the same averaged gradient: \(O_1 = O_0 - \eta \langle \partial L/\partial O \rangle\) and \(a_i^{(1)} = a_i^{(0)} - \eta \langle \partial L/\partial a_i \rangle\). All GPUs now have identical \((O_1, a_i^{(1)})\).
Subsequent iterations maintain synchronization because all GPUs apply the same update and remain at the same state. Each iteration, they compute gradients from this shared state, average them, and update synchronously. If any GPU had different initial parameters, this synchronization would break immediately—each GPU would diverge to a different reconstruction.
What is communication overhead in multi-GPU training?
Communication overhead is the time GPUs spend exchanging gradient data instead of computing. During all-reduce, GPUs must transfer their local gradients to each other and wait for all transfers to complete before proceeding.
- How all-reduce works
All-reduce is a collective operation where every GPU ends up with the same averaged result. No single GPU “processes” the reduction—instead, GPUs communicate peer-to-peer to efficiently compute the average in parallel. Modern implementations (NCCL on NVIDIA GPUs) use ring or tree algorithms.
In a ring all-reduce example with 4 GPUs, each GPU divides its gradient array into 4 chunks. During the reduce-scatter phase, GPUs pass chunks around the ring, accumulating sums; after 3 steps, each GPU holds the sum of one chunk across all GPUs. During the all-gather phase, GPUs pass the summed chunks around the ring; after 3 more steps, every GPU has all summed chunks. Each GPU then divides by 4 to get the average. The result is that all 4 GPUs end up with identical averaged gradients stored locally. No “master” GPU collects everything; the operation is fully distributed.
- Timing breakdown per iteration
The forward pass takes 50ms per GPU for its subset of positions (parallel computation, no communication). The backward pass takes another 50ms (parallel, no communication). The gradient all-reduce takes 15ms (GPUs send and receive gradient chunks and wait for completion). The optimizer step takes 5ms (parallel computation using the local copy of the averaged gradients).
During the 15ms all-reduce, GPUs are mostly idle, waiting for data transfers over PCIe or NVLink. With 4 GPUs the iteration therefore takes 50+50+15+5 = 120ms, whereas a single GPU working through all of the positions would need roughly 4 × (50+50) + 5 = 405ms, so the realized speedup is about 405/120 ≈ 3.4× rather than the ideal 4×.
With 4 GPUs and efficient interconnects, expect a 3.2-3.6× speedup. The gap from the ideal 4× is the communication overhead.
Beam-tilt corrected bright-field STEM (tcBF-STEM)
What problem does tcBF-STEM solve?
The goal is to obtain high-resolution images while accounting for the aberrations present in the microscope. Full iterative ptychography solves this problem but is computationally expensive. tcBF-STEM is a computationally simpler method that extracts the aberrations from a 4D-STEM dataset and corrects for them.
How does tcBF-STEM work in a nutshell?
The approach extracts aberration coefficients from shifts in virtual bright-field (BF) images, uses these coefficients to correct the images, then combines multiple corrected views through upsampling to achieve high-resolution aberration-free images. Furthermore, these aberration coefficients can serve as initial values for iterative ptychography.
How is tcBF-STEM different from traditional aberration correction?
Traditional aberration correction hardware (e.g., Cs correctors) physically adjusts the electron optics to minimize aberrations before imaging. In contrast, tcBF-STEM is a computational post-processing technique that extracts aberration information from 4D-STEM datasets and applies numerical corrections to the images after acquisition. This allows residual aberrations to be corrected without additional hardware.
Why can tcBF-STEM achieve super-resolution?
The key insight is that each detector position \((k_x, k_y)\) provides a virtual BF image that is slightly shifted in real space. These shifts contain two types of information: systematic shifts caused by microscope aberrations (defocus, astigmatism, coma), and sub-pixel sampling information where each shifted view samples the specimen at slightly different positions, providing information between pixel centers. By first correcting for aberrations to align all images, then combining them through upsampling, you can achieve resolution finer than the original probe step size.
The tcBF-STEM algorithm
- Step 1: extract virtual BF images and measure shifts
For each \((k_x, k_y)\) detector position, create a virtual bright-field image \(I_{\text{BF}}(x,y;k_x,k_y)\). Then measure how much each image is shifted relative to a reference image using cross-correlation in Fourier space:
\[\mathrm{XC}(x,y) = \mathcal{F}^{-1}\!\left[\mathcal{F}\{I_{\text{BF}}\} \cdot \mathcal{F}\{I_{\text{ref}}\}^{*}\right]\]The location of the cross-correlation peak gives the displacement field
\[\big(\Delta x(k_x,k_y),\ \Delta y(k_x,k_y)\big)\]across all detector positions; a minimal shift-measurement sketch follows below.
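A sketch of measuring one virtual BF image's integer-pixel shift against the reference via Fourier cross-correlation (sub-pixel refinement omitted):

```python
import torch

def measure_shift(img: torch.Tensor, ref: torch.Tensor):
    """Return (dy, dx): the shift of img relative to ref from the cross-correlation peak."""
    xc = torch.fft.ifft2(torch.fft.fft2(img) * torch.conj(torch.fft.fft2(ref))).real
    peak = torch.argmax(xc)
    dy, dx = divmod(int(peak), xc.shape[1])
    # Wrap shifts larger than half the field of view to negative values
    if dy > xc.shape[0] // 2:
        dy -= xc.shape[0]
    if dx > xc.shape[1] // 2:
        dx -= xc.shape[1]
    return dy, dx
```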
- Step 2: fit polynomial aberration model to shifts
Model the aberration function \(\chi(k_x, k_y)\) as a polynomial and find the coefficients \(a_i\) that best explain the measured shifts. The measured shifts relate to the aberration function through derivatives:
\[\Delta x = \frac{\partial \chi}{\partial k_x}, \quad \Delta y = \frac{\partial \chi}{\partial k_y}\]Expand the phase function as a polynomial:
\[\chi(k_x, k_y) = a_0 + a_1 k_x + a_2 k_y + a_3(k_x^2 + k_y^2) + a_4(k_x^2 - k_y^2) + a_5(2k_x k_y) + \ldots\]Find coefficients \(a_i\) by minimizing squared error:
\[\min_{a_i} \sum_{(k_x, k_y)} \left[ (\Delta x_{\text{meas}} - \partial \chi_{\text{model}}/\partial k_x)^2 + (\Delta y_{\text{meas}} - \partial \chi_{\text{model}}/\partial k_y)^2 \right]\]The coefficients correspond to physical aberrations: \(a_0\) (phase offset), \(a_1, a_2\) (beam tilt), \(a_3\) (defocus), \(a_4, a_5\) (astigmatism), and higher-order terms (coma, spherical aberration).
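Because the model is linear in the coefficients, the fit can be solved directly with least squares. A sketch covering the tilt, defocus, and astigmatism terms (the constant \(a_0\) does not appear in the derivatives and is omitted):

```python
import torch

def fit_aberrations(kx, ky, dx, dy):
    """Least-squares fit of a1..a5 (tilt, defocus, astigmatism) to measured shifts.

    kx, ky, dx, dy: 1-D tensors over all used detector positions.
    Returns (a1, ..., a5) such that dx ~ d(chi)/dkx and dy ~ d(chi)/dky.
    """
    one, zero = torch.ones_like(kx), torch.zeros_like(kx)
    # d(chi)/dkx = a1 + 2*a3*kx + 2*a4*kx + 2*a5*ky
    rows_x = torch.stack([one, zero, 2 * kx, 2 * kx, 2 * ky], dim=1)
    # d(chi)/dky = a2 + 2*a3*ky - 2*a4*ky + 2*a5*kx
    rows_y = torch.stack([zero, one, 2 * ky, -2 * ky, 2 * kx], dim=1)
    A = torch.cat([rows_x, rows_y], dim=0)          # (2N, 5) design matrix
    b = torch.cat([dx, dy], dim=0).unsqueeze(1)     # (2N, 1) measured shifts
    coeffs = torch.linalg.lstsq(A, b).solution.squeeze(1)
    return coeffs                                    # (a1, a2, a3, a4, a5)
```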
- Step 3: correct images by reversing shifts
For each detector position, undo the aberration-induced displacement by applying the inverse shift in Fourier space using the Fourier shift theorem:
\[I_{\text{aligned}}(x, y) = \mathcal{F}^{-1}\left[\mathcal{F}\{I_{\text{BF}}(x, y)\} \cdot e^{+2\pi i (f_x \Delta x + f_y \Delta y)}\right]\]where \((f_x, f_y)\) are the Fourier frequency coordinates. This translates the real-space image by \(-\Delta \mathbf{r}\):
\[I_{\text{aligned}}(x, y) = I_{\text{BF}}(x + \Delta x, y + \Delta y)\]This aligns all virtual BF images so they can be coherently combined.
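A sketch of this Fourier-shift alignment for one image, following the sign convention above:

```python
import math
import torch

def shift_image(img: torch.Tensor, dx: float, dy: float) -> torch.Tensor:
    """Translate img so that I_aligned(x, y) = img(x + dx, y + dy)."""
    ny, nx = img.shape
    fy = torch.fft.fftfreq(ny).reshape(-1, 1)     # cycles per pixel
    fx = torch.fft.fftfreq(nx).reshape(1, -1)
    angle = 2 * math.pi * (fx * dx + fy * dy)
    phase = torch.exp(torch.complex(torch.zeros_like(angle), angle))
    return torch.fft.ifft2(torch.fft.fft2(img) * phase).real
```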
Note
To learn more about cross-correlation and the Fourier shift theorem, refer to the relevant sections in the Fourier page.
- Step 4: upsample using kernel density estimation
Combine all aligned images onto a fine regular grid using KDE:
\[I_{\text{tcBF}}(x, y) = \frac{\sum_{i} I_i(x - \Delta x_i, y - \Delta y_i) \cdot K(d_i)}{\sum_{i} K(d_i)}\]where \(d_i = \sqrt{\Delta x_i^2 + \Delta y_i^2}\) is the aberration shift magnitude and \(K(d) = \exp(-d^2/2\sigma^2)\) is the Gaussian kernel. This achieves super-resolution by combining overlapping views with different shift offsets.
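A much-simplified sketch of this combination step: the aligned images are placed on a finer grid and averaged with Gaussian weights derived from their shift magnitudes, following the formula above; the interpolation-based upsampling stands in for a full sub-pixel scatter.

```python
import torch

def kde_combine(aligned_images, shifts, sigma, upsample=2):
    """Weighted combination of aligned BF images (simplified sketch).

    aligned_images: (N, ny, nx) real tensor (already shift-corrected)
    shifts:         (N, 2) tensor of aberration shifts (dy, dx) per image
    sigma:          Gaussian kernel bandwidth (pixels)
    """
    # Gaussian kernel weight per image from its shift magnitude d_i
    d = torch.linalg.vector_norm(shifts, dim=1)
    w = torch.exp(-d**2 / (2 * sigma**2)).view(-1, 1, 1)

    # Place each image on a finer grid before averaging (bilinear interpolation
    # here; a full implementation scatters pixels at their sub-pixel offsets)
    up = torch.nn.functional.interpolate(aligned_images.unsqueeze(1),
                                         scale_factor=upsample,
                                         mode="bilinear",
                                         align_corners=False).squeeze(1)
    return (w * up).sum(dim=0) / w.sum()
```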
Why don’t we integrate the displacement field directly to find aberrations?
Integration of measured displacements would be noisy. Instead, fit a parametric polynomial model whose derivatives match the measured shifts:
\[\Delta x_{\text{meas}} \approx \frac{\partial \chi_{\text{model}}}{\partial k_x}, \qquad \Delta y_{\text{meas}} \approx \frac{\partial \chi_{\text{model}}}{\partial k_y}\]This directly extracts the aberration coefficients \(a_i\) without noisy integration.
Why use KDE for tcBF-STEM?
KDE provides three key advantages. It handles irregular sampling by naturally interpolating between images shifted by different amounts. It preserves sub-pixel information by weighting all nearby measurements by distance, extracting maximum information. It enables super-resolution by combining multiple shifted views to reconstruct features finer than the probe spacing.
The bandwidth \(\sigma\) controls smoothness: too small creates noise, too large blurs features. Optimal choice is \(\sigma \approx\) half the average shift spacing.
Why are virtual BF images shifted at different detector positions?
By reciprocity, extracting a virtual BF image from detector angle \((k_x, k_y)\) corresponds to tilted plane-wave illumination. Different tilt angles produce shifted real-space images due to aberrations (defocus, astigmatism, coma) and wave propagation geometry. These shifts encode both aberration information and sub-pixel sampling information.