# JDD
JDD is the joint denoise and demosaic module. It learns the mapping from noisy packed RAW bursts to clean linear RGB:

\[\hat{Y} = f_\theta\big(R_1^{noisy}, \ldots, R_T^{noisy}, M_\sigma\big)\]

where \(T=3\) in the current training setup, \(R_t^{noisy}\) are independently degraded RAW frames, and \(M_\sigma\) is a noise standard-deviation map produced by raw-sim.
## Task Definition
The input is a packed RAW burst plus one noise-map channel:

\[X = \mathrm{concat}\big(R_1^{noisy}, \ldots, R_T^{noisy}, M_\sigma\big)\]

For Bayer or 2x2 binning RAW, each frame has 4 packed channels, so:

\[X \in \mathbb{R}^{(4T+1) \times \frac{H}{2} \times \frac{W}{2}}\]

For Quad Bayer RAW, each frame has 16 packed channels:

\[X \in \mathbb{R}^{(16T+1) \times \frac{H}{4} \times \frac{W}{4}}\]

The target is clean linear RGB at the full spatial resolution:

\[Y \in \mathbb{R}^{3 \times H \times W}\]
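The 2x2 packing for the Bayer case can be illustrated with a small space-to-depth helper. This is a sketch; `pack_bayer` is our name for illustration and is not part of the JDD codebase:

```python
import numpy as np

def pack_bayer(raw: np.ndarray) -> np.ndarray:
    """Pack an (H, W) Bayer mosaic into (4, H/2, W/2) channels.

    Each output channel holds one CFA site (e.g. R, G1, G2, B for an RGGB
    layout), so the network sees aligned single-color planes at half the
    spatial resolution.
    """
    h, w = raw.shape
    assert h % 2 == 0 and w % 2 == 0, "Bayer packing needs even dimensions"
    return np.stack([
        raw[0::2, 0::2],  # top-left site of each 2x2 tile
        raw[0::2, 1::2],  # top-right site
        raw[1::2, 0::2],  # bottom-left site
        raw[1::2, 1::2],  # bottom-right site
    ])

raw = np.arange(16, dtype=np.float32).reshape(4, 4)
packed = pack_bayer(raw)
print(packed.shape)  # (4, 2, 2)
```

Quad Bayer packing follows the same idea with 4x4 tiles, yielding 16 channels at quarter resolution.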
## Random Degradation
JDD does not store a fixed noisy dataset by default. During training, it samples RGB patches and generates RAW degradation online through the public raw-sim batch interface.
For each sample:
- Randomly choose an RGB source image.
- Randomly crop a patch of size `data.patch_size`.
- Sample analog gain from the camera module.
- Convert sRGB to the clean linear RGB target.
- Run the `raw-sim` pipeline to produce three noisy RAW frames.
- Build a noise map from the same calibrated noise model.
This makes every epoch see different crop positions, noise realizations, and analog gain values.
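The per-sample flow can be sketched as follows. This is a simplified stand-in, not the real raw-sim pipeline: the gamma-2.2 shortcut approximates sRGB-to-linear conversion, and a plain gain-scaled Gaussian replaces the calibrated Poisson-Gaussian noise model:

```python
import numpy as np

def degrade_sample(rgb: np.ndarray, patch_size: int, rng: np.random.Generator,
                   gain_range=(1.0, 16.0), frames: int = 3):
    """Online degradation sketch: random crop, random gain, per-frame noise.

    `rgb` is an (H, W, 3) sRGB image in [0, 1]. The real pipeline also packs
    the CFA mosaic; that step is omitted here for brevity.
    """
    h, w, _ = rgb.shape
    # Fresh crop position every time the sample is drawn.
    y0 = rng.integers(0, h - patch_size + 1)
    x0 = rng.integers(0, w - patch_size + 1)
    patch = rgb[y0:y0 + patch_size, x0:x0 + patch_size]
    # Random analog gain for this sample.
    gain = rng.uniform(*gain_range)
    # Approximate sRGB -> linear conversion (gamma 2.2 shortcut).
    clean_linear = patch ** 2.2
    # Independent noise realization per burst frame; noise scales with gain.
    sigma = 0.01 * gain
    burst = [clean_linear + rng.normal(0.0, sigma, clean_linear.shape)
             for _ in range(frames)]
    noise_map = np.full(clean_linear.shape[:2], sigma, dtype=np.float32)
    return np.stack(burst), clean_linear, noise_map

rng = np.random.default_rng(0)
rgb = rng.random((64, 64, 3))
burst, target, noise_map = degrade_sample(rgb, patch_size=32, rng=rng)
print(burst.shape, target.shape, noise_map.shape)
```

Because the crop position, gain, and noise are all drawn from the RNG on every call, no two epochs see the same degraded sample.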
## Analog Gain Sampling
The camera JSON can define camera.analog_gain as a scalar, a two-element range, or an object with min and max. JDD uses camera.analog_gain_sampling to choose the sampling distribution.
Uniform sampling:

\[g \sim \mathcal{U}(g_{\min},\, g_{\max})\]

Log-uniform sampling:

\[\log g \sim \mathcal{U}(\log g_{\min},\, \log g_{\max})\]

The ISO convention is inherited from raw-sim.
Because the Poisson-Gaussian noise model scales with analog gain, random gain sampling exposes the model to a broad noise range instead of a single fixed ISO.
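Assuming the gain range is given as a `[min, max]` pair, the two distributions can be sketched as below. The mode strings `"uniform"` and `"log_uniform"` are illustrative; the actual values accepted by `camera.analog_gain_sampling` may differ:

```python
import numpy as np

def sample_analog_gain(gain_min: float, gain_max: float, mode: str,
                       rng: np.random.Generator) -> float:
    """Draw one analog gain value, uniformly in linear or log space."""
    if mode == "uniform":
        return float(rng.uniform(gain_min, gain_max))
    if mode == "log_uniform":
        # Uniform in log space: each octave of gain is sampled equally often,
        # so low-noise settings are not drowned out by the high-gain tail.
        return float(np.exp(rng.uniform(np.log(gain_min), np.log(gain_max))))
    raise ValueError(f"unknown sampling mode: {mode}")

rng = np.random.default_rng(42)
gains = [sample_analog_gain(1.0, 64.0, "log_uniform", rng)
         for _ in range(10000)]
# The median of a log-uniform draw sits near the geometric mean,
# sqrt(1 * 64) = 8, rather than the arithmetic midpoint 32.5.
print(float(np.median(gains)))
```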
## Burst Image Restoration
For each RGB patch, JDD synthesizes three noisy RAW frames from the same clean signal:

\[R_t^{noisy} = \mathcal{D}(Y;\, \epsilon_t), \quad t = 1, 2, 3\]

where \(\mathcal{D}\) is the raw-sim degradation pipeline and \(\epsilon_t\) is the frame-specific random noise seed.
The three frames share the same crop, camera parameters, analog gain, CFA layout, and lens filter, but use independent noise samples. This trains the network to use burst redundancy for denoising while also learning demosaicing and RAW-to-linear-RGB reconstruction.
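A quick numerical check of why burst redundancy helps: averaging three frames with independent noise reduces the noise standard deviation by roughly \(\sqrt{3}\):

```python
import numpy as np

rng = np.random.default_rng(0)
clean = np.full((256, 256), 0.5)
sigma = 0.1

# Three frames share the clean signal but use independent noise samples.
frames = [clean + rng.normal(0.0, sigma, clean.shape) for _ in range(3)]

single_err = float(np.std(frames[0] - clean))
avg_err = float(np.std(np.mean(frames, axis=0) - clean))
print(single_err / avg_err)  # close to sqrt(3) ~ 1.73
```

A learned network exploits the same redundancy adaptively instead of averaging blindly, which also lets it handle the demosaicing and RAW-to-linear-RGB parts of the task.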
## Training Data Path
When data.simulate_on_device is enabled, the DataLoader only returns:
- RGB patch
- random seed
- sampled analog gain
- metadata
The training loop then calls:
```python
from raw_sim.batch import simulate_burst_batch_on_device

x, y = simulate_burst_batch_on_device(
    rgb=batch["rgb"],
    analog_gain=batch["analog_gain"],
    seed=batch["seed"],
    camera=camera,
    frames=3,
    noise_map_reduce="mean",
)
```
This keeps the JDD degradation path coupled to raw-sim instead of duplicating simulation logic inside JDD. Updates to lens PSF, CFA packing, or noise modeling in raw-sim automatically flow into JDD training.
## Model
The current network is a compact NAFNet-style restoration model:
```
packed RAW burst + noise map
  -> input convolution
  -> NAF blocks
  -> output convolution
  -> PixelShuffle CFA upsampling
  -> clean linear RGB
```
The PixelShuffle scale follows the CFA type:
| CFA type | Packed scale | RAW channels per frame | Input channels for 3 frames |
|---|---|---|---|
| Bayer / binning | 2 | 4 | 13 |
| Quad Bayer | 4 | 16 | 49 |
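The input-channel counts in the table follow directly from frames × packed channels per frame, plus one noise-map channel:

```python
def input_channels(frames: int, packed_scale: int) -> int:
    """Packed RAW channels per frame are packed_scale**2; +1 for the noise map."""
    return frames * packed_scale ** 2 + 1

print(input_channels(3, 2))  # Bayer / binning -> 13
print(input_channels(3, 4))  # Quad Bayer -> 49
```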
## Main Training Command
```bash
cd JDD
bash ./scripts/train.sh --config ./configs/train_JDD_patch128_3f_iter100000.json
```
Useful training fields:
| Field | Purpose |
|---|---|
| `data.simulate_on_device` | Use GPU-side raw-sim batch degradation. |
| `data.frames` | Number of burst frames. Current JDD expects 3. |
| `data.noise_map_reduce` | Reduces per-channel noise maps before appending the map channel. |
| `train.batch_size` | Number of RGB patches per step. |
| `train.amp` | Mixed-precision training on CUDA. |
| `train.channels_last` | Channels-last memory format for CUDA kernels. |
| `train.val_every` | Validation interval. |