"""
Training script for conditional DDPM on The Well datasets.
Includes periodic evaluation with WandB video logging.

Usage:
    python train_diffusion.py --dataset turbulent_radiative_layer_2D --wandb
    python train_diffusion.py --dataset active_matter --batch_size 4 --wandb
"""
import argparse
import logging
import math
import os
import time

import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from data_pipeline import create_dataloader, prepare_batch, get_channel_info
from unet import UNet
from diffusion import GaussianDiffusion


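# Console logging: keep the root logger at WARNING so noisy libraries stay
# quiet, while this script (and eval_utils below) log INFO through a single
# timestamped handler.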
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("train_diffusion")
logger.setLevel(logging.INFO)
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"))
logger.addHandler(_h)
logger.propagate = False

logging.getLogger("eval_utils").setLevel(logging.INFO)
logging.getLogger("eval_utils").addHandler(_h)
logging.getLogger("eval_utils").propagate = False


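# Learning-rate schedule: linear warmup for `warmup` steps, then cosine decay
# from base_lr down to min_lr over the remaining steps.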
def cosine_lr(step, warmup, total, base_lr, min_lr=1e-6):
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(progress * math.pi))


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

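    # Optional experiment tracking; wandb is imported lazily so it is only
    # required when --wandb is passed.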
    wandb_run = None
    if args.wandb:
        import wandb

        wandb_run = wandb.init(
            project="the-well-diffusion",
            name=f"{args.dataset}_bs{args.batch_size}_lr{args.lr}",
            config=vars(args),
        )
        logger.info(f"WandB run: {wandb_run.url}")

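    # Training data; the dataset object also carries the channel metadata used
    # to size the model below.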
| | logger.info(f"Loading training data: {args.dataset} (streaming={args.streaming})") |
| | train_loader, train_dataset = create_dataloader( |
| | dataset_name=args.dataset, |
| | split="train", |
| | batch_size=args.batch_size, |
| | n_steps_input=args.n_input, |
| | n_steps_output=args.n_output, |
| | num_workers=args.workers, |
| | streaming=args.streaming, |
| | local_path=args.local_path, |
| | ) |
| |
|
| | ch_info = get_channel_info(train_dataset) |
| | logger.info(f"Channel info: {ch_info}") |
| |
|
| | c_in = ch_info["input_channels"] |
| | c_out = ch_info["output_channels"] |
| |
|
| | |
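    # Validation data for one-step MSE during periodic evaluation.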
| | logger.info("Loading validation data...") |
| | val_loader, _ = create_dataloader( |
| | dataset_name=args.dataset, |
| | split="valid", |
| | batch_size=args.batch_size, |
| | n_steps_input=args.n_input, |
| | n_steps_output=args.n_output, |
| | num_workers=0, |
| | streaming=args.streaming, |
| | local_path=args.local_path, |
| | ) |
| |
|
| | |
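    # Longer-horizon validation data (batch_size=1) for autoregressive rollouts.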
| | logger.info(f"Loading rollout data (n_steps_output={args.n_rollout})...") |
| | rollout_loader, _ = create_dataloader( |
| | dataset_name=args.dataset, |
| | split="valid", |
| | batch_size=1, |
| | n_steps_input=args.n_input, |
| | n_steps_output=args.n_rollout, |
| | num_workers=0, |
| | streaming=args.streaming, |
| | local_path=args.local_path, |
| | ) |
| |
|
| | |
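    # Conditional DDPM: the UNet denoises the target channels given the
    # conditioning frames, so its input is the concatenation of both.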
    unet = UNet(
        in_channels=c_out + c_in,
        out_channels=c_out,
        base_ch=args.base_ch,
        ch_mults=tuple(args.ch_mults),
        n_res=args.n_res,
        attn_levels=tuple(args.attn_levels),
        dropout=args.dropout,
    )
    diffusion = GaussianDiffusion(unet, timesteps=args.timesteps).to(device)

    n_params = sum(p.numel() for p in diffusion.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {n_params:,}")

    if wandb_run:
        wandb_run.summary["n_params"] = n_params

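    # AMP is only meaningful on CUDA; gating it on the device avoids spurious
    # warnings on CPU. bf16 has float32's exponent range, so the GradScaler is
    # largely redundant under bf16 autocast, but it keeps one step/clip path.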
    optimizer = torch.optim.AdamW(diffusion.parameters(), lr=args.lr, weight_decay=args.wd)
    use_amp = args.amp and device.type == "cuda"
    scaler = GradScaler("cuda", enabled=use_amp)

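    # Optionally resume full training state (model, optimizer, scaler, counters).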
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        diffusion.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        logger.info(f"Resumed from epoch {start_epoch}, step {global_step}")

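    # total_steps drives the cosine schedule; len(train_loader) assumes the
    # loader exposes a length (i.e. a map-style dataset).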
    os.makedirs(args.ckpt_dir, exist_ok=True)
    total_steps = args.epochs * len(train_loader)

    logger.info(f"Starting training: {args.epochs} epochs, ~{total_steps} steps")
    logger.info(f"Eval every {args.eval_every} epochs, rollout {args.n_rollout} steps")

    for epoch in range(start_epoch, args.epochs):
        diffusion.train()
        epoch_loss = 0.0
        n_batches = 0
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            try:
                x_cond, x_target = prepare_batch(batch, device)
            except Exception as e:
                logger.warning(f"Batch error: {e}, skipping")
                continue

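            # Set the per-step learning rate from the warmup+cosine schedule.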
            lr = cosine_lr(global_step, args.warmup, total_steps, args.lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_amp):
                loss = diffusion.training_loss(x_target, x_cond)

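            # Unscale before clipping so the norm threshold applies to the true
            # gradients, not the scaled ones.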
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(diffusion.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()
            n_batches += 1
            global_step += 1

            pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}")

            if wandb_run and global_step % 20 == 0:
                wandb_run.log({"train/loss": loss.item(), "train/lr": lr}, step=global_step)

        avg_loss = epoch_loss / max(n_batches, 1)
        elapsed = time.time() - t0
        logger.info(
            f"Epoch {epoch}: loss={avg_loss:.4f}, batches={n_batches}, "
            f"time={elapsed:.1f}s, lr={lr:.2e}"
        )
        if wandb_run:
            wandb_run.log({"train/epoch_loss": avg_loss, "epoch": epoch}, step=global_step)

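        # Periodic evaluation: one-step val MSE plus an autoregressive rollout
        # (logged as video when wandb is enabled). Lazy import, as with wandb.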
        if (epoch + 1) % args.eval_every == 0:
            from eval_utils import run_evaluation

            logger.info("=" * 40)
            logger.info(f"EVALUATION at epoch {epoch}")
            logger.info("=" * 40)

            eval_metrics = run_evaluation(
                model=diffusion,
                val_loader=val_loader,
                rollout_loader=rollout_loader,
                device=device,
                global_step=global_step,
                wandb_run=wandb_run,
                n_val_batches=args.eval_batches,
                n_rollout=args.n_rollout,
                ddim_steps=args.ddim_steps,
            )

            logger.info(
                f" val/mse={eval_metrics['val/mse']:.6f}, "
                f"rollout_mse_mean={eval_metrics['val/rollout_mse_mean']:.6f}"
            )
            logger.info("=" * 40)

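        # Checkpoint the full training state so --resume can restart exactly here.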
        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt_path = os.path.join(args.ckpt_dir, f"diffusion_ep{epoch:04d}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": diffusion.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "args": vars(args),
                    "ch_info": ch_info,
                },
                ckpt_path,
            )
            logger.info(f"Saved {ckpt_path}")

    if wandb_run:
        wandb_run.finish()
    logger.info("Training complete.")


def main():
    p = argparse.ArgumentParser(description="Train conditional DDPM on The Well")

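    # Data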
| | p.add_argument("--dataset", default="turbulent_radiative_layer_2D") |
| | p.add_argument("--streaming", action="store_true", default=True) |
| | p.add_argument("--no-streaming", dest="streaming", action="store_false") |
| | p.add_argument("--local_path", default=None) |
| | p.add_argument("--batch_size", type=int, default=8) |
| | p.add_argument("--workers", type=int, default=0) |
| | p.add_argument("--n_input", type=int, default=1) |
| | p.add_argument("--n_output", type=int, default=1) |
| | |
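    # Model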
| | p.add_argument("--base_ch", type=int, default=64) |
| | p.add_argument("--ch_mults", type=int, nargs="+", default=[1, 2, 4, 8]) |
| | p.add_argument("--n_res", type=int, default=2) |
| | p.add_argument("--attn_levels", type=int, nargs="+", default=[3]) |
| | p.add_argument("--dropout", type=float, default=0.1) |
| | p.add_argument("--timesteps", type=int, default=1000) |
| | |
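    # Optimization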
| | p.add_argument("--lr", type=float, default=1e-4) |
| | p.add_argument("--wd", type=float, default=0.01) |
| | p.add_argument("--warmup", type=int, default=1000) |
| | p.add_argument("--grad_clip", type=float, default=1.0) |
| | p.add_argument("--amp", action="store_true", default=True) |
| | p.add_argument("--no-amp", dest="amp", action="store_false") |
| | p.add_argument("--epochs", type=int, default=100) |
| | |
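    # Evaluation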
| | p.add_argument("--eval_every", type=int, default=5, help="Eval every N epochs") |
| | p.add_argument("--eval_batches", type=int, default=4, help="Val batches for MSE") |
| | p.add_argument("--n_rollout", type=int, default=20, help="Rollout steps for video") |
| | p.add_argument("--ddim_steps", type=int, default=50, help="DDIM steps for eval sampling") |
| | |
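    # Checkpointing / logging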
| | p.add_argument("--ckpt_dir", default="checkpoints/diffusion") |
| | p.add_argument("--save_every", type=int, default=5) |
| | p.add_argument("--resume", default=None) |
| | |
| | p.add_argument("--wandb", action="store_true", default=False) |
| |
|
| | args = p.parse_args() |
| | train(args) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|