Initial Commit

2025-12-12 20:41:37 -06:00
commit 782d258660
11 changed files with 3464 additions and 0 deletions

.gitignore vendored Normal file (15 lines added)

@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
audio/
data/
logs/
checkpoints/
*.wav

.python-version Normal file (1 line added)

@@ -0,0 +1 @@
3.10

cochlear_config.json Normal file (12 lines added)

@@ -0,0 +1,12 @@
{
  "sample_rate": 24000,
  "min_freq": 80.0,
  "max_freq": 12000.0,
  "num_nodes": 32,
  "q_factor": 4.0,
  "segment_size_seconds": 1.0,
  "warmup_seconds": 0.25,
  "normalize_output": true,
  "node_distribution": "log",
  "custom_frequencies": null
}
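A quick sketch of how this file maps onto the CochlearConfig dataclass defined in cspace.py below; the path and printed fields are illustrative:

from cspace import CochlearConfig

cfg = CochlearConfig.from_json("cochlear_config.json")
# Every key overrides the matching dataclass default, so here
# cfg.sample_rate == 24000 and cfg.num_nodes == 32 rather than 44100.0 / 128.
print(cfg.num_nodes, cfg.node_distribution)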

cspace.py Normal file (154 lines added)

@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import numpy as np
import json
import math
import torch.fft
from typing import Optional, List, Tuple, Union
from dataclasses import dataclass


@dataclass
class CochlearConfig:
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str):
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)


class CSpace(nn.Module):
    def __init__(self, config: Union[CochlearConfig, str], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()
        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config
        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        # 1. Frequency Setup
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq, self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)
        self.freqs = freqs

        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt
        t = t.unsqueeze(1)
        omegas_broad = omegas.unsqueeze(0)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Create Unnormalized Decay Kernel (For State Transitions)
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t)
        )
        # Calculate single-step transition factors for the boundaries
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Create Normalized Filter Kernel (For Convolution)
        self.n_fft = seg_len * 2
        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Calculate Normalization Factors (Peak = 1.0)
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])
        node_peak_responses = torch.zeros(self.config.num_nodes, device=self.device)
        for i in range(self.config.num_nodes):
            peak = torch.max(response_matrix[:, i])
            node_peak_responses[i] = peak if peak > 1e-8 else 1.0
        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)
        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")
        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT Convolution setup
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2 ** math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]

            # ZIR Calculation (State Decay)
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve
            total_response = zs_response + zi_response

            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State Update for Next Segment
            current_state = total_response[-1] * self.step_factors

        return results

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # Assume it's already real
        # Simple summation of real parts, which is the correct inverse operation
        recon = torch.sum(real_part, dim=1)
        return recon.cpu().numpy()
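A minimal round-trip sketch for the encode/decode path above, assuming the repo's cochlear_config.json and a CPU device; the test tone is arbitrary:

import numpy as np
from cspace import CSpace

cs = CSpace(config="cochlear_config.json", device="cpu")
sr = int(cs.config.sample_rate)
t = np.arange(sr, dtype=np.float32) / sr           # 1 second of audio
tone = 0.5 * np.sin(2 * np.pi * 440.0 * t)         # 440 Hz test tone

coeffs = cs.encode(tone)                           # complex64, shape (samples, num_nodes)
recon = cs.decode(coeffs)                          # float array, same length as the input
print(coeffs.shape, recon.shape)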

dataset.py Normal file (116 lines added)

@@ -0,0 +1,116 @@
import torch
import numpy as np
from torch.utils.data import Dataset
from datasets import load_dataset, Audio


def collate_variable_length(batch):
    """Custom collate for variable-length C-Space tensors. Returns (padded_data, mask)."""
    max_time = max([x.shape[0] for x in batch])
    batch_size = len(batch)
    num_nodes = batch[0].shape[1]
    padded = torch.zeros(batch_size, max_time, num_nodes, dtype=torch.complex64)
    mask = torch.zeros(batch_size, max_time, 1, dtype=torch.bool)
    for i, seq in enumerate(batch):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
        mask[i, :seq_len, :] = True
    return padded, mask


class CSpaceDataset(Dataset):
    """Dataset that processes audio on-the-fly into C-Space representations using sliding windows."""

    def __init__(self, hf_dataset, cspace_model, sample_rate=24000, segment_seconds=1.0, overlap=0.5, augment=True, max_duration=5.0, min_duration=1.0):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.segment_samples = int(segment_seconds * sample_rate)
        self.hop_samples = int(self.segment_samples * (1 - overlap))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)

        # Build index of (audio_idx, segment_start) pairs
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len >= self.min_samples:
                # Extract sliding windows
                truncated_len = min(audio_len, self.max_samples)
                num_segments = max(1, (truncated_len - self.segment_samples) // self.hop_samples + 1)
                for seg in range(num_segments):
                    start = seg * self.hop_samples
                    self.segment_indices.append((i, start))

    def __len__(self):
        return len(self.segment_indices)

    def __getitem__(self, idx):
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)

        # Extract segment
        seg_end = seg_start + self.segment_samples
        audio_segment = audio_np[seg_start:seg_end]

        # Pad if too short
        if len(audio_segment) < self.segment_samples:
            audio_segment = np.pad(audio_segment, (0, self.segment_samples - len(audio_segment)))

        # Augmentation (phase-tolerant)
        if self.augment:
            audio_segment = self._augment_audio(audio_segment)

        # Normalize
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak

        # Encode to C-Space
        cspace_data = self.cspace_model.encode(audio_segment)

        # Already a tensor, just ensure dtype
        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)

    def _augment_audio(self, audio):
        """Phase-tolerant augmentations."""
        # Random gain
        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)

        # Time stretch (resample)
        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # Unsqueeze to (1, 1, time) for interpolate
            audio_t = torch.tensor(audio).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t,
                scale_factor=stretch_factor,
                mode='linear',
                align_corners=False
            )
            audio = audio_t.squeeze().numpy()
            # Stretching changes the length slightly; the variable-length collate and the
            # C-Space encoder both handle arbitrary lengths, so the stretched audio is returned as-is.

        # Add small noise (phase-tolerant)
        if np.random.rand() > 0.5:
            noise = np.random.normal(0, 0.01, len(audio))
            audio = audio + noise

        return audio
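A sketch of how the pieces above plug into a DataLoader; hf_train_split and the batch size are placeholders, and any Hugging Face split whose items expose sample["audio"]["array"] would work:

from torch.utils.data import DataLoader
from cspace import CSpace
from dataset import CSpaceDataset, collate_variable_length

cs = CSpace(config="cochlear_config.json", device="cpu")
train_set = CSpaceDataset(hf_train_split, cs, sample_rate=24000, augment=True)
loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collate_variable_length)

padded, mask = next(iter(loader))
# padded: (batch, max_time, num_nodes) complex64; mask: (batch, max_time, 1) bool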

inference.py Normal file (98 lines added)

@@ -0,0 +1,98 @@
import torch
import numpy as np
import argparse
import soundfile as sf
import os
import torchaudio

# Local imports
from cspace import CSpace
from model import CSpaceCompressor


def load_and_prepare_audio(audio_path, target_sample_rate, device):
    """Loads, resamples, mono-mixes, and normalizes audio."""
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Mix to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)

    # Normalize input
    peak = np.max(np.abs(audio_np))
    if peak > 0:
        audio_np = audio_np / peak
    return audio_np, peak


def main():
    parser = argparse.ArgumentParser(description="Run inference on an audio file using CSpaceCompressor.")
    parser.add_argument("input_wav", type=str, help="Path to input .wav file")
    parser.add_argument("checkpoint", type=str, help="Path to model checkpoint (.pt)")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to config json")
    parser.add_argument("--output", type=str, default="output.wav", help="Path to save output .wav")
    parser.add_argument("--num-nodes", type=int, default=32, help="Must match training config")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    print(f"Using device: {args.device}")

    # 1. Load CSpace (the transformation layer)
    print(f"Loading CSpace configuration from {args.config}...")
    cspace_model = CSpace(config=args.config, device=args.device)
    sample_rate = int(cspace_model.config.sample_rate)

    # 2. Load Compressor Model
    print(f"Loading model checkpoint from {args.checkpoint}...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(args.device)
    state_dict = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    # 3. Load Audio
    print(f"Loading audio: {args.input_wav}")
    audio_input, original_peak = load_and_prepare_audio(args.input_wav, sample_rate, args.device)
    print(f"Audio loaded: {len(audio_input)} samples @ {sample_rate}Hz")

    # 4. Encode to C-Space
    print("Encoding to C-Space...")
    # cspace_model.encode returns a complex tensor on args.device
    cspace_data = cspace_model.encode(audio_input)
    # Model expects (Batch, Time, Nodes), so unsqueeze batch dim
    cspace_batch = cspace_data.unsqueeze(0)  # -> (1, Time, Nodes)

    # 5. Model Forward Pass
    print("Running Compressor...")
    with torch.no_grad():
        reconstructed_cspace = model(cspace_batch)
    # Remove batch dim
    reconstructed_cspace = reconstructed_cspace.squeeze(0)  # -> (Time, Nodes)

    # 6. Decode back to Audio
    print("Decoding C-Space to Audio...")
    audio_output = cspace_model.decode(reconstructed_cspace)

    # 7. Renormalize/Scale
    # Normalize the output to [-1.0, 1.0] to prevent clipping; original volume dynamics are not restored here.
    out_peak = np.max(np.abs(audio_output))
    if out_peak > 0:
        audio_output = audio_output / out_peak

    # 8. Save
    print(f"Saving to {args.output}...")
    sf.write(args.output, audio_output, sample_rate)
    print("Done.")


if __name__ == "__main__":
    main()

model.py Normal file (95 lines added)

@@ -0,0 +1,95 @@
import torch
import torch.nn as nn


class ResizeConv2d(nn.Module):
    """
    Upsamples the input and then performs a convolution.
    This avoids checkerboard artifacts common in ConvTranspose2d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super().__init__()
        # 'bilinear' is usually smoother for continuous signals like spectrograms
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        # Standard convolution with padding to maintain size after upsampling
        # Padding = (kernel_size - 1) // 2 for odd kernels
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)


class CSpaceCompressor(nn.Module):
    """Convolutional compressor for C-Space as a 2D image (frequency x time)."""

    def __init__(self, num_nodes=32, compression_factor=4, latent_dim=128):
        super().__init__()
        self.num_nodes = num_nodes
        self.compression_factor = compression_factor
        self.latent_dim = latent_dim

        # Encoder: (batch, 2, num_nodes, time) -> Latent
        # Standard Convs are fine here.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(128, latent_dim, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2)),
        )

        # Decoder: ConvTranspose2d replaced with ResizeConv2d.
        # The encoder strides determine the decoder scale factors:
        # Encoder strides: (1,2) -> (2,2) -> (2,2) -> (2,1)
        # Decoder scales:  (2,1) -> (2,2) -> (2,2) -> (1,2)
        self.decoder = nn.Sequential(
            ResizeConv2d(latent_dim, 128, kernel_size=(3, 5), scale_factor=(2, 1)),
            nn.ReLU(),
            ResizeConv2d(128, 64, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(32, 2, kernel_size=(3, 5), scale_factor=(1, 2)),
            # No final ReLU, we need raw coordinate values
        )

    def forward(self, cspace_complex):
        batch_size, time_steps, num_nodes = cspace_complex.shape

        # Split into real and imaginary
        real = cspace_complex.real
        imag = cspace_complex.imag
        # Stack: (batch, 2, time, num_nodes)
        x = torch.stack([real, imag], dim=1)
        # Transpose to Frequency-First for 2D Conv logic: (batch, 2, num_nodes, time)
        x = x.transpose(2, 3)

        # Encode
        latent = self.encoder(x)
        # Decode
        x_recon = self.decoder(latent)

        # Force exact size match (upsampling might result in 1-2 pixel differences depending on rounding)
        # We interpolate to the exact target size if slightly off
        if x_recon.shape[2] != num_nodes or x_recon.shape[3] != time_steps:
            x_recon = torch.nn.functional.interpolate(
                x_recon,
                size=(num_nodes, time_steps),
                mode='bilinear',
                align_corners=False
            )

        # Extract real and imag
        real_recon = x_recon[:, 0, :, :]
        imag_recon = x_recon[:, 1, :, :]

        # Transpose back to (batch, time, num_nodes)
        real_recon = real_recon.transpose(1, 2)
        imag_recon = imag_recon.transpose(1, 2)

        return torch.complex(real_recon, imag_recon)
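A shape sanity check for the compressor using a random complex batch; the batch size and time length are arbitrary:

import torch
from model import CSpaceCompressor

model = CSpaceCompressor(num_nodes=32)
x = torch.complex(torch.randn(2, 1000, 32), torch.randn(2, 1000, 32))  # (batch, time, nodes)
with torch.no_grad():
    y = model(x)
print(y.shape, y.dtype)  # torch.Size([2, 1000, 32]) torch.complex64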

pyproject.toml Normal file (15 lines added)

@@ -0,0 +1,15 @@
[project]
name = "cspace"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"datasets>=4.4.1",
"matplotlib>=3.10.8",
"numpy>=2.2.6",
"soundfile>=0.13.1",
"torch>=2.9.1",
"torchaudio>=2.9.1",
"torchcodec>=0.9.1",
]

train.py Normal file (190 lines added)

@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import argparse
from pathlib import Path
from collections import deque
import csv
from datetime import datetime
from datasets import load_dataset, Audio
import os

# Local imports
from cspace import CSpace
from model import CSpaceCompressor
from dataset import CSpaceDataset, collate_variable_length

# Configuration Constants
DATASET_NAME = "mythicinfinity/libritts"
SUBSET = "dev"
SAMPLE_RATE = 24000
LOG_INTERVAL = 10
CHECKPOINT_DIR = Path("checkpoints")
LOGS_DIR = Path("logs")


def cspace_mse_loss(pred, target, mask=None):
    """MSE loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target) ** 2
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)


def cspace_l1_loss(pred, target, mask=None):
    """L1 loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target)
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)


def get_gradient_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5


def train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, csv_file):
    model.train()
    mses = deque(maxlen=LOG_INTERVAL)
    l1s = deque(maxlen=LOG_INTERVAL)
    grad_norms = deque(maxlen=LOG_INTERVAL)

    for batch_idx, (cspace_batch, mask) in enumerate(train_loader):
        cspace_batch = cspace_batch.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        pred = model(cspace_batch)
        loss = cspace_mse_loss(pred, cspace_batch, mask=mask)
        loss.backward()

        grad_norm = get_gradient_norm(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        mse = loss.item()
        l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
        mses.append(mse)
        l1s.append(l1)
        grad_norms.append(grad_norm)

        if (batch_idx + 1) % LOG_INTERVAL == 0:
            avg_mse = np.mean(list(mses))
            avg_l1 = np.mean(list(l1s))
            avg_grad = np.mean(list(grad_norms))
            print(f"E{epoch+1} B{batch_idx+1:>4} | MSE: {avg_mse:.4f} | L1: {avg_l1:.4f} | Grad: {avg_grad:.3f}")
            csv_writer.writerow({
                'epoch': epoch + 1, 'batch': batch_idx + 1,
                'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': f"{avg_grad:.3f}"
            })
            csv_file.flush()

    return np.mean(list(mses))


def validate(model, val_loader, device, epoch, csv_writer, csv_file):
    model.eval()
    val_mses = []
    val_l1s = []
    with torch.no_grad():
        for cspace_batch, mask in val_loader:
            cspace_batch = cspace_batch.to(device)
            mask = mask.to(device)
            pred = model(cspace_batch)
            mse = cspace_mse_loss(pred, cspace_batch, mask=mask).item()
            l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
            val_mses.append(mse)
            val_l1s.append(l1)

    avg_mse = np.mean(val_mses)
    avg_l1 = np.mean(val_l1s)
    print(f"VAL E{epoch+1}: MSE={avg_mse:.4f} L1={avg_l1:.4f}")
    csv_writer.writerow({
        'epoch': epoch + 1, 'batch': 'VAL',
        'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': 'N/A'
    })
    csv_file.flush()
    return avg_mse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="cochlear_config.json")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-nodes", type=int, default=32)
    args = parser.parse_args()

    # 1. Setup Directories (Robustly)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=device)

    print(f"Loading {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, SUBSET, split="dev.clean")
    ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))
    split = ds.train_test_split(test_size=0.1)
    train_ds = split["train"]
    val_ds = split["test"]

    train_dataset = CSpaceDataset(train_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=True)
    val_dataset = CSpaceDataset(val_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_variable_length, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=collate_variable_length, num_workers=0)

    print("Initializing CSpace Compressor...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    log_file_path = LOGS_DIR / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    log_file = open(log_file_path, 'w', newline='')
    csv_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'batch', 'mse', 'l1', 'grad_norm'])
    csv_writer.writeheader()

    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        print(f"\n{'='*80}\nEpoch {epoch + 1}/{args.epochs}\n{'='*80}")
        train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, log_file)
        val_loss = validate(model, val_loader, device, epoch, csv_writer, log_file)
        scheduler.step()

        # Save Logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = CHECKPOINT_DIR / f"best_model_epoch{epoch + 1}.pt"
            # PARANOID SAFETY CHECK: Ensure dir exists right before saving
            if not CHECKPOINT_DIR.exists():
                print(f"Warning: {CHECKPOINT_DIR} was missing. Recreating...")
                CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
            print(f"✓ Saved checkpoint: {ckpt_path}")

    log_file.close()
    print("Training complete.")


if __name__ == "__main__":
    main()
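A quick hand-checked example of how the masked losses above normalize; the shapes and values are chosen for illustration:

import torch
from train import cspace_mse_loss

pred = torch.zeros(1, 4, 2, dtype=torch.complex64)
target = torch.ones(1, 4, 2, dtype=torch.complex64)
mask = torch.tensor([True, True, True, False]).view(1, 4, 1)
# Squared error is 1 at every position; 3 valid time steps x 2 nodes = 6 masked terms,
# divided by mask.sum() = 3, so the loss is summed over nodes and averaged over valid time steps.
print(cspace_mse_loss(pred, target, mask=mask))  # tensor(2.)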

uv.lock generated Normal file (2625 lines added)

File diff suppressed because it is too large

visualize.py Normal file (143 lines added)

@@ -0,0 +1,143 @@
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import argparse
from cspace import CSpace
import os


def normalize(data: np.ndarray) -> tuple[np.ndarray, float]:
    """Normalizes a numpy array to the range [-1.0, 1.0] and returns it and the peak value."""
    peak = np.max(np.abs(data))
    if peak == 0:
        return data, 1.0
    return data / peak, peak


def load_and_prepare_audio(audio_path: str, target_sample_rate: int, device: str) -> np.ndarray:
    """Loads an audio file, resamples it, and converts it to a mono float32 numpy array."""
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if necessary
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)
    return audio_np


def save_audio(audio_data: np.ndarray, path: str, sample_rate: int):
    """Saves a numpy audio array to a WAV file."""
    # torchaudio.save expects a tensor, shape (channels, samples)
    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
    torchaudio.save(path, audio_tensor, sample_rate)
    print(f"Reconstructed audio saved to {path}")


def print_reconstruction_stats(original_signal: np.ndarray, reconstructed_signal: np.ndarray):
    """Calculates and prints RMS and MSE between two signals."""
    # Ensure signals have the same length for comparison
    min_len = min(len(original_signal), len(reconstructed_signal))
    original_trimmed = original_signal[:min_len]
    reconstructed_trimmed = reconstructed_signal[:min_len]

    # RMS Calculation
    rms_original = np.sqrt(np.mean(np.square(original_trimmed)))
    rms_reconstructed = np.sqrt(np.mean(np.square(reconstructed_trimmed)))

    # MSE Calculation
    mse = np.mean(np.square(original_trimmed - reconstructed_trimmed))

    print("\n--- Reconstruction Stats ---")
    print(f" - Original RMS: {rms_original:.6f}")
    print(f" - Reconstructed RMS: {rms_reconstructed:.6f}")
    print(f" - Mean Squared Error (MSE): {mse:.8f}")
    print("--------------------------\n")


def visualize_cspace(cspace_data: torch.Tensor, freqs: np.ndarray, output_path: str, sample_rate: float):
    """Generates and saves a visualization of the C-Space data."""
    magnitude_data = torch.abs(cspace_data).cpu().numpy().T

    fig, ax = plt.subplots(figsize=(30, 5))
    im = ax.imshow(
        magnitude_data,
        aspect='auto',
        origin='lower',
        interpolation='none',
        cmap='magma'
    )
    fig.colorbar(im, ax=ax, label='Magnitude')
    ax.set_title('C-Space Visualization')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    num_samples = cspace_data.shape[0]
    duration_s = num_samples / sample_rate
    time_ticks = np.linspace(0, num_samples, num=10)
    time_labels = np.linspace(0, duration_s, num=10)
    ax.set_xticks(time_ticks)
    ax.set_xticklabels([f'{t:.2f}' for t in time_labels])

    num_nodes = len(freqs)
    freq_indices = np.linspace(0, num_nodes - 1, num=10, dtype=int)
    ax.set_yticks(freq_indices)
    ax.set_yticklabels([f'{freqs[i]:.0f}' for i in freq_indices])

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Visualization saved to {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Generate a C-Space visualization and reconstructed audio from an audio file."
    )
    parser.add_argument("audio_path", type=str, help="Path to the input audio file.")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to the cochlear config JSON file.")
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save the output files.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for computation (e.g., 'cuda', 'cpu').")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # --- 1. Set up paths ---
    base_name = os.path.splitext(os.path.basename(args.audio_path))[0]
    viz_output_path = os.path.join(args.output_dir, f"{base_name}_cspace.png")
    recon_output_path = os.path.join(args.output_dir, f"{base_name}_recon.wav")

    # --- 2. Load Model and Audio ---
    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=args.device)
    config = cspace_model.config

    print(f"Loading and preparing audio from: {args.audio_path}")
    original_audio_np = load_and_prepare_audio(args.audio_path, int(config.sample_rate), args.device)
    original_audio_np, _ = normalize(original_audio_np)

    # --- 3. Encode and Decode ---
    print("Encoding audio to C-Space...")
    cspace_data = cspace_model.encode(original_audio_np)
    print("Decoding C-Space back to audio...")
    reconstructed_audio_np = cspace_model.decode(cspace_data)

    # --- 4. Normalize, Calculate Stats, and Save Files ---
    reconstructed_audio_np, _ = normalize(reconstructed_audio_np)
    print_reconstruction_stats(original_audio_np, reconstructed_audio_np)
    save_audio(reconstructed_audio_np, recon_output_path, int(config.sample_rate))
    visualize_cspace(cspace_data, cspace_model.freqs, viz_output_path, config.sample_rate)


if __name__ == "__main__":
    main()