Ambient Occlusion Objective in DepthDif

This document describes the changes made to go from the standard DepthDif setup and the occlusion branch to a re-implementation of the Ambient Diffusion training objective from:

TL;DR

Before this change, the model saw one masked input and trained mostly on the pixels that were already missing.

Now, during training only, we add one extra step: we randomly hide some of the pixels that are still visible. The model therefore gets a harder, more incomplete input.

The important part is the loss target: ambient mode now scores the model against the original sparse x (itself a degraded view of the ocean field, not the further-masked input), and only where the original input is valid and the GLORYS depth support is valid.

Why this matters: it avoids a weak objective where the model can learn shortcuts from the currently visible subset, and instead teaches it to recover stable structure under extra random occlusion.

Inference/sampling did not change. This is a training-objective change only.

Visual Walkthrough

Ambient objective step-by-step example

This figure shows one real training sample and the exact split that ambient occlusion creates between observations the model still receives and observations that are temporarily hidden from it.

From left to right, top to bottom:

  1. Argo max over channels: the original sparse Argo observations in the patch, collapsed across depth so the observed track pattern is easy to see.
  2. Points model still sees: the subset of those Argo observations that remains after the extra ambient masking step. These are the measurements still available to the model as input.
  3. Points withheld from model: the Argo observations that were originally present in the sample but were removed by the extra ambient masking step.
  4. Seen/withheld mask: the same split shown as a categorical mask, with green for pixels the model still sees, red for pixels withheld from the model, and black for locations with no Argo observation in the first place.
  5. OSTIA: the surface conditioning field for the same patch.
  6. GLORYS surface level: the target ocean field at the shallowest GLORYS level for the same patch, plotted on the same blue-to-red color scale as OSTIA.

What matters in this figure is the contrast between panels 1, 2, and 3: ambient occlusion does not invent a new sampling pattern from scratch. It starts from the real sparse Argo observations already present in the dataset, then removes an additional subset of them so the model has to work from a stricter, harder version of the same sample.
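The categorical split in panels 2 to 4 can be derived from the two masks alone. The following is a minimal NumPy sketch, assuming masks shaped (C, H, W) with 1 marking an observed entry. The function name seen_withheld_map and the integer codes are hypothetical, chosen to mirror the black/green/red coloring, and the collapse across depth follows the max-over-channels view of panel 1.

```python
import numpy as np

# Hypothetical category codes mirroring the figure: black, green, red.
NO_OBS, SEEN, WITHHELD = 0, 1, 2


def seen_withheld_map(a: np.ndarray, a_tilde: np.ndarray) -> np.ndarray:
    """Collapse (C, H, W) masks across depth into one (H, W) categorical map.

    a:       original Argo observation mask A (x_valid_mask), values in {0, 1}
    a_tilde: further-corrupted mask A~ = B * A, values in {0, 1}
    """
    observed = a.max(axis=0) > 0      # any depth observed in the original sample
    seen = a_tilde.max(axis=0) > 0    # any depth still given to the model
    out = np.full(observed.shape, NO_OBS, dtype=np.int8)
    out[observed & ~seen] = WITHHELD  # present originally, removed by B
    out[seen] = SEEN                  # still visible to the model
    return out
```

A location counts as seen if any depth survives the further masking. With shared_spatial_mask=true that case cannot be mixed, because one spatial \(B\) is shared across channels.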

1. Top-Level Perspective

DepthDif previously trained a conditional diffusion model with a single corruption stage (dataset occlusion mask) and (typically) a loss focused on missing pixels.

The new procedure adds a second stochastic corruption stage during training:

  1. Start from the original observation mask \(A\) (from x_valid_mask).
  2. Sample an additional random keep/drop operator \(B\).
  3. Form a further-corrupted mask \(\tilde{A} = B \odot A\).
  4. Build the conditioning input for the model from the \(\tilde{A}\)-corrupted observations \(\tilde{A}\odot x\).
  5. Supervise the prediction on the ambient support mask \(S = A \odot Y\), where \(A\) is x_valid_mask and \(Y\) is y_valid_mask.

Intuition: the model is forced to reconstruct original observed x values from a stricter subset of those same observations.
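The five steps above can be written for a single sample as a short NumPy sketch. It assumes {0,1} masks shaped (C, H, W), draws \(B\) independently per entry with keep probability \(1-\delta\), and omits the safety guards listed in Section 6. ambient_masks is a hypothetical helper, not the repository’s _build_ambient_further_valid_mask.

```python
import numpy as np


def ambient_masks(x, a, y_valid, further_drop_prob, rng):
    """Run steps 1-5 on one sample; x and all masks are shaped (C, H, W).

    Step 1 is the input mask `a` itself. Returns (x_tilde, a_tilde, s):
    the corrupted input, the further-corrupted mask A~ = B * A, and the
    supervision support S = A * Y.
    """
    # Step 2: B keeps each entry with probability 1 - delta (no safety guards).
    b = (rng.random(a.shape) >= further_drop_prob).astype(a.dtype)
    # Step 3: further-corrupted mask; A~ <= A because B is {0, 1}.
    a_tilde = b * a
    # Step 4: the conditioning input is built from A~-corrupted observations.
    x_tilde = a_tilde * x
    # Step 5: supervision support uses the ORIGINAL mask A, not A~.
    s = a * y_valid
    return x_tilde, a_tilde, s
```

Because \(S\) does not depend on \(B\), the withheld observations stay in the loss, which is what lets the model be scored on points it did not see.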

2. Notation

For one sample:

  • \(x_0 \in \mathbb{R}^{C \times H \times W}\): clean diffusion target (the normalized x in this repository’s ambient mode).
  • \(A \in \{0,1\}^{C \times H \times W}\): original validity/observation mask (x_valid_mask).
  • \(x = A \odot x_0\): original sparse observed input (in this repo, x already carries this structure).
  • \(t \sim \mathrm{Unif}\{0,\dots,T-1\}\): diffusion timestep.
  • \(x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ \epsilon\sim\mathcal{N}(0,I)\): noisy target branch sample.
  • \(B \in \{0,1\}^{C \times H \times W}\): further keep mask sampled with keep-probability \(1-\delta\) on observed entries.
  • \(\tilde{A} = B \odot A\): further-corrupted observation mask.

In implementation, \(\delta =\) model.ambient_occlusion.further_drop_prob.
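For reference, the timestep draw and the noisy branch \(x_t\) from the notation list can be sketched as follows. The linear \(\beta\) schedule (1e-4 to 2e-2) is an assumption for illustration only; this document does not specify DepthDif’s noise schedule.

```python
import numpy as np


def linear_alpha_bar(num_steps: int, beta_start: float = 1e-4,
                     beta_end: float = 2e-2) -> np.ndarray:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an ASSUMED linear schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)


def q_sample(x0: np.ndarray, t: int, alpha_bar: np.ndarray,
             eps: np.ndarray) -> np.ndarray:
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    ab = alpha_bar[t]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps


# Usage on a toy (C, H, W) target.
rng = np.random.default_rng(0)
T = 1000
alpha_bar = linear_alpha_bar(T)
x0 = rng.standard_normal((4, 8, 8))
t = int(rng.integers(0, T))          # t ~ Unif{0, ..., T-1}
eps = rng.standard_normal(x0.shape)  # eps ~ N(0, I)
x_t = q_sample(x0, t, alpha_bar, eps)
```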

3. Previous Objective (Repository Before This Change)

With mask_loss_with_valid_pixels=true, the loss was computed on missing pixels:

\[ \mathcal{L}_{\text{prev}}(\theta) = \frac{ \left\|(1-A)\odot\left(\text{target}_t-\hat{x}_{\theta}\right)\right\|_2^2 }{ \|(1-A)\|_1 }, \]

where:

  • \(\hat{x}_{\theta}\) is the denoiser output.
  • \(\text{target}_t = x_0\) for parameterization="x0" or \(\epsilon\) for parameterization="epsilon".

Conditioning used the original sparse input/mask pair \((x, A)\) (plus EO, if enabled), without extra stochastic masking during training.
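As a reference point for comparison with Section 4, the pre-change objective for one sample can be sketched in NumPy. previous_loss is a hypothetical name, and the small eps guard against an all-observed sample (empty \(1-A\)) is an assumption added for the sketch.

```python
import numpy as np


def masked_mse(target, pred, mask, eps=1e-8):
    """||mask * (target - pred)||_2^2 / ||mask||_1 for a binary mask."""
    sq = (mask * (target - pred)) ** 2
    return float(sq.sum() / max(mask.sum(), eps))


def previous_loss(target_t, x_hat, a):
    """Pre-change objective: supervise only where the input was missing (1 - A)."""
    return masked_mse(target_t, x_hat, 1.0 - a)
```

Errors on observed pixels (\(A = 1\)) do not contribute at all, which is the missing-pixel focus described above.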

4. New Ambient Objective

4.1 Training Inputs

Define:

\[ \tilde{x} = \tilde{A}\odot x. \]

The model condition is built from \((\tilde{x}, \tilde{A}, \text{EO})\) instead of \((x, A, \text{EO})\).

Optionally (enabled by default), the noisy branch is also masked:

\[ \tilde{x}_t = \tilde{A}\odot x_t. \]
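The two training inputs can be sketched as below. Stacking \((\tilde{x}, \tilde{A}, \text{EO})\) along the channel axis is an assumed layout for illustration only; this document does not say how the condition tensor is assembled. The keyword apply_further_corruption_to_noisy_branch mirrors the option of that name, and both function names are hypothetical.

```python
import numpy as np


def build_ambient_condition(x, a_tilde, eo=None):
    """Stack (x~, A~, EO) along the channel axis (ASSUMED layout).

    x:       original sparse input (C, H, W), already equal to A * x0
    a_tilde: further-corrupted mask A~ (C, H, W)
    eo:      optional surface conditioning such as OSTIA, shape (C_eo, H, W)
    """
    x_tilde = a_tilde * x
    parts = [x_tilde, a_tilde] + ([eo] if eo is not None else [])
    return np.concatenate(parts, axis=0)


def noisy_branch_input(x_t, a_tilde, apply_further_corruption_to_noisy_branch=True):
    """Return x~_t = A~ * x_t when the option is on, otherwise x_t unchanged."""
    if apply_further_corruption_to_noisy_branch:
        return a_tilde * x_t
    return x_t
```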

4.2 Loss Region and Target

The implemented ambient mode uses the original sparse input x (already a degraded view of the field) as the target and supervises it only where both the input support and the target support are valid. The ambient supervision mask is:

\[ S = A \odot Y, \]

where \(A\) is x_valid_mask and \(Y\) is y_valid_mask. Because \(S\) is built from the original mask \(A\) rather than \(\tilde{A}\), supervision covers every valid observed input pixel, including those withheld by the further masking, for both synthetic and real Argo ambient inputs.

\[ \mathcal{L}_{\text{ambient}}(\theta) = \frac{ \left\|S\odot\left(\text{target}_t-\hat{x}_{\theta}\right)\right\|_2^2 }{ \|S\|_1 }. \]

land_mask no longer changes the diffusion loss support in the current training path; the implemented loss mask is exactly the task-valid mask \(S = A \odot Y\).
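A matching sketch of the ambient loss for one sample, assuming \(x_0\) parameterization so the target is the normalized original x. The function name and the eps guard are hypothetical; land_mask is intentionally not an argument, consistent with the paragraph above.

```python
import numpy as np


def ambient_loss(x0_target, x_hat, x_valid_mask, y_valid_mask, eps=1e-8):
    """L_ambient = ||S * (target - x_hat)||_2^2 / ||S||_1 with S = A * Y.

    x0_target: normalized original sparse x (x0 parameterization)
    land_mask does not gate this loss, so it is not an input here.
    """
    s = x_valid_mask * y_valid_mask
    sq = s * (x0_target - x_hat) ** 2
    return float(sq.sum() / max(s.sum(), eps))
```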

4.3 Relation to Paper Objective

The procedure matches the paper’s core structure:

\[ J_{\mathrm{corr}}(\theta)=\frac12\,\mathbb{E}\left[\left\|A\left(h_{\theta}(\tilde{A}\,x_t,\tilde{A},t)-x_0\right)\right\|_2^2\right], \]

up to the repository’s existing normalization/parameterization conventions, per-mask normalization by mask cardinality, and the use of \(S = A \odot Y\) in place of \(A\). In this repository, that clean target is the original sparse-observation tensor x in ambient mode, while the standard non-ambient path continues to target y.
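To make the normalization difference concrete: for a single sample with \(Y\) equal to 1 everywhere (so \(S = A\)), the repository form equals \(2 J_{\mathrm{corr}} / \|A\|_1\) before the expectation is taken. The sketch below checks this numerically; both function names are hypothetical, and h stands for the denoiser output \(h_{\theta}\).

```python
import numpy as np


def j_corr_single(h, x0, a):
    """Per-sample paper objective 0.5 * ||A * (h - x0)||_2^2 (no expectation)."""
    return 0.5 * float(((a * (h - x0)) ** 2).sum())


def repo_ambient_single(h, x0, s):
    """Per-sample repository form ||S * (h - x0)||_2^2 / ||S||_1."""
    return float(((s * (h - x0)) ** 2).sum() / s.sum())


rng = np.random.default_rng(1)
x0 = rng.standard_normal((2, 5, 5))
h = rng.standard_normal((2, 5, 5))
a = (rng.random((2, 5, 5)) < 0.4).astype(float)
a[0, 0, 0] = 1.0                       # ensure at least one observed entry
s = a * np.ones_like(a)                # Y = 1 everywhere, so S = A
ratio = repo_ambient_single(h, x0, s) / j_corr_single(h, x0, a)
```

If the division is applied per sample, \(\|A\|_1\) varies between samples, so the cardinality normalization reweights samples relative to the paper’s plain expectation rather than rescaling the whole objective by one constant.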

5. What Changed vs What Stayed the Same

Changed

  1. Two-stage masking during training (\(A \rightarrow \tilde{A}\)).
  2. Condition path uses \(\tilde{A}\) and \(\tilde{x}\).
  3. Diffusion target switches to original x in ambient mode.
  4. Loss region switches to the valid ambient support mask \(A \odot Y\) in ambient mode.
  5. New ambient metrics are logged:
      • train/ambient_further_drop_fraction
      • train/ambient_observed_fraction_original
      • train/ambient_observed_fraction_further
      • the same keys under val/*.
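The three logged metrics are not defined in this document. The sketch below uses definitions inferred from their names only, which is an assumption: the fraction of all entries observed before and after further masking, and the share of originally observed entries that \(B\) removed. ambient_metrics is a hypothetical helper.

```python
import numpy as np


def ambient_metrics(a, a_tilde, prefix="train"):
    """ASSUMED metric definitions inferred from the key names.

    a, a_tilde: {0, 1} masks A and A~ for a batch, any matching shape.
    """
    n_orig = float(a.sum())
    n_further = float(a_tilde.sum())
    # Share of originally observed entries that the further mask dropped.
    drop = (n_orig - n_further) / n_orig if n_orig > 0 else 0.0
    return {
        f"{prefix}/ambient_further_drop_fraction": drop,
        f"{prefix}/ambient_observed_fraction_original": n_orig / a.size,
        f"{prefix}/ambient_observed_fraction_further": n_further / a.size,
    }
```

Under these definitions the drop fraction should hover near further_drop_prob when no guard intervenes.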

Unchanged

  1. Inference/sampler algorithms (DDPM/DDIM) are unchanged.
  2. Dataset generation of original corruption mask \(A\) is unchanged.
  3. The dataset still returns the horizontal land_mask in the batch contract, but the diffusion loss is driven by task-valid masks instead of land-mask gating.
  4. If ambient mode is disabled, the model and samplers stay the same while the objective reverts to direct y reconstruction over y_valid_mask.

6. Implemented Safety and Constraints

  1. Parameterization guard: if ambient mode is enabled and require_x0_parameterization=true, then parameterization must be "x0"; otherwise construction raises ValueError.
  2. Mask monotonicity: \(\tilde{A} \le A\) elementwise by construction.
  3. Degeneracy guard: at least min_kept_observed_pixels are kept per sample when possible, so the further corruption stage cannot leave a sample with an empty observed input.
  4. shared_spatial_mask=true enforces one spatial \(B\) per sample, broadcast across channels.
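Constraints 1 to 4 can be illustrated with a single-sample NumPy sketch. Everything here is hypothetical: the function names, the way min_kept_observed_pixels is counted (spatial locations when the mask is shared, individual entries otherwise), and the choice to satisfy the degeneracy guard by restoring randomly chosen dropped observations. This document does not say how _build_ambient_further_valid_mask enforces the guard.

```python
import numpy as np


def check_parameterization(ambient_enabled, require_x0, parameterization):
    """Guard 1: ambient mode with require_x0_parameterization needs "x0"."""
    if ambient_enabled and require_x0 and parameterization != "x0":
        raise ValueError("ambient occlusion requires parameterization='x0'")


def build_further_mask(a, further_drop_prob, min_kept, shared_spatial, rng):
    """Return A~ = B * A for one (C, H, W) sample (hypothetical sketch).

    With shared_spatial=True one (H, W) keep mask is drawn and broadcast
    across channels (guard 4), so min_kept counts spatial locations;
    otherwise it counts individual (c, h, w) entries.
    """
    observed = a.any(axis=0) if shared_spatial else a.astype(bool)
    keep = rng.random(observed.shape) >= further_drop_prob
    kept = keep & observed
    # Guard 3: restore random dropped observations until min_kept are kept,
    # or until every observation is kept when fewer exist ("when possible").
    deficit = min(min_kept, int(observed.sum())) - int(kept.sum())
    if deficit > 0:
        dropped = np.flatnonzero(observed & ~kept)
        restore = rng.choice(dropped, size=deficit, replace=False)
        kept.flat[restore] = True
    b = np.broadcast_to(kept, a.shape) if shared_spatial else kept
    return a * b  # guard 2: A~ <= A elementwise, since b is {0, 1}
```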

7. Code Mapping (Equation to Implementation)

  • Ambient config surface:
    • configs/px_space/model_config.yaml (model.ambient_occlusion.*)
  • Runtime config wiring and safety:
    • models/difFF/PixelDiffusion.py
      • PixelDiffusionConditional.from_config(...)
      • PixelDiffusionConditional.__init__(...)
  • \(\tilde{A}\) construction:
    • PixelDiffusionConditional._build_ambient_further_valid_mask(...)
  • Condition path replacement \((x,A)\to(\tilde{x},\tilde{A})\):
    • training_step(...), validation_step(...)
  • Ambient loss execution:
    • models/difFF/DenoisingDiffusionProcess/DenoisingDiffusionProcess.py
      • DenoisingDiffusionConditionalProcess.p_loss(...)
      • ambient loss mask = x_valid_mask intersected with y_valid_mask
      • optional apply_further_corruption_to_noisy_branch

8. Practical Interpretation

The old setup primarily asked the model to reconstruct hidden regions given fixed observed context.

The new setup introduces random context removal during training while preserving supervision on the original observed support. This makes the learning problem closer to the ambient objective: robustly estimate clean content under stochastic measurement degradation, not only under one fixed missingness pattern per sample.