Initial Commit
.gitignore (vendored, Normal file, 15 lines)
@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
audio/
data/
logs/
checkpoints/
*.wav
.python-version (Normal file, 1 line)
@@ -0,0 +1 @@
3.10
cochlear_config.json (Normal file, 12 lines)
@@ -0,0 +1,12 @@
{
    "sample_rate": 24000,
    "min_freq": 80.0,
    "max_freq": 12000.0,
    "num_nodes": 32,
    "q_factor": 4.0,
    "segment_size_seconds": 1.0,
    "warmup_seconds": 0.25,
    "normalize_output": true,
    "node_distribution": "log",
    "custom_frequencies": null
}
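This config drives the filterbank geometry in cspace.py: with node_distribution set to "log", the 32 analysis frequencies are spaced logarithmically between min_freq and max_freq. A minimal sketch of that spacing, mirroring the logspace branch of CSpace.setup_kernel below (the printout values are approximate):

import math
import numpy as np

# Reproduce the node frequencies implied by cochlear_config.json
min_freq, max_freq, num_nodes = 80.0, 12000.0, 32
freqs = np.logspace(math.log10(min_freq), math.log10(max_freq), num_nodes)
print(freqs[:3], freqs[-1])  # ~[80.0, 94.0, 110.5] ... 12000.0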
cspace.py (Normal file, 154 lines)
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import numpy as np
import json
import math
import torch.fft
from typing import Optional, List, Tuple, Union
from dataclasses import dataclass


@dataclass
class CochlearConfig:
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str):
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)


class CSpace(nn.Module):
    def __init__(self, config: Union[CochlearConfig, str], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()

        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config

        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        # 1. Frequency Setup
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq, self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)

        self.freqs = freqs
        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt

        t = t.unsqueeze(1)
        omegas_broad = omegas.unsqueeze(0)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Create Unnormalized Decay Kernel (For State Transitions)
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t)
        )

        # Calculate single-step transition factors for the boundaries
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Create Normalized Filter Kernel (For Convolution)
        self.n_fft = seg_len * 2

        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Calculate Normalization Factors (Peak = 1.0)
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])

        node_peak_responses = torch.zeros(self.config.num_nodes, device=self.device)
        for i in range(self.config.num_nodes):
            peak = torch.max(response_matrix[:, i])
            node_peak_responses[i] = peak if peak > 1e-8 else 1.0

        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)

        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")

        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT Convolution setup
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2**math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]

            # ZIR Calculation (State Decay)
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve

            total_response = zs_response + zi_response

            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State Update for Next Segment
            current_state = total_response[-1] * self.step_factors

        return results

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # Assume it's already real

        # Simple summation of real parts, which is the correct inverse operation
        recon = torch.sum(real_part, dim=1)

        return recon.cpu().numpy()
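As a quick sanity check of the encoder/decoder pair, a round trip on a synthetic tone can be run without any trained weights. A minimal sketch; the 440 Hz test tone and the CPU device are illustrative choices, not part of the commit:

import numpy as np
from cspace import CSpace, CochlearConfig

config = CochlearConfig(sample_rate=24000, min_freq=80.0, max_freq=12000.0, num_nodes=32)
model = CSpace(config, device="cpu")

t = np.arange(24000) / 24000.0
audio = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)   # 1 s, 440 Hz

cspace = model.encode(audio)   # complex tensor, shape (24000, 32)
recon = model.decode(cspace)   # float array, shape (24000,)
print(cspace.shape, recon.shape)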
dataset.py (Normal file, 116 lines)
@@ -0,0 +1,116 @@
import torch
import numpy as np
from torch.utils.data import Dataset
from datasets import load_dataset, Audio


def collate_variable_length(batch):
    """Custom collate for variable-length C-Space tensors. Returns (padded_data, mask)."""
    max_time = max([x.shape[0] for x in batch])
    batch_size = len(batch)
    num_nodes = batch[0].shape[1]

    padded = torch.zeros(batch_size, max_time, num_nodes, dtype=torch.complex64)
    mask = torch.zeros(batch_size, max_time, 1, dtype=torch.bool)

    for i, seq in enumerate(batch):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
        mask[i, :seq_len, :] = True

    return padded, mask


class CSpaceDataset(Dataset):
    """Dataset that processes audio on-the-fly to C-Space representations with sliding windows."""

    def __init__(self, hf_dataset, cspace_model, sample_rate=24000, segment_seconds=1.0, overlap=0.5, augment=True, max_duration=5.0, min_duration=1.0):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.segment_samples = int(segment_seconds * sample_rate)
        self.hop_samples = int(self.segment_samples * (1 - overlap))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)

        # Build index of (audio_idx, segment_start) pairs
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len >= self.min_samples:
                # Extract sliding windows
                truncated_len = min(audio_len, self.max_samples)
                num_segments = max(1, (truncated_len - self.segment_samples) // self.hop_samples + 1)
                for seg in range(num_segments):
                    start = seg * self.hop_samples
                    self.segment_indices.append((i, start))

    def __len__(self):
        return len(self.segment_indices)

    def __getitem__(self, idx):
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)

        # Extract segment
        seg_end = seg_start + self.segment_samples
        audio_segment = audio_np[seg_start:seg_end]

        # Pad if too short
        if len(audio_segment) < self.segment_samples:
            audio_segment = np.pad(audio_segment, (0, self.segment_samples - len(audio_segment)))

        # Augmentation (phase-tolerant)
        if self.augment:
            audio_segment = self._augment_audio(audio_segment)

        # Normalize
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak

        # Encode to C-Space
        cspace_data = self.cspace_model.encode(audio_segment)

        # Already a tensor, just ensure dtype
        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)

    def _augment_audio(self, audio):
        """Phase-tolerant augmentations."""
        # Random gain
        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)

        # Time stretch (resample)
        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # Unsqueeze to (1, 1, time) for interpolate
            audio_t = torch.tensor(audio).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t,
                scale_factor=stretch_factor,
                mode='linear',
                align_corners=False
            )
            audio = audio_t.squeeze().numpy()

            # The stretch changes the length slightly; the modified audio is returned
            # as-is and collate_variable_length pads/masks the batch to a common length.

        # Add small noise (phase-tolerant)
        if np.random.rand() > 0.5:
            noise = np.random.normal(0, 0.01, len(audio))
            audio = audio + noise

        return audio
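collate_variable_length pads a list of (time, nodes) tensors to the longest item and returns a boolean mask marking the real frames. A toy sketch of that contract (the lengths are illustrative):

import torch
from dataset import collate_variable_length

a = torch.zeros(100, 32, dtype=torch.complex64)
b = torch.zeros(150, 32, dtype=torch.complex64)

padded, mask = collate_variable_length([a, b])
print(padded.shape)   # torch.Size([2, 150, 32])
print(mask.shape)     # torch.Size([2, 150, 1])
print(mask[0, :100].all(), mask[0, 100:].any())  # tensor(True) tensor(False)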
inference.py (Normal file, 98 lines)
@@ -0,0 +1,98 @@
import torch
import numpy as np
import argparse
import soundfile as sf
import os
import torchaudio

# Local imports
from cspace import CSpace
from model import CSpaceCompressor


def load_and_prepare_audio(audio_path, target_sample_rate, device):
    """Loads, resamples, mono-mixes, and normalizes audio."""
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Mix to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)

    # Normalize input
    peak = np.max(np.abs(audio_np))
    if peak > 0:
        audio_np = audio_np / peak

    return audio_np, peak


def main():
    parser = argparse.ArgumentParser(description="Run inference on an audio file using CSpaceCompressor.")
    parser.add_argument("input_wav", type=str, help="Path to input .wav file")
    parser.add_argument("checkpoint", type=str, help="Path to model checkpoint (.pt)")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to config json")
    parser.add_argument("--output", type=str, default="output.wav", help="Path to save output .wav")
    parser.add_argument("--num-nodes", type=int, default=32, help="Must match training config")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    args = parser.parse_args()

    print(f"Using device: {args.device}")

    # 1. Load CSpace (the transformation layer)
    print(f"Loading CSpace configuration from {args.config}...")
    cspace_model = CSpace(config=args.config, device=args.device)
    sample_rate = int(cspace_model.config.sample_rate)

    # 2. Load Compressor Model
    print(f"Loading model checkpoint from {args.checkpoint}...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(args.device)
    state_dict = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    # 3. Load Audio
    print(f"Loading audio: {args.input_wav}")
    audio_input, original_peak = load_and_prepare_audio(args.input_wav, sample_rate, args.device)
    print(f"Audio loaded: {len(audio_input)} samples @ {sample_rate}Hz")

    # 4. Encode to C-Space
    print("Encoding to C-Space...")
    # cspace_model.encode returns a complex tensor on args.device
    cspace_data = cspace_model.encode(audio_input)

    # Model expects (Batch, Time, Nodes), so unsqueeze batch dim
    cspace_batch = cspace_data.unsqueeze(0)  # -> (1, Time, Nodes)

    # 5. Model Forward Pass
    print("Running Compressor...")
    with torch.no_grad():
        reconstructed_cspace = model(cspace_batch)

    # Remove batch dim
    reconstructed_cspace = reconstructed_cspace.squeeze(0)  # -> (Time, Nodes)

    # 6. Decode back to Audio
    print("Decoding C-Space to Audio...")
    audio_output = cspace_model.decode(reconstructed_cspace)

    # 7. Renormalize/Scale
    # Normalize output to prevent clipping, but try to respect original volume dynamics if desired
    # Here we just normalize the output to -1.0 to 1.0 safely.
    out_peak = np.max(np.abs(audio_output))
    if out_peak > 0:
        audio_output = audio_output / out_peak

    # 8. Save
    print(f"Saving to {args.output}...")
    sf.write(args.output, audio_output, sample_rate)
    print("Done.")


if __name__ == "__main__":
    main()
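The same pipeline can be driven without the CLI wrapper. A minimal sketch, assuming a checkpoint such as checkpoints/best_model_epoch1.pt produced by train.py; the checkpoint path and the random stand-in signal are illustrative:

import numpy as np
import torch
from cspace import CSpace
from model import CSpaceCompressor

cspace_model = CSpace(config="cochlear_config.json", device="cpu")
model = CSpaceCompressor(num_nodes=32)
model.load_state_dict(torch.load("checkpoints/best_model_epoch1.pt", map_location="cpu"))
model.eval()

audio = np.random.randn(24000).astype(np.float32)          # stand-in for a loaded .wav
cspace = cspace_model.encode(audio)                          # (24000, 32) complex64
with torch.no_grad():
    recon_cspace = model(cspace.unsqueeze(0)).squeeze(0)     # (1, T, 32) -> (T, 32)
audio_out = cspace_model.decode(recon_cspace)                # (24000,) float32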
model.py (Normal file, 95 lines)
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn


class ResizeConv2d(nn.Module):
    """
    Upsamples the input and then performs a convolution.
    This avoids checkerboard artifacts common in ConvTranspose2d.
    """
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super().__init__()
        # 'bilinear' is usually smoother for continuous signals like spectrograms
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        # Standard convolution with padding to maintain size after upsampling
        # Padding = (kernel_size - 1) // 2 for odd kernels
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)


class CSpaceCompressor(nn.Module):
    """Convolutional compressor for C-Space as a 2D image (frequency x time)."""

    def __init__(self, num_nodes=32, compression_factor=4, latent_dim=128):
        super().__init__()
        self.num_nodes = num_nodes
        self.compression_factor = compression_factor
        self.latent_dim = latent_dim

        # Encoder: (batch, 2, num_nodes, time) -> Latent
        # We keep the encoder largely the same, but standard Convs are fine here.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(128, latent_dim, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2)),
        )

        # Decoder: Replaced ConvTranspose2d with ResizeConv2d
        # We look at the strides of the Encoder to determine scale factors for Decoder
        # Encoder strides: (1,2) -> (2,2) -> (2,2) -> (2,1)
        # Decoder scales:  (2,1) -> (2,2) -> (2,2) -> (1,2)

        self.decoder = nn.Sequential(
            ResizeConv2d(latent_dim, 128, kernel_size=(3, 5), scale_factor=(2, 1)),
            nn.ReLU(),
            ResizeConv2d(128, 64, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(32, 2, kernel_size=(3, 5), scale_factor=(1, 2)),
            # No final ReLU, we need raw coordinate values
        )

    def forward(self, cspace_complex):
        batch_size, time_steps, num_nodes = cspace_complex.shape

        # Split into real and imaginary
        real = cspace_complex.real
        imag = cspace_complex.imag

        # Stack: (batch, 2, time, num_nodes)
        x = torch.stack([real, imag], dim=1)
        # Transpose to Frequency-First for 2D Conv logic: (batch, 2, num_nodes, time)
        x = x.transpose(2, 3)

        # Encode
        latent = self.encoder(x)

        # Decode
        x_recon = self.decoder(latent)

        # Force exact size match (Upsampling might result in 1-2 pixel differences depending on rounding)
        # We interpolate to the exact target size if slightly off
        if x_recon.shape[2] != num_nodes or x_recon.shape[3] != time_steps:
            x_recon = torch.nn.functional.interpolate(
                x_recon,
                size=(num_nodes, time_steps),
                mode='bilinear',
                align_corners=False
            )

        # Extract real and imag
        real_recon = x_recon[:, 0, :, :]
        imag_recon = x_recon[:, 1, :, :]

        # Transpose back to (batch, time, num_nodes)
        real_recon = real_recon.transpose(1, 2)
        imag_recon = imag_recon.transpose(1, 2)

        return torch.complex(real_recon, imag_recon)
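A quick shape check of the autoencoder with untrained weights and a random complex input; the batch and time sizes are arbitrary:

import torch
from model import CSpaceCompressor

model = CSpaceCompressor(num_nodes=32)
x = torch.randn(2, 1000, 32, dtype=torch.complex64)   # (batch, time, nodes)
with torch.no_grad():
    y = model(x)
print(y.shape, y.dtype)   # torch.Size([2, 1000, 32]) torch.complex64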
pyproject.toml (Normal file, 15 lines)
@@ -0,0 +1,15 @@
[project]
name = "cspace"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "datasets>=4.4.1",
    "matplotlib>=3.10.8",
    "numpy>=2.2.6",
    "soundfile>=0.13.1",
    "torch>=2.9.1",
    "torchaudio>=2.9.1",
    "torchcodec>=0.9.1",
]
train.py (Normal file, 190 lines)
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import argparse
from pathlib import Path
from collections import deque
import csv
from datetime import datetime
from datasets import load_dataset, Audio
import os

# Local imports
from cspace import CSpace
from model import CSpaceCompressor
from dataset import CSpaceDataset, collate_variable_length

# Configuration Constants
DATASET_NAME = "mythicinfinity/libritts"
SUBSET = "dev"
SAMPLE_RATE = 24000
LOG_INTERVAL = 10
CHECKPOINT_DIR = Path("checkpoints")
LOGS_DIR = Path("logs")


def cspace_mse_loss(pred, target, mask=None):
    """MSE loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target) ** 2
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)


def cspace_l1_loss(pred, target, mask=None):
    """L1 loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target)
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)


def get_gradient_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5


def train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, csv_file):
    model.train()
    mses = deque(maxlen=LOG_INTERVAL)
    l1s = deque(maxlen=LOG_INTERVAL)
    grad_norms = deque(maxlen=LOG_INTERVAL)

    for batch_idx, (cspace_batch, mask) in enumerate(train_loader):
        cspace_batch = cspace_batch.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()

        pred = model(cspace_batch)
        loss = cspace_mse_loss(pred, cspace_batch, mask=mask)

        loss.backward()
        grad_norm = get_gradient_norm(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        mse = loss.item()
        l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()

        mses.append(mse)
        l1s.append(l1)
        grad_norms.append(grad_norm)

        if (batch_idx + 1) % LOG_INTERVAL == 0:
            avg_mse = np.mean(list(mses))
            avg_l1 = np.mean(list(l1s))
            avg_grad = np.mean(list(grad_norms))

            print(f"E{epoch+1} B{batch_idx+1:>4} | MSE: {avg_mse:.4f} | L1: {avg_l1:.4f} | Grad: {avg_grad:.3f}")

            csv_writer.writerow({
                'epoch': epoch + 1, 'batch': batch_idx + 1,
                'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': f"{avg_grad:.3f}"
            })
            csv_file.flush()

    return np.mean(list(mses))


def validate(model, val_loader, device, epoch, csv_writer, csv_file):
    model.eval()
    val_mses = []
    val_l1s = []

    with torch.no_grad():
        for cspace_batch, mask in val_loader:
            cspace_batch = cspace_batch.to(device)
            mask = mask.to(device)
            pred = model(cspace_batch)
            mse = cspace_mse_loss(pred, cspace_batch, mask=mask).item()
            l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
            val_mses.append(mse)
            val_l1s.append(l1)

    avg_mse = np.mean(val_mses)
    avg_l1 = np.mean(val_l1s)
    print(f"VAL E{epoch+1}: MSE={avg_mse:.4f} L1={avg_l1:.4f}")

    csv_writer.writerow({
        'epoch': epoch + 1, 'batch': 'VAL',
        'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': 'N/A'
    })
    csv_file.flush()
    return avg_mse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="cochlear_config.json")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-nodes", type=int, default=32)
    args = parser.parse_args()

    # 1. Setup Directories (Robustly)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=device)

    print(f"Loading {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, SUBSET, split="dev.clean")
    ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))

    split = ds.train_test_split(test_size=0.1)
    train_ds = split["train"]
    val_ds = split["test"]

    train_dataset = CSpaceDataset(train_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=True)
    val_dataset = CSpaceDataset(val_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_variable_length, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=collate_variable_length, num_workers=0)

    print("Initializing CSpace Compressor...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    log_file_path = LOGS_DIR / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    log_file = open(log_file_path, 'w', newline='')
    csv_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'batch', 'mse', 'l1', 'grad_norm'])
    csv_writer.writeheader()

    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        print(f"\n{'='*80}\nEpoch {epoch + 1}/{args.epochs}\n{'='*80}")
        train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, log_file)
        val_loss = validate(model, val_loader, device, epoch, csv_writer, log_file)
        scheduler.step()

        # Save Logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = CHECKPOINT_DIR / f"best_model_epoch{epoch + 1}.pt"

            # PARANOID SAFETY CHECK: Ensure dir exists right before saving
            if not CHECKPOINT_DIR.exists():
                print(f"Warning: {CHECKPOINT_DIR} was missing. Recreating...")
                CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), ckpt_path)
            print(f"✓ Saved checkpoint: {ckpt_path}")

    log_file.close()
    print("Training complete.")


if __name__ == "__main__":
    main()
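The masked losses average only over valid (unpadded) frames; since the mask has shape (batch, time, 1), the denominator counts frames rather than individual node values. A toy sketch of the behaviour (the tensor values are illustrative):

import torch
from train import cspace_mse_loss

pred   = torch.zeros(1, 4, 2, dtype=torch.complex64)
target = torch.ones(1, 4, 2, dtype=torch.complex64)
mask   = torch.tensor([[[True], [True], [False], [False]]])

# Only the first two frames count: error is |0 - 1|^2 = 1 per element,
# summed over 2 frames x 2 nodes = 4, divided by sum(mask) = 2 -> 2.0
print(cspace_mse_loss(pred, target, mask=mask))   # tensor(2.)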
visualize.py (Normal file, 143 lines)
@@ -0,0 +1,143 @@
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import argparse
from cspace import CSpace
import os


def normalize(data: np.ndarray) -> tuple[np.ndarray, float]:
    """Normalizes a numpy array to the range [-1.0, 1.0] and returns it and the peak value."""
    peak = np.max(np.abs(data))
    if peak == 0:
        return data, 1.0
    return data / peak, peak


def load_and_prepare_audio(audio_path: str, target_sample_rate: int, device: str) -> np.ndarray:
    """Loads an audio file, resamples it, and converts it to a mono float32 numpy array."""
    waveform, sr = torchaudio.load(audio_path)

    waveform = waveform.to(device)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if necessary
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)
    return audio_np


def save_audio(audio_data: np.ndarray, path: str, sample_rate: int):
    """Saves a numpy audio array to a WAV file."""
    # torchaudio.save expects a tensor, shape (channels, samples)
    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
    torchaudio.save(path, audio_tensor, sample_rate)
    print(f"Reconstructed audio saved to {path}")


def print_reconstruction_stats(original_signal: np.ndarray, reconstructed_signal: np.ndarray):
    """Calculates and prints RMS and MSE between two signals."""
    # Ensure signals have the same length for comparison
    min_len = min(len(original_signal), len(reconstructed_signal))
    original_trimmed = original_signal[:min_len]
    reconstructed_trimmed = reconstructed_signal[:min_len]

    # RMS Calculation
    rms_original = np.sqrt(np.mean(np.square(original_trimmed)))
    rms_reconstructed = np.sqrt(np.mean(np.square(reconstructed_trimmed)))

    # MSE Calculation
    mse = np.mean(np.square(original_trimmed - reconstructed_trimmed))

    print("\n--- Reconstruction Stats ---")
    print(f" - Original RMS: {rms_original:.6f}")
    print(f" - Reconstructed RMS: {rms_reconstructed:.6f}")
    print(f" - Mean Squared Error (MSE): {mse:.8f}")
    print("--------------------------\n")


def visualize_cspace(cspace_data: torch.Tensor, freqs: np.ndarray, output_path: str, sample_rate: float):
    """Generates and saves a visualization of the C-Space data."""
    magnitude_data = torch.abs(cspace_data).cpu().numpy().T

    fig, ax = plt.subplots(figsize=(30, 5))

    im = ax.imshow(
        magnitude_data,
        aspect='auto',
        origin='lower',
        interpolation='none',
        cmap='magma'
    )

    fig.colorbar(im, ax=ax, label='Magnitude')

    ax.set_title('C-Space Visualization')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    num_samples = cspace_data.shape[0]
    duration_s = num_samples / sample_rate
    time_ticks = np.linspace(0, num_samples, num=10)
    time_labels = np.linspace(0, duration_s, num=10)
    ax.set_xticks(time_ticks)
    ax.set_xticklabels([f'{t:.2f}' for t in time_labels])

    num_nodes = len(freqs)
    freq_indices = np.linspace(0, num_nodes - 1, num=10, dtype=int)
    ax.set_yticks(freq_indices)
    ax.set_yticklabels([f'{freqs[i]:.0f}' for i in freq_indices])

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Visualization saved to {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Generate a C-Space visualization and reconstructed audio from an audio file."
    )
    parser.add_argument("audio_path", type=str, help="Path to the input audio file.")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to the cochlear config JSON file.")
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save the output files.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for computation (e.g., 'cuda', 'cpu').")

    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # --- 1. Set up paths ---
    base_name = os.path.splitext(os.path.basename(args.audio_path))[0]
    viz_output_path = os.path.join(args.output_dir, f"{base_name}_cspace.png")
    recon_output_path = os.path.join(args.output_dir, f"{base_name}_recon.wav")

    # --- 2. Load Model and Audio ---
    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=args.device)
    config = cspace_model.config

    print(f"Loading and preparing audio from: {args.audio_path}")
    original_audio_np = load_and_prepare_audio(args.audio_path, int(config.sample_rate), args.device)
    original_audio_np, _ = normalize(original_audio_np)

    # --- 3. Encode and Decode ---
    print("Encoding audio to C-Space...")
    cspace_data = cspace_model.encode(original_audio_np)

    print("Decoding C-Space back to audio...")
    reconstructed_audio_np = cspace_model.decode(cspace_data)

    # --- 4. Normalize, Calculate Stats, and Save Files ---
    reconstructed_audio_np, _ = normalize(reconstructed_audio_np)

    print_reconstruction_stats(original_audio_np, reconstructed_audio_np)

    save_audio(reconstructed_audio_np, recon_output_path, int(config.sample_rate))

    visualize_cspace(cspace_data, cspace_model.freqs, viz_output_path, config.sample_rate)


if __name__ == "__main__":
    main()
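normalize and print_reconstruction_stats can also be exercised on their own; a small sketch with synthetic signals (the sine test tone and noise level are illustrative):

import numpy as np
from visualize import normalize, print_reconstruction_stats

t = np.arange(24000) / 24000.0
original = 0.5 * np.sin(2 * np.pi * 220.0 * t).astype(np.float32)
noisy = original + np.random.normal(0, 0.01, original.shape).astype(np.float32)

original, peak = normalize(original)   # peak is ~0.5 for this signal
noisy, _ = normalize(noisy)
print(f"original peak before normalization: {peak:.3f}")
print_reconstruction_stats(original, noisy)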