"""
Training script for Spatial JEPA on The Well datasets.

Usage:
    python train_jepa.py --dataset turbulent_radiative_layer_2D --batch_size 16
    python train_jepa.py --dataset active_matter --streaming --epochs 50
"""
import argparse
import logging
import math
import os
import time

import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from data_pipeline import create_dataloader, prepare_batch, get_channel_info
from jepa import JEPA
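
# Interface of jepa.JEPA as exercised by this script (a summary inferred from
# the call sites below; see jepa.py for the authoritative definitions):
#   model.online_encoder, model.predictor          -- gradient-trained submodules
#   model.compute_loss(x_in, x_tgt)                -- returns (loss, {"sim", "var", "cov"})
#   model.set_ema_decay(m), model.update_target()  -- EMA update of the target encoder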

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("train_jepa")
logger.setLevel(logging.INFO)
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"))
logger.addHandler(_handler)
logger.propagate = False


def cosine_lr(step, warmup, total, base_lr, min_lr=1e-6):
    """Linear warmup for `warmup` steps, then cosine decay from base_lr to min_lr."""
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(progress * math.pi))
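
# Endpoint check (these values follow directly from the formula above):
#   cosine_lr(0,      500, 10_000, 3e-4) == 0.0   -- start of warmup
#   cosine_lr(500,    500, 10_000, 3e-4) == 3e-4  -- warmup done, peak LR
#   cosine_lr(10_000, 500, 10_000, 3e-4) == 1e-6  -- end of decay, min_lr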


def cosine_ema(step, total, start=0.996, end=1.0):
    """EMA decay schedule: ramps from start to end over training."""
    progress = step / max(total, 1)
    return end - (end - start) * (1 + math.cos(progress * math.pi)) / 2
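
# Endpoint check: at step 0 the cosine term (1 + cos 0)/2 equals 1, giving
# `start`; at step == total it is (1 + cos pi)/2 == 0, giving `end`.
#   cosine_ema(0, 10_000)      == 0.996
#   cosine_ema(5_000, 10_000)  == 0.998
#   cosine_ema(10_000, 10_000) == 1.0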


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # autocast/GradScaler below target CUDA; fall back to full precision on CPU
    use_amp = args.amp and device.type == "cuda"
    logger.info(f"Device: {device} (amp={use_amp})")

| logger.info(f"Loading dataset: {args.dataset} (streaming={args.streaming})") |
| train_loader, train_dataset = create_dataloader( |
| dataset_name=args.dataset, |
| split="train", |
| batch_size=args.batch_size, |
| n_steps_input=args.n_input, |
| n_steps_output=args.n_output, |
| num_workers=args.workers, |
| streaming=args.streaming, |
| local_path=args.local_path, |
| ) |
|
|
| ch_info = get_channel_info(train_dataset) |
| logger.info(f"Channel info: {ch_info}") |
|
|
| c_in = ch_info["input_channels"] |
| c_out = ch_info["output_channels"] |
|
|
| |
| |
| assert c_in == c_out, ( |
| f"JEPA expects same input/output channels, got {c_in} vs {c_out}. " |
| "Set n_input == n_output or use different architecture." |
| ) |

    model = JEPA(
        in_channels=c_in,
        latent_channels=args.latent_ch,
        base_ch=args.base_ch,
        pred_hidden=args.pred_hidden,
        ema_decay=args.ema_start,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Trainable parameters: {n_params:,}")

    # Only the online encoder and predictor receive gradients; the target
    # encoder is updated via EMA in model.update_target(), not by the optimizer.
    trainable = list(model.online_encoder.parameters()) + list(model.predictor.parameters())
    optimizer = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.wd)
    scaler = GradScaler("cuda", enabled=use_amp)
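    # Note: loss scaling is unnecessary under bfloat16 autocast (bf16 has
    # float32's exponent range), but keeping the GradScaler means the same
    # code path works unchanged if the autocast dtype is switched to float16.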

    # Optionally resume from a checkpoint written by this script
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        global_step = ckpt["global_step"]
        logger.info(f"Resumed from epoch {start_epoch}, step {global_step}")

    os.makedirs(args.ckpt_dir, exist_ok=True)
    total_steps = args.epochs * len(train_loader)
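    # Caveat: this presumes the loader reports a length. If the streaming path
    # ever yields an IterableDataset without __len__, len(train_loader) raises
    # TypeError and a fixed steps-per-epoch estimate would be needed instead.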

    if args.wandb:
        try:
            import wandb

            wandb.init(project="the-well-jepa", config=vars(args))
        except ImportError:
            logger.warning("wandb requested but not installed; disabling W&B logging")
            args.wandb = False

    logger.info(f"Starting training: {args.epochs} epochs, ~{total_steps} steps")

    for epoch in range(start_epoch, args.epochs):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {"sim": 0.0, "var": 0.0, "cov": 0.0}
        n_batches = 0
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            try:
                x_input, x_target = prepare_batch(batch, device)
            except Exception as e:
                logger.warning(f"Batch error: {e}, skipping")
                continue

            # Per-step learning-rate schedule
            lr = cosine_lr(global_step, args.warmup, total_steps, args.lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            # Per-step EMA-decay schedule
            ema = cosine_ema(global_step, total_steps, args.ema_start, args.ema_end)
            model.set_ema_decay(ema)

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_amp):
                loss, metrics = model.compute_loss(x_input, x_target)

            # Scaled backward, then unscale so gradient clipping sees true norms
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(trainable, args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
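
            # EMA update of the target encoder after each optimizer step. The
            # usual JEPA/BYOL-style rule (assumed here; see jepa.py for the
            # actual implementation) is:
            #   theta_target <- ema * theta_target + (1 - ema) * theta_online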
            model.update_target()

            epoch_loss += loss.item()
            for k in epoch_metrics:
                epoch_metrics[k] += metrics[k]
            n_batches += 1
            global_step += 1

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                sim=f"{metrics['sim']:.4f}",
                ema=f"{ema:.4f}",
            )

            if args.wandb:
                wandb.log(
                    {
                        "train/loss": loss.item(),
                        "train/lr": lr,
                        "train/ema": ema,
                        **{f"train/{k}": v for k, v in metrics.items()},
                    },
                    step=global_step,
                )

        avg_loss = epoch_loss / max(n_batches, 1)
        avg_m = {k: v / max(n_batches, 1) for k, v in epoch_metrics.items()}
        elapsed = time.time() - t0
        logger.info(
            f"Epoch {epoch}: loss={avg_loss:.4f}, sim={avg_m['sim']:.4f}, "
            f"var={avg_m['var']:.4f}, cov={avg_m['cov']:.4f}, "
            f"time={elapsed:.1f}s"
        )

        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt_path = os.path.join(args.ckpt_dir, f"jepa_ep{epoch:04d}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),
                    "args": vars(args),
                    "ch_info": ch_info,
                },
                ckpt_path,
            )
            logger.info(f"Saved {ckpt_path}")

    logger.info("Training complete.")


def main():
    p = argparse.ArgumentParser(description="Train Spatial JEPA on The Well")

    # Data
    p.add_argument("--dataset", default="turbulent_radiative_layer_2D")
    p.add_argument("--streaming", action="store_true", default=True)
    p.add_argument("--no-streaming", dest="streaming", action="store_false")
    p.add_argument("--local_path", default=None)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--workers", type=int, default=0)
    p.add_argument("--n_input", type=int, default=1)
    p.add_argument("--n_output", type=int, default=1)

    # Model
    p.add_argument("--latent_ch", type=int, default=128)
    p.add_argument("--base_ch", type=int, default=32)
    p.add_argument("--pred_hidden", type=int, default=256)

    # Optimization
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--wd", type=float, default=0.05)
    p.add_argument("--warmup", type=int, default=500)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--amp", action="store_true", default=True)
    p.add_argument("--no-amp", dest="amp", action="store_false")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--ema_start", type=float, default=0.996)
    p.add_argument("--ema_end", type=float, default=1.0)

    # Checkpointing
    p.add_argument("--ckpt_dir", default="checkpoints/jepa")
    p.add_argument("--save_every", type=int, default=5)
    p.add_argument("--resume", default=None)

    # Logging
    p.add_argument("--wandb", action="store_true", default=False)

    args = p.parse_args()
    train(args)


if __name__ == "__main__":
    main()