"""Training script for the C-Space compressor autoencoder.

Trains a CSpaceCompressor to reconstruct complex-valued C-Space
representations of LibriTTS audio, logging metrics to CSV and saving
the best checkpoint by validation MSE.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import argparse
from pathlib import Path
from collections import deque
import csv
from datetime import datetime
from datasets import load_dataset, Audio
import os

# Local imports
from cspace import CSpace
from model import CSpaceCompressor
from dataset import CSpaceDataset, collate_variable_length

# Configuration Constants
DATASET_NAME = "mythicinfinity/libritts"
SUBSET = "dev"
SAMPLE_RATE = 24000
LOG_INTERVAL = 10
CHECKPOINT_DIR = Path("checkpoints")
LOGS_DIR = Path("logs")


def cspace_mse_loss(pred, target, mask=None):
    """MSE loss directly in C-Space (complex domain), optionally masked.

    Args:
        pred: Predicted complex tensor.
        target: Target complex tensor, same shape as ``pred``.
        mask: Optional real-valued mask (1 = valid, 0 = padding). When
            given, the loss is averaged over valid elements only.

    Returns:
        Scalar real tensor. ``torch.abs`` maps the complex error to its
        magnitude, so the result is real-valued and differentiable.
    """
    error = torch.abs(pred - target) ** 2
    if mask is not None:
        error = error * mask
        # clamp_min guards against an all-padding batch (sum(mask) == 0),
        # which would otherwise produce NaN.
        return torch.sum(error) / torch.sum(mask).clamp_min(1e-8)
    return torch.mean(error)


def cspace_l1_loss(pred, target, mask=None):
    """L1 loss directly in C-Space (complex domain), optionally masked.

    Same masking semantics as :func:`cspace_mse_loss`.
    """
    error = torch.abs(pred - target)
    if mask is not None:
        error = error * mask
        # Same div-by-zero guard as in cspace_mse_loss.
        return torch.sum(error) / torch.sum(mask).clamp_min(1e-8)
    return torch.mean(error)


def get_gradient_norm(model):
    """Return the global L2 norm of all parameter gradients in ``model``.

    Parameters with ``grad is None`` (e.g. frozen layers) are skipped.
    Intended to be called after ``backward()`` but before clipping, so
    the logged value reflects the raw gradient magnitude.
    """
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5


def train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, csv_file):
    """Run one training epoch; log rolling-average metrics every LOG_INTERVAL batches.

    The model is trained as an autoencoder: the target is the input
    ``cspace_batch`` itself.

    Returns:
        Mean MSE over the most recent LOG_INTERVAL batches (the deques are
        bounded, so this is a "recent loss" figure, not a full-epoch mean).
    """
    model.train()
    # Bounded deques give a rolling window for smoother console/CSV logs.
    mses = deque(maxlen=LOG_INTERVAL)
    l1s = deque(maxlen=LOG_INTERVAL)
    grad_norms = deque(maxlen=LOG_INTERVAL)

    for batch_idx, (cspace_batch, mask) in enumerate(train_loader):
        cspace_batch = cspace_batch.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        pred = model(cspace_batch)
        loss = cspace_mse_loss(pred, cspace_batch, mask=mask)
        loss.backward()

        # Measure the raw gradient norm BEFORE clipping so logs show
        # whether clipping is actually engaging.
        grad_norm = get_gradient_norm(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        mse = loss.item()
        l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
        mses.append(mse)
        l1s.append(l1)
        grad_norms.append(grad_norm)

        if (batch_idx + 1) % LOG_INTERVAL == 0:
            avg_mse = np.mean(list(mses))
            avg_l1 = np.mean(list(l1s))
            avg_grad = np.mean(list(grad_norms))
            print(f"E{epoch+1} B{batch_idx+1:>4} | MSE: {avg_mse:.4f} | L1: {avg_l1:.4f} | Grad: {avg_grad:.3f}")
            csv_writer.writerow({
                'epoch': epoch + 1,
                'batch': batch_idx + 1,
                'mse': f"{avg_mse:.4f}",
                'l1': f"{avg_l1:.4f}",
                'grad_norm': f"{avg_grad:.3f}"
            })
            # Flush so metrics survive a crash mid-epoch.
            csv_file.flush()

    return np.mean(list(mses))


def validate(model, val_loader, device, epoch, csv_writer, csv_file):
    """Evaluate reconstruction loss on the validation set.

    Writes a single CSV row with batch='VAL' and returns the mean MSE
    (simple mean over batches, not weighted by batch size).
    """
    model.eval()
    val_mses = []
    val_l1s = []
    with torch.no_grad():
        for cspace_batch, mask in val_loader:
            cspace_batch = cspace_batch.to(device)
            mask = mask.to(device)
            pred = model(cspace_batch)
            mse = cspace_mse_loss(pred, cspace_batch, mask=mask).item()
            l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
            val_mses.append(mse)
            val_l1s.append(l1)

    avg_mse = np.mean(val_mses)
    avg_l1 = np.mean(val_l1s)
    print(f"VAL E{epoch+1}: MSE={avg_mse:.4f} L1={avg_l1:.4f}")
    csv_writer.writerow({
        'epoch': epoch + 1,
        'batch': 'VAL',
        'mse': f"{avg_mse:.4f}",
        'l1': f"{avg_l1:.4f}",
        'grad_norm': 'N/A'
    })
    csv_file.flush()
    return avg_mse


def main():
    """Parse args, build data pipeline and model, then train with checkpointing."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="cochlear_config.json")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-nodes", type=int, default=32)
    args = parser.parse_args()

    # 1. Setup Directories (Robustly)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=device)

    print(f"Loading {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, SUBSET, split="dev.clean")
    # Resample on the fly so every clip matches the CSpace sample rate.
    ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))
    split = ds.train_test_split(test_size=0.1)
    train_ds = split["train"]
    val_ds = split["test"]

    train_dataset = CSpaceDataset(train_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=True)
    val_dataset = CSpaceDataset(val_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_variable_length, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=collate_variable_length, num_workers=0)

    print("Initializing CSpace Compressor...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    log_file_path = LOGS_DIR / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    # Context manager ensures the log file is closed (and buffered rows
    # flushed) even if training raises mid-epoch.
    with open(log_file_path, 'w', newline='') as log_file:
        csv_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'batch', 'mse', 'l1', 'grad_norm'])
        csv_writer.writeheader()

        best_val_loss = float('inf')
        for epoch in range(args.epochs):
            print(f"\n{'='*80}\nEpoch {epoch + 1}/{args.epochs}\n{'='*80}")
            train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, log_file)
            val_loss = validate(model, val_loader, device, epoch, csv_writer, log_file)
            scheduler.step()

            # Save Logic: keep a checkpoint each time validation improves.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                ckpt_path = CHECKPOINT_DIR / f"best_model_epoch{epoch + 1}.pt"
                # PARANOID SAFETY CHECK: Ensure dir exists right before saving
                # (it could have been deleted externally during a long run).
                if not CHECKPOINT_DIR.exists():
                    print(f"Warning: {CHECKPOINT_DIR} was missing. Recreating...")
                    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), ckpt_path)
                print(f"✓ Saved checkpoint: {ckpt_path}")

    print("Training complete.")


if __name__ == "__main__":
    main()