Inference¶
There are two practical inference workflows in this repository:
- run the standalone script inference.py
- call PixelDiffusionConditional.predict_step(...) directly
Workflow 1: Use inference.py¶
inference.py is a configurable script for quick prediction sanity checks.
What it supports¶
- load config files and instantiate model/datamodule
- load checkpoint (explicit override or model.resume_checkpoint)
- run from:
  - dataloader sample (MODE="dataloader")
  - synthetic random batch (MODE="random")
- optional intermediate sample capture
Important script settings¶
At the top of inference.py, set:
- MODEL_CONFIG_PATH
- DATA_CONFIG_PATH
- TRAIN_CONFIG_PATH
- CHECKPOINT_PATH (or keep None to use config resume path)
- MODE, LOADER_SPLIT, DEVICE, INCLUDE_INTERMEDIATES
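As a sketch, the settings block at the top of inference.py might look like the following; the paths and values here are illustrative assumptions, not repository defaults:

```python
# Illustrative settings block for the top of inference.py.
# All values below are example assumptions.
MODEL_CONFIG_PATH = "configs/model_config_eo_4band.yaml"
DATA_CONFIG_PATH = "configs/data_config_eo_4band.yaml"
TRAIN_CONFIG_PATH = "configs/training_config_eo_4band.yaml"
CHECKPOINT_PATH = None   # None -> fall back to model.resume_checkpoint from config
MODE = "dataloader"      # "dataloader" or "random"
LOADER_SPLIT = "val"
DEVICE = "cpu"           # e.g. "cuda" when a GPU is available
INCLUDE_INTERMEDIATES = False
```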
Note on default paths¶
The script constants should be set explicitly. In this repository, the actively used configs are:
- EO setup: configs/*_eo_4band.yaml
- legacy single-band setup: configs/older_configs/*.yaml
Workflow 2: Direct predict_step¶
The model inference entry point is:
- PixelDiffusionConditional.predict_step(batch, batch_idx=0)
Minimum required batch key:
- x
Common optional keys:
- eo
- valid_mask
- land_mask
- coords
- date
- sampler
- clamp_known_pixels
- return_intermediates
- intermediate_step_indices
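A minimal synthetic batch combining these keys can be sketched as follows; the tensor shapes and channel counts are illustrative assumptions, not the repository's actual dimensions:

```python
import torch

# Minimal batch sketch for predict_step; shapes and channel counts here are
# illustrative assumptions, not the repository's actual dimensions.
B, H, W = 2, 64, 64
batch = {
    "x": torch.randn(B, 1, H, W),          # required conditioning input
    "eo": torch.randn(B, 4, H, W),         # optional EO bands
    "valid_mask": torch.ones(B, 1, H, W),  # optional valid-pixel mask
    "return_intermediates": True,          # optional: capture reverse samples
}
```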
Returned outputs¶
predict_step returns a dictionary containing:
- y_hat: standardized prediction
- y_hat_denorm: temperature-denormalized prediction
- denoise_samples: reverse samples (if requested)
- x0_denoise_samples: per-step x0 predictions (if requested)
- sampler: sampler used for prediction
Example (EO config)¶
```python
import torch

from data.datamodule import DepthTileDataModule
from data.dataset_4bands import SurfaceTempPatch4BandsLightDataset
from models.difFF import PixelDiffusionConditional

model_config = "configs/model_config_eo_4band.yaml"
data_config = "configs/data_config_eo_4band.yaml"
train_config = "configs/training_config_eo_4band.yaml"
ckpt_path = "logs/<run>/best-epochXXX.ckpt"

dataset = SurfaceTempPatch4BandsLightDataset.from_config(data_config, split="all")
datamodule = DepthTileDataModule(dataset=dataset)
datamodule.setup("fit")

model = PixelDiffusionConditional.from_config(
    model_config_path=model_config,
    data_config_path=data_config,
    training_config_path=train_config,
    datamodule=datamodule,
)

state = torch.load(ckpt_path, map_location="cpu")
state_dict = state["state_dict"] if "state_dict" in state else state
model.load_state_dict(state_dict, strict=False)
model.eval()

batch = next(iter(datamodule.val_dataloader()))
with torch.no_grad():
    pred = model.predict_step(batch, batch_idx=0)

y_hat = pred["y_hat"]
y_hat_denorm = pred["y_hat_denorm"]
```
Sampler Choice¶
The validation/inference sampler can be switched via the training config:
- training.validation_sampling.sampler: "ddpm" or "ddim"
- DDIM controls:
- ddim_num_timesteps
- ddim_eta
The same sampler can also be injected per batch through batch["sampler"] in direct prediction calls.
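For reference, the config-side switch might look like the following fragment; the key layout follows the options listed above, and the numeric values are example assumptions:

```yaml
training:
  validation_sampling:
    sampler: "ddim"          # or "ddpm"
    ddim_num_timesteps: 50   # example value
    ddim_eta: 0.0            # example value
```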