Initial Commit
train.py · 190 lines · Normal file
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import argparse
from pathlib import Path
from collections import deque
import csv
from datetime import datetime
from datasets import load_dataset, Audio

# Local imports
from cspace import CSpace
from model import CSpaceCompressor
from dataset import CSpaceDataset, collate_variable_length

# Configuration Constants
DATASET_NAME = "mythicinfinity/libritts"
SUBSET = "dev"
SAMPLE_RATE = 24000
LOG_INTERVAL = 10
CHECKPOINT_DIR = Path("checkpoints")
LOGS_DIR = Path("logs")

def cspace_mse_loss(pred, target, mask=None):
    """MSE loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target) ** 2
    if mask is not None:
        # Average over valid positions only, so zero-padding from the
        # variable-length collate does not dilute the loss.
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)

def cspace_l1_loss(pred, target, mask=None):
    """L1 loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target)
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)
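
# A minimal usage sketch (hypothetical shapes): with pred/target of shape
# [batch, nodes, time] and a 0/1 mask broadcastable to them, both losses
# reduce to a scalar over valid positions only:
#   loss = cspace_mse_loss(pred, target, mask=mask)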

def get_gradient_norm(model):
    """Global L2 norm of all parameter gradients (measured before clipping)."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5
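
# Note: torch.nn.utils.clip_grad_norm_ already returns the total pre-clip
# norm, so this helper could equivalently be folded into the training step:
#   grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()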

def train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, csv_file):
    model.train()
    # Rolling windows: logged values are moving averages over the last
    # LOG_INTERVAL batches, not cumulative epoch averages.
    mses = deque(maxlen=LOG_INTERVAL)
    l1s = deque(maxlen=LOG_INTERVAL)
    grad_norms = deque(maxlen=LOG_INTERVAL)

    for batch_idx, (cspace_batch, mask) in enumerate(train_loader):
        cspace_batch = cspace_batch.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()

        # Autoencoding objective: reconstruct the C-Space input.
        pred = model(cspace_batch)
        loss = cspace_mse_loss(pred, cspace_batch, mask=mask)

        loss.backward()
        grad_norm = get_gradient_norm(model)  # measure before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        mse = loss.item()
        l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()

        mses.append(mse)
        l1s.append(l1)
        grad_norms.append(grad_norm)

        if (batch_idx + 1) % LOG_INTERVAL == 0:
            avg_mse = np.mean(list(mses))
            avg_l1 = np.mean(list(l1s))
            avg_grad = np.mean(list(grad_norms))

            print(f"E{epoch+1} B{batch_idx+1:>4} | MSE: {avg_mse:.4f} | L1: {avg_l1:.4f} | Grad: {avg_grad:.3f}")

            csv_writer.writerow({
                'epoch': epoch + 1, 'batch': batch_idx + 1,
                'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': f"{avg_grad:.3f}"
            })
            csv_file.flush()

    return np.mean(list(mses))

def validate(model, val_loader, device, epoch, csv_writer, csv_file):
    model.eval()
    val_mses = []
    val_l1s = []

    with torch.no_grad():
        for cspace_batch, mask in val_loader:
            cspace_batch = cspace_batch.to(device)
            mask = mask.to(device)
            pred = model(cspace_batch)
            mse = cspace_mse_loss(pred, cspace_batch, mask=mask).item()
            l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
            val_mses.append(mse)
            val_l1s.append(l1)

    avg_mse = np.mean(val_mses)
    avg_l1 = np.mean(val_l1s)
    print(f"VAL E{epoch+1}: MSE={avg_mse:.4f} L1={avg_l1:.4f}")

    csv_writer.writerow({
        'epoch': epoch + 1, 'batch': 'VAL',
        'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': 'N/A'
    })
    csv_file.flush()
    return avg_mse
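
# The CSV log interleaves rolling training averages (numeric 'batch') with
# per-epoch validation rows ('batch' == 'VAL', 'grad_norm' == 'N/A'), so
# downstream analysis should filter on the 'batch' column.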

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="cochlear_config.json")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-nodes", type=int, default=32)
    args = parser.parse_args()

    # 1. Set up directories (robustly)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
print(f"Loading CSpace model with config: {args.config}")
|
||||
cspace_model = CSpace(config=args.config, device=device)
|
||||
|
||||
print(f"Loading {DATASET_NAME}...")
|
||||
ds = load_dataset(DATASET_NAME, SUBSET, split="dev.clean")
|
||||
ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))
|
||||
|
||||
split = ds.train_test_split(test_size=0.1)
|
||||
train_ds = split["train"]
|
||||
val_ds = split["test"]
|
||||

    train_dataset = CSpaceDataset(train_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=True)
    val_dataset = CSpaceDataset(val_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_variable_length, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=collate_variable_length, num_workers=0)
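
    # num_workers=0 keeps CSpaceDataset's feature extraction on the main
    # process; presumably the CSpace transform is not safe to fork across
    # worker processes (e.g. if it holds CUDA state).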
print("Initializing CSpace Compressor...")
|
||||
model = CSpaceCompressor(num_nodes=args.num_nodes).to(device)
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
|
||||
|
||||
log_file_path = LOGS_DIR / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
|
||||
log_file = open(log_file_path, 'w', newline='')
|
||||
csv_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'batch', 'mse', 'l1', 'grad_norm'])
|
||||
csv_writer.writeheader()
|
||||
|
||||

    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        print(f"\n{'='*80}\nEpoch {epoch + 1}/{args.epochs}\n{'='*80}")
        train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, log_file)
        val_loss = validate(model, val_loader, device, epoch, csv_writer, log_file)
        scheduler.step()

        # Save Logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = CHECKPOINT_DIR / f"best_model_epoch{epoch + 1}.pt"

            # PARANOID SAFETY CHECK: Ensure dir exists right before saving
            if not CHECKPOINT_DIR.exists():
                print(f"Warning: {CHECKPOINT_DIR} was missing. Recreating...")
                CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), ckpt_path)
            print(f"✓ Saved checkpoint: {ckpt_path}")

    log_file.close()
    print("Training complete.")


if __name__ == "__main__":
    main()
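
# Reloading the best checkpoint later (a sketch; <N> is a placeholder for the
# epoch number in the saved filename, and the model must be rebuilt with the
# same --num-nodes used for training, since only the state_dict is saved):
#   model = CSpaceCompressor(num_nodes=32)
#   model.load_state_dict(torch.load("checkpoints/best_model_epoch<N>.pt"))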