Inference

There are two practical inference workflows in this repository:
- run the standalone script inference.py
- call PixelDiffusionConditional.predict_step(...) directly

DepthDif supports pixel-space configs (configs/px_space/*) and latent-workflow configs (configs/lat_space/*).
For latent workflow setup and command flow, see Autoencoder + Latent Diffusion.

Workflow 1: Use inference.py

inference.py is a configurable script for quick prediction sanity checks.

What it supports

  • load the config files and instantiate the model and datamodule
  • load a checkpoint (explicit override, or model.load_checkpoint / model.resume_checkpoint)
  • run from either:
    • a dataloader sample (MODE="dataloader")
    • a synthetic random batch (MODE="random")
  • optional capture of intermediate samples

Important script settings

At the top of inference.py, set:
- MODEL_CONFIG_PATH
- DATA_CONFIG_PATH
- TRAIN_CONFIG_PATH
- CHECKPOINT_PATH (or keep None to use model.load_checkpoint then model.resume_checkpoint)
- MODE, LOADER_SPLIT, DEVICE, INCLUDE_INTERMEDIATES
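The constants above might be filled in as follows. This is a minimal sketch: the constant names come from the section above, but the specific values (and the comments explaining them) are illustrative assumptions, not defaults shipped with the script.

```python
# Illustrative values for the constants at the top of inference.py.
# Paths follow the OSTIA + Argo disk setup described below; adjust to your run.
MODEL_CONFIG_PATH = "configs/px_space/model_config.yaml"
DATA_CONFIG_PATH = "configs/px_space/data_ostia_argo_disk.yaml"
TRAIN_CONFIG_PATH = "configs/px_space/training_config.yaml"

CHECKPOINT_PATH = None          # None -> fall back to model.load_checkpoint, then model.resume_checkpoint
MODE = "dataloader"             # "dataloader" or "random"
LOADER_SPLIT = "val"            # which dataloader split to draw the sample batch from
DEVICE = "cpu"                  # or "cuda"
INCLUDE_INTERMEDIATES = False   # capture intermediate denoising samples
```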

Note on default paths

The script constants should be set explicitly. In this repository, the actively used configs are:
- OSTIA + Argo disk setup: configs/px_space/model_config.yaml, configs/px_space/data_ostia_argo_disk.yaml, configs/px_space/training_config.yaml

Workflow 2: Direct predict_step

The model inference entry point is:
- PixelDiffusionConditional.predict_step(batch, batch_idx=0)

Minimum required batch keys:
- x
- x_valid_mask
- y_valid_mask

Common optional keys:
- eo
- x_valid_mask_1d
- land_mask
- coords
- date
- sampler
- clamp_known_pixels
- return_intermediates
- intermediate_step_indices

Returned outputs

predict_step returns a dictionary containing:
- y_hat: standardized prediction
- y_hat_denorm: temperature-denormalized prediction, masked to NaN where y_valid_mask==0
- denoise_samples: reverse samples (if requested)
- x0_denoise_samples: per-step x0 predictions (if requested)
- sampler: sampler used for prediction
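Because y_hat_denorm carries NaN wherever y_valid_mask is 0, downstream reductions should be NaN-aware. A minimal sketch, using a hand-built stand-in for the returned dict (the tensor values here are fabricated for illustration):

```python
import torch

# Stand-in for the dict returned by predict_step (values are illustrative).
pred = {
    "y_hat": torch.zeros(1, 1, 2, 2),
    "y_hat_denorm": torch.tensor([[[[20.0, float("nan")],
                                    [21.0, 22.0]]]]),   # NaN = masked-out pixel
    "sampler": "ddpm",
}

# Reduce with nanmean so masked (NaN) pixels are ignored.
valid_mean = torch.nanmean(pred["y_hat_denorm"])  # mean over valid pixels only
```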

Example (ostia_argo_disk config)

import torch  

from data.datamodule import DepthTileDataModule  
from data.dataset_ostia_argo_disk import OstiaArgoTiffDataset  
from models.difFF import PixelDiffusionConditional  

model_config = "configs/px_space/model_config.yaml"  
data_config = "configs/px_space/data_ostia_argo_disk.yaml"  
train_config = "configs/px_space/training_config.yaml"  
ckpt_path = "logs/<run>/best-epochXXX.ckpt"  

train_dataset = OstiaArgoTiffDataset.from_config(data_config, split="train")  
val_dataset = OstiaArgoTiffDataset.from_config(data_config, split="val")  
datamodule = DepthTileDataModule(dataset=train_dataset, val_dataset=val_dataset)  

model = PixelDiffusionConditional.from_config(  
    model_config_path=model_config,  
    data_config_path=data_config,  
    training_config_path=train_config,  
    datamodule=datamodule,  
)  

# Lightning checkpoints store weights under "state_dict"; raw state dicts load as-is
state = torch.load(ckpt_path, map_location="cpu")
state_dict = state.get("state_dict", state)
model.load_state_dict(state_dict, strict=False)
model.eval()  

batch = next(iter(datamodule.val_dataloader()))  
with torch.no_grad():  
    pred = model.predict_step(batch, batch_idx=0)  

y_hat = pred["y_hat"]  
y_hat_denorm = pred["y_hat_denorm"]  

Sampler Choice

The validation/inference sampler can be switched via the training config:
- training.validation_sampling.sampler: "ddpm" or "ddim"
- DDIM-specific controls:
  - ddim_num_timesteps
  - ddim_eta
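As a config fragment, the keys above might look like this. The exact nesting is inferred from the key names listed above, and the values are illustrative; check your actual training_config.yaml.

```yaml
# Assumed structure, inferred from the key names above.
training:
  validation_sampling:
    sampler: "ddim"         # or "ddpm"
    ddim_num_timesteps: 50  # illustrative value
    ddim_eta: 0.0           # 0.0 = deterministic DDIM sampling
```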

A sampler can also be injected per batch via batch["sampler"] when calling predict_step directly.