
Model

DepthDif uses a conditional pixel-space diffusion model implemented in models/difFF/PixelDiffusion.py.

Model schema (figure: depthdif_schema)

Core stack:
- Lightning wrapper: PixelDiffusionConditional
- diffusion core: DenoisingDiffusionConditionalProcess
- denoiser backbone: UnetConvNextBlock (ConvNeXt-style U-Net)

The model learns to generate y conditioned on the observed channels: x, plus optional eo and mask channels.

Conditioning Setup

The code and configs support two conditioning layouts:

- Single-band task: x -> y
- EO multiband task: [eo, x, x_valid_mask] -> y

Condition assembly happens in _prepare_condition_for_model:
- optionally prepend eo (condition_include_eo=true)
- append data channels from x
- optionally append x_valid_mask channels (condition_use_valid_mask=true)
- enforce that the total channel count equals model.condition_channels
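The assembly order above can be sketched in NumPy as follows. This is a minimal sketch assuming channel-first arrays of shape (B, C, H, W); the function prepare_condition and its keyword arguments are hypothetical stand-ins that mirror the config flags, not the actual signature of _prepare_condition_for_model.

```python
import numpy as np


def prepare_condition(x, eo=None, x_valid_mask=None, *,
                      include_eo=False, use_valid_mask=False,
                      condition_channels=None):
    """Hypothetical sketch: stack conditioning inputs along the channel axis."""
    parts = []
    if include_eo:
        if eo is None:
            raise ValueError("include_eo=True but no eo tensor was given")
        parts.append(eo)                      # eo channels go first
    parts.append(x)                           # then the data channels from x
    if use_valid_mask:
        if x_valid_mask is None:
            raise ValueError("use_valid_mask=True but no x_valid_mask was given")
        parts.append(x_valid_mask.astype(x.dtype))  # then the validity-mask channels
    cond = np.concatenate(parts, axis=1)
    # Fail loudly if the assembled stack does not match the configured width.
    if condition_channels is not None and cond.shape[1] != condition_channels:
        raise ValueError(
            f"condition has {cond.shape[1]} channels, expected {condition_channels}"
        )
    return cond


# Example: 1 eo channel + 3 data channels + 3 mask channels = 7 condition channels.
x = np.zeros((2, 3, 8, 8), dtype=np.float32)
eo = np.ones((2, 1, 8, 8), dtype=np.float32)
mask = np.ones((2, 3, 8, 8), dtype=bool)
cond = prepare_condition(x, eo, mask, include_eo=True, use_valid_mask=True,
                         condition_channels=7)
```

The explicit channel check turns a config/data mismatch into an immediate error instead of a shape failure deep inside the U-Net.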

Architecture Summary

UnetConvNextBlock follows a U-Net encoder/decoder with ConvNeXt blocks and linear attention.

With default dim_mults=[1,2,4,8]:
- 4 downsampling stages
- bottleneck block with attention
- 3 upsampling stages with skip connections
- final ConvNeXt block + 1x1 output conv to generated_channels

For the ambient EO preset in configs/px_space/model_config_ambient.yaml, the U-Net base width is increased to dim: 96. The depth stays the same (dim_mults=[1,2,4,8]), but the wider denoiser has more capacity, which matters when moving from earlier low-channel setups to the current 50 generated channels and 52 condition channels.
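To make the widths concrete, the sketch below computes per-stage channel widths for dim: 96, assuming the common convention that stage i has width dim * dim_mults[i] and that an initial convolution projects the input to dim channels. The helper unet_stage_widths is hypothetical and only does the arithmetic; it does not build the network.

```python
def unet_stage_widths(dim, dim_mults):
    """Per-stage widths and (in, out) channel pairs, assuming width_i = dim * dim_mults[i]."""
    dims = [dim * m for m in dim_mults]
    # Assumed: an initial conv maps the input to `dim` channels, so stage 0 starts at dim.
    in_out = list(zip([dim] + dims[:-1], dims))
    return dims, in_out


dims, in_out = unet_stage_widths(96, [1, 2, 4, 8])
print(dims)    # [96, 192, 384, 768]
print(in_out)  # [(96, 96), (96, 192), (192, 384), (384, 768)]
```

Because a convolution's weight count scales with in_channels * out_channels, raising the base width by a factor k grows those weights by roughly k squared, which is where the extra capacity comes from.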

Time conditioning:
- sinusoidal timestep embedding -> MLP -> additive bias in ConvNeXt blocks
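A minimal NumPy sketch of this path, assuming the standard transformer-style sinusoidal embedding and a two-layer GELU MLP. The layer sizes, the helper names, and exactly where the bias is added inside a ConvNeXt block are assumptions, not taken from UnetConvNextBlock.

```python
import math

import numpy as np


def sinusoidal_embedding(t, dim):
    """Sinusoidal embedding of integer timesteps t (shape (B,)) -> (B, dim)."""
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/10000.
    freqs = np.exp(-math.log(10000.0) * np.arange(half) / (half - 1))
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


def time_bias(emb, w1, b1, w2, b2):
    """Two-layer MLP with tanh-approximated GELU, producing one bias per channel."""
    h = emb @ w1 + b1
    h = 0.5 * h * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (h + 0.044715 * h ** 3)))
    return h @ w2 + b2


rng = np.random.default_rng(0)
B, D, C = 2, 32, 16  # batch, embedding size, feature channels in one block
emb = sinusoidal_embedding(np.array([0, 500]), D)
bias = time_bias(emb, rng.normal(size=(D, 4 * D)), np.zeros(4 * D),
                 rng.normal(size=(4 * D, C)), np.zeros(C))
features = rng.normal(size=(B, C, 8, 8))
features = features + bias[:, :, None, None]  # additive bias, broadcast over H and W
```

The bias is the same at every pixel, so the timestep shifts each channel uniformly rather than changing spatial structure directly.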

Coordinate/date conditioning (when enabled):
- per-channel FiLM scale/shift in ConvNeXt blocks
- see Data + Coordinate Injection for details
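Per-channel FiLM can be sketched as a linear head that maps a conditioning vector to one scale and one shift per channel. The (1 + scale) form, the helper names, and the example coordinate vector are assumptions for illustration; the actual coordinate/date encoding is described in Data + Coordinate Injection.

```python
import numpy as np


def film_params(cond_vec, w, b):
    """Hypothetical linear head: (B, K) conditioning vector -> (B, C) scale and (B, C) shift."""
    out = cond_vec @ w + b
    scale, shift = np.split(out, 2, axis=-1)
    return scale, shift


def film(h, scale, shift):
    """Per-channel FiLM on (B, C, H, W) features: h * (1 + scale) + shift.

    The (1 + scale) form (assumed) makes a zero-initialized head an identity map.
    """
    return h * (1.0 + scale[:, :, None, None]) + shift[:, :, None, None]


rng = np.random.default_rng(0)
B, K, C = 2, 4, 8
h = rng.normal(size=(B, C, 5, 5))
coords = rng.normal(size=(B, K))  # stand-in for an encoded coordinate/date vector
scale, shift = film_params(coords, np.zeros((K, 2 * C)), np.zeros(2 * C))
out = film(h, scale, shift)
```

Unlike the timestep bias, FiLM also rescales each channel, so the coordinates can change feature magnitudes as well as offsets.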

Training Objective

training_step calls the conditional diffusion p_loss on standardized temperature tensors.

Behavior:
- sample a random timestep t
- forward-diffuse the selected training target (y in standard mode, x in ambient mode) to obtain the noisy target branch x_t
- predict either:
  - the noise (epsilon parameterization), or
  - the clean sample (x0 parameterization)
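The forward-diffusion step and the two regression targets can be sketched as below. The linear beta schedule, the number of steps, the helper names, and the "eps" label are assumptions for illustration; only "x0" matches a value that appears in the configs.

```python
import numpy as np


def linear_alphas_cumprod(T, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of alphas for a linear beta schedule (assumed schedule)."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)


def q_sample(x0, t, noise, alphas_cumprod):
    """Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t][:, None, None, None]  # one value per batch element
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise


def regression_target(x0, noise, parameterization):
    """What the denoiser is trained to output under each parameterization."""
    if parameterization == "eps":  # hypothetical label for epsilon parameterization
        return noise
    if parameterization == "x0":
        return x0
    raise ValueError(f"unknown parameterization: {parameterization}")


rng = np.random.default_rng(0)
T = 1000
abar = linear_alphas_cumprod(T)
x0 = rng.normal(size=(4, 1, 16, 16))  # standardized target (y, or x in ambient mode)
t = rng.integers(0, T, size=4)        # one random timestep per sample
noise = rng.normal(size=x0.shape)
x_t = q_sample(x0, t, noise, abar)
target = regression_target(x0, noise, "x0")
```

With x0 parameterization the network regresses the clean field directly; epsilon parameterization instead regresses the injected noise, and the two are interchangeable given x_t and abar_t.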

Loss options:
- unmasked MSE (default when masking is disabled)
- masked MSE, whose supervised support depends on the mode:
  - standard mode: y_valid_mask, applied to the full y target
  - ambient mode: x_valid_mask intersected with y_valid_mask, applied to the degraded x target
  - the horizontal land_mask remains part of the batch contract, but the task-valid masks already define the supervised support
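The masked loss and its mode-dependent support can be sketched as follows. Averaging over the supervised pixels (sum of squared errors divided by the mask count) is an assumed reduction, and both helper names are hypothetical.

```python
import numpy as np


def masked_mse(pred, target, mask):
    """Mean squared error over pixels where mask is True; 0.0 for an empty mask."""
    mask = mask.astype(bool)
    if mask.sum() == 0:
        return 0.0
    diff = (pred - target)[mask]
    return float(np.mean(diff ** 2))


def supervision_mask(y_valid_mask, x_valid_mask=None, ambient=False):
    """Standard mode: y_valid_mask. Ambient mode: x_valid_mask AND y_valid_mask."""
    if ambient:
        return np.logical_and(x_valid_mask, y_valid_mask)
    return y_valid_mask.astype(bool)


pred = np.array([1.0, 2.0, 3.0, 4.0])
target = np.zeros(4)
y_valid = np.array([True, True, False, True])
x_valid = np.array([True, False, True, True])
print(masked_mse(pred, target, supervision_mask(y_valid)))                         # 7.0
print(masked_mse(pred, target, supervision_mask(y_valid, x_valid, ambient=True)))  # 8.5
```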

Ambient occlusion objective (model.ambient_occlusion.enabled: true):
- let A be the mask of already observed pixels (x_valid_mask); sample an additional Bernoulli keep-mask B over them, giving A_tilde = B * A
- feed the model a further-corrupted condition x_tilde = x * A_tilde, with A_tilde as the condition mask
- switch the diffusion target from y to the original sparse-observation tensor x
- optionally apply A_tilde to the noisy target branch during p_loss (A_tilde * x_t)
- compute the masked MSE on the originally valid x support intersected with the valid y support (A ∩ Y, where Y is y_valid_mask; not A_tilde)
- detailed walkthrough and citation: Ambient Occlusion Objective
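The mask bookkeeping above can be sketched in NumPy. The helper name ambient_occlusion and the keep_prob argument are hypothetical (the real config key for the Bernoulli keep probability is not shown on this page); A is x_valid_mask and Y is y_valid_mask.

```python
import numpy as np


def ambient_occlusion(x, x_valid_mask, y_valid_mask, keep_prob, rng):
    """Sketch: A_tilde = B * A, x_tilde = x * A_tilde, loss support A AND Y."""
    A = x_valid_mask.astype(bool)
    B = rng.random(A.shape) < keep_prob         # Bernoulli(keep_prob) keep-mask
    A_tilde = A & B                             # further-corrupted mask, a subset of A
    x_tilde = x * A_tilde                       # condition: x with extra pixels dropped
    loss_mask = A & y_valid_mask.astype(bool)   # supervise on A ∩ Y, not on A_tilde
    return x_tilde, A_tilde, loss_mask


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 1, 8, 8))
A = rng.random((1, 1, 8, 8)) < 0.3   # sparse observations
Y = np.ones_like(A)                  # every target pixel valid in this toy case
x_tilde, A_tilde, loss_mask = ambient_occlusion(x, A, Y, keep_prob=0.8, rng=rng)
x_t = rng.normal(size=x.shape)       # stand-in for the noisy target branch
x_t_masked = x_t * A_tilde           # optional masking of the noisy branch
```

The loss support deliberately stays on A rather than A_tilde: the pixels hidden by B are still supervised, which is what forces the model to infer observations it was not given.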

Current EO config (configs/px_space/model_config.yaml) uses:
- parameterization: "x0"
- mask_loss_with_valid_pixels: true

With these settings and ambient mode disabled, the training loss is computed over all valid y pixels, as given by y_valid_mask.

The latent-space model workflow is configured via configs/lat_space/model_config.yaml, with autoencoder controls in configs/lat_space/ae_config.yaml; see Autoencoder + Latent Diffusion for the full setup.

Inference Flow

The prediction entry point is predict_step.

At inference:
- build condition tensor from batch inputs
- start the reverse process from Gaussian noise
- keep the condition fixed throughout reverse sampling
- use the configured sampler (ddpm by default, ddim optional)
- optionally apply known-pixel clamping, which overwrites known pixels at each step
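A minimal ancestral DDPM loop for an x0-parameterized denoiser shows how the fixed condition and optional clamping fit together. The denoiser callable, the helper name, and the choice to clamp the x0 estimate (rather than x_t) are assumptions; the repository's clamp_known_pixels may clamp at a different point, and the DDIM path is not shown.

```python
import numpy as np


def ddpm_sample_x0(denoise_x0, cond, shape, betas, rng, known=None, known_mask=None):
    """Ancestral DDPM sampling with a hypothetical x0-predicting denoiser.

    If known/known_mask are given, known pixels of every x0 estimate are overwritten
    with the known values (one possible form of known-pixel clamping).
    """
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)
    abar_prev = np.concatenate([[1.0], abar[:-1]])
    x = rng.normal(size=shape)                       # x_T ~ N(0, I)
    for t in range(len(betas) - 1, -1, -1):
        x0_hat = denoise_x0(x, t, cond)              # same condition at every step
        if known_mask is not None:
            x0_hat = np.where(known_mask, known, x0_hat)
        # Mean and variance of the posterior q(x_{t-1} | x_t, x0_hat).
        coef_x0 = betas[t] * np.sqrt(abar_prev[t]) / (1.0 - abar[t])
        coef_xt = (1.0 - abar_prev[t]) * np.sqrt(alphas[t]) / (1.0 - abar[t])
        mean = coef_x0 * x0_hat + coef_xt * x
        if t > 0:
            var = betas[t] * (1.0 - abar_prev[t]) / (1.0 - abar[t])
            x = mean + np.sqrt(var) * rng.normal(size=shape)
        else:
            x = mean                                 # no noise on the final step
    return x


# Toy run: a "denoiser" that always predicts zeros, with the top half known to be 3.0.
rng = np.random.default_rng(0)
mask = np.zeros((1, 1, 4, 4), dtype=bool)
mask[..., :2, :] = True
known = np.full((1, 1, 4, 4), 3.0)
out = ddpm_sample_x0(lambda x, t, c: np.zeros_like(x), None, (1, 1, 4, 4),
                     np.linspace(1e-4, 0.02, 50), rng, known=known, known_mask=mask)
```

At the last step the posterior mean reduces to x0_hat, so clamped pixels come out equal to the known values.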

Output dictionary from predict_step:
- y_hat: standardized model output
- y_hat_denorm: denormalized output
- denoise_samples: optional intermediate reverse samples
- x0_denoise_samples: optional per-step x0 predictions
- sampler: sampler object used

Post-Processing in Lightning Inference

After denormalization, inference can apply:
- optional Gaussian blur (model.post_process.gaussian_blur.*)
- direct y prediction: keep the generated field and set y_valid_mask==0 pixels to NaN
- ambient x completion: return the model prediction as-is after optional sampler-time clamp_known_pixels, then set y_valid_mask==0 pixels to NaN

This post-processing is centralized in predict_step.
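The denormalize-then-mask step can be sketched as below, assuming z-score standardization with a known mean and std. The Gaussian blur option is omitted, and postprocess is a hypothetical name.

```python
import numpy as np


def postprocess(y_hat, mean, std, y_valid_mask):
    """Undo assumed z-score standardization, then set pixels outside y_valid_mask to NaN."""
    y = y_hat * std + mean                   # inverse of (y - mean) / std
    return np.where(y_valid_mask.astype(bool), y, np.nan)


y_hat = np.array([0.0, 1.0, -1.0])
out = postprocess(y_hat, mean=10.0, std=2.0, y_valid_mask=np.array([True, True, False]))
# values: 10.0, 12.0, nan
```

Using NaN rather than zero keeps invalid pixels out of downstream means and plots without being mistaken for real temperatures.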

Validation Diagnostics

Validation computes two paths:
- per-batch validation loss (validation_step) using the same objective as training
- one full reverse-diffusion reconstruction per epoch on the first validation batch, cached on global rank 0 (on_validation_epoch_end)

When available, full reconstruction logging includes:
- MSE
- PSNR/SSIM (if skimage is installed)
- qualitative reconstruction grid
- denoising-intermediate grid and MAE-vs-step curve (when intermediates enabled)
- a reconstruction plot that keeps the unmerged model prediction panel and masks pixels outside y_valid_mask

These epoch-end diagnostics run only on global rank 0, which avoids DDP logging mismatches for optional metrics such as PSNR/SSIM.
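The scalar diagnostics reduce to a few lines. This sketch computes MSE, PSNR, and an MAE-vs-step curve over valid pixels only, which is an assumption about how masked support is handled; the real code may instead call skimage.metrics.peak_signal_noise_ratio and skimage.metrics.structural_similarity. All helper names are hypothetical.

```python
import math

import numpy as np


def masked_mean_sq(pred, target, mask):
    """MSE over pixels where mask is True."""
    m = mask.astype(bool)
    return float(np.mean((pred[m] - target[m]) ** 2))


def psnr(pred, target, mask, data_range):
    """PSNR in dB: 10 * log10(data_range**2 / MSE); infinite for a perfect match."""
    err = masked_mean_sq(pred, target, mask)
    return math.inf if err == 0 else 10.0 * math.log10(data_range ** 2 / err)


def mae_per_step(x0_steps, target, mask):
    """MAE of each intermediate x0 prediction, for an MAE-vs-step curve."""
    m = mask.astype(bool)
    return [float(np.mean(np.abs(s[m] - target[m]))) for s in x0_steps]


target = np.zeros((8, 8))
valid = np.ones((8, 8), dtype=bool)
pred = np.full((8, 8), 0.1)
score = psnr(pred, target, valid, data_range=1.0)   # ≈ 20.0 dB
curve = mae_per_step([np.full((8, 8), 0.5), pred], target, valid)
```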