Model Settings

This page maps key configuration flags to their runtime behavior in code.

Primary config files used in current EO setup:
- configs/px_space/data_ostia_argo_disk.yaml (OSTIA + Argo + GLORYS GeoTIFF disk preset)
- configs/px_space/model_config.yaml
- configs/px_space/model_config_ambient.yaml
- configs/px_space/training_config.yaml

Latent-space config set:
- configs/lat_space/model_config.yaml
- configs/lat_space/training_config.yaml
- configs/lat_space/data_config.yaml
- configs/lat_space/ae_config.yaml

See Autoencoder + Latent Diffusion for latent architecture and training workflow.

Major Settings

Conditioning channels

Config (model_config_*):
- model.generated_channels
- model.condition_channels
- model.condition_mask_channels
- model.condition_include_eo
- model.condition_use_valid_mask

Runtime effect:
- controls how condition = [eo?, x, x_valid_mask?] is assembled
- channel count is validated against expected condition_channels
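
A minimal sketch of how this assembly and validation could look under these flags; tensor names follow the batch keys used above, and the helper itself is an assumption, not the repository's API:

```python
import torch

def build_condition(batch, cfg):
    # Hypothetical helper: assembles condition = [eo?, x, x_valid_mask?]
    # according to condition_include_eo / condition_use_valid_mask.
    parts = []
    if cfg["condition_include_eo"]:
        parts.append(batch["eo"])             # (B, 1, H, W) OSTIA EO field
    parts.append(batch["x"])                  # (B, 50, H, W) corrupted Argo stack
    if cfg["condition_use_valid_mask"]:
        parts.append(batch["x_valid_mask"])   # (B, 1, H, W) collapsed validity mask
    condition = torch.cat(parts, dim=1)

    # The resulting channel count is checked against the configured value.
    expected = cfg["condition_channels"]
    if condition.shape[1] != expected:
        raise ValueError(f"condition has {condition.shape[1]} channels, expected {expected}")
    return condition
```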

Diffusion target parameterization

Config:
- model.parameterization: epsilon or x0

Runtime effect:
- defines target in diffusion loss and sampler conversions
- current EO config uses x0
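
A hedged sketch of what the switch typically changes, assuming standard DDPM notation (x_t = sqrt(ᾱ_t) x_0 + sqrt(1 − ᾱ_t) ε); the function names are illustrative, not the repository's:

```python
def diffusion_target(parameterization, x0, noise):
    # Loss target: the network regresses either the injected noise ("epsilon")
    # or the clean signal ("x0").
    return noise if parameterization == "epsilon" else x0

def to_x0(parameterization, model_out, x_t, sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t):
    # Sampler-side conversion of the model output to a predicted x0.
    if parameterization == "epsilon":
        return (x_t - sqrt_one_minus_alpha_bar_t * model_out) / sqrt_alpha_bar_t
    return model_out  # "x0": the network predicts the clean signal directly
```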

Masked loss

Config:
- model.mask_loss_with_valid_pixels

Runtime effect:
- if enabled, loss uses the task-valid support instead of the old missing-pixel mask
- standard mode: y_valid_mask over the full y target
- ambient mode: x_valid_mask ∩ y_valid_mask over the x target
- mask alignment preserves per-band semantics (B x C x H x W) unless a single shared mask channel is explicitly used
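
A minimal sketch of a masked MSE under these rules; the exact mask construction for the two modes is an assumption based on the description above:

```python
import torch
import torch.nn.functional as F

def masked_mse(pred, target, y_valid_mask, x_valid_mask=None, ambient=False):
    # Standard mode: supervise on y_valid_mask; ambient mode: on x_valid_mask ∩ y_valid_mask.
    mask = y_valid_mask if not ambient else x_valid_mask * y_valid_mask
    # A single shared mask channel is broadcast; otherwise per-band (B, C, H, W) semantics are kept.
    mask = mask.expand_as(pred) if mask.shape[1] == 1 else mask
    per_pixel = F.mse_loss(pred, target, reduction="none")
    return (per_pixel * mask).sum() / mask.sum().clamp(min=1.0)
```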

Inference output composition

Runtime effect:
- direct y prediction keeps the generated field as the output
- ambient x completion leaves known-pixel enforcement to clamp_known_pixels during sampling
- both modes then set pixels outside the y_valid_mask support to NaN
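
As a sketch, the final NaN-masking step could look like the following (assuming y_valid_mask broadcasts over the output channels; the helper name is illustrative):

```python
def mask_invalid_to_nan(output, y_valid_mask):
    # Pixels outside the y_valid_mask support are reported as NaN in the output.
    invalid = ~y_valid_mask.bool().expand_as(output)
    return output.masked_fill(invalid, float("nan"))
```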

Known-pixel clamping during sampling

Config:
- model.clamp_known_pixels

Runtime effect:
- if enabled and known masks/values are available, known pixels are overwritten each reverse step
- useful for inpainting-style stability
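
A hedged sketch of the clamping idea; whether the overwrite happens in the clean or the noised domain is an implementation detail not shown here:

```python
def clamp_known_pixels(x, known_values, known_mask):
    # Applied after each reverse step when model.clamp_known_pixels is enabled:
    # pixels where known_mask == 1 are overwritten with the known values,
    # the model's prediction is kept elsewhere.
    return known_mask * known_values + (1.0 - known_mask) * x
```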

Coordinate/date FiLM conditioning

Config:
- data: dataset.output.return_coords
- model:
  - coord_conditioning.enabled
  - coord_conditioning.encoding
  - coord_conditioning.include_date
  - coord_conditioning.date_encoding
  - coord_conditioning.embed_dim

Runtime effect:
- creates a coordinate/date embedding and injects it via FiLM in ConvNeXt blocks
- details: Data + Coordinate Injection
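
A compact sketch of FiLM-style injection, assuming the coordinate/date embedding is mapped to per-channel scale/shift pairs inside each ConvNeXt block; module and variable names are illustrative:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    # Maps an embedding of size embed_dim to per-channel (gamma, beta) and
    # modulates a feature map h as h * (1 + gamma) + beta.
    def __init__(self, embed_dim: int, channels: int):
        super().__init__()
        self.to_scale_shift = nn.Linear(embed_dim, 2 * channels)

    def forward(self, h: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.to_scale_shift(emb).chunk(2, dim=-1)
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]
        return h * (1.0 + gamma) + beta
```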

Training and Optimization Settings

Noise schedule and diffusion steps

Config (training.noise):
- num_timesteps
- schedule: linear, cosine, quadratic, sigmoid
- beta_start, beta_end
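
For reference, common formulations of these schedules (the cosine variant follows Nichol & Dhariwal); whether the repository clips or normalizes the betas exactly this way is an assumption:

```python
import math
import torch

def make_betas(schedule, num_timesteps, beta_start=1.0e-4, beta_end=2.0e-2):
    if schedule == "linear":
        return torch.linspace(beta_start, beta_end, num_timesteps)
    if schedule == "quadratic":
        return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
    if schedule == "sigmoid":
        return torch.sigmoid(torch.linspace(-6, 6, num_timesteps)) * (beta_end - beta_start) + beta_start
    if schedule == "cosine":
        s = 0.008
        t = torch.linspace(0, num_timesteps, num_timesteps + 1) / num_timesteps
        alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
        return betas.clamp(max=0.999)
    raise ValueError(f"unknown schedule: {schedule}")
```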

Validation sampling mode

Config (training.validation_sampling):
- sampler: ddpm or ddim
- ddim_num_timesteps, ddim_eta
- log_intermediates

Runtime effect:
- training loss still uses forward noising objective
- full reverse sampling diagnostics use chosen validation sampler
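
A hedged sketch of one DDIM update from a predicted x0 in standard DDIM notation; with ddim_eta = 0.0 the step is deterministic. This is not the repository's sampler code:

```python
import torch

def ddim_step(x_t, x0_pred, alpha_bar_t, alpha_bar_prev, eta=0.0):
    # alpha_bar_* are cumulative alpha products for the current and previous retained timesteps.
    a_t = torch.as_tensor(alpha_bar_t, dtype=x_t.dtype)
    a_prev = torch.as_tensor(alpha_bar_prev, dtype=x_t.dtype)
    eps = (x_t - a_t.sqrt() * x0_pred) / (1.0 - a_t).sqrt()
    sigma = eta * ((1.0 - a_prev) / (1.0 - a_t)).sqrt() * (1.0 - a_t / a_prev).sqrt()
    mean = a_prev.sqrt() * x0_pred + (1.0 - a_prev - sigma ** 2).sqrt() * eps
    return mean + sigma * torch.randn_like(x_t)
```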

Learning-rate warmup and plateau scheduler

Config (scheduler):
- warmup.enabled, warmup.steps, warmup.start_ratio
- reduce_on_plateau.enabled
- reduce_on_plateau.monitor, mode, factor, patience, threshold, cooldown

Runtime effect:
- warmup is applied per optimizer step in optimizer_step
- plateau scheduler is applied on epoch-level monitored metric
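
A sketch of the two mechanisms using standard PyTorch APIs; how the repository wires them into Lightning's optimizer_step and validation loop is not shown:

```python
import torch

def warmup_lr(base_lr, step, warmup_steps, start_ratio):
    # Linear ramp from start_ratio * base_lr to base_lr over warmup_steps optimizer steps.
    if step >= warmup_steps:
        return base_lr
    frac = step / max(warmup_steps, 1)
    return base_lr * (start_ratio + (1.0 - start_ratio) * frac)

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-4)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=20,
    threshold=1.0e-4, threshold_mode="rel", cooldown=0, min_lr=1.0e-6, eps=1.0e-8,
)

# Per optimizer step: scale the LR during warmup.
for group in optimizer.param_groups:
    group["lr"] = warmup_lr(1.0e-4, step=100, warmup_steps=2000, start_ratio=0.2)

# Per epoch: step the plateau scheduler on the monitored metric (e.g. val/loss_ckpt).
plateau.step(0.123)
```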

Trainer/Runtime Controls

Config (trainer):
- hardware/precision: accelerator, devices, optional num_gpus, precision
- logging/checkpoint cadence: log_every_n_steps, ckpt_monitor, lr_logging_interval
- validation load: val_batches_per_epoch or limit_val_batches
- stability knobs: gradient_clip_val, warning suppressions
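
These keys map fairly directly onto pytorch_lightning.Trainer arguments; a hedged sketch, assuming the defaults listed in the reference tables below (how num_gpus and val_batches_per_epoch are resolved is not shown):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",          # trainer.accelerator
    devices="auto",              # trainer.devices (optionally overridden via num_gpus)
    precision="16-mixed",        # trainer.precision
    max_epochs=1500,             # trainer.max_epochs
    log_every_n_steps=25,        # trainer.log_every_n_steps
    limit_val_batches=4,         # trainer.limit_val_batches / val_batches_per_epoch
    gradient_clip_val=1.0,       # trainer.gradient_clip_val (0.0 disables)
    num_sanity_val_steps=1,      # trainer.num_sanity_val_steps
    enable_model_summary=True,   # trainer.enable_model_summary
)
```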

Dataloader Settings

Config (dataloader):
- batch_size, val_batch_size
- num_workers, val_num_workers
- shuffle, val_shuffle
- pin_memory, persistent_workers, prefetch_factor

Runtime notes:
- prefetch_factor is only applied when num_workers > 0
- validation shuffle defaults to true in DataModule unless explicitly changed
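
A sketch of how these keys typically map to torch.utils.data.DataLoader arguments, including the prefetch_factor rule noted above; the dataset object is a placeholder:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 1))  # placeholder dataset

num_workers = 4
kwargs = dict(
    batch_size=4,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=num_workers > 0,
)
if num_workers > 0:
    # prefetch_factor is only valid when worker processes are used.
    kwargs["prefetch_factor"] = 2

train_loader = DataLoader(dataset, **kwargs)
```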

Logging Settings (W&B)

Config (wandb):
- project/entity/run naming
- model logging policy
- watch toggles (watch_gradients, watch_parameters)
- scalar/image logging intervals

Runtime notes:
- watch mode is resolved from explicit gradient/parameter toggles
- config files used in the run are uploaded to W&B run files when possible
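
A hedged sketch of resolving the watch mode from the two toggles and uploading config files; wandb.watch and wandb.save are the standard W&B calls, but how the repository invokes them is an assumption:

```python
import wandb

def resolve_watch_mode(watch_gradients: bool, watch_parameters: bool):
    # Both toggles on -> "all"; one on -> that mode; both off -> no watching.
    if watch_gradients and watch_parameters:
        return "all"
    if watch_gradients:
        return "gradients"
    if watch_parameters:
        return "parameters"
    return None

# Example usage inside an active run (model is a torch.nn.Module):
# mode = resolve_watch_mode(False, False)
# if mode is not None:
#     wandb.watch(model, log=mode, log_freq=100, log_graph=False)
# wandb.save("configs/px_space/training_config.yaml")  # upload config file to run files
```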

Full Settings Documentation

This section contains the complete key-by-key configuration reference previously documented on the separate Configs page.

Dataset Configs (configs/px_space/data_ostia_argo_disk.yaml)

Dataset settings are grouped by intent (core, validity, degradation, conditioning, augmentation, output, runtime).
Defaults below refer to configs/px_space/data_ostia_argo_disk.yaml unless noted.

| Config key | Default value | Explanation |
| --- | --- | --- |
| dataset.core.dataset_variant | "ostia_argo_disk" | Selects dataset in train.py ("eo_4band", "ostia", or "ostia_argo_disk"). |
| dataset.core.dataloader_type | "light" | The current training runner supports only "light" loading. |
| dataset.core.manifest_csv_path | "/work/data/depth_prod/ostia_argo_tiff_index.csv" | Manifest CSV used by OstiaArgoTiffDataset; paths inside it are resolved relative to the manifest location. |
| dataset.output.return_info | false | Returns per-sample metadata under batch["info"]. |
| dataset.output.return_coords | true | Returns patch-center coordinates under batch["coords"]. |
| dataset.runtime.random_seed | 7 | Seed used for deterministic split and random dataset sampling behavior. |
| split.val_fraction | 0.2 | Fraction of dataset reserved for validation. |

configs/px_space/model_config.yaml

| Config key | Default value | Explanation |
| --- | --- | --- |
| model.model_type | "cond_px_dif" | Model type ("cond_px_dif" for pixel diffusion, "latent_cond_dif" for latent diffusion with AE bridge). |
| model.resume_checkpoint | false | false/null starts from scratch; a checkpoint path resumes training. |
| model.load_checkpoint | false | false/null disables warm start; a checkpoint path loads model state_dict only (no Lightning optimizer/trainer resume). |
| model.generated_channels | 50 | Number of predicted GLORYS depth channels. |
| model.condition_channels | 52 | Condition channel count: OSTIA EO (1) + corrupted Argo stack (50) + collapsed x_valid_mask (1). |
| model.condition_mask_channels | 1 | Number of x_valid_mask condition channels. |
| model.condition_include_eo | true | Includes batch["eo"] as condition input. |
| model.condition_use_valid_mask | true | Includes x_valid_mask in condition input. |
| model.clamp_known_pixels | false | Clamps known pixels each reverse step for inpainting-style stability. |
| model.mask_loss_with_valid_pixels | true | Computes loss on the task-valid supervision mask (y_valid_mask in standard mode, x_valid_mask ∩ y_valid_mask in ambient mode). |
| model.parameterization | "x0" | Diffusion training target ("epsilon" or "x0"). |
| model.log_intermediates | true | Default validation intermediate logging behavior. |
| model.ambient_occlusion.enabled | false | Enables ambient-diffusion style occlusion objective (further-corrupt input, supervise x on x_valid_mask ∩ y_valid_mask). |
| model.ambient_occlusion.further_drop_prob | 0.1 | Additional drop probability delta applied on already observed pixels during training. |
| model.ambient_occlusion.apply_to_noisy_branch | true | Applies the further mask to the noisy target branch in p_loss (Ã x_t). |
| model.ambient_occlusion.shared_spatial_mask | true | Uses one spatial further-mask per sample and shares it across channels. |
| model.ambient_occlusion.min_kept_observed_pixels | 1 | Guarantees a minimum number of observed pixels kept after further corruption. |
| model.ambient_occlusion.require_x0_parameterization | true | Enforces model.parameterization == "x0" when the ambient objective is enabled. |
| model.post_process.gaussian_blur.enabled | false | Enables final denormalized Gaussian blur post-process. |
| model.post_process.gaussian_blur.sigma | 0.5 | Gaussian blur sigma in pixels. |
| model.post_process.gaussian_blur.kernel_size | 3 | Blur kernel size; even values are adjusted to odd. |
| model.coord_conditioning.enabled | true | Enables coordinate conditioning with FiLM. |
| model.coord_conditioning.encoding | "unit_sphere" | Coordinate encoding type ("unit_sphere", "sincos", "raw"). |
| model.coord_conditioning.include_date | true | Includes date encoding with coordinates. |
| model.coord_conditioning.date_encoding | "day_of_year_sincos" | Date encoding mode (day-of-year sin/cos, denominator 365). |
| model.coord_conditioning.embed_dim | null | FiLM embedding dimension; defaults to unet.dim when null. |
| model.unet.dim | 64 | Base channel width of U-Net denoiser. |
| model.unet.dim_mults | [1, 2, 4, 8] | Per-stage width multipliers; controls depth/width scaling. |
| model.unet.with_time_emb | true | Enables timestep embeddings in denoiser. |
| model.unet.output_mean_scale | false | Optional output mean correction for diffusion variants. |
| model.unet.residual | false | If enabled, predicts a residual added to the input. |

Detailed objective math, implementation mapping, visualization, and citation: Ambient Occlusion Objective.
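
A loose sketch of the further-corruption step described by the ambient_occlusion keys above (further_drop_prob, shared_spatial_mask, min_kept_observed_pixels); the actual sampling and fallback behavior in the repository may differ:

```python
import torch

def further_occlusion_mask(observed_mask, further_drop_prob=0.1, min_kept_observed_pixels=1):
    # observed_mask: (B, 1, H, W) float mask of already observed pixels.
    # Randomly drops observed pixels with probability further_drop_prob (one shared
    # spatial mask per sample); falls back to the original mask if too few survive.
    keep = (torch.rand_like(observed_mask) >= further_drop_prob).float()
    further = observed_mask * keep
    too_few = further.flatten(1).sum(dim=1) < min_kept_observed_pixels
    further[too_few] = observed_mask[too_few]
    return further
```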

configs/px_space/training_config.yaml

| Config key | Default value | Explanation |
| --- | --- | --- |
| training.lr | 1.0e-4 | Optimizer learning rate. |
| training.batch_size | 4 | Informational training batch size (the dataloader section is the source of truth). |
| training.noise.num_timesteps | 1000 | Number of diffusion timesteps. |
| training.noise.schedule | "cosine" | Noise schedule: linear, cosine, quadratic, sigmoid. |
| training.noise.beta_start | 1.0e-4 | First-step noise level (must be positive and below beta_end). |
| training.noise.beta_end | 2.0e-2 | Final-step noise level (must be below 1 and above beta_start). |
| training.validation_sampling.sampler | "ddim" | Validation sampler (ddpm full chain, ddim faster). |
| training.validation_sampling.ddim_num_timesteps | 100 | DDIM steps when sampler="ddim". |
| training.validation_sampling.ddim_eta | 0.0 | DDIM eta; 0.0 is deterministic DDIM. |
| training.validation_sampling.log_intermediates | false | Captures/logs denoising intermediate images in validation. |
| training.validation_sampling.skip_full_reconstruction_in_sanity_check | true | Skips expensive full reconstruction during Lightning sanity checks when true. |
| training.validation_sampling.max_full_reconstruction_samples | 2 | Max first-batch val samples used for the full reconstruction pass. |
| trainer.max_epochs | 1500 | Maximum training epochs. |
| trainer.accelerator | "auto" | Lightning accelerator backend selection. |
| trainer.devices | "auto" | Device selection (auto, int, list). |
| trainer.num_gpus | null | Legacy explicit GPU count override; null leaves accelerator/devices in control. |
| trainer.strategy | "auto" | Distributed strategy selection. |
| trainer.precision | "16-mixed" | Mixed precision mode. |
| trainer.matmul_precision | "high" | torch.set_float32_matmul_precision mode. |
| trainer.suppress_accumulate_grad_stream_mismatch_warning | true | Suppresses PyTorch stream mismatch warning noise. |
| trainer.suppress_lightning_pytree_warning | true | Suppresses Lightning LeafSpec deprecation warning noise. |
| trainer.ckpt_monitor | "val/loss_ckpt" | Metric monitored for best-checkpoint saving. |
| trainer.lr_logging_interval | "step" | Learning-rate logging cadence (step or epoch). |
| trainer.log_every_n_steps | 25 | Trainer logging interval in steps. |
| trainer.num_sanity_val_steps | 1 | Number of startup sanity-validation steps. |
| trainer.limit_val_batches | 4 | Number/fraction of validation batches per epoch. |
| trainer.enable_model_summary | true | Enables Lightning model summary printout. |
| trainer.gradient_clip_val | 1.0 | Gradient clipping threshold (0.0 disables). |
| wandb.project | "DepthDif_Simon" | W&B project name. |
| wandb.entity | "esa-phi-lab" | W&B entity/team (null uses default account). |
| wandb.run_name | "ostia_argo_disk_px" | Explicit run name. |
| wandb.log_model | "false" | W&B model artifact logging policy. |
| wandb.verbose | true | Enables extra metric/image logging. |
| wandb.watch_gradients | false | Enables gradient history logging via wandb.watch. |
| wandb.watch_parameters | false | Enables parameter history logging via wandb.watch. |
| wandb.watch_log_freq | 100 | wandb.watch logging frequency in steps. |
| wandb.watch_log_graph | false | Logs the computation graph when watch is enabled. |
| wandb.log_stats_every_n_steps | 200 | Step interval for scalar debug stats. |
| wandb.log_images_every_n_steps | 200 | Step interval for validation preview images. |
| dataloader.batch_size | 4 | Training dataloader batch size. |
| dataloader.val_batch_size | 2 | Validation batch size (falls back to batch_size if omitted). |
| dataloader.num_workers | 4 | Number of training dataloader workers. |
| dataloader.val_num_workers | 0 | Validation workers (0 avoids h5netcdf sanity-check instability). |
| dataloader.persistent_workers | true | Keeps train workers alive across epochs when true. |
| dataloader.val_persistent_workers | false | Validation worker persistence (when val_num_workers > 0). |
| dataloader.prefetch_factor | 2 | Prefetched batches per worker (only used when workers > 0). |
| dataloader.shuffle | true | Shuffles the training dataset each epoch. |
| dataloader.val_shuffle | false | Shuffles the validation set (often used with limited val batches). |
| dataloader.pin_memory | true | Enables pinned host memory for faster H2D transfer. |
| scheduler.warmup.enabled | true | Enables linear warmup before plateau scheduling. |
| scheduler.warmup.steps | 2000 | Warmup step count to ramp LR from start_ratio to base LR. |
| scheduler.warmup.start_ratio | 0.2 | Initial warmup LR as a ratio of training.lr. |
| scheduler.reduce_on_plateau.enabled | true | Enables ReduceLROnPlateau. |
| scheduler.reduce_on_plateau.monitor | "val/loss_ckpt" | Metric monitored for LR reduction. |
| scheduler.reduce_on_plateau.mode | "min" | Plateau mode (min or max). |
| scheduler.reduce_on_plateau.factor | 0.5 | Multiplicative LR decay factor on plateau. |
| scheduler.reduce_on_plateau.patience | 20 | Validation epochs with no improvement before reducing LR. |
| scheduler.reduce_on_plateau.threshold | 1.0e-4 | Minimum significant metric change. |
| scheduler.reduce_on_plateau.threshold_mode | "rel" | Threshold mode (rel or abs). |
| scheduler.reduce_on_plateau.cooldown | 0 | Epoch cooldown after LR reduction. |
| scheduler.reduce_on_plateau.min_lr | 1.0e-6 | Lower bound for LR. |
| scheduler.reduce_on_plateau.eps | 1.0e-8 | Minimum effective LR change. |