API Reference
This page is automatically generated via mkdocstrings from the current codebase.
Training & Inference Entry Points
depth_recon
Public DepthDif inference API.
__getattr__(name)
Lazily expose public inference helpers without importing the full stack.
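Lazy exposure like this is usually built on a module-level `__getattr__` (PEP 562), which Python calls only when normal attribute lookup on the module fails. The sketch below shows the mechanism on a throwaway module so it runs standalone; `make_lazy_module` and the export table are hypothetical, and in a real package the function would simply be defined at the top level of `__init__.py`.

```python
import importlib
import types


def make_lazy_module(name: str, exports: dict[str, tuple[str, str]]) -> types.ModuleType:
    """Build a module whose names in `exports` are imported on first access."""
    module = types.ModuleType(name)

    def __getattr__(attr: str):
        # Called only when regular lookup fails (PEP 562).
        try:
            module_path, target = exports[attr]
        except KeyError:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}") from None
        value = getattr(importlib.import_module(module_path), target)
        setattr(module, attr, value)  # cache it: later lookups skip __getattr__
        return value

    module.__getattr__ = __getattr__
    return module


# Hypothetical export table mapping public names to (module, attribute).
lazy = make_lazy_module("demo", {"sqrt": ("math", "sqrt"), "dumps": ("json", "dumps")})
print(lazy.sqrt(16.0))  # 4.0
```

Caching the resolved object on the module keeps the lazy path to a single import per name.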
depth_recon.inference.api
Public inference API for PyPI and notebook usage.
InferenceAssets (dataclass)
Local paths to model/config artifacts used by public inference.
PublicInferenceAssets (dataclass)
Local public inference artifacts, including the model assets and land mask.
download_argo_cli(argv=None)
Run the public ARGO download console command.
download_argo_for_week(year, iso_week, output_dir, *, base_url=DEFAULT_EN4_BASE_URL, cache_dir=None, force_download=False, downloader=None, progress_callback=None)
Download and extract EN4/ARGO profile files needed for one ISO week.
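EN4 profile files are organised by calendar month, so an ISO week that straddles a month boundary can require two of them. Whether `download_argo_for_week` resolves its files exactly this way is an assumption; the hypothetical helpers below only illustrate the ISO-week date arithmetic with the standard library.

```python
from datetime import date, timedelta


def iso_week_days(year: int, iso_week: int) -> list[date]:
    """Return the seven dates (Monday to Sunday) of an ISO week."""
    # date.fromisocalendar raises ValueError for a week the year does not have.
    monday = date.fromisocalendar(year, iso_week, 1)
    return [monday + timedelta(days=offset) for offset in range(7)]


def months_touched(year: int, iso_week: int) -> list[tuple[int, int]]:
    """Return the (year, month) pairs an ISO week overlaps, in calendar order."""
    months: list[tuple[int, int]] = []
    for day in iso_week_days(year, iso_week):
        key = (day.year, day.month)
        if key not in months:
            months.append(key)
    return months


print(months_touched(2020, 53))  # [(2020, 12), (2021, 1)]
```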
download_ostia_cli(argv=None)
Run the public OSTIA download console command.
download_ostia_for_week(year, iso_week, output_dir, *, dataset_candidates=DEFAULT_OSTIA_DATASET_CANDIDATES, force_download=False, username=None, password=None, token=None, runner=None)
Download the OSTIA SST file needed for one ISO-week inference date.
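The `dataset_candidates` parameter suggests that several OSTIA dataset IDs are tried in order until one yields the file. A minimal sketch of that first-success fallback, assuming a `fetch` callable that stands in for the real download (`runner`) and raises on failure; the function name and the dataset IDs are placeholders, not the package's actual values.

```python
from typing import Callable, Sequence


def first_successful_candidate(
    candidates: Sequence[str],
    fetch: Callable[[str], str],
) -> tuple[str, str]:
    """Call `fetch` for each dataset ID in order; return (dataset_id, result) of the first success."""
    errors: list[str] = []
    for dataset_id in candidates:
        try:
            return dataset_id, fetch(dataset_id)
        except Exception as exc:  # remember why this candidate failed, then try the next
            errors.append(f"{dataset_id}: {exc}")
    raise RuntimeError("no dataset candidate succeeded:\n" + "\n".join(errors))


# Hypothetical fetcher: the first placeholder dataset does not cover the date.
def fake_fetch(dataset_id: str) -> str:
    if dataset_id == "placeholder-nrt":
        raise FileNotFoundError("date not covered")
    return f"/tmp/{dataset_id}.nc"


print(first_successful_candidate(["placeholder-nrt", "placeholder-rep"], fake_fetch))
# ('placeholder-rep', '/tmp/placeholder-rep.nc')
```

Collecting every failure reason makes the final error actionable when no candidate covers the requested date.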
infer_week_cli(argv=None)
Run the public week-inference console command.
main(argv=None)
Run public API subcommands when invoked as a module.
resolve_hf_assets(*, config_repo=DEFAULT_HF_REPO_ID, revision=DEFAULT_HF_REVISION, cache_dir=None, model_config_path=DEFAULT_HF_MODEL_CONFIG, data_config_path=DEFAULT_HF_DATA_CONFIG, train_config_path=DEFAULT_HF_TRAIN_CONFIG, checkpoint_path=DEFAULT_HF_CHECKPOINT, force_download=False, downloader=None, progress_callback=None)
Download or reuse configs and checkpoint from Hugging Face.
resolve_hf_land_mask(*, config_repo=DEFAULT_HF_REPO_ID, revision=DEFAULT_HF_REVISION, cache_dir=None, land_mask_path=DEFAULT_HF_LAND_MASK, force_download=False, downloader=None, progress_callback=None)
Download or reuse the public land-mask GeoTIFF from Hugging Face.
resolve_public_inference_assets(*, config_repo=DEFAULT_HF_REPO_ID, revision=DEFAULT_HF_REVISION, cache_dir=None, model_config_path=DEFAULT_HF_MODEL_CONFIG, data_config_path=DEFAULT_HF_DATA_CONFIG, train_config_path=DEFAULT_HF_TRAIN_CONFIG, checkpoint_path=DEFAULT_HF_CHECKPOINT, land_mask_path=DEFAULT_HF_LAND_MASK, force_download=False, downloader=None, progress_callback=None)
Resolve all public artifacts needed before ARGO/OSTIA inference.
run_argo_week_inference(year, iso_week, rectangle=None, output_root=DEFAULT_OUTPUT_ROOT, device='auto', checkpoint=None, config_repo=DEFAULT_HF_REPO_ID, *, revision=DEFAULT_HF_REVISION, cache_dir=None, argo_dir=None, ostia_dir=None, auto_download_argo=True, auto_download_ostia=True, copernicus_username=None, copernicus_password=None, copernicus_token=None, batch_size=None, land_mask_path=None, min_ocean_fraction=0.05, sigma=DEFAULT_EXPORT_GAUSSIAN_BLUR_SIGMA, sampler=None, ddim_num_timesteps=None, uncertainty_sampler=None, uncertainty_ddim_num_timesteps=None, export_uncertainty=False, uncertainty_num_samples=DEFAULT_UNCERTAINTY_NUM_SAMPLES, uncertainty_only=False, strict_load=False, force_download=False, downloader=None, progress_callback=None)
Run public ARGO inference for one ISO week and return the run directory.
run_week_inference(year, iso_week, rectangle=None, output_root=DEFAULT_OUTPUT_ROOT, device='auto', checkpoint=None, config_repo=DEFAULT_HF_REPO_ID, *, revision=DEFAULT_HF_REVISION, cache_dir=None, argo_dir=None, glorys_dir=None, ostia_dir=None, sealevel_dir=None, metadata_cache_dir=None, auto_download_argo=False, auto_download_ostia=True, copernicus_username=None, copernicus_password=None, copernicus_token=None, export_ground_truth=True, full_sample_count=0, batch_size=None, land_mask_path=DEFAULT_LAND_MASK_PATH, min_ocean_fraction=0.05, sigma=DEFAULT_EXPORT_GAUSSIAN_BLUR_SIGMA, sampler=None, ddim_num_timesteps=None, uncertainty_sampler=None, uncertainty_ddim_num_timesteps=None, export_uncertainty=False, uncertainty_num_samples=DEFAULT_UNCERTAINTY_NUM_SAMPLES, uncertainty_only=False, strict_load=False, force_download=False, downloader=None, progress_callback=None)
Run DepthDif inference for one ISO week and return the run directory.
depth_recon.inference
Inference package for runtime helpers and hosted export workflows.
__getattr__(name)
Lazily expose public API helpers without preloading the CLI module.
build_datamodule(dataset, data_cfg, training_cfg)
Build and return the Lightning datamodule for a dataset.
build_dataset(data_config_path, ds_cfg, *, split='all', dataset_overrides=None)
Build and return the dataset for the requested split.
build_model(model_config_path, data_config_path, training_config_path, model_cfg, datamodule)
Build and return the model from its config files and datamodule.
build_random_batch(model, data_cfg, batch_size, height, width, device)
Build and return a random input batch for `model` with the given batch size and spatial size, placed on `device`.
choose_device(device_arg)
Resolve the device argument (for example `'auto'`) to a concrete device and return it.
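A sketch of how an `'auto'` device argument is commonly resolved. The real `choose_device` takes only `device_arg` and would presumably query PyTorch (for example `torch.cuda.is_available()` and `torch.backends.mps.is_available()`); here the availability flags are injected so the sketch runs without PyTorch, and the CUDA, then MPS, then CPU preference order is an assumption.

```python
def resolve_device_sketch(
    device_arg: str,
    *,
    cuda_available: bool,
    mps_available: bool = False,
) -> str:
    """Map 'auto' to the first available backend; return explicit choices unchanged."""
    if device_arg != "auto":
        return device_arg  # e.g. "cpu", "cuda", "cuda:1"
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


print(resolve_device_sketch("auto", cuda_available=False))  # cpu
```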
ds_cfg_value(ds_cfg, nested_key, flat_key, *, default)
Read one dataset config field while preferring the nested schema.
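One way to read a field "while preferring the nested schema": look up the nested path first, fall back to the legacy flat key, then to the default. This sketch assumes nested keys are written as dotted paths; `cfg_value_sketch`, `lookup_dotted` and the example keys are hypothetical.

```python
from typing import Any, Mapping

_MISSING = object()  # sentinel so that a stored None still counts as "present"


def lookup_dotted(cfg: Mapping[str, Any], dotted_key: str) -> Any:
    """Follow a dotted path such as 'patches.tile_size'; return _MISSING if any level is absent."""
    node: Any = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, Mapping) or part not in node:
            return _MISSING
        node = node[part]
    return node


def cfg_value_sketch(cfg: Mapping[str, Any], nested_key: str, flat_key: str, *, default: Any) -> Any:
    """Prefer the nested key, fall back to the flat key, then to `default`."""
    value = lookup_dotted(cfg, nested_key)
    if value is not _MISSING:
        return value
    return cfg.get(flat_key, default)


nested = {"patches": {"tile_size": 128}, "tile_size": 64}
legacy = {"tile_size": 64}
print(cfg_value_sketch(nested, "patches.tile_size", "tile_size", default=32))  # 128
print(cfg_value_sketch(legacy, "patches.tile_size", "tile_size", default=32))  # 64
```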
load_checkpoint_weights(model, checkpoint_path, *, strict=False, prefer_ema=True)
Load checkpoint weights into a model, preferring EMA weights when available.
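A sketch of the "prefer EMA weights when available" choice on a plain checkpoint dictionary. Lightning checkpoints store module weights under `"state_dict"`; the `"ema_state_dict"` key used here is a hypothetical name for wherever the EMA copy lives. The selected mapping would then be passed to `model.load_state_dict(weights, strict=strict)`.

```python
from typing import Any, Mapping


def select_state_dict(
    checkpoint: Mapping[str, Any],
    *,
    prefer_ema: bool = True,
    ema_key: str = "ema_state_dict",  # hypothetical key name
) -> tuple[str, Mapping[str, Any]]:
    """Return (source_key, state_dict), choosing EMA weights when requested and present."""
    if prefer_ema and checkpoint.get(ema_key):
        return ema_key, checkpoint[ema_key]
    if "state_dict" in checkpoint:  # the key Lightning uses for module weights
        return "state_dict", checkpoint["state_dict"]
    raise KeyError("checkpoint contains no usable state dict")


ckpt = {"state_dict": {"w": 1.0}, "ema_state_dict": {"w": 0.9}}
print(select_state_dict(ckpt))  # ('ema_state_dict', {'w': 0.9})
```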
load_yaml(path)
Load a YAML file and return its parsed contents.
pretty_shape(value)
Return a compact human-readable shape/type description.
resolve_checkpoint_path(ckpt_override, model_cfg)
Resolve the checkpoint path (from `ckpt_override` or the model config) and validate it.
resolve_dataset_variant(ds_cfg, data_config_path)
Resolve and validate the dataset variant from the dataset config.
resolve_model_type(model_cfg)
Resolve and validate the model type from the model config.
run_predict_once(model, batch, include_intermediates)
Run a single prediction pass on `batch` and return the result.
to_device(batch, device)
Move tensor values in a batch dictionary to the target device.
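A minimal sketch of moving a batch dictionary to a device while leaving metadata untouched. It duck-types on a `.to()` method so it runs without PyTorch (`FakeTensor` is a stand-in); the real helper may instead test `isinstance(value, torch.Tensor)`, and nested containers are not handled here.

```python
from typing import Any


def move_batch(batch: dict[str, Any], device: Any) -> dict[str, Any]:
    """Return a new dict where values exposing `.to()` are moved to `device`."""
    return {
        key: value.to(device) if hasattr(value, "to") else value  # non-tensors pass through
        for key, value in batch.items()
    }


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


moved = move_batch({"x": FakeTensor(), "info": "tile_0007"}, "cuda")
print(moved["x"].device, moved["info"])  # cuda tile_0007
```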
Data
depth_recon.data.datamodule
DepthTileDataModule
Bases: LightningDataModule
Lightning DataModule that builds train and validation dataloaders.
__init__(*, dataset, val_dataset=None, dataloader_cfg=None, val_fraction=0.2, seed=7)
Initialize DepthTileDataModule with the given datasets and dataloader settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset | Dataset used to build the dataloaders. | required |
| val_dataset | Dataset \| None | Optional separate validation dataset. | None |
| dataloader_cfg | dict[str, Any] \| None | Dataloader configuration section. | None |
| val_fraction | float | Fraction of samples reserved for validation. | 0.2 |
| seed | int | Random seed for the validation split. | 7 |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
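When no `val_dataset` is supplied, a plausible reading of `val_fraction` and `seed` is a seeded random hold-out split. The sketch below shows that on plain indices; `split_indices` is hypothetical, and the actual module might use `torch.utils.data.random_split` or a different rounding rule.

```python
import random


def split_indices(n: int, val_fraction: float = 0.2, seed: int = 7) -> tuple[list[int], list[int]]:
    """Seeded shuffle of range(n); the first round(n * val_fraction) indices become validation."""
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # private RNG: global random state is untouched
    n_val = round(n * val_fraction)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(10)
print(len(train_idx), len(val_idx))  # 8 2
```

A fixed seed keeps the split identical across runs, so validation metrics stay comparable.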
setup(stage=None)
Prepare the datasets for the given Lightning stage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stage | str \| None | Lightning stage name (for example `"fit"`), or None. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
train_dataloader()
Return the training dataloader.
This method takes no arguments.
Returns:
| Name | Type | Description |
|---|---|---|
| DataLoader | DataLoader | Training dataloader. |
val_dataloader()
Return the validation dataloader.
This method takes no arguments.
Returns:
| Name | Type | Description |
|---|---|---|
| DataLoader | DataLoader | Validation dataloader. |
depth_recon.data.dataset_argo_geotiff_gridded
ArgoGeoTIFFGriddedPatchDataset
Bases: Dataset
Dataset that lazily reads training patches from exported GeoTIFF stores.
depth_axis_m (property)
Return the GLORYS depth axis in meters.
rows (property)
Return the patch/date metadata rows.
__getitem__(idx)
Return one model-ready training sample.
__init__(*, geotiff_root_dir=DEFAULT_GEOTIFF_ROOT_DIR, metadata_cache_dir=DEFAULT_METADATA_CACHE_DIR, split='all', tile_size=128, resolution_deg=0.1, patch_grid_source='land_mask', land_mask_path=None, patch_stride=None, max_land_fraction=0.3, force_include_regions=None, finetune_sampling=None, temporal_window_days=7, glorys_var_name='thetao', ostia_var_name='analysed_sst', eo_source='ostia', eo_var_name=None, require_argo_for_train=True, require_argo_for_val=True, require_argo_for_all=False, synthetic_mode=False, synthetic_pixel_count=250, return_info=True, return_coords=True, include_salinity=False, output_fields=None, random_seed=7, cache_size=8, val_fraction=0.2, val_year=None)
Initialize the GeoTIFF-backed patch dataset.
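The constructor exposes `tile_size`, `patch_stride`, `max_land_fraction` and `patch_grid_source='land_mask'`. A common way such parameters combine is to slide a tile over the land mask and drop tiles with too much land; the sketch below assumes that `patch_stride=None` means non-overlapping tiles and that the threshold is inclusive. All helper names are hypothetical.

```python
from typing import Optional, Sequence


def patch_origins(height: int, width: int, tile_size: int = 128,
                  stride: Optional[int] = None) -> list[tuple[int, int]]:
    """Top-left (y0, x0) corners of all full tiles on a height x width grid."""
    step = stride or tile_size  # assumption: no stride means non-overlapping tiles
    return [
        (y0, x0)
        for y0 in range(0, height - tile_size + 1, step)
        for x0 in range(0, width - tile_size + 1, step)
    ]


def land_fraction(land_mask: Sequence[Sequence[int]], y0: int, x0: int, tile_size: int) -> float:
    """Share of land pixels (value 1) inside one tile."""
    land = sum(sum(row[x0:x0 + tile_size]) for row in land_mask[y0:y0 + tile_size])
    return land / (tile_size * tile_size)


def ocean_patches(land_mask: Sequence[Sequence[int]], tile_size: int,
                  stride: Optional[int] = None, max_land_fraction: float = 0.3) -> list[tuple[int, int]]:
    height, width = len(land_mask), len(land_mask[0])
    return [
        (y0, x0)
        for y0, x0 in patch_origins(height, width, tile_size, stride)
        if land_fraction(land_mask, y0, x0, tile_size) <= max_land_fraction
    ]


mask = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
print(ocean_patches(mask, tile_size=2))  # [(0, 2), (2, 0), (2, 2)]
```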
__len__()
Return the number of rows in the dataset.
from_config(config_path=None, *, split='all', dataset_overrides=None) (classmethod)
Build a GeoTIFF dataset from a YAML data config.
ArgoGeoTIFFProfileStore
Profile-indexed ARGO zarr source exported with the GeoTIFF dataset.
__init__(path, *, include_salinity=False)
Open a compact ARGO profile zarr store.
close()
Close the opened zarr dataset.
load_salinity_profiles(indices)
Load selected ARGO salinity profiles as raw PSU arrays.
load_temperature_profiles(indices)
Load selected ARGO temperature profiles as Celsius arrays.
query_indices(*, target_date, grid_y0, grid_x0, tile_size)
Return profile indices assigned to one date and grid patch.
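A sketch of the selection `query_indices` describes: profiles assigned to one date whose grid cell falls inside the tile starting at `(grid_y0, grid_x0)`. The `ProfileRecord` layout, the exact-date match and the half-open tile bounds are assumptions; the real store reads this metadata from its zarr index.

```python
from dataclasses import dataclass
from datetime import date
from typing import Sequence


@dataclass(frozen=True)
class ProfileRecord:
    """Hypothetical per-profile metadata: assigned date and grid cell."""
    day: date
    grid_y: int
    grid_x: int


def query_indices_sketch(records: Sequence[ProfileRecord], *, target_date: date,
                         grid_y0: int, grid_x0: int, tile_size: int) -> list[int]:
    """Indices of profiles on `target_date` whose cell lies inside the tile."""
    return [
        i
        for i, rec in enumerate(records)
        if rec.day == target_date
        and grid_y0 <= rec.grid_y < grid_y0 + tile_size  # half-open row range
        and grid_x0 <= rec.grid_x < grid_x0 + tile_size  # half-open column range
    ]


records = [
    ProfileRecord(date(2024, 3, 4), 10, 10),
    ProfileRecord(date(2024, 3, 4), 200, 10),  # outside the tile
    ProfileRecord(date(2024, 3, 5), 12, 15),   # different date
    ProfileRecord(date(2024, 3, 4), 127, 0),
]
print(query_indices_sketch(records, target_date=date(2024, 3, 4),
                           grid_y0=0, grid_x0=0, tile_size=128))  # [0, 3]
```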
GeoTIFFPatchIndex
GeoTIFFRasterStore
Date-indexed GeoTIFF raster source for one exported variable.
Model Core
depth_recon.models.diffusion.PixelDiffusion
PixelDiffusionConditional
Bases: LightningModule
Lightning module that trains and samples conditional pixel diffusion.
__init__(datamodule=None, generated_channels=1, condition_channels=1, output_fields=None, variable_scenario=None, condition_mask_channels=1, condition_include_eo=False, condition_use_valid_mask=True, condition_use_land_mask=False, clamp_known_pixels=True, mask_loss_with_valid_pixels=False, coastal_loss_enabled=False, coastal_loss_radius_px=5, coastal_loss_weight=3.0, coastal_loss_ramp='linear', parameterization='epsilon', num_timesteps=1000, noise_schedule='linear', noise_beta_start=0.0001, noise_beta_end=0.02, unet_dim=64, unet_dim_mults=(1, 2, 4, 8), unet_with_time_emb=True, unet_output_mean_scale=False, unet_residual=False, coord_conditioning_enabled=False, coord_encoding='unit_sphere', date_conditioning_enabled=False, date_encoding='day_of_year_sincos', coord_embed_dim=None, batch_size=1, lr=0.001, lr_scheduler_enabled=False, lr_scheduler_monitor='val/loss_ckpt', lr_scheduler_interval='epoch', lr_scheduler_mode='min', lr_scheduler_factor=0.5, lr_scheduler_patience=10, lr_scheduler_threshold=0.0001, lr_scheduler_threshold_mode='rel', lr_scheduler_cooldown=0, lr_scheduler_min_lr=0.0, lr_scheduler_eps=1e-08, lr_warmup_enabled=True, lr_warmup_steps=1000, lr_warmup_start_ratio=0.1, val_inference_sampler='ddpm', val_ddim_num_timesteps=200, val_ddim_eta=0.0, val_ddim_temperature=1.0, log_intermediates=True, ambient_occlusion_enabled=False, ambient_further_drop_prob=0.1, ambient_apply_to_noisy_branch=True, ambient_shared_spatial_mask=True, ambient_min_kept_observed_pixels=1, ambient_require_x0_parameterization=True, skip_full_reconstruction_in_sanity_check=True, max_full_reconstruction_samples=5, postprocess_gaussian_blur_enabled=False, postprocess_gaussian_blur_sigma=0.35, postprocess_gaussian_blur_kernel_size=3, model_summary_input_size=128, wandb_verbose=True, log_stats_every_n_steps=1, log_images_every_n_steps=200)
Initialize PixelDiffusionConditional with the configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| datamodule | LightningDataModule \| None | Optional datamodule that supplies the train/validation dataloaders. | None |
| generated_channels | int | Number of generated (output) channels. | 1 |
| condition_channels | int | Number of conditioning channels. | 1 |
| output_fields | tuple[str, ...] \| list[str] \| None | Output variables to train/predict. Defaults to temperature only. | None |
| variable_scenario | str \| None | Scenario label embedded in checkpoints. | None |
| condition_mask_channels | int | Number of mask channels in the conditioning input. | 1 |
| condition_include_eo | bool | Include the EO (surface observation) field as conditioning. | False |
| condition_use_valid_mask | bool | Include the valid-pixel mask as conditioning. | True |
| condition_use_land_mask | bool | Include GLORYS spatial support as conditioning. | False |
| clamp_known_pixels | bool | Replace known pixels with their observed values during sampling. | True |
| mask_loss_with_valid_pixels | bool | Restrict the loss to valid pixels. | False |
| coastal_loss_enabled | bool | Increase supervised ocean-pixel loss near land. | False |
| coastal_loss_radius_px | int | Pixel radius around land to upweight. | 5 |
| coastal_loss_weight | float | Maximum land-adjacent loss weight. | 3.0 |
| coastal_loss_ramp | str | Distance falloff mode for coastal weights. | 'linear' |
| parameterization | str | Prediction target of the network. | 'epsilon' |
| num_timesteps | int | Number of diffusion timesteps. | 1000 |
| noise_schedule | str | Name of the noise (beta) schedule. | 'linear' |
| noise_beta_start | float | First beta of the noise schedule. | 0.0001 |
| noise_beta_end | float | Last beta of the noise schedule. | 0.02 |
| unet_dim | int | Base channel width of the U-Net. | 64 |
| unet_dim_mults | tuple[int, ...] | Channel multipliers for the U-Net levels. | (1, 2, 4, 8) |
| unet_with_time_emb | bool | Use a timestep embedding in the U-Net. | True |
| unet_output_mean_scale | bool | Enable output mean scaling in the U-Net. | False |
| unet_residual | bool | Enable residual mode in the U-Net. | False |
| coord_conditioning_enabled | bool | Condition on patch coordinates. | False |
| coord_encoding | str | Coordinate encoding scheme. | 'unit_sphere' |
| date_conditioning_enabled | bool | Condition on the sample date. | False |
| date_encoding | str | Date encoding scheme. | 'day_of_year_sincos' |
| coord_embed_dim | int \| None | Size of the coordinate embedding. | None |
| batch_size | int | Batch size. | 1 |
| lr | float | Learning rate. | 0.001 |
| lr_scheduler_enabled | bool | Enable the learning-rate scheduler. | False |
| lr_scheduler_monitor | str | Metric monitored by the scheduler. | 'val/loss_ckpt' |
| lr_scheduler_interval | str | Scheduler cadence, "step" or "epoch". | 'epoch' |
| lr_scheduler_mode | str | Whether the monitored metric is minimized ('min') or maximized ('max'). | 'min' |
| lr_scheduler_factor | float | Factor by which the learning rate is reduced. | 0.5 |
| lr_scheduler_patience | int | Number of checks without improvement before the learning rate is reduced. | 10 |
| lr_scheduler_threshold | float | Minimum change that counts as an improvement. | 0.0001 |
| lr_scheduler_threshold_mode | str | How the threshold is applied ('rel' or 'abs'). | 'rel' |
| lr_scheduler_cooldown | int | Checks to wait after a reduction before resuming monitoring. | 0 |
| lr_scheduler_min_lr | float | Lower bound on the learning rate. | 0.0 |
| lr_scheduler_eps | float | Minimum learning-rate change; smaller updates are skipped. | 1e-08 |
| lr_warmup_enabled | bool | Enable learning-rate warmup. | True |
| lr_warmup_steps | int | Number of warmup steps. | 1000 |
| lr_warmup_start_ratio | float | Initial learning rate as a fraction of `lr`. | 0.1 |
| val_inference_sampler | str | Sampler used for validation reconstructions. | 'ddpm' |
| val_ddim_num_timesteps | int | Number of DDIM steps used during validation. | 200 |
| val_ddim_eta | float | DDIM eta used during validation. | 0.0 |
| val_ddim_temperature | float | Scale for DDIM initial and step noise. | 1.0 |
| log_intermediates | bool | Log intermediate denoising steps. | True |
| ambient_occlusion_enabled | bool | Enable ambient occlusion (additional dropping of observed pixels) during training. | False |
| ambient_further_drop_prob | float | Probability of dropping an additional observed pixel. | 0.1 |
| ambient_apply_to_noisy_branch | bool | Also apply the ambient mask to the noisy input branch. | True |
| ambient_shared_spatial_mask | bool | Share one spatial ambient mask across channels. | True |
| ambient_min_kept_observed_pixels | int | Minimum number of observed pixels kept after dropping. | 1 |
| ambient_require_x0_parameterization | bool | Require x0 parameterization when ambient occlusion is enabled. | True |
| skip_full_reconstruction_in_sanity_check | bool | Skip full reconstructions during Lightning's sanity check. | True |
| max_full_reconstruction_samples | int | Maximum number of samples fully reconstructed during validation. | 5 |
| postprocess_gaussian_blur_enabled | bool | Apply a Gaussian blur to the output. | False |
| postprocess_gaussian_blur_sigma | float | Standard deviation of the post-processing blur. | 0.35 |
| postprocess_gaussian_blur_kernel_size | int | Kernel size of the post-processing blur. | 3 |
| model_summary_input_size | int | Spatial size used for Lightning FLOP summary. | 128 |
| wandb_verbose | bool | Enable verbose Weights & Biases logging. | True |
| log_stats_every_n_steps | int | Interval, in steps, for logging statistics. | 1 |
| log_images_every_n_steps | int | Interval, in steps, for logging images. | 200 |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
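The coastal-loss parameters (`coastal_loss_radius_px=5`, `coastal_loss_weight=3.0`, `coastal_loss_ramp='linear'`) describe a per-pixel weight that peaks next to land and decays with distance. This is a minimal sketch of one linear ramp under stated assumptions: Euclidean distance, the maximum weight at distance 1, weight 1.0 beyond the radius, and land pixels set to 0 as if excluded from the loss. A brute-force distance search keeps it dependency-free; on real grids a distance transform such as `scipy.ndimage.distance_transform_edt` would be the usual choice.

```python
import math
from typing import Sequence


def coastal_weights(land_mask: Sequence[Sequence[int]], radius_px: int = 5,
                    max_weight: float = 3.0) -> list[list[float]]:
    """Linear-ramp loss weights: max_weight next to land, 1.0 beyond radius_px, 0.0 on land."""
    height, width = len(land_mask), len(land_mask[0])
    land = [(y, x) for y in range(height) for x in range(width) if land_mask[y][x]]
    weights = [[0.0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            if land_mask[y][x]:
                continue  # land pixels assumed excluded from the loss
            # Euclidean distance to the nearest land pixel (inf when there is no land).
            dist = min((math.hypot(y - ly, x - lx) for ly, lx in land), default=math.inf)
            # Ramp is 1 for a pixel touching land (dist 1) and reaches 0 past radius_px.
            ramp = min(1.0, max(0.0, 1.0 - (dist - 1.0) / radius_px))
            weights[y][x] = 1.0 + (max_weight - 1.0) * ramp
    return weights


row = coastal_weights([[1, 0, 0, 0, 0, 0, 0]])[0]
print([round(w, 2) for w in row])
```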
configure_optimizers()
Create the optimizer and, when enabled, the learning-rate scheduler configuration.
This method takes no arguments.
Returns:
| Type | Description |
|---|---|
| Optimizer \| dict[str, Any] | The optimizer alone, or a Lightning optimizer/scheduler configuration dictionary. |
forward(condition, sampler=None, verbose=False, clamp_known_pixels=None, *, known_mask=None, known_values=None, coords=None, date=None, return_intermediates=False, intermediate_step_indices=None, return_x0_intermediates=False)
Run reverse diffusion conditioned on `condition` and return the generated sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| condition | Tensor | Conditioning tensor. | required |
| sampler | Module \| None | Sampler instance used for reverse diffusion. | None |
| verbose | bool | Enable verbose sampling output. | False |
| clamp_known_pixels | bool \| None | Per-call override of the `clamp_known_pixels` setting. | None |
| known_mask | Tensor \| None | Mask of known (observed) pixels. | None |
| known_values | Tensor \| None | Observed values at the known pixels. | None |
| coords | Tensor \| None | Coordinate conditioning values. | None |
| date | Tensor \| None | Date conditioning values. | None |
| return_intermediates | bool | Also return intermediate samples. | False |
| intermediate_step_indices | list[int] \| None | Steps at which intermediate samples are captured. | None |
| return_x0_intermediates | bool | Also return intermediate x0 predictions. | False |
Returns:
| Type | Description |
|---|---|
| Tensor \| tuple[Tensor, list[tuple[int, Tensor]]] \| tuple[Tensor, list[tuple[int, Tensor]], list[tuple[int, Tensor]]] | The generated sample, optionally followed by a list of `(step, sample)` intermediates and a list of `(step, x0)` intermediates. |
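Clamping known pixels amounts to keeping observed values wherever `known_mask` is set and generated values elsewhere. The plain-list sketch below shows only that selection; with tensors the same operation is `torch.where(known_mask.bool(), known_values, x)`. Whether DepthDif applies it at every reverse step or only to the final sample, and whether known values are first noised to the current timestep, is not stated here and is left open.

```python
from typing import Sequence


def clamp_known(sample: Sequence[Sequence[float]],
                known_mask: Sequence[Sequence[int]],
                known_values: Sequence[Sequence[float]]) -> list[list[float]]:
    """Keep observed values where known_mask is 1 and generated values elsewhere."""
    return [
        [known if is_known else generated
         for generated, is_known, known in zip(s_row, m_row, k_row)]
        for s_row, m_row, k_row in zip(sample, known_mask, known_values)
    ]


print(clamp_known([[0.1, 0.2]], [[1, 0]], [[5.0, 9.0]]))  # [[5.0, 0.2]]
```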
from_config(model_config_path=None, data_config_path=None, training_config_path=None, datamodule=None) (classmethod)
Build a PixelDiffusionConditional from model, data and training config files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_config_path | str | Path to the model config file. | None |
| data_config_path | str | Path to the data config file. | None |
| training_config_path | str | Path to the training config file. | None |
| datamodule | LightningDataModule \| None | Optional datamodule to attach to the model. | None |
Returns:
| Type | Description |
|---|---|
| PixelDiffusionConditional | The constructed model. |
input_T(value)
Map a tensor from data space into the model's input space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Tensor | Tensor to transform. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed tensor. |
load_state_dict(state_dict, strict=True)
Load checkpoint weights into the current module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state_dict | dict[str, Tensor] | Mapping from parameter names to tensors. | required |
| strict | bool | Require the checkpoint keys to match the module's keys exactly. | True |
Returns:
| Name | Type | Description |
|---|---|---|
| Any | Any | Result of the underlying state-dict load. |
on_load_checkpoint(checkpoint)
Validate variable scenario metadata before Lightning restores weights.
on_save_checkpoint(checkpoint)
Embed variable scenario metadata in Lightning checkpoints.
on_validation_epoch_end()
Lightning hook run at the end of each validation epoch.
This method takes no arguments.
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_validation_epoch_start()
Lightning hook run at the start of each validation epoch.
This method takes no arguments.
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)
Perform one optimizer step with optional learning-rate warmup.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| epoch | int | Current epoch. | required |
| batch_idx | int | Zero-based index of the current batch. | required |
| optimizer | Optimizer | Optimizer used for parameter updates. | required |
| optimizer_closure | Any \| None | Closure that computes the loss, passed in by Lightning. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
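The warmup parameters (`lr_warmup_steps=1000`, `lr_warmup_start_ratio=0.1`) suggest a ramp from a fraction of `lr` up to `lr`. This sketch assumes the ramp is linear in the step count; `warmup_lr` is a hypothetical helper, not the module's code.

```python
def warmup_lr(base_lr: float, step: int, warmup_steps: int = 1000,
              start_ratio: float = 0.1) -> float:
    """Linear warmup from start_ratio * base_lr (step 0) to base_lr (step >= warmup_steps)."""
    if warmup_steps <= 0:
        return base_lr
    progress = min(1.0, step / warmup_steps)
    return base_lr * (start_ratio + (1.0 - start_ratio) * progress)


for step in (0, 500, 1000, 5000):
    print(step, warmup_lr(1e-3, step))
```

Inside an `optimizer_step` override, such a value would typically be written to every `group["lr"]` in `optimizer.param_groups` before calling `optimizer.step(closure=optimizer_closure)`.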
output_T(value)
Map a tensor from the model's output space back to data space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Tensor | Tensor to transform. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed tensor. |
predict_step(batch, batch_idx, dataloader_idx=0)
Run prediction on one batch and return the outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Any] | Input batch. | required |
| batch_idx | int | Zero-based index of the current batch. | required |
| dataloader_idx | int | Index of the dataloader that produced the batch. | 0 |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing the computed outputs. |
train_dataloader()
Return the training dataloader from the attached datamodule.
This method takes no arguments.
Returns:
| Type | Description |
|---|---|
| DataLoader[Any] | Training dataloader. |
training_step(batch, batch_idx)
Compute the training loss for one batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Any] | Training batch. | required |
| batch_idx | int | Zero-based index of the current batch. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Training loss. |
uncertainty_step(batch, batch_idx, dataloader_idx=0, num_samples=20, sampler=None, collapse_channels=True)
Estimate pixel-wise generation uncertainty from repeated predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Any] | Input batch passed to prediction. | required |
| batch_idx | int | Zero-based index of the current batch. | required |
| dataloader_idx | int | Dataloader index passed through to prediction. | 0 |
| num_samples | int | Number of repeated generations used for uncertainty. | 20 |
| sampler | Module \| None | Optional sampler used only for this uncertainty pass. | None |
| collapse_channels | bool | Collapse depth/channel uncertainty to one raster. | True |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing uncertainty maps and metadata. |
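`uncertainty_step` repeats generation `num_samples` times and reports the pixel-wise spread. A plain-Python sketch of one way to compute it: population standard deviation across generations for each pixel and channel, then, when `collapse_channels=True`, the mean over channels. Whether DepthDif uses standard deviation or variance, and which reduction collapses the depth channels, are assumptions here; the function is hypothetical.

```python
import statistics
from typing import Sequence

Grid = Sequence[Sequence[float]]


def pixelwise_std(samples: Sequence[Sequence[Grid]], collapse_channels: bool = True):
    """samples[k][c][y][x] is generation k. Returns std maps across generations."""
    n_channels = len(samples[0])
    height, width = len(samples[0][0]), len(samples[0][0][0])
    per_channel = [
        [
            [statistics.pstdev([s[c][y][x] for s in samples]) for x in range(width)]
            for y in range(height)
        ]
        for c in range(n_channels)
    ]
    if not collapse_channels:
        return per_channel  # one map per depth/channel
    # Collapse depth/channels by averaging their std maps (an assumed reduction).
    return [
        [statistics.fmean([per_channel[c][y][x] for c in range(n_channels)]) for x in range(width)]
        for y in range(height)
    ]


# Two generations, two channels, a single 1x1 pixel.
samples = [
    [[[1.0]], [[2.0]]],
    [[[3.0]], [[2.0]]],
]
print(pixelwise_std(samples, collapse_channels=False))  # [[[1.0]], [[0.0]]]
print(pixelwise_std(samples))                           # [[0.5]]
```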
val_dataloader()
Return the validation dataloader from the attached datamodule.
This method takes no arguments.
Returns:
| Type | Description |
|---|---|
| DataLoader[Any] \| None | Validation dataloader, or None. |
validation_step(batch, batch_idx)
Run one validation step on a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Any] | Validation batch. | required |
| batch_idx | int | Zero-based index of the current batch. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Validation loss for the batch. |
depth_recon.models.diffusion.EMA
EMA
Bases: Callback
Callback that maintains exponential moving-average model weights.
ema_initialized (property)
Return whether the EMA weights have been initialized.
Returns:
| Name | Type | Description |
|---|---|---|
| bool | bool | True once EMA weights exist. |
weights_are_applied (property)
Return whether EMA weights are currently loaded into the module.
__init__(decay, apply_ema_every_n_steps=1, start_step=0, save_ema_weights_in_callback_state=False, evaluate_ema_weights_instead=False)
Initialize the EMA callback.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| decay | float | EMA decay factor. | required |
| apply_ema_every_n_steps | int | Apply an EMA update every N training steps. | 1 |
| start_step | int | First step at which EMA updates are applied. | 0 |
| save_ema_weights_in_callback_state | bool | Store EMA weights in the callback's checkpoint state. | False |
| evaluate_ema_weights_instead | bool | Use EMA weights instead of the raw weights for validation and testing. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
apply_ema(pl_module)
Update the EMA weights from the module's current parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pl_module | LightningModule | Module whose parameters are averaged. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
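The update behind an EMA callback is `ema = decay * ema + (1 - decay) * param`. The sketch pairs that rule with gating modeled on the documented `apply_ema_every_n_steps` and `start_step` parameters. `EmaSketch`, its method names, and the "at most one update per step index" guard are illustrative assumptions, and plain floats stand in for tensors.

```python
from typing import Optional, Sequence


class EmaSketch:
    """Minimal EMA tracker over a flat list of float parameters."""

    def __init__(self, decay: float, apply_every_n_steps: int = 1, start_step: int = 0) -> None:
        self.decay = decay
        self.apply_every_n_steps = apply_every_n_steps
        self.start_step = start_step
        self.shadow: Optional[list[float]] = None  # EMA copy, created on first update
        self._last_step: Optional[int] = None

    def should_apply(self, step: int) -> bool:
        return (
            step != self._last_step  # at most one update per step index
            and step >= self.start_step
            and step % self.apply_every_n_steps == 0
        )

    def update(self, params: Sequence[float], step: int) -> bool:
        """Apply ema = decay * ema + (1 - decay) * param if an update is due."""
        if not self.should_apply(step):
            return False
        if self.shadow is None:
            self.shadow = list(params)  # initialize from the current weights
        else:
            d = self.decay
            self.shadow = [d * e + (1.0 - d) * p for e, p in zip(self.shadow, params)]
        self._last_step = step
        return True


ema = EmaSketch(decay=0.9, apply_every_n_steps=2, start_step=2)
for step, params in [(0, [1.0]), (2, [1.0]), (3, [5.0]), (4, [2.0])]:
    print(step, ema.update(params, step), ema.shadow)
```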
apply_multi_tensor_ema(pl_module)
Update the EMA weights using multi-tensor operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pl_module | LightningModule | Module whose parameters are averaged. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
compute_weight_delta_metrics(pl_module)
Compute raw-vs-EMA weight distance metrics.
ema(pl_module)
Run one EMA update on the module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pl_module | LightningModule | Module whose parameters are averaged. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
load_state_dict(state_dict)
Restore the callback state from a checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state_dict | Dict[str, Any] | Callback state previously returned by `state_dict()`. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
log_weight_delta_metrics(trainer, pl_module)
Log EMA scalar diagnostics for the current validation epoch.
on_fit_start(trainer, pl_module)
Initialize EMA before sanity validation can run.
on_load_checkpoint(trainer, pl_module, checkpoint)
Lightning hook run when a checkpoint is loaded.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being trained. | required |
| checkpoint | Dict[str, Any] | Loaded checkpoint dictionary. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_test_end(trainer, pl_module)
Lightning hook run at the end of testing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being tested. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_test_start(trainer, pl_module)
Lightning hook run at the start of testing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being tested. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Lightning hook run after each training batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being trained. | required |
| outputs | STEP_OUTPUT | Output returned by `training_step`. | required |
| batch | Any | Current training batch. | required |
| batch_idx | int | Zero-based index of the current batch. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_train_start(trainer, pl_module)
Lightning hook run at the start of training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being trained. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_validation_end(trainer, pl_module)
Lightning hook run at the end of validation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being validated. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
on_validation_epoch_end(trainer, pl_module)
Log EMA weight diagnostics once per validation epoch.
on_validation_start(trainer, pl_module)
Lightning hook run at the start of validation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Active Lightning trainer. | required |
| pl_module | LightningModule | Module being validated. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
replace_model_weights(pl_module)
Load the EMA weights into the module, keeping a copy of the original weights.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pl_module | LightningModule | Module that receives the EMA weights. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
restore_original_weights(pl_module)
Restore the original (non-EMA) weights into the module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pl_module | LightningModule | Module whose original weights are restored. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
should_apply_ema(step)
Return whether an EMA update should be applied at `step`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step | int | Current training step. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| bool | bool | True if an EMA update is due. |
state_dict()
Return the serializable state of this callback.
This method takes no arguments.
Returns:
| Type | Description |
|---|---|
| Dict[str, Any] | Callback state. |
depth_recon.models.diffusion.DenoisingDiffusionProcess.DenoisingDiffusionProcess
ConvNextBlock
Bases: Module
ConvNeXt residual block used within the U-Net backbone.
__init__(dim, dim_out, *, time_emb_dim=None, coord_emb_dim=None, mult=2, norm=True)
Initialize ConvNextBlock with the configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Number of input channels. | required |
| dim_out | int | Number of output channels. | required |
| time_emb_dim | int \| None | Size of the timestep embedding, or None. | None |
| coord_emb_dim | int \| None | Size of the coordinate embedding, or None. | None |
| mult | int | Channel expansion factor of the block's inner layer. | 2 |
| norm | bool | Apply normalization in the block. | True |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
forward(x, time_emb=None, coord_emb=None)
Apply the block to a feature map.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input feature map. | required |
| time_emb | Tensor \| None | Timestep embedding. | None |
| coord_emb | Tensor \| None | Coordinate embedding. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Output feature map. |
DDIM_Sampler
Bases: Module
DDIM sampler that performs accelerated reverse-diffusion updates.
__init__(num_timesteps=100, train_timesteps=1000, clip_sample=True, schedule='linear', beta_start=0.0001, beta_end=0.02, eta=0.0, temperature=1.0, betas=None, parameterization='epsilon')
¶
Initialize DDIM_Sampler with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of DDIM sampling steps. |
100
|
train_timesteps
|
int
|
Number of diffusion steps used during training. |
1000
|
clip_sample
|
bool
|
Whether to clip the predicted clean sample. |
True
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
eta
|
float
|
DDIM stochasticity; 0.0 gives deterministic sampling. |
0.0
|
temperature
|
float
|
Scale for DDIM initial and stochastic step noise. |
1.0
|
betas
|
Tensor | list[float] | tuple[float, ...] | None
|
Optional explicit beta schedule used instead of a named one. |
None
|
parameterization
|
str
|
Network prediction target; 'epsilon' means the model predicts the added noise. |
'epsilon'
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
estimate_std(alpha_cumprod, alpha_cumprod_prev)
¶
Estimate the standard deviation of the stochastic DDIM step between two timesteps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
alpha_cumprod
|
Tensor
|
Cumulative product of alphas at the current timestep. |
required |
alpha_cumprod_prev
|
Tensor
|
Cumulative product of alphas at the previous timestep. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Standard deviation of the step noise. |
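For orientation, here is a scalar sketch of the DDIM noise scale. It rests on two assumptions that the docstring above does not confirm: that `estimate_std` returns the eta-free factor from the DDIM formulation, and that `DDIM_Sampler` then multiplies that factor by `eta`. The name `ddim_std` is hypothetical, and plain floats stand in for tensors.

```python
import math


def ddim_std(alpha_cumprod: float, alpha_cumprod_prev: float) -> float:
    """Eta-free DDIM noise scale between two timesteps (hypothetical scalar sketch).

    sigma = sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
    """
    variance = ((1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)) * (
        1.0 - alpha_cumprod / alpha_cumprod_prev
    )
    return math.sqrt(variance)


# Assumed usage: the sampler scales the factor by eta, so eta = 0 removes
# the injected noise and makes sampling deterministic.
eta = 0.0
sigma = eta * ddim_std(0.5, 0.6)
```

Between adjacent timesteps with `eta = 1`, this factor equals the DDPM posterior standard deviation. That is why `eta` is often described as interpolating between DDIM and DDPM-style sampling.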
forward(*args, **kwargs)
¶
Run the sampler call and return the next sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
*args
|
Any
|
Additional positional arguments forwarded to the underlying call. |
()
|
**kwargs
|
Any
|
Additional keyword arguments forwarded to the underlying call. |
{}
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Next sample in the reverse process. |
set_parameterization(parameterization)
¶
Set the network prediction target used by the sampler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameterization
|
str
|
Prediction target, such as 'epsilon'. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
step(x_t, t, z_t)
¶
Predict the previous diffusion sample for one timestep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Noisy sample at timestep t. |
required |
t
|
Tensor
|
Current timestep indices. |
required |
z_t
|
Tensor
|
Network prediction for x_t (the noise when parameterization is 'epsilon'). |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Estimated sample at the previous timestep. |
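The scalar sketch below shows one DDIM update with epsilon prediction. It is meant to make the roles of `z_t`, `eta`, `clip_sample`, and the two timestep counts concrete. Several details are assumptions, not confirmed by the reference above:

- `z_t` is the network's noise prediction.
- Each call jumps `train_timesteps // num_timesteps` training steps.
- The clean-sample estimate is clipped to [-1, 1].
- `temperature` is ignored.

All function names are hypothetical.

```python
import math


def linear_alphas_cumprod(train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta) for a linear beta schedule."""
    step = (beta_end - beta_start) / (train_timesteps - 1)
    out, prod = [], 1.0
    for i in range(train_timesteps):
        prod *= 1.0 - (beta_start + i * step)
        out.append(prod)
    return out


def ddim_step(x_t, t, eps_pred, alphas_cumprod, ratio, eta=0.0, z=0.0, clip_sample=True):
    """One DDIM update from training timestep t to t - ratio (scalar sketch)."""
    a_t = alphas_cumprod[t]
    prev_t = t - ratio
    a_prev = alphas_cumprod[prev_t] if prev_t >= 0 else 1.0
    # Clean-sample estimate implied by the noise prediction.
    x0_pred = (x_t - math.sqrt(1.0 - a_t) * eps_pred) / math.sqrt(a_t)
    if clip_sample:
        x0_pred = max(-1.0, min(1.0, x0_pred))
    sigma = eta * math.sqrt((1.0 - a_prev) / (1.0 - a_t) * (1.0 - a_t / a_prev))
    # Deterministic "direction pointing to x_t" term; max() guards rounding.
    direction = math.sqrt(max(0.0, 1.0 - a_prev - sigma ** 2)) * eps_pred
    return math.sqrt(a_prev) * x0_pred + direction + sigma * z


acp = linear_alphas_cumprod()
ratio = 1000 // 100  # train_timesteps // num_timesteps under the assumption above
```

With the defaults `num_timesteps=100` and `train_timesteps=1000`, each call would move 10 training steps under this assumption. With `eta = 0`, a consistent noise prediction maps x_t exactly onto the same trajectory at the earlier timestep.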
DDPM_Sampler
¶
Bases: Module
DDPM sampler that performs one reverse-diffusion step at a time.
__init__(num_timesteps=1000, schedule='linear', beta_start=0.0001, beta_end=0.02, parameterization='epsilon')
¶
Initialize DDPM_Sampler with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
parameterization
|
str
|
Network prediction target; 'epsilon' means the model predicts the added noise. |
'epsilon'
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(*args, **kwargs)
¶
Run the sampler call and return the next sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
*args
|
Any
|
Additional positional arguments forwarded to the underlying call. |
()
|
**kwargs
|
Any
|
Additional keyword arguments forwarded to the underlying call. |
{}
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Next sample in the reverse process. |
posterior_params(x_t, t, noise_pred)
¶
Compute the mean and standard deviation of p(x_{t-1} | x_t) from the network prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Noisy sample at timestep t. |
required |
t
|
Tensor
|
Current timestep indices. |
required |
noise_pred
|
Tensor
|
Network prediction for x_t. |
required |
Returns:
| Type | Description |
|---|---|
tuple[Tensor, Tensor]
|
tuple[torch.Tensor, torch.Tensor]: Posterior mean and standard deviation. |
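For epsilon prediction, the standard DDPM reverse step is a Gaussian. Its mean is (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t), and its variance is beta_t * (1 - abar_{t-1}) / (1 - abar_t).

The scalar sketch below assumes `posterior_params` computes that form. It also assumes that `z_t` in `step` corresponds to the noise prediction, with the injected noise `z` drawn separately. Both are assumptions, and the helper names are hypothetical.

```python
import math


def ddpm_posterior_params(x_t, t, eps_pred, betas):
    """Mean and std of p(x_{t-1} | x_t) for epsilon prediction (scalar sketch)."""
    alphas_cumprod, prod = [], 1.0
    for beta in betas:  # recomputed per call for clarity, not speed
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    a_t = alphas_cumprod[t]
    a_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    mean = (x_t - beta_t / math.sqrt(1.0 - a_t) * eps_pred) / math.sqrt(alpha_t)
    variance = beta_t * (1.0 - a_prev) / (1.0 - a_t)
    return mean, math.sqrt(variance)


def ddpm_step(x_t, t, eps_pred, betas, z):
    """Sample x_{t-1} = mean + std * z; no noise is added at the final step."""
    mean, std = ddpm_posterior_params(x_t, t, eps_pred, betas)
    return mean + std * z if t > 0 else mean
```

At t = 0 the posterior variance is zero, so the last step returns the mean, which is the model's clean-sample estimate.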
set_parameterization(parameterization)
¶
Set the network prediction target used by the sampler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameterization
|
str
|
Prediction target, such as 'epsilon'. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
step(x_t, t, z_t)
¶
Predict the previous diffusion sample for one timestep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Noisy sample at timestep t. |
required |
t
|
Tensor
|
Current timestep indices. |
required |
z_t
|
Tensor
|
Network prediction for x_t (the noise when parameterization is 'epsilon'). |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Estimated sample at the previous timestep. |
DenoisingDiffusionConditionalProcess
¶
Bases: Module
Conditional diffusion process module for guided reconstruction.
__init__(generated_channels=3, condition_channels=3, loss_fn=F.mse_loss, schedule='linear', beta_start=0.0001, beta_end=0.02, num_timesteps=1000, unet_dim=64, unet_dim_mults=(1, 2, 4, 8), unet_with_time_emb=True, unet_output_mean_scale=False, unet_residual=False, coord_conditioning_enabled=False, coord_encoding='unit_sphere', date_conditioning_enabled=False, date_encoding='day_of_year_sincos', coord_embed_dim=None, parameterization='epsilon', sampler=None)
¶
Initialize DenoisingDiffusionConditionalProcess with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
generated_channels
|
int
|
Number of channels in the generated output. |
3
|
condition_channels
|
int
|
Number of channels in the conditioning input. |
3
|
loss_fn
|
Callable[[Tensor, Tensor], Tensor]
|
Loss function comparing the network prediction with its target. |
mse_loss
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
unet_dim
|
int
|
Base channel width of the U-Net. |
64
|
unet_dim_mults
|
tuple[int, ...]
|
Channel multipliers for each U-Net resolution level. |
(1, 2, 4, 8)
|
unet_with_time_emb
|
bool
|
Whether the U-Net uses a timestep embedding. |
True
|
unet_output_mean_scale
|
bool
|
Boolean flag controlling behavior. |
False
|
unet_residual
|
bool
|
Whether the U-Net adds a residual connection from its input to its output. |
False
|
coord_conditioning_enabled
|
bool
|
Whether to condition on geographic coordinates. |
False
|
coord_encoding
|
str
|
Encoding applied to coordinates, such as 'unit_sphere'. |
'unit_sphere'
|
date_conditioning_enabled
|
bool
|
Whether to condition on the sample date. |
False
|
date_encoding
|
str
|
Encoding applied to dates, such as 'day_of_year_sincos'. |
'day_of_year_sincos'
|
coord_embed_dim
|
int | None
|
Optional size of the coordinate/date embedding. |
None
|
parameterization
|
str
|
Network prediction target; 'epsilon' means the model predicts the added noise. |
'epsilon'
|
sampler
|
Module | None
|
Sampler instance used for reverse diffusion. |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(condition, sampler=None, verbose=False, known_mask=None, known_values=None, coord=None, date=None, return_intermediates=False, intermediate_step_indices=None, return_x0_intermediates=False)
¶
Run reverse diffusion and return generated outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
condition
|
Tensor
|
Conditioning input. |
required |
sampler
|
Module | None
|
Sampler instance used for reverse diffusion. |
None
|
verbose
|
bool
|
Whether to report progress while sampling. |
False
|
known_mask
|
Tensor | None
|
Mask tensor controlling valid or known pixels. |
None
|
known_values
|
Tensor | None
|
Values of the known pixels selected by known_mask. |
None
|
coord
|
Tensor | None
|
Coordinate conditioning values. |
None
|
date
|
Tensor | None
|
Date conditioning values. |
None
|
return_intermediates
|
bool
|
Whether to also return intermediate samples. |
False
|
intermediate_step_indices
|
list[int] | None
|
Sampling step indices at which to record intermediates. |
None
|
return_x0_intermediates
|
bool
|
Whether to also return intermediate clean-sample (x0) predictions. |
False
|
Returns:
| Type | Description |
|---|---|
Tensor | tuple[Tensor, list[tuple[int, Tensor]]] | tuple[Tensor, list[tuple[int, Tensor]], list[tuple[int, Tensor]]]
|
torch.Tensor | tuple[torch.Tensor, list[tuple[int, torch.Tensor]]] | tuple[torch.Tensor, list[tuple[int, torch.Tensor]], list[tuple[int, torch.Tensor]]]: Generated sample, optionally followed by lists of (step index, tensor) intermediates. |
p_loss(output, condition, *, loss_mask=None, further_valid_mask=None, land_mask=None, mask_loss=False, coastal_loss_enabled=False, coastal_loss_radius_px=0, coastal_loss_weight=1.0, coastal_loss_ramp='linear', apply_further_corruption_to_noisy_branch=False, coord=None, date=None)
¶
Compute the diffusion training loss for the current batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
output
|
Tensor
|
Clean target samples to be generated. |
required |
condition
|
Tensor
|
Conditioning input for the batch. |
required |
loss_mask
|
Tensor | None
|
Mask tensor selecting the supervised pixels. |
None
|
further_valid_mask
|
Tensor | None
|
Mask tensor controlling valid or known pixels. |
None
|
land_mask
|
Tensor | None
|
GLORYS spatial ocean/domain support mask. |
None
|
mask_loss
|
bool
|
Whether to apply masking to the loss. |
False
|
coastal_loss_enabled
|
bool
|
Increase supervised ocean-pixel loss near land. |
False
|
coastal_loss_radius_px
|
int
|
Pixel radius around land to upweight. |
0
|
coastal_loss_weight
|
float
|
Maximum land-adjacent loss weight. |
1.0
|
coastal_loss_ramp
|
str
|
Distance falloff mode for coastal weights. |
'linear'
|
apply_further_corruption_to_noisy_branch
|
bool
|
Boolean flag controlling behavior. |
False
|
coord
|
Tensor | None
|
Coordinate conditioning values. |
None
|
date
|
Tensor | None
|
Date conditioning values. |
None
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Scalar training loss. |
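The coastal-loss options suggest a per-pixel weight that is largest next to land and decays with distance. The sketch below is a purely hypothetical illustration: the real distance metric, ramp formula, and treatment of land pixels are not documented here. It builds such weights on a small grid using two stand-ins:

- a 4-neighbour distance transform;
- a linear ramp mirroring `coastal_loss_radius_px`, `coastal_loss_weight`, and `coastal_loss_ramp='linear'`.

```python
from collections import deque


def distance_to_land(land):
    """4-neighbour BFS distance (pixels) from each cell to the nearest land cell."""
    rows, cols = len(land), len(land[0])
    dist = [[None] * cols for _ in range(rows)]
    queue = deque()
    for r in range(rows):
        for c in range(cols):
            if land[r][c]:
                dist[r][c] = 0
                queue.append((r, c))
    while queue:
        r, c = queue.popleft()
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols and dist[nr][nc] is None:
                dist[nr][nc] = dist[r][c] + 1
                queue.append((nr, nc))
    return dist


def coastal_weights(land, radius_px, max_weight):
    """Hypothetical linear ramp: max_weight next to land, 1.0 beyond radius_px."""
    weights = []
    for row_land, row_dist in zip(land, distance_to_land(land)):
        row = []
        for is_land, d in zip(row_land, row_dist):
            if is_land or d is None or d > radius_px:
                # Land keeps weight 1.0 here; masking it out is left to the caller.
                row.append(1.0)
            else:
                row.append(1.0 + (max_weight - 1.0) * (1.0 - (d - 1) / radius_px))
        weights.append(row)
    return weights
```

With `radius_px = 0`, every ocean pixel keeps weight 1.0. This is consistent with the default `coastal_loss_radius_px=0` having no effect, though that reading of the default is itself an assumption.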
DenoisingDiffusionProcess
¶
Bases: Module
Unconditional diffusion process module for training and sampling.
__init__(generated_channels=3, loss_fn=F.mse_loss, schedule='linear', beta_start=0.0001, beta_end=0.02, num_timesteps=1000, unet_dim=64, unet_dim_mults=(1, 2, 4, 8), unet_with_time_emb=True, unet_output_mean_scale=False, unet_residual=False, parameterization='epsilon', sampler=None)
¶
Initialize DenoisingDiffusionProcess with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
generated_channels
|
int
|
Number of channels in the generated output. |
3
|
loss_fn
|
Callable[[Tensor, Tensor], Tensor]
|
Loss function comparing the network prediction with its target. |
mse_loss
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
unet_dim
|
int
|
Base channel width of the U-Net. |
64
|
unet_dim_mults
|
tuple[int, ...]
|
Channel multipliers for each U-Net resolution level. |
(1, 2, 4, 8)
|
unet_with_time_emb
|
bool
|
Whether the U-Net uses a timestep embedding. |
True
|
unet_output_mean_scale
|
bool
|
Boolean flag controlling behavior. |
False
|
unet_residual
|
bool
|
Whether the U-Net adds a residual connection from its input to its output. |
False
|
parameterization
|
str
|
Network prediction target; 'epsilon' means the model predicts the added noise. |
'epsilon'
|
sampler
|
Module | None
|
Sampler instance used for reverse diffusion. |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(shape=(256, 256), batch_size=1, sampler=None, verbose=False)
¶
Run reverse diffusion and return generated outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shape
|
tuple[int, int]
|
Spatial size (H, W) of the generated samples. |
(256, 256)
|
batch_size
|
int
|
Number of samples to generate. |
1
|
sampler
|
Module | None
|
Sampler instance used for reverse diffusion. |
None
|
verbose
|
bool
|
Whether to report progress while sampling. |
False
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Generated samples of the requested shape. |
p_loss(output)
¶
Compute the diffusion training loss for the current batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
output
|
Tensor
|
Clean training samples. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Scalar training loss. |
ForwardModel
¶
Bases: Module
Base interface for forward diffusion process implementations.
__init__(num_timesteps=1000, schedule='linear')
¶
Initialize ForwardModel with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x_0, t)
¶
Return a noisy version of x_0 at timestep t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_0
|
Tensor
|
Clean input sample x_0. |
required |
t
|
Tensor
|
Timestep indices. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Noisy sample x_t. |
step(x_t, t)
¶
Advance the forward process by one timestep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Sample at timestep t. |
required |
t
|
Tensor
|
Timestep indices. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Sample at the next timestep. |
GaussianForwardProcess
¶
Bases: ForwardModel
Forward diffusion process based on Gaussian noise transitions.
__init__(num_timesteps=1000, schedule='linear', beta_start=0.0001, beta_end=0.02)
¶
Initialize GaussianForwardProcess with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x_0, t, return_noise=False)
¶
Sample x_t from q(x_t | x_0) by adding Gaussian noise to x_0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_0
|
Tensor
|
Clean input sample x_0. |
required |
t
|
Tensor
|
Timestep indices. |
required |
return_noise
|
bool
|
Whether to also return the sampled noise. |
False
|
Returns:
| Type | Description |
|---|---|
Tensor | tuple[Tensor, Tensor]
|
torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Noisy sample x_t, or (x_t, noise) when return_noise is True. |
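For reference, the Gaussian forward process has the closed form x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, where abar_t is the cumulative product of (1 - beta). This scalar sketch mirrors the `return_noise` behaviour described above. It uses hypothetical names, plain floats instead of tensors, and an optional caller-supplied noise value for reproducibility.

```python
import math
import random


def alphas_cumprod_from_betas(betas):
    """Cumulative products abar_t = prod_{s <= t} (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x_0, t, alphas_cumprod, noise=None, return_noise=False):
    """Draw x_t ~ N(sqrt(abar_t) * x_0, 1 - abar_t) for a scalar x_0."""
    if noise is None:
        noise = random.gauss(0.0, 1.0)
    a_t = alphas_cumprod[t]
    x_t = math.sqrt(a_t) * x_0 + math.sqrt(1.0 - a_t) * noise
    return (x_t, noise) if return_noise else x_t


betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]
acp = alphas_cumprod_from_betas(betas)
```

Under the default linear schedule, abar at the last of 1000 steps is tiny. Late-timestep samples are therefore almost pure noise, which is what the reverse process starts from.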
step(x_t, t, return_noise=False)
¶
Apply one forward-diffusion transition step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Sample at timestep t. |
required |
t
|
Tensor
|
Timestep indices. |
required |
return_noise
|
bool
|
Whether to also return the sampled noise. |
False
|
Returns:
| Type | Description |
|---|---|
Tensor | tuple[Tensor, Tensor]
|
torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Sample at the next timestep, or (sample, noise) when return_noise is True. |
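A single forward transition scales the sample by sqrt(1 - beta_t) and adds noise with variance beta_t. The sketch below assumes `step` uses the beta at index `t`, which is not stated above; `forward_step` is a hypothetical name. With the noise fixed to zero, chaining every step scales the input by sqrt(abar_T). This ties `step` back to the closed form used by `forward`.

```python
import math
import random


def forward_step(x_t, t, betas, noise=None, return_noise=False):
    """One forward transition: x_next ~ N(sqrt(1 - beta_t) * x_t, beta_t)."""
    if noise is None:
        noise = random.gauss(0.0, 1.0)
    beta_t = betas[t]
    x_next = math.sqrt(1.0 - beta_t) * x_t + math.sqrt(beta_t) * noise
    return (x_next, noise) if return_noise else x_next
```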
LayerNorm
¶
Bases: Module
Channel-wise layer normalization for 2D feature maps.
__init__(dim, eps=1e-05)
¶
Initialize LayerNorm with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Number of channels. |
required |
eps
|
float
|
Small constant added to the variance for numerical stability. |
1e-05
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x)
¶
Run the module forward computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Tensor
|
Feature map with dim channels. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Normalized feature map of the same shape. |
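Channel-wise normalization means each pixel's channel vector is normalized on its own, independently of other pixels. The minimal sketch below handles one pixel. It assumes the conventions common in U-Net implementations of this layer: population (biased) variance, an optional per-channel gain and bias, and `eps` added inside the square root. These details are assumptions here, and `layer_norm_channels` is a hypothetical name.

```python
import math


def layer_norm_channels(values, gain=None, bias=None, eps=1e-5):
    """Normalize one pixel's channel vector to zero mean and unit variance."""
    n = len(values)
    gain = gain if gain is not None else [1.0] * n
    bias = bias if bias is not None else [0.0] * n
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # population variance
    scale = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * scale * g + b for v, g, b in zip(values, gain, bias)]
```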
LinearAttention
¶
Bases: Module
Linear attention block for efficient spatial mixing.
__init__(dim, heads=4, dim_head=32)
¶
Initialize LinearAttention with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Number of input channels. |
required |
heads
|
int
|
Number of attention heads. |
4
|
dim_head
|
int
|
Channels per attention head. |
32
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x)
¶
Run the module forward computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Tensor
|
Feature map with dim channels. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Attention output with the same shape as x. |
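Here is a minimal single-head sketch of softmax-normalized linear attention, the "efficient attention" formulation used in many diffusion U-Nets. It assumes queries are softmax-normalized over features and keys over positions, with no extra scaling. The real module works on (B, C, H, W) maps with `heads` heads of `dim_head` channels and may differ in these details. Nested lists stand in for tensors.

```python
import math


def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def linear_attention(q, k, v):
    """Single-head linear attention over N positions.

    q, k: N x d lists, v: N x e list. Runs in O(N * d * e) rather than O(N^2).
    """
    n, d, e = len(q), len(q[0]), len(v[0])
    q = [softmax(row) for row in q]  # each query normalized over its d features
    # Each key feature normalized over the N positions.
    k_cols = [softmax([k[i][j] for i in range(n)]) for j in range(d)]
    # d x e context summarizing all positions; its size does not depend on N.
    context = [[sum(k_cols[j][i] * v[i][c] for i in range(n)) for c in range(e)] for j in range(d)]
    return [[sum(q[i][j] * context[j][c] for j in range(d)) for c in range(e)] for i in range(n)]
```

Both normalizations sum to one, so a constant value vector passes through unchanged. This makes a handy sanity check.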
PreNorm
¶
Bases: Module
Module that normalizes inputs before applying a submodule.
__init__(dim, fn)
¶
Initialize PreNorm with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Number of channels to normalize. |
required |
fn
|
Module
|
Submodule applied to the normalized input. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x)
¶
Run the module forward computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Tensor
|
Input feature map. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Output of fn applied to the normalized input. |
Residual
¶
Bases: Module
Wrapper module that adds a residual skip connection.
__init__(fn)
¶
Initialize Residual with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fn
|
Module
|
Submodule whose output is added to its input. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x, *args, **kwargs)
¶
Run the module forward computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Tensor
|
Input tensor; also the skip-connection term. |
required |
*args
|
Any
|
Additional positional arguments forwarded to the underlying call. |
()
|
**kwargs
|
Any
|
Additional keyword arguments forwarded to the underlying call. |
{}
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: x plus the output of fn. |
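In U-Nets of this style, the two wrappers are commonly stacked as `Residual(PreNorm(dim, fn))`, which computes x + fn(norm(x)). Whether this codebase stacks them exactly this way is an assumption. The sketch uses plain functions on lists, with hypothetical names, to show the composition. It ignores the extra `*args, **kwargs` that the documented `Residual.forward` forwards to `fn`.

```python
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def pre_norm(norm: Layer, fn: Layer) -> Layer:
    """PreNorm: normalize the input first, then apply fn."""
    return lambda x: fn(norm(x))


def residual(fn: Layer) -> Layer:
    """Residual: add fn's output back onto its input (shapes must match)."""
    return lambda x: [a + b for a, b in zip(x, fn(x))]
```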
SinusoidalPosEmb
¶
Bases: Module
Module that generates sinusoidal timestep embeddings.
__init__(dim)
¶
Initialize SinusoidalPosEmb with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Embedding dimension. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x)
¶
Run the module forward computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Tensor
|
1-D tensor of timesteps. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Embeddings with dim features per timestep. |
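Sinusoidal timestep embeddings map a scalar timestep onto sines and cosines at geometrically spaced frequencies. The sketch below follows the layout common in diffusion U-Nets: the first half holds sines, the second half cosines, and frequencies run from 1 down to 1/10000. That layout is an assumption about this module. The name is hypothetical, and it needs `dim >= 4`.

```python
import math


def sinusoidal_embedding(t, dim):
    """Sin/cos timestep embedding: first half sines, second half cosines."""
    half = dim // 2
    scale = math.log(10000) / (half - 1)
    freqs = [math.exp(-scale * i) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]
```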
UnetConvNextBlock
¶
Bases: Module
U-Net/ConvNeXt backbone used by the diffusion model.
__init__(dim, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, with_time_emb=True, coord_emb_dim=None, output_mean_scale=False, residual=False)
¶
Initialize UnetConvNextBlock with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Base channel width. |
required |
out_dim
|
int | None
|
Number of output channels. |
None
|
dim_mults
|
tuple[int, ...]
|
Multipliers of dim giving the channel width at each resolution level. |
(1, 2, 4, 8)
|
channels
|
int
|
Number of input channels. |
3
|
with_time_emb
|
bool
|
Whether to use a timestep embedding. |
True
|
coord_emb_dim
|
int | None
|
Optional size of the coordinate/date embedding. |
None
|
output_mean_scale
|
bool
|
Boolean flag controlling behavior. |
False
|
residual
|
bool
|
Whether to add a residual connection from the input to the output. |
False
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
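In lucidrains-style U-Nets, `dim` and `dim_mults` set the channel width at each encoder level. Every level except the last then halves H and W. The sketch below assumes this backbone follows that layout (not confirmed above) and derives the resulting channel pairs and the size divisor; both helper names are hypothetical.

```python
def unet_level_channels(dim, dim_mults=(1, 2, 4, 8), channels=3):
    """(in_channels, out_channels) for each encoder level (assumed layout)."""
    dims = [channels] + [dim * m for m in dim_mults]
    return list(zip(dims[:-1], dims[1:]))


def required_divisor(dim_mults=(1, 2, 4, 8)):
    """Assumed: all levels but the last downsample by 2, so H and W must divide by this."""
    return 2 ** (len(dim_mults) - 1)
```

Under this assumption, the defaults `dim_mults=(1, 2, 4, 8)` need inputs whose height and width are multiples of 8.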
forward(x, time=None, coord_emb=None)
¶
Run the module forward computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Tensor
|
Input tensor with the configured number of channels. |
required |
time
|
Tensor | None
|
Optional timesteps for the time embedding. |
None
|
coord_emb
|
Tensor | None
|
Optional coordinate/date embedding. |
None
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Network output. |
Downsample(dim)
¶
Create a strided-convolution downsampling layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Number of input and output channels. |
required |
Returns:
| Type | Description |
|---|---|
Conv2d
|
nn.Conv2d: Layer that halves the spatial resolution. |
Upsample(dim)
¶
Create a transpose-convolution upsampling layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim
|
int
|
Number of input and output channels. |
required |
Returns:
| Type | Description |
|---|---|
ConvTranspose2d
|
nn.ConvTranspose2d: Layer that doubles the spatial resolution. |
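The output size of both layers follows the PyTorch shape formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1). The kernel size 4, stride 2, padding 1 used below is an assumption borrowed from common U-Net code, not something this reference states. The helper names are hypothetical.

```python
def conv2d_out(size, kernel=4, stride=2, padding=1):
    """Output length of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel=4, stride=2, padding=1, output_padding=0):
    """Output length of nn.ConvTranspose2d along one spatial axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding
```

Odd sizes do not round-trip (255 becomes 127 and then 254). This is one reason U-Net inputs are usually kept at sizes divisible by a power of two.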
cosine_beta_schedule(timesteps, s=0.008, beta_start=0.0001, beta_end=None)
¶
Compute cosine beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
s
|
float
|
Small offset that keeps betas from becoming too small near t = 0. |
0.008
|
beta_start
|
float | None
|
Input value. |
0.0001
|
beta_end
|
float | None
|
Input value. |
None
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
default(val, d)
¶
Return the input value or a fallback default.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
val
|
T | None
|
Value to return when it is not None. |
required |
d
|
T | Callable[[], T]
|
Fallback value, or a zero-argument callable that produces it. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
T |
T
|
val if it is not None, otherwise the fallback. |
exists(x)
¶
Return whether the provided value is not None.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
object
|
Value to test. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
bool |
bool
|
True if x is not None. |
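`exists` and `default` are small helpers of the kind common in diffusion codebases. The sketch below assumes `default` calls `d` only when `val` is None and `d` is callable, which keeps expensive fallbacks lazy. Treating any callable this way is an assumption; some implementations check specifically for functions.

```python
from typing import Callable, Optional, TypeVar, Union

T = TypeVar("T")


def exists(x: object) -> bool:
    """True unless x is None; falsy values such as 0 or '' still exist."""
    return x is not None


def default(val: Optional[T], d: Union[T, Callable[[], T]]) -> T:
    """Return val if it is set; otherwise d, calling it first when callable."""
    if exists(val):
        return val
    return d() if callable(d) else d
```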
get_beta_schedule(variant, timesteps, beta_start=0.0001, beta_end=0.02)
¶
Return the beta schedule selected by variant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
variant
|
str
|
Schedule name, such as 'linear'; the module also defines cosine, quadratic, and sigmoid schedules. |
required |
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute linear beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute quadratic beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute sigmoid beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
depth_recon.models.diffusion.DenoisingDiffusionProcess.forward¶
This module implements the forward diffusion process.
Current models:
1) Gaussian diffusion
ForwardModel
¶
Bases: Module
Base interface for forward diffusion process implementations.
__init__(num_timesteps=1000, schedule='linear')
¶
Initialize ForwardModel with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x_0, t)
¶
Return a noisy version of x_0 at timestep t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_0
|
Tensor
|
Clean input sample x_0. |
required |
t
|
Tensor
|
Timestep indices. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Noisy sample x_t. |
step(x_t, t)
¶
Advance the forward process by one timestep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Sample at timestep t. |
required |
t
|
Tensor
|
Timestep indices. |
required |
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Sample at the next timestep. |
GaussianForwardProcess
¶
Bases: ForwardModel
Forward diffusion process based on Gaussian noise transitions.
__init__(num_timesteps=1000, schedule='linear', beta_start=0.0001, beta_end=0.02)
¶
Initialize GaussianForwardProcess with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(x_0, t, return_noise=False)
¶
Sample x_t from q(x_t | x_0) by adding Gaussian noise to x_0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_0
|
Tensor
|
Clean input sample x_0. |
required |
t
|
Tensor
|
Timestep indices. |
required |
return_noise
|
bool
|
Whether to also return the sampled noise. |
False
|
Returns:
| Type | Description |
|---|---|
Tensor | tuple[Tensor, Tensor]
|
torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Noisy sample x_t, or (x_t, noise) when return_noise is True. |
step(x_t, t, return_noise=False)
¶
Apply one forward-diffusion transition step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Sample at timestep t. |
required |
t
|
Tensor
|
Timestep indices. |
required |
return_noise
|
bool
|
Whether to also return the sampled noise. |
False
|
Returns:
| Type | Description |
|---|---|
Tensor | tuple[Tensor, Tensor]
|
torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Sample at the next timestep, or (sample, noise) when return_noise is True. |
cosine_beta_schedule(timesteps, s=0.008, beta_start=0.0001, beta_end=None)
¶
Compute cosine beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
s
|
float
|
Small offset that keeps betas from becoming too small near t = 0. |
0.008
|
beta_start
|
float | None
|
Input value. |
0.0001
|
beta_end
|
float | None
|
Input value. |
None
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
get_beta_schedule(variant, timesteps, beta_start=0.0001, beta_end=0.02)
¶
Return the beta schedule selected by variant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
variant
|
str
|
Schedule name, such as 'linear'; the module also defines cosine, quadratic, and sigmoid schedules. |
required |
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute linear beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute quadratic beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute sigmoid beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
depth_recon.models.diffusion.DenoisingDiffusionProcess.beta_schedules¶
cosine_beta_schedule(timesteps, s=0.008, beta_start=0.0001, beta_end=None)
¶
Compute cosine beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
s
|
float
|
Small offset that keeps betas from becoming too small near t = 0. |
0.008
|
beta_start
|
float | None
|
Input value. |
0.0001
|
beta_end
|
float | None
|
Input value. |
None
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
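For reference, the widely used cosine schedule of Nichol and Dhariwal defines abar(t) = f(t) / f(0) with f(t) = cos(((t / T + s) / (1 + s)) * pi / 2)^2. It then sets beta_t = 1 - abar(t + 1) / abar(t), clipped to at most 0.999.

The pure-Python sketch below implements that textbook form. How this module's `beta_start` and `beta_end` arguments constrain the result is not documented above, so the sketch uses the conventional clip instead; `cosine_betas` is a hypothetical name.

```python
import math


def cosine_betas(timesteps, s=0.008, max_beta=0.999):
    """Textbook cosine schedule, with betas clipped to [0, max_beta]."""
    def f(step):
        return math.cos(((step / timesteps) + s) / (1 + s) * math.pi / 2) ** 2

    alphas_cumprod = [f(step) / f(0) for step in range(timesteps + 1)]
    betas = []
    for prev, cur in zip(alphas_cumprod[:-1], alphas_cumprod[1:]):
        betas.append(min(max(1.0 - cur / prev, 0.0), max_beta))
    return betas
```

The betas grow slowly at first and sharply near the end, where the clip to `max_beta` takes over. The offset `s` keeps the first betas from being vanishingly small.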
get_beta_schedule(variant, timesteps, beta_start=0.0001, beta_end=0.02)
¶
Return the beta schedule selected by variant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
variant
|
str
|
Schedule name, such as 'linear'; the module also defines cosine, quadratic, and sigmoid schedules. |
required |
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
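A name-based dispatcher like `get_beta_schedule` can be sketched as a lookup table of schedule builders. The sketch uses the conventional definitions of the linear, quadratic, and sigmoid schedules:

- linear: evenly spaced betas;
- quadratic: evenly spaced in sqrt(beta), then squared;
- sigmoid: a sigmoid over [-6, 6] rescaled into [beta_start, beta_end].

These definitions are assumptions about this module. The cosine schedule is left out to keep the block short, but it would register the same way. All names are hypothetical.

```python
import math


def linear_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Betas evenly spaced from beta_start to beta_end (inclusive)."""
    if timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def quadratic_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced in sqrt(beta), then squared."""
    roots = linear_betas(timesteps, math.sqrt(beta_start), math.sqrt(beta_end))
    return [r * r for r in roots]


def sigmoid_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Sigmoid of points evenly spaced on [-6, 6], rescaled into [beta_start, beta_end]."""
    xs = linear_betas(timesteps, -6.0, 6.0)
    return [1.0 / (1.0 + math.exp(-x)) * (beta_end - beta_start) + beta_start for x in xs]


SCHEDULES = {"linear": linear_betas, "quadratic": quadratic_betas, "sigmoid": sigmoid_betas}


def get_betas(variant, timesteps, beta_start=1e-4, beta_end=0.02):
    """Dispatch on the schedule name; raise ValueError for unknown variants."""
    try:
        builder = SCHEDULES[variant]
    except KeyError:
        raise ValueError(f"unknown beta schedule: {variant!r}") from None
    return builder(timesteps, beta_start, beta_end)
```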
linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute linear beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute quadratic beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
¶
Compute sigmoid beta schedule and return the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timesteps
|
int
|
Number of timesteps in the schedule. |
required |
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Beta values, one per timestep. |
depth_recon.models.diffusion.DenoisingDiffusionProcess.samplers.DDPM¶
This module contains the DDPM sampler class for the diffusion process.
DDPM_Sampler
¶
Bases: Module
DDPM sampler that performs one reverse-diffusion step at a time.
__init__(num_timesteps=1000, schedule='linear', beta_start=0.0001, beta_end=0.02, parameterization='epsilon')
¶
Initialize DDPM_Sampler with configured parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_timesteps
|
int
|
Number of diffusion timesteps. |
1000
|
schedule
|
str
|
Name of the beta schedule, such as 'linear'. |
'linear'
|
beta_start
|
float
|
Beta (noise variance) at the first timestep. |
0.0001
|
beta_end
|
float
|
Beta (noise variance) at the last timestep. |
0.02
|
parameterization
|
str
|
Network prediction target; 'epsilon' means the model predicts the added noise. |
'epsilon'
|
Returns:
| Name | Type | Description |
|---|---|---|
None |
None
|
No value is returned. |
forward(*args, **kwargs)
¶
Run the sampler call and return the next sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
*args
|
Any
|
Additional positional arguments forwarded to the underlying call. |
()
|
**kwargs
|
Any
|
Additional keyword arguments forwarded to the underlying call. |
{}
|
Returns:
| Type | Description |
|---|---|
Tensor
|
torch.Tensor: Next sample in the reverse process. |
posterior_params(x_t, t, noise_pred)
¶
Compute the mean and standard deviation of p(x_{t-1} | x_t) from the network prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x_t
|
Tensor
|
Noisy sample at timestep t. |
required |
t
|
Tensor
|
Current timestep indices. |
required |
noise_pred
|
Tensor
|
Network prediction for x_t. |
required |
Returns:
| Type | Description |
|---|---|
tuple[Tensor, Tensor]
|
tuple[torch.Tensor, torch.Tensor]: Tuple containing computed outputs. |
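For reference, the standard DDPM posterior q(x_{t-1} | x_t, x_0) can be written with x_0 estimated from the predicted noise. The scalar sketch below shows those textbook formulas under the 'epsilon' parameterization. It is not the library's code: posterior_params itself works on batched tensors and may organize the computation differently, and the helper name is hypothetical.

```python
import math


def posterior_params_sketch(x_t, noise_pred, beta_t, alpha_bar_t, alpha_bar_prev):
    """Scalar DDPM posterior mean and variance (epsilon parameterization)."""
    alpha_t = 1.0 - beta_t
    # Estimate the clean sample x_0 from x_t and the predicted noise.
    x0_hat = (x_t - math.sqrt(1.0 - alpha_bar_t) * noise_pred) / math.sqrt(alpha_bar_t)
    # The posterior mean is a weighted combination of x0_hat and x_t.
    coef_x0 = beta_t * math.sqrt(alpha_bar_prev) / (1.0 - alpha_bar_t)
    coef_xt = (1.0 - alpha_bar_prev) * math.sqrt(alpha_t) / (1.0 - alpha_bar_t)
    mean = coef_x0 * x0_hat + coef_xt * x_t
    # The posterior variance depends only on the schedule.
    variance = beta_t * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return mean, variance


# At the first timestep alpha_bar_prev = 1, so the posterior collapses onto x0_hat.
mean, var = posterior_params_sketch(x_t=1.0, noise_pred=0.0, beta_t=0.02, alpha_bar_t=0.98, alpha_bar_prev=1.0)
print(var)  # 0.0
```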
set_parameterization(parameterization)
Set the parameterization the sampler assumes for network outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| parameterization | str | Parameterization name, e.g. 'epsilon'. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
step(x_t, t, z_t)
Predict the previous diffusion sample, x_{t-1}, for one timestep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_t | Tensor | Noisy sample at timestep t. | required |
| t | Tensor | Timestep indices, typically one per batch element. | required |
| z_t | Tensor | Network output at timestep t (the predicted noise, under the 'epsilon' parameterization). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Sample at the previous timestep. |
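A reverse DDPM step draws x_{t-1} around the posterior mean. The sketch below uses the standard noise-based form of that mean with scalar inputs and Python's random module. The integer timestep, the argument order, and the name ddpm_step_sketch are assumptions, and the documented step operates on tensors.

```python
import math
import random


def ddpm_step_sketch(x_t, t, eps_pred, betas, alphas_cumprod, rng=None):
    """One scalar DDPM reverse step, x_t -> x_{t-1} (epsilon parameterization)."""
    if rng is None:
        rng = random.Random()
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    alpha_bar_t = alphas_cumprod[t]
    # Posterior mean expressed directly in terms of the predicted noise.
    mean = (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t)
    if t == 0:
        # The final step returns the mean without adding noise.
        return mean
    # Posterior standard deviation for this step.
    alpha_bar_prev = alphas_cumprod[t - 1]
    sigma = math.sqrt(beta_t * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t))
    return mean + sigma * rng.gauss(0.0, 1.0)


betas = [0.1, 0.2]
alphas_cumprod = [0.9, 0.9 * 0.8]
x_prev = ddpm_step_sketch(1.0, 1, 0.5, betas, alphas_cumprod, rng=random.Random(0))
```

Passing a seeded random.Random makes a single step reproducible, which is convenient when comparing against a reference implementation.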
cosine_beta_schedule(timesteps, s=0.008, beta_start=0.0001, beta_end=None)
Compute a cosine beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| s | float | Small offset that keeps betas from being too small near t = 0. | 0.008 |
| beta_start | float \| None | Input value. | 0.0001 |
| beta_end | float \| None | Input value. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
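The cosine schedule of Nichol and Dhariwal (2021) defines the cumulative product alpha_bar through a squared cosine and derives each beta from the ratio of consecutive alpha_bar values. The value s = 0.008 is the offset used there. This pure-Python sketch clips the betas to [beta_min, beta_max]. Those two names, their defaults, and any correspondence to the documented beta_start and beta_end are assumptions.

```python
import math


def cosine_betas_sketch(timesteps: int, s: float = 0.008, beta_min: float = 0.0001, beta_max: float = 0.9999) -> list[float]:
    """Pure-Python sketch of the cosine beta schedule."""
    # alpha_bar(t) follows a squared cosine over normalized time t / T.
    f = [math.cos((i / timesteps + s) / (1 + s) * math.pi / 2) ** 2 for i in range(timesteps + 1)]
    alphas_cumprod = [v / f[0] for v in f]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = [1.0 - alphas_cumprod[i + 1] / alphas_cumprod[i] for i in range(timesteps)]
    # Clip to avoid degenerate betas at either end of the schedule.
    return [min(max(b, beta_min), beta_max) for b in betas]


betas = cosine_betas_sketch(1000)
print(betas[0], betas[-1])  # 0.0001 0.9999
```

Without clipping, the final beta would be essentially 1 because alpha_bar reaches (almost exactly) zero at t = T. The upper clip prevents that singularity.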
get_beta_schedule(variant, timesteps, beta_start=0.0001, beta_end=0.02)
Build the beta schedule selected by variant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| variant | str | Schedule name, e.g. 'linear'. | required |
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
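A name-based dispatcher such as get_beta_schedule is commonly written as a lookup table of builders, with an explicit error for unknown names. The sketch below is hypothetical: it registers only pure-Python 'linear' and 'quadratic' builders. The set of names the real function accepts and its behavior for unknown names are assumptions.

```python
def get_beta_schedule_sketch(variant: str, timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02) -> list[float]:
    """Hypothetical name-based dispatch over beta schedules."""
    if timesteps < 2:
        raise ValueError("timesteps must be at least 2")

    def ramp(lo: float, hi: float) -> list[float]:
        # `timesteps` evenly spaced values from lo to hi, inclusive.
        return [lo + (hi - lo) * i / (timesteps - 1) for i in range(timesteps)]

    builders = {
        # Linear: betas evenly spaced between beta_start and beta_end.
        "linear": lambda: ramp(beta_start, beta_end),
        # Quadratic: evenly spaced in sqrt(beta), then squared.
        "quadratic": lambda: [v * v for v in ramp(beta_start ** 0.5, beta_end ** 0.5)],
    }
    if variant not in builders:
        raise ValueError(f"unknown beta schedule {variant!r}; expected one of {sorted(builders)}")
    return builders[variant]()


linear = get_beta_schedule_sketch("linear", 5)
quadratic = get_beta_schedule_sketch("quadratic", 5)
# The quadratic schedule stays below the linear one between the endpoints.
print(quadratic[2] < linear[2])  # True
```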
linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
Compute a linear beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
Compute a quadratic beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
Compute a sigmoid-shaped beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
depth_recon.models.diffusion.DenoisingDiffusionProcess.samplers.DDIM
This module contains the DDIM sampler class for a diffusion process.
DDIM_Sampler
Bases: Module
DDIM sampler that performs accelerated reverse-diffusion updates.
__init__(num_timesteps=100, train_timesteps=1000, clip_sample=True, schedule='linear', beta_start=0.0001, beta_end=0.02, eta=0.0, temperature=1.0, betas=None, parameterization='epsilon')
Initialize DDIM_Sampler with the given schedule, step count, and noise settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_timesteps | int | Number of DDIM sampling steps. | 100 |
| train_timesteps | int | Number of diffusion timesteps used during training. | 1000 |
| clip_sample | bool | Whether to clip predicted samples during sampling. | True |
| schedule | str | Name of the beta schedule, e.g. 'linear'. | 'linear' |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
| eta | float | DDIM noise weight; 0.0 makes each update deterministic. | 0.0 |
| temperature | float | Scale for DDIM initial and stochastic step noise. | 1.0 |
| betas | Tensor \| list[float] \| tuple[float, ...] \| None | Optional explicit beta values. | None |
| parameterization | str | Quantity the denoising network predicts; 'epsilon' means the added noise. | 'epsilon' |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
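With num_timesteps=100 and train_timesteps=1000, a DDIM sampler visits only a subsequence of the training timesteps. One common choice is a fixed stride, sketched below. Whether DDIM_Sampler spaces its steps this way is an assumption, and the helper name is hypothetical.

```python
def ddim_timesteps_sketch(num_timesteps: int = 100, train_timesteps: int = 1000) -> list[int]:
    """Strided subsequence of training timesteps, ordered from noisiest to cleanest."""
    if not 0 < num_timesteps <= train_timesteps:
        raise ValueError("need 0 < num_timesteps <= train_timesteps")
    stride = train_timesteps // num_timesteps
    # Keep every stride-th training timestep, then reverse for sampling.
    return list(range(0, train_timesteps, stride))[:num_timesteps][::-1]


timesteps = ddim_timesteps_sketch()
print(timesteps[:3], timesteps[-1])  # [990, 980, 970] 0
```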
estimate_std(alpha_cumprod, alpha_cumprod_prev)
Estimate the standard deviation of the DDIM step noise from the cumulative alpha products.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| alpha_cumprod | Tensor | Cumulative product of alphas at the current timestep. | required |
| alpha_cumprod_prev | Tensor | Cumulative product of alphas at the previous timestep. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Noise standard deviation for the step. |
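The DDIM paper gives the per-step noise scale as sigma_t = eta * sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t)) * sqrt(1 - alpha_bar_t / alpha_bar_prev). Since estimate_std takes no eta argument, it presumably returns the eta = 1 value and leaves the eta scaling to the caller; that split is an assumption. The scalar sketch below is illustrative only, and the helper name is hypothetical.

```python
import math


def ddim_sigma_sketch(alpha_bar_t: float, alpha_bar_prev: float, eta: float = 1.0) -> float:
    """DDIM noise scale for one step, from cumulative alpha products."""
    # Variance of the DDPM posterior, written with cumulative products only.
    variance = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * (1.0 - alpha_bar_t / alpha_bar_prev)
    # eta = 0 gives deterministic DDIM; eta = 1 gives DDPM-like noise.
    return eta * math.sqrt(variance)


print(ddim_sigma_sketch(0.5, 0.8, eta=0.0))  # 0.0
```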
forward(*args, **kwargs)
Run the sampler and return the next sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Additional positional arguments forwarded to the underlying call. | () |
| **kwargs | Any | Additional keyword arguments forwarded to the underlying call. | {} |
Returns:
| Type | Description |
|---|---|
| Tensor | The next sample. |
set_parameterization(parameterization)
Set the parameterization the sampler assumes for network outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| parameterization | str | Parameterization name, e.g. 'epsilon'. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
step(x_t, t, z_t)
Predict the previous diffusion sample for one timestep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_t | Tensor | Noisy sample at the current timestep. | required |
| t | Tensor | Timestep indices, typically one per batch element. | required |
| z_t | Tensor | Network output at the current timestep (the predicted noise, under the 'epsilon' parameterization). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Sample at the previous timestep. |
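The DDIM update of Song et al. (2021) combines three terms: the predicted clean sample, a direction pointing toward x_t, and optional fresh noise. The scalar sketch below follows that update. The clipping range [-1, 1] for clip_sample, the argument names, and the helper name are assumptions. The temperature option of the documented sampler is omitted.

```python
import math
import random


def ddim_step_sketch(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta=0.0, clip_sample=True, rng=None):
    """One scalar DDIM update from x_t to x_prev (epsilon parameterization)."""
    if rng is None:
        rng = random.Random()
    # 1) Predicted clean sample.
    x0_hat = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    if clip_sample:
        # Assumed data range of [-1, 1].
        x0_hat = max(-1.0, min(1.0, x0_hat))
    # 2) Noise scale; eta = 0 makes the update deterministic.
    sigma = eta * math.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * (1.0 - alpha_bar_t / alpha_bar_prev))
    # 3) Direction pointing to x_t, reusing the predicted noise.
    direction = math.sqrt(max(0.0, 1.0 - alpha_bar_prev - sigma**2)) * eps_pred
    noise = sigma * rng.gauss(0.0, 1.0) if sigma > 0 else 0.0
    return math.sqrt(alpha_bar_prev) * x0_hat + direction + noise


# With the true noise and eta = 0, the step recovers the x_prev implied by x_0 (up to rounding).
x0, eps = 0.5, 0.3
x_t = math.sqrt(0.5) * x0 + math.sqrt(0.5) * eps
x_prev = ddim_step_sketch(x_t, eps, alpha_bar_t=0.5, alpha_bar_prev=0.8)
```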
cosine_beta_schedule(timesteps, s=0.008, beta_start=0.0001, beta_end=None)
Compute a cosine beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| s | float | Small offset that keeps betas from being too small near t = 0. | 0.008 |
| beta_start | float \| None | Input value. | 0.0001 |
| beta_end | float \| None | Input value. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
get_beta_schedule(variant, timesteps, beta_start=0.0001, beta_end=0.02)
Build the beta schedule selected by variant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| variant | str | Schedule name, e.g. 'linear'. | required |
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
Compute a linear beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
Compute a quadratic beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02)
Compute a sigmoid-shaped beta (noise-variance) schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timesteps | int | Number of diffusion timesteps (schedule length). | required |
| beta_start | float | Lower end of the beta range. | 0.0001 |
| beta_end | float | Upper end of the beta range. | 0.02 |
Returns:
| Type | Description |
|---|---|
| Tensor | Beta values, one per timestep. |
depth_recon.models.diffusion.DenoisingDiffusionProcess.DenoisingDiffusionProcess.UnetConvNextBlock
Bases: Module
U-Net/ConvNeXt backbone used by the diffusion model.
__init__(dim, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, with_time_emb=True, coord_emb_dim=None, output_mean_scale=False, residual=False)
Initialize UnetConvNextBlock with the given widths and conditioning options.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Base channel width. | required |
| out_dim | int \| None | Number of output channels (optional). | None |
| dim_mults | tuple[int, ...] | Channel-width multipliers for successive resolution levels. | (1, 2, 4, 8) |
| channels | int | Number of input channels. | 3 |
| with_time_emb | bool | Whether to condition on a timestep embedding. | True |
| coord_emb_dim | int \| None | Size of the optional coordinate embedding. | None |
| output_mean_scale | bool | Boolean flag controlling behavior. | False |
| residual | bool | Boolean flag controlling behavior. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
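In U-Net implementations of this style, dim and dim_mults usually determine the channel width at each resolution level. The input channels feed the first level, and each later level multiplies the base width. The sketch below shows that arithmetic only. It is an assumption that UnetConvNextBlock derives its level widths exactly this way, and the helper name is hypothetical.

```python
def unet_level_widths_sketch(dim: int, dim_mults: tuple[int, ...] = (1, 2, 4, 8), channels: int = 3) -> list[tuple[int, int]]:
    """(in_channels, out_channels) pairs for successive U-Net resolution levels."""
    dims = [channels] + [dim * m for m in dim_mults]
    # Pair each width with the next one to get per-level input/output sizes.
    return list(zip(dims[:-1], dims[1:]))


print(unet_level_widths_sketch(64))  # [(3, 64), (64, 128), (128, 256), (256, 512)]
```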
forward(x, time=None, coord_emb=None)
Run the U-Net forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| time | Tensor \| None | Diffusion timesteps for the time embedding. | None |
| coord_emb | Tensor \| None | Optional coordinate embedding. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Network output. |
Utilities
depth_recon.utils.normalizations
salinity_normalize(mode, tensor)
Apply the salinity normalization transform selected by mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mode | str | Selects which transform to apply. | required |
| tensor | Tensor | Salinity values. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed salinity values. |
salinity_to_plot_unit(tensor, *, tensor_is_normalized=True)
Convert salinity values to 0..1 plot units.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Salinity values. | required |
| tensor_is_normalized | bool | Whether tensor is in normalized rather than physical units. | True |
Returns:
| Type | Description |
|---|---|
| Tensor | Values in plot units. |
temperature_normalize(mode, tensor)
Apply the temperature normalization transform selected by mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mode | str | Selects which transform to apply. | required |
| tensor | Tensor | Temperature values. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed temperature values. |
temperature_to_plot_unit(tensor, *, tensor_is_normalized=True)
Convert temperature values to 0..1 plot units.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Temperature values. | required |
| tensor_is_normalized | bool | Whether tensor is in normalized rather than physical units. | True |
Returns:
| Type | Description |
|---|---|
| Tensor | Values in plot units. |
depth_recon.utils.stretching
minmax_stretch(tensor, *, mask=None, nodata_value=None)
Linearly rescale tensor values to the 0..1 range using their minimum and maximum.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Values to stretch. | required |
| mask | Tensor \| None | Mask tensor controlling valid or known pixels. | None |
| nodata_value | float \| None | Value that marks missing data. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Stretched values. |
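A masked min-max stretch computes the minimum and maximum over valid entries only and then rescales linearly to [0, 1]. The pure-Python sketch below is hypothetical. Treating NaN, masked-out, and nodata entries as invalid, and filling them with 0.0 in the output, are assumptions rather than documented behavior of minmax_stretch.

```python
import math


def minmax_stretch_sketch(values, mask=None, nodata_value=None, fill=0.0):
    """Linearly rescale valid entries of `values` to [0, 1]."""

    def is_valid(i, v):
        if mask is not None and not mask[i]:
            return False
        if nodata_value is not None and v == nodata_value:
            return False
        return not math.isnan(v)

    valid = [v for i, v in enumerate(values) if is_valid(i, v)]
    if not valid:
        return [fill] * len(values)
    lo, hi = min(valid), max(valid)
    span = hi - lo
    out = []
    for i, v in enumerate(values):
        if not is_valid(i, v):
            out.append(fill)  # invalid entries get the fill value
        else:
            # A constant input maps to 0.0 instead of dividing by zero.
            out.append((v - lo) / span if span > 0 else 0.0)
    return out


print(minmax_stretch_sketch([2.0, 4.0, -999.0, 6.0], nodata_value=-999.0))  # [0.0, 0.5, 0.0, 1.0]
```

Excluding nodata values before taking the minimum matters: a sentinel such as -999.0 would otherwise become the minimum and compress all real values toward 1.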
depth_recon.utils.validation_denoise
average_observed_argo_pixels_per_image(valid_mask)
Return the average number of spatial pixels with ARGO observations.
build_capture_indices(total_steps, intermediate_step_indices)
Build the set of sampler step indices at which intermediate denoising samples are captured.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| total_steps | int | Total number of sampling steps. | required |
| intermediate_step_indices | list[int] \| None | Requested intermediate step indices, or None. | required |
Returns:
| Type | Description |
|---|---|
| set[int] | Step indices to capture. |
build_evenly_spaced_capture_steps(total_steps, num_frames)
Return capture steps spread evenly across the sampling run.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| total_steps | int | Total number of sampling steps. | required |
| num_frames | int | Number of steps (frames) to capture. | required |
Returns:
| Type | Description |
|---|---|
| list[int] | Selected step indices. |
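One way to choose num_frames evenly spaced capture steps is to round evenly spaced positions on [0, total_steps - 1]. The sketch below is hypothetical. In particular, always including the first and last step, returning the last step for a single frame, and capturing every step when num_frames exceeds total_steps are assumptions about build_evenly_spaced_capture_steps.

```python
def evenly_spaced_steps_sketch(total_steps: int, num_frames: int) -> list[int]:
    """Pick up to num_frames step indices spread evenly over range(total_steps)."""
    if total_steps <= 0 or num_frames <= 0:
        return []
    if num_frames >= total_steps:
        # Not enough steps to skip any: capture every step.
        return list(range(total_steps))
    if num_frames == 1:
        # Assumed: a single frame captures the final step.
        return [total_steps - 1]
    positions = (i * (total_steps - 1) / (num_frames - 1) for i in range(num_frames))
    # Round to integer indices; the set guards against duplicates.
    return sorted({round(p) for p in positions})


print(evenly_spaced_steps_sketch(100, 5))  # [0, 25, 50, 74, 99]
```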
log_wandb_conditional_reconstruction_grid(*, logger, x, y=None, y_hat, y_target, valid_mask=None, land_mask=None, eo=None, prefix='val_imgs', image_key='x_y_full_reconstruction', cmap='turbo', show_valid_mask_panel=True, plot_unit='temperature', error_metric_prefix='val_absolute_band_error', error_metric_unit='deg', error_metric_label='L1 (deg)', error_metric_title='Generated-Pixel L1 by Band')
Log a conditional reconstruction image grid and per-band error metrics to W&B for monitoring.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logger | Any | Logger instance used for experiment tracking. | required |
| x | Tensor | Conditioning input tensor. | required |
| y | Tensor \| None | Optional input tensor. | None |
| y_hat | Tensor | Reconstructed (predicted) tensor. | required |
| y_target | Tensor | Target tensor for comparison. | required |
| valid_mask | Tensor \| None | Mask of valid (known) pixels. | None |
| land_mask | Tensor \| None | Land mask. | None |
| eo | Tensor \| None | Optional EO conditioning tensor. | None |
| prefix | str | Key prefix for the logged images. | 'val_imgs' |
| image_key | str | Key for the logged image. | 'x_y_full_reconstruction' |
| cmap | str | Colormap name. | 'turbo' |
| show_valid_mask_panel | bool | Whether to show the valid mask as a separate panel. | True |
| plot_unit | str | Physical variable scale to map into 0..1 plot units. | 'temperature' |
| error_metric_prefix | str | W&B namespace for per-band error metrics. | 'val_absolute_band_error' |
| error_metric_unit | str | Unit suffix used in per-band metric names. | 'deg' |
| error_metric_label | str | Series label for the compact W&B line chart. | 'L1 (deg)' |
| error_metric_title | str | Title for the compact W&B line chart. | 'Generated-Pixel L1 by Band' |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
log_wandb_denoise_timestep_grid(*, logger, denoise_samples, mae_samples=None, total_steps, sampler, conditioning_image=None, eo_conditioning_image=None, ground_truth=None, valid_mask=None, land_mask=None, prefix='val_imgs', cmap='turbo', plot_unit='temperature', nrows=4, ncols=4, tile_size_px=128, tile_pad_px=2)
Log a grid of intermediate denoising samples across sampler timesteps to W&B for monitoring.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logger | Any | Logger instance used for experiment tracking. | required |
| denoise_samples | list[tuple[int, Tensor]] | (step, sample) pairs captured during denoising. | required |
| mae_samples | list[tuple[int, Tensor]] \| None | Optional additional (step, tensor) pairs. | None |
| total_steps | int | Total number of sampling steps. | required |
| sampler | Any | Sampler instance used for reverse diffusion. | required |
| conditioning_image | Tensor \| None | Optional conditioning input to display. | None |
| eo_conditioning_image | Tensor \| None | Optional EO conditioning input to display. | None |
| ground_truth | Tensor \| None | Optional ground-truth target to display. | None |
| valid_mask | Tensor \| None | Mask of valid (known) pixels. | None |
| land_mask | Tensor \| None | Land mask. | None |
| prefix | str | Key prefix for the logged images. | 'val_imgs' |
| cmap | str | Colormap name. | 'turbo' |
| plot_unit | str | Physical variable scale to map into 0..1 plot units. | 'temperature' |
| nrows | int | Number of grid rows. | 4 |
| ncols | int | Number of grid columns. | 4 |
| tile_size_px | int | Size of each tile in pixels. | 128 |
| tile_pad_px | int | Padding between tiles in pixels. | 2 |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
log_wandb_depth_level_reconstruction_grid(*, logger, y_hat, y_target, valid_mask=None, eo=None, land_mask=None, prefix='val_imgs', image_key='depth_level_reconstruction_grid', band_indices=(0, 1, 3), sample_idx=0, cmap='turbo')
Log a grid of reconstructions at selected depth levels to W&B for monitoring.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logger | Any | Logger instance used for experiment tracking. | required |
| y_hat | Tensor | Reconstructed (predicted) tensor. | required |
| y_target | Tensor | Target tensor for comparison. | required |
| valid_mask | Tensor \| None | Mask of valid (known) pixels. | None |
| eo | Tensor \| None | Optional EO conditioning tensor. | None |
| land_mask | Tensor \| None | Land mask. | None |
| prefix | str | Key prefix for the logged images. | 'val_imgs' |
| image_key | str | Key for the logged image. | 'depth_level_reconstruction_grid' |
| band_indices | tuple[int, ...] | Indices of the depth-level bands to plot. | (0, 1, 3) |
| sample_idx | int | Zero-based index of the batch sample to plot. | 0 |
| cmap | str | Colormap name. | 'turbo' |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
log_wandb_diffusion_schedule_profile(*, logger, sampler, total_steps, prefix='val_imgs', eps=1e-12)
Log a profile of the sampler's diffusion schedule to W&B for monitoring.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logger | Any | Logger instance used for experiment tracking. | required |
| sampler | Any | Sampler instance used for reverse diffusion. | required |
| total_steps | int | Total number of sampling steps. | required |
| prefix | str | Key prefix for the logged images. | 'val_imgs' |
| eps | float | Small constant for numerical stability. | 1e-12 |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
log_wandb_glorys_profile_comparison(*, logger, x, y_hat, y_target, conditioning_mask=None, candidate_mask=None, prefix='val_imgs', image_key='glorys_profile_comparison', sample_idx=0, profile_x_label='Temperature (deg C)')
Log full-depth profile comparisons at generated-only validation pixels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logger | Any | Logger instance used for experiment tracking. | required |
| x | Tensor | Conditioning tensor containing sparse Argo-aligned profiles. | required |
| y_hat | Tensor | Reconstructed tensor in denormalized space. | required |
| y_target | Tensor | GLORYS target tensor in denormalized space. | required |
| conditioning_mask | Tensor \| None | Mask tensor marking known x pixels. | None |
| candidate_mask | Tensor \| None | Mask tensor selecting generated-only pixels. | None |
| prefix | str | Key prefix for the logged images. | 'val_imgs' |
| image_key | str | Key for the logged image. | 'glorys_profile_comparison' |
| sample_idx | int | Zero-based index of the batch sample to plot. | 0 |
| profile_x_label | str | X-axis label for physical profile values. | 'Temperature (deg C)' |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | No value is returned. |
plot_average_glorys_profile_error_axis(ax, *, mean_abs_error_prediction_vs_glorys, mean_abs_error_prediction_vs_argo, depth_axis=None, title=None, show_legend=False)
Draw one pooled absolute-error-vs-depth axis for the validation summary.
plot_glorys_profile_comparison_axis(ax, *, x_profile, y_hat_profile, y_target_profile, observed_profile, depth_axis=None, ostia_sst_c=None, title=None, show_legend=False, profile_x_label='Temperature (deg C)', surface_context_label='OSTIA SST')
Draw one validation-style profile comparison axis.
plot_glorys_profile_error_axis(ax, *, x_profile, y_hat_profile, y_target_profile, observed_profile, depth_axis=None, title=None, show_legend=False, error_x_label='Absolute error (deg C)')
Draw one absolute-error-vs-depth axis for prediction errors.
save_average_glorys_profile_and_error_plot(*, output_path, mean_argo_profile_c, mean_prediction_profile_c, mean_glorys_profile_c, mean_abs_error_prediction_vs_glorys, mean_abs_error_prediction_vs_argo, depth_axis=None, figure_title=None, dpi=180)
Save one two-panel pooled profile/error validation summary plot to disk.
save_average_glorys_profile_error_plot(*, output_path, mean_abs_error_prediction_vs_glorys, mean_abs_error_prediction_vs_argo, depth_axis=None, figure_title=None, dpi=180)
Save one single-panel validation-summary error plot to disk.
save_glorys_profile_comparison_plot(*, output_path, x_profile, y_hat_profile, y_target_profile, observed_profile, depth_axis=None, ostia_sst_c=None, title=None, figure_title=None, profile_x_label='Temperature (deg C)', error_x_label='Absolute error (deg C)', surface_context_label='OSTIA SST', dpi=180, webp_quality=95)
Save one validation-style profile comparison plot to disk.
step_to_sampler_timestep_label(*, step_index, total_steps, sampler)
Convert a sampler step index into the diffusion timestep used as its label.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_index | int | Zero-based sampler step index. | required |
| total_steps | int | Total number of sampling steps. | required |
| sampler | Any | Sampler instance used for reverse diffusion. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| int | int | Timestep value used as the label. |
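In a reverse loop that runs from the noisiest timestep down to 0, loop step i typically corresponds to timestep total_steps - 1 - i. When a DDIM-style sampler skips training timesteps, that value is also scaled by the stride. The mapping and the train_timesteps argument below are assumptions for illustration, not the documented behavior of step_to_sampler_timestep_label, and the helper name is hypothetical.

```python
from typing import Optional


def step_to_timestep_sketch(step_index: int, total_steps: int, train_timesteps: Optional[int] = None) -> int:
    """Map a zero-based reverse-loop step index to the diffusion timestep it denoises."""
    if not 0 <= step_index < total_steps:
        raise ValueError("step_index must be in [0, total_steps)")
    remaining = total_steps - 1 - step_index
    if train_timesteps is None:
        # Full-length loop: every training timestep is visited once.
        return remaining
    # Strided loop: each sampling step skips train_timesteps // total_steps timesteps.
    return remaining * (train_timesteps // total_steps)


print(step_to_timestep_sketch(0, 1000), step_to_timestep_sketch(0, 100, 1000))  # 999 990
```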