# cspace/train.py
# Snapshot: 2025-12-12 20:41:37 -06:00 (190 lines, 6.8 KiB, Python)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import argparse
from pathlib import Path
from collections import deque
import csv
from datetime import datetime
from datasets import load_dataset, Audio
import os
# Local imports
from cspace import CSpace
from model import CSpaceCompressor
from dataset import CSpaceDataset, collate_variable_length
# --- Configuration constants -------------------------------------------------
# Hugging Face dataset id, its config name, and the rate audio is resampled to.
DATASET_NAME = "mythicinfinity/libritts"
SUBSET = "dev"
SAMPLE_RATE = 24_000
# Print / CSV-log a running average every LOG_INTERVAL training batches.
LOG_INTERVAL = 10
# Output locations (relative paths, so they resolve against the working dir).
CHECKPOINT_DIR = Path("checkpoints")
LOGS_DIR = Path("logs")
def _masked_mean(error, mask):
    """Return the mean of ``error``, restricted to positions selected by ``mask``.

    With ``mask=None`` this is a plain ``torch.mean``. Otherwise the masked sum
    is divided by the number of mask-weighted elements *after* broadcasting,
    so a mask that broadcasts over extra dimensions of ``error`` still yields a
    true mean. An all-zero mask returns 0 instead of NaN.
    """
    if mask is None:
        return torch.mean(error)
    masked = error * mask
    # Count the mask over the broadcast shape, not its own (possibly smaller) shape.
    count = torch.broadcast_to(mask, masked.shape).sum().to(masked.dtype)
    # With an all-zero mask the numerator is exactly 0; flooring the denominator
    # turns 0/0 = NaN into 0 so one empty batch cannot poison training.
    return masked.sum() / count.clamp(min=torch.finfo(masked.dtype).eps)


def cspace_mse_loss(pred, target, mask=None):
    """MSE loss directly in C-Space (complex domain), optionally masked.

    The per-element error is ``|pred - target| ** 2`` (the squared modulus for
    complex inputs). If ``mask`` is given (0/1 or bool, broadcastable to the
    error), the result is averaged only over masked-in elements.
    """
    return _masked_mean(torch.abs(pred - target) ** 2, mask)


def cspace_l1_loss(pred, target, mask=None):
    """L1 loss directly in C-Space (complex domain), optionally masked.

    The per-element error is ``|pred - target|`` (the modulus for complex
    inputs). If ``mask`` is given (0/1 or bool, broadcastable to the error),
    the result is averaged only over masked-in elements.
    """
    return _masked_mean(torch.abs(pred - target), mask)
def get_gradient_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model`` as a float.

    Parameters whose ``.grad`` is ``None`` are skipped; if none has a gradient
    the result is 0.0. Per-parameter norms are combined on-device and copied
    to the host once, instead of forcing one device sync per parameter.
    """
    per_param = [
        p.grad.detach().norm(2).float()
        for p in model.parameters()
        if p.grad is not None
    ]
    if not per_param:
        return 0.0
    # Gather onto one device so parameters spread across devices still work.
    device = per_param[0].device
    stacked = torch.stack([n.to(device) for n in per_param])
    return stacked.norm(2).item()
def train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, csv_file):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    Each batch ``(cspace_batch, mask)`` is reconstructed by ``model`` and
    trained on the masked C-Space MSE, with gradients clipped to a global L2
    norm of 1.0. Every ``LOG_INTERVAL`` batches, the averages of MSE, L1 and
    the pre-clip gradient norm over the last ``LOG_INTERVAL`` batches are
    printed and written to ``csv_writer``. ``csv_file`` is then flushed so the
    log survives a crash.

    Returns:
        Mean training MSE over every batch of the epoch, or ``nan`` if the
        loader yielded no batches.
    """
    model.train()
    mses = deque(maxlen=LOG_INTERVAL)
    l1s = deque(maxlen=LOG_INTERVAL)
    grad_norms = deque(maxlen=LOG_INTERVAL)
    epoch_mse_sum = 0.0
    num_batches = 0

    for batch_idx, (cspace_batch, mask) in enumerate(train_loader):
        cspace_batch = cspace_batch.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        pred = model(cspace_batch)
        loss = cspace_mse_loss(pred, cspace_batch, mask=mask)
        loss.backward()
        # clip_grad_norm_ returns the total norm measured *before* clipping,
        # which is the value we log, so no separate norm pass is needed.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
        optimizer.step()

        mse = loss.item()
        # L1 is a logging-only metric: don't extend the autograd graph for it.
        with torch.no_grad():
            l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()

        mses.append(mse)
        l1s.append(l1)
        grad_norms.append(grad_norm)
        epoch_mse_sum += mse
        num_batches += 1

        if (batch_idx + 1) % LOG_INTERVAL == 0:
            avg_mse = np.mean(mses)
            avg_l1 = np.mean(l1s)
            avg_grad = np.mean(grad_norms)
            print(f"E{epoch+1} B{batch_idx+1:>4} | MSE: {avg_mse:.4f} | L1: {avg_l1:.4f} | Grad: {avg_grad:.3f}")
            csv_writer.writerow({
                'epoch': epoch + 1, 'batch': batch_idx + 1,
                'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': f"{avg_grad:.3f}"
            })
            csv_file.flush()

    return epoch_mse_sum / num_batches if num_batches else float("nan")
def validate(model, val_loader, device, epoch, csv_writer, csv_file):
    """Evaluate ``model`` on ``val_loader`` and log the averaged metrics.

    For every batch the masked C-Space MSE and L1 are computed under
    ``torch.no_grad()``; the reported figures are the unweighted mean of those
    per-batch values. One console line and one CSV row (``batch='VAL'``) are
    written, then ``csv_file`` is flushed.

    Returns:
        The mean validation MSE.
    """
    model.eval()
    per_batch = []  # one (mse, l1) pair per validation batch
    with torch.no_grad():
        for features, valid in val_loader:
            features = features.to(device)
            valid = valid.to(device)
            reconstruction = model(features)
            batch_mse = cspace_mse_loss(reconstruction, features, mask=valid).item()
            batch_l1 = cspace_l1_loss(reconstruction, features, mask=valid).item()
            per_batch.append((batch_mse, batch_l1))

    avg_mse = np.mean([m for m, _ in per_batch])
    avg_l1 = np.mean([l for _, l in per_batch])

    print(f"VAL E{epoch+1}: MSE={avg_mse:.4f} L1={avg_l1:.4f}")
    val_row = {
        'epoch': epoch + 1, 'batch': 'VAL',
        'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': 'N/A'
    }
    csv_writer.writerow(val_row)
    csv_file.flush()
    return avg_mse
def _build_dataloaders(cspace_model, batch_size, seed=None):
    """Load the configured dataset split, divide it 90/10 into train/val, and wrap each half in a DataLoader.

    Args:
        cspace_model: CSpace front end handed to both ``CSpaceDataset`` wrappers.
        batch_size: Batch size for both loaders.
        seed: Optional seed for ``train_test_split``; ``None`` keeps the split random.

    Returns:
        ``(train_loader, val_loader)``. Only the training half is augmented and shuffled.
    """
    print(f"Loading {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, SUBSET, split="dev.clean")
    ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))
    split = ds.train_test_split(test_size=0.1, seed=seed)
    train_dataset = CSpaceDataset(split["train"], cspace_model, sample_rate=SAMPLE_RATE, augment=True)
    val_dataset = CSpaceDataset(split["test"], cspace_model, sample_rate=SAMPLE_RATE, augment=False)
    # num_workers=0 keeps loading in the main process. Presumably this is because
    # both datasets share cspace_model (possibly on the GPU); TODO confirm before raising it.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_variable_length, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_variable_length, num_workers=0)
    return train_loader, val_loader


def _save_best_checkpoint(model, epoch):
    """Save ``model``'s state_dict as ``CHECKPOINT_DIR/best_model_epoch{epoch + 1}.pt`` and return the path."""
    ckpt_path = CHECKPOINT_DIR / f"best_model_epoch{epoch + 1}.pt"
    # The directory is created at startup; re-check in case it was removed
    # while training ran, so a long run is not lost at save time.
    if not CHECKPOINT_DIR.exists():
        print(f"Warning: {CHECKPOINT_DIR} was missing. Recreating...")
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), ckpt_path)
    print(f"✓ Saved checkpoint: {ckpt_path}")
    return ckpt_path


def main():
    """Command-line entry point: train a CSpaceCompressor and keep the best checkpoints.

    The run proceeds as follows:

    1. Parse CLI args and build the CSpace front end.
    2. Load and split the dataset.
    3. Train for ``--epochs`` epochs with Adam and cosine LR decay, validating
       after each epoch.

    Every epoch whose validation MSE beats the best so far writes
    ``CHECKPOINT_DIR/best_model_epoch{N}.pt``. Training and validation metrics
    go to a timestamped CSV under ``LOGS_DIR``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="cochlear_config.json")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-nodes", type=int, default=32)
    parser.add_argument("--seed", type=int, default=None,
                        help="seed for the train/val split and torch's global RNG "
                             "(default: unseeded, i.e. a different split every run)")
    args = parser.parse_args()

    # 1. Set up directories (idempotent thanks to exist_ok).
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=device)
    train_loader, val_loader = _build_dataloaders(cspace_model, args.batch_size, seed=args.seed)

    print("Initializing CSpace Compressor...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # T_max=epochs with one scheduler.step() per epoch: a single cosine decay over the run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    log_file_path = LOGS_DIR / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    # `with` guarantees the CSV is closed even if training raises or is interrupted.
    with open(log_file_path, 'w', newline='') as log_file:
        csv_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'batch', 'mse', 'l1', 'grad_norm'])
        csv_writer.writeheader()

        best_val_loss = float('inf')
        for epoch in range(args.epochs):
            print(f"\n{'='*80}\nEpoch {epoch + 1}/{args.epochs}\n{'='*80}")
            train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, log_file)
            val_loss = validate(model, val_loader, device, epoch, csv_writer, log_file)
            scheduler.step()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                _save_best_checkpoint(model, epoch)

    print("Training complete.")


if __name__ == "__main__":
    main()