Initial Commit
15 .gitignore vendored Normal file
@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
audio/
data/
logs/
checkpoints/
*.wav
1 .python-version Normal file
@@ -0,0 +1 @@
3.10
12 cochlear_config.json Normal file
@@ -0,0 +1,12 @@
{
    "sample_rate": 24000,
    "min_freq": 80.0,
    "max_freq": 12000.0,
    "num_nodes": 32,
    "q_factor": 4.0,
    "segment_size_seconds": 1.0,
    "warmup_seconds": 0.25,
    "normalize_output": true,
    "node_distribution": "log",
    "custom_frequencies": null
}
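These keys map one-to-one onto the CochlearConfig dataclass defined in cspace.py below, so the file can be loaded directly. A minimal sketch of that round trip (run from the repository root):

from cspace import CochlearConfig, CSpace

# Load the JSON above into the dataclass; a key that is not a
# dataclass field would raise a TypeError from cls(**data).
config = CochlearConfig.from_json("cochlear_config.json")
assert config.num_nodes == 32 and config.node_distribution == "log"

# CSpace also accepts the path directly and loads it internally.
model = CSpace(config="cochlear_config.json", device="cpu")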
154 cspace.py Normal file
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import numpy as np
import json
import math
from typing import Optional, List, Union
from dataclasses import dataclass


@dataclass
class CochlearConfig:
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str):
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)


class CSpace(nn.Module):
    def __init__(self, config: Union[CochlearConfig, str], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()

        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config

        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        # 1. Frequency Setup
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq, self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)

        self.freqs = freqs
        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt

        t = t.unsqueeze(1)
        omegas_broad = omegas.unsqueeze(0)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Create Unnormalized Decay Kernel (For State Transitions)
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t)
        )

        # Calculate single-step transition factors for the boundaries
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Create Normalized Filter Kernel (For Convolution)
        self.n_fft = seg_len * 2

        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Calculate Normalization Factors (Peak = 1.0)
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])

        node_peak_responses = torch.zeros(self.config.num_nodes, device=self.device)
        for i in range(self.config.num_nodes):
            peak = torch.max(response_matrix[:, i])
            node_peak_responses[i] = peak if peak > 1e-8 else 1.0

        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)

        # Fall back to host memory if the full result does not fit on the GPU
        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")

        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT Convolution setup
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2 ** math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]

            # ZIR Calculation (State Decay)
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve

            total_response = zs_response + zi_response

            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State Update for Next Segment
            current_state = total_response[-1] * self.step_factors

        return results

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # Assume it's already real

        # Simple summation of real parts, which is the correct inverse operation
        recon = torch.sum(real_part, dim=1)

        return recon.cpu().numpy()
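encode produces one complex coefficient per node per input sample, so the analysis/synthesis pair is easy to sanity-check with a pure-tone round trip. A minimal sketch, assuming the 24 kHz, 32-node values from cochlear_config.json:

import numpy as np
from cspace import CochlearConfig, CSpace

cfg = CochlearConfig(sample_rate=24000, min_freq=80.0, max_freq=12000.0, num_nodes=32)
model = CSpace(config=cfg, device="cpu")

# One second of a 440 Hz tone.
t = np.arange(24000, dtype=np.float32) / 24000.0
tone = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

coeffs = model.encode(tone)   # complex64, shape (24000, 32)
recon = model.decode(coeffs)  # float32 numpy array, shape (24000,)
print(coeffs.shape, recon.shape)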
116 dataset.py Normal file
@@ -0,0 +1,116 @@
import torch
import numpy as np
from torch.utils.data import Dataset


def collate_variable_length(batch):
    """Custom collate for variable-length C-Space tensors. Returns (padded_data, mask)."""
    max_time = max([x.shape[0] for x in batch])
    batch_size = len(batch)
    num_nodes = batch[0].shape[1]

    padded = torch.zeros(batch_size, max_time, num_nodes, dtype=torch.complex64)
    mask = torch.zeros(batch_size, max_time, 1, dtype=torch.bool)

    for i, seq in enumerate(batch):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
        mask[i, :seq_len, :] = True

    return padded, mask


class CSpaceDataset(Dataset):
    """Dataset that processes audio on-the-fly to C-Space representations with sliding windows."""

    def __init__(self, hf_dataset, cspace_model, sample_rate=24000, segment_seconds=1.0, overlap=0.5, augment=True, max_duration=5.0, min_duration=1.0):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.segment_samples = int(segment_seconds * sample_rate)
        self.hop_samples = int(self.segment_samples * (1 - overlap))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)

        # Build index of (audio_idx, segment_start) pairs
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len >= self.min_samples:
                # Extract sliding windows
                truncated_len = min(audio_len, self.max_samples)
                num_segments = max(1, (truncated_len - self.segment_samples) // self.hop_samples + 1)
                for seg in range(num_segments):
                    start = seg * self.hop_samples
                    self.segment_indices.append((i, start))

    def __len__(self):
        return len(self.segment_indices)

    def __getitem__(self, idx):
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)

        # Extract segment
        seg_end = seg_start + self.segment_samples
        audio_segment = audio_np[seg_start:seg_end]

        # Pad if too short
        if len(audio_segment) < self.segment_samples:
            audio_segment = np.pad(audio_segment, (0, self.segment_samples - len(audio_segment)))

        # Augmentation (phase-tolerant)
        if self.augment:
            audio_segment = self._augment_audio(audio_segment)

        # Normalize
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak

        # Encode to C-Space
        cspace_data = self.cspace_model.encode(audio_segment)

        # Already a tensor, just ensure dtype
        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)

    def _augment_audio(self, audio):
        """Phase-tolerant augmentations."""
        target_len = len(audio)

        # Random gain
        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)

        # Time stretch (resample)
        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # Unsqueeze to (1, 1, time) for interpolate
            audio_t = torch.tensor(audio).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t,
                scale_factor=stretch_factor,
                mode='linear',
                align_corners=False
            )
            audio = audio_t.squeeze().numpy()

            # Restore the original window length: crop from the start if the
            # stretch made the segment longer, zero-pad if it made it shorter.
            if len(audio) > target_len:
                audio = audio[:target_len]
            elif len(audio) < target_len:
                audio = np.pad(audio, (0, target_len - len(audio)))

        # Add small noise (phase-tolerant)
        if np.random.rand() > 0.5:
            noise = np.random.normal(0, 0.01, len(audio))
            audio = audio + noise

        return audio
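collate_variable_length zero-pads along the time axis and returns a boolean mask marking the real frames, which is what the masked losses in train.py consume. A minimal sketch of that contract, using two dummy sequences:

import torch
from dataset import collate_variable_length

# Two fake C-Space sequences of different lengths, 32 nodes each.
a = torch.randn(100, 32, dtype=torch.complex64)
b = torch.randn(60, 32, dtype=torch.complex64)

padded, mask = collate_variable_length([a, b])
print(padded.shape)        # torch.Size([2, 100, 32])
print(mask.shape)          # torch.Size([2, 100, 1])
print(mask[1, 60:].any())  # tensor(False): padded frames are masked out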
98 inference.py Normal file
@@ -0,0 +1,98 @@
import torch
import numpy as np
import argparse
import soundfile as sf
import torchaudio

# Local imports
from cspace import CSpace
from model import CSpaceCompressor


def load_and_prepare_audio(audio_path, target_sample_rate, device):
    """Loads, resamples, mono-mixes, and normalizes audio."""
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Mix to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)

    # Normalize input
    peak = np.max(np.abs(audio_np))
    if peak > 0:
        audio_np = audio_np / peak

    return audio_np, peak


def main():
    parser = argparse.ArgumentParser(description="Run inference on an audio file using CSpaceCompressor.")
    parser.add_argument("input_wav", type=str, help="Path to input .wav file")
    parser.add_argument("checkpoint", type=str, help="Path to model checkpoint (.pt)")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to config json")
    parser.add_argument("--output", type=str, default="output.wav", help="Path to save output .wav")
    parser.add_argument("--num-nodes", type=int, default=32, help="Must match training config")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    args = parser.parse_args()

    print(f"Using device: {args.device}")

    # 1. Load CSpace (the transformation layer)
    print(f"Loading CSpace configuration from {args.config}...")
    cspace_model = CSpace(config=args.config, device=args.device)
    sample_rate = int(cspace_model.config.sample_rate)

    # 2. Load Compressor Model
    print(f"Loading model checkpoint from {args.checkpoint}...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(args.device)
    state_dict = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    # 3. Load Audio
    print(f"Loading audio: {args.input_wav}")
    audio_input, original_peak = load_and_prepare_audio(args.input_wav, sample_rate, args.device)
    print(f"Audio loaded: {len(audio_input)} samples @ {sample_rate}Hz")

    # 4. Encode to C-Space
    print("Encoding to C-Space...")
    # cspace_model.encode returns a complex tensor on args.device
    cspace_data = cspace_model.encode(audio_input)

    # Model expects (Batch, Time, Nodes), so unsqueeze batch dim
    cspace_batch = cspace_data.unsqueeze(0)  # -> (1, Time, Nodes)

    # 5. Model Forward Pass
    print("Running Compressor...")
    with torch.no_grad():
        reconstructed_cspace = model(cspace_batch)

    # Remove batch dim
    reconstructed_cspace = reconstructed_cspace.squeeze(0)  # -> (Time, Nodes)

    # 6. Decode back to Audio
    print("Decoding C-Space to Audio...")
    audio_output = cspace_model.decode(reconstructed_cspace)

    # 7. Renormalize/Scale
    # Normalize the output to [-1.0, 1.0] to prevent clipping; original volume
    # dynamics could be restored with `original_peak` if desired.
    out_peak = np.max(np.abs(audio_output))
    if out_peak > 0:
        audio_output = audio_output / out_peak

    # 8. Save
    print(f"Saving to {args.output}...")
    sf.write(args.output, audio_output, sample_rate)
    print("Done.")


if __name__ == "__main__":
    main()
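A typical invocation, given the arguments registered above (the checkpoint path is illustrative; it follows the naming scheme used by train.py):

python inference.py input.wav checkpoints/best_model_epoch5.pt --config cochlear_config.json --output recon.wav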
95 model.py Normal file
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn


class ResizeConv2d(nn.Module):
    """
    Upsamples the input and then performs a convolution.
    This avoids the checkerboard artifacts common with ConvTranspose2d.
    """
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super().__init__()
        # 'bilinear' is usually smoother for continuous signals like spectrograms
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        # Standard convolution with padding to maintain size after upsampling
        # Padding = (kernel_size - 1) // 2 per dimension, for odd kernels
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)


class CSpaceCompressor(nn.Module):
    """Convolutional compressor that treats C-Space as a 2D image (frequency x time)."""

    def __init__(self, num_nodes=32, compression_factor=4, latent_dim=128):
        super().__init__()
        self.num_nodes = num_nodes
        self.compression_factor = compression_factor
        self.latent_dim = latent_dim

        # Encoder: (batch, 2, num_nodes, time) -> latent
        # Standard strided convolutions are fine here.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(128, latent_dim, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2)),
        )

        # Decoder: ConvTranspose2d replaced with ResizeConv2d.
        # The decoder scale factors mirror the encoder strides in reverse:
        # Encoder strides: (1,2) -> (2,2) -> (2,2) -> (2,1)
        # Decoder scales:  (2,1) -> (2,2) -> (2,2) -> (1,2)
        self.decoder = nn.Sequential(
            ResizeConv2d(latent_dim, 128, kernel_size=(3, 5), scale_factor=(2, 1)),
            nn.ReLU(),
            ResizeConv2d(128, 64, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(32, 2, kernel_size=(3, 5), scale_factor=(1, 2)),
            # No final ReLU: we need raw coordinate values
        )

    def forward(self, cspace_complex):
        batch_size, time_steps, num_nodes = cspace_complex.shape

        # Split into real and imaginary parts
        real = cspace_complex.real
        imag = cspace_complex.imag

        # Stack: (batch, 2, time, num_nodes)
        x = torch.stack([real, imag], dim=1)
        # Transpose to frequency-first for the 2D conv stack: (batch, 2, num_nodes, time)
        x = x.transpose(2, 3)

        # Encode
        latent = self.encoder(x)

        # Decode
        x_recon = self.decoder(latent)

        # Force an exact size match: upsampling can be off by a pixel or two
        # depending on rounding, so interpolate to the target size if needed.
        if x_recon.shape[2] != num_nodes or x_recon.shape[3] != time_steps:
            x_recon = torch.nn.functional.interpolate(
                x_recon,
                size=(num_nodes, time_steps),
                mode='bilinear',
                align_corners=False
            )

        # Extract real and imaginary channels
        real_recon = x_recon[:, 0, :, :]
        imag_recon = x_recon[:, 1, :, :]

        # Transpose back to (batch, time, num_nodes)
        real_recon = real_recon.transpose(1, 2)
        imag_recon = imag_recon.transpose(1, 2)

        return torch.complex(real_recon, imag_recon)
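Because the trailing interpolate snaps the decoder output back onto the input grid, the compressor preserves the (batch, time, nodes) shape for complex input. A minimal shape-check sketch:

import torch
from model import CSpaceCompressor

model = CSpaceCompressor(num_nodes=32)
x = torch.randn(1, 24000, 32, dtype=torch.complex64)  # (batch, time, nodes)

with torch.no_grad():
    y = model(x)

print(y.shape, y.dtype)  # torch.Size([1, 24000, 32]) torch.complex64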
15 pyproject.toml Normal file
@@ -0,0 +1,15 @@
[project]
name = "cspace"
version = "0.1.0"
description = "Cochlear-inspired C-Space audio encoding and a convolutional compressor"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "datasets>=4.4.1",
    "matplotlib>=3.10.8",
    "numpy>=2.2.6",
    "soundfile>=0.13.1",
    "torch>=2.9.1",
    "torchaudio>=2.9.1",
    "torchcodec>=0.9.1",
]
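The PEP 621 metadata together with the pinned .python-version matches a uv-style project layout; assuming uv is installed, `uv sync` creates the environment and resolves these dependencies, and `uv run python train.py` executes against it.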
190 train.py Normal file
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import argparse
from pathlib import Path
from collections import deque
import csv
from datetime import datetime
from datasets import load_dataset, Audio

# Local imports
from cspace import CSpace
from model import CSpaceCompressor
from dataset import CSpaceDataset, collate_variable_length

# Configuration Constants
DATASET_NAME = "mythicinfinity/libritts"
SUBSET = "dev"
SAMPLE_RATE = 24000
LOG_INTERVAL = 10
CHECKPOINT_DIR = Path("checkpoints")
LOGS_DIR = Path("logs")


def cspace_mse_loss(pred, target, mask=None):
    """MSE loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target) ** 2
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)


def cspace_l1_loss(pred, target, mask=None):
    """L1 loss directly in C-Space (complex domain), optionally masked."""
    error = torch.abs(pred - target)
    if mask is not None:
        error = error * mask
        return torch.sum(error) / torch.sum(mask)
    return torch.mean(error)


def get_gradient_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5


def train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, csv_file):
    model.train()
    mses = deque(maxlen=LOG_INTERVAL)
    l1s = deque(maxlen=LOG_INTERVAL)
    grad_norms = deque(maxlen=LOG_INTERVAL)

    for batch_idx, (cspace_batch, mask) in enumerate(train_loader):
        cspace_batch = cspace_batch.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()

        pred = model(cspace_batch)
        loss = cspace_mse_loss(pred, cspace_batch, mask=mask)

        loss.backward()
        grad_norm = get_gradient_norm(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        mse = loss.item()
        l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()

        mses.append(mse)
        l1s.append(l1)
        grad_norms.append(grad_norm)

        if (batch_idx + 1) % LOG_INTERVAL == 0:
            avg_mse = np.mean(list(mses))
            avg_l1 = np.mean(list(l1s))
            avg_grad = np.mean(list(grad_norms))

            print(f"E{epoch+1} B{batch_idx+1:>4} | MSE: {avg_mse:.4f} | L1: {avg_l1:.4f} | Grad: {avg_grad:.3f}")

            csv_writer.writerow({
                'epoch': epoch + 1, 'batch': batch_idx + 1,
                'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': f"{avg_grad:.3f}"
            })
            csv_file.flush()

    return np.mean(list(mses))


def validate(model, val_loader, device, epoch, csv_writer, csv_file):
    model.eval()
    val_mses = []
    val_l1s = []

    with torch.no_grad():
        for cspace_batch, mask in val_loader:
            cspace_batch = cspace_batch.to(device)
            mask = mask.to(device)
            pred = model(cspace_batch)
            mse = cspace_mse_loss(pred, cspace_batch, mask=mask).item()
            l1 = cspace_l1_loss(pred, cspace_batch, mask=mask).item()
            val_mses.append(mse)
            val_l1s.append(l1)

    avg_mse = np.mean(val_mses)
    avg_l1 = np.mean(val_l1s)
    print(f"VAL E{epoch+1}: MSE={avg_mse:.4f} L1={avg_l1:.4f}")

    csv_writer.writerow({
        'epoch': epoch + 1, 'batch': 'VAL',
        'mse': f"{avg_mse:.4f}", 'l1': f"{avg_l1:.4f}", 'grad_norm': 'N/A'
    })
    csv_file.flush()
    return avg_mse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="cochlear_config.json")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-nodes", type=int, default=32)
    args = parser.parse_args()

    # 1. Setup Directories (Robustly)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=device)

    print(f"Loading {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, SUBSET, split="dev.clean")
    ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))

    split = ds.train_test_split(test_size=0.1)
    train_ds = split["train"]
    val_ds = split["test"]

    train_dataset = CSpaceDataset(train_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=True)
    val_dataset = CSpaceDataset(val_ds, cspace_model, sample_rate=SAMPLE_RATE, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_variable_length, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=collate_variable_length, num_workers=0)

    print("Initializing CSpace Compressor...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    log_file_path = LOGS_DIR / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    log_file = open(log_file_path, 'w', newline='')
    csv_writer = csv.DictWriter(log_file, fieldnames=['epoch', 'batch', 'mse', 'l1', 'grad_norm'])
    csv_writer.writeheader()

    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        print(f"\n{'='*80}\nEpoch {epoch + 1}/{args.epochs}\n{'='*80}")
        train_epoch(model, train_loader, optimizer, device, epoch, csv_writer, log_file)
        val_loss = validate(model, val_loader, device, epoch, csv_writer, log_file)
        scheduler.step()

        # Save Logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = CHECKPOINT_DIR / f"best_model_epoch{epoch + 1}.pt"

            # PARANOID SAFETY CHECK: Ensure dir exists right before saving
            if not CHECKPOINT_DIR.exists():
                print(f"Warning: {CHECKPOINT_DIR} was missing. Recreating...")
                CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), ckpt_path)
            print(f"✓ Saved checkpoint: {ckpt_path}")

    log_file.close()
    print("Training complete.")


if __name__ == "__main__":
    main()
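A typical run, using the flags registered above (the values shown are simply the defaults):

python train.py --config cochlear_config.json --batch-size 8 --epochs 10 --lr 1e-3 --num-nodes 32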
143 visualize.py Normal file
@@ -0,0 +1,143 @@
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import argparse
from cspace import CSpace
import os


def normalize(data: np.ndarray) -> tuple[np.ndarray, float]:
    """Normalizes a numpy array to the range [-1.0, 1.0] and returns it and the peak value."""
    peak = np.max(np.abs(data))
    if peak == 0:
        return data, 1.0
    return data / peak, peak


def load_and_prepare_audio(audio_path: str, target_sample_rate: int, device: str) -> np.ndarray:
    """Loads an audio file, resamples it, and converts it to a mono float32 numpy array."""
    waveform, sr = torchaudio.load(audio_path)

    waveform = waveform.to(device)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if necessary
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)
    return audio_np


def save_audio(audio_data: np.ndarray, path: str, sample_rate: int):
    """Saves a numpy audio array to a WAV file."""
    # torchaudio.save expects a tensor of shape (channels, samples)
    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
    torchaudio.save(path, audio_tensor, sample_rate)
    print(f"Reconstructed audio saved to {path}")


def print_reconstruction_stats(original_signal: np.ndarray, reconstructed_signal: np.ndarray):
    """Calculates and prints RMS and MSE between two signals."""
    # Ensure signals have the same length for comparison
    min_len = min(len(original_signal), len(reconstructed_signal))
    original_trimmed = original_signal[:min_len]
    reconstructed_trimmed = reconstructed_signal[:min_len]

    # RMS Calculation
    rms_original = np.sqrt(np.mean(np.square(original_trimmed)))
    rms_reconstructed = np.sqrt(np.mean(np.square(reconstructed_trimmed)))

    # MSE Calculation
    mse = np.mean(np.square(original_trimmed - reconstructed_trimmed))

    print("\n--- Reconstruction Stats ---")
    print(f" - Original RMS: {rms_original:.6f}")
    print(f" - Reconstructed RMS: {rms_reconstructed:.6f}")
    print(f" - Mean Squared Error (MSE): {mse:.8f}")
    print("--------------------------\n")


def visualize_cspace(cspace_data: torch.Tensor, freqs: np.ndarray, output_path: str, sample_rate: float):
    """Generates and saves a visualization of the C-Space data."""
    magnitude_data = torch.abs(cspace_data).cpu().numpy().T

    fig, ax = plt.subplots(figsize=(30, 5))

    im = ax.imshow(
        magnitude_data,
        aspect='auto',
        origin='lower',
        interpolation='none',
        cmap='magma'
    )

    fig.colorbar(im, ax=ax, label='Magnitude')

    ax.set_title('C-Space Visualization')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    num_samples = cspace_data.shape[0]
    duration_s = num_samples / sample_rate
    time_ticks = np.linspace(0, num_samples, num=10)
    time_labels = np.linspace(0, duration_s, num=10)
    ax.set_xticks(time_ticks)
    ax.set_xticklabels([f'{t:.2f}' for t in time_labels])

    num_nodes = len(freqs)
    freq_indices = np.linspace(0, num_nodes - 1, num=10, dtype=int)
    ax.set_yticks(freq_indices)
    ax.set_yticklabels([f'{freqs[i]:.0f}' for i in freq_indices])

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Visualization saved to {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Generate a C-Space visualization and reconstructed audio from an audio file."
    )
    parser.add_argument("audio_path", type=str, help="Path to the input audio file.")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to the cochlear config JSON file.")
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save the output files.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for computation (e.g., 'cuda', 'cpu').")

    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # --- 1. Set up paths ---
    base_name = os.path.splitext(os.path.basename(args.audio_path))[0]
    viz_output_path = os.path.join(args.output_dir, f"{base_name}_cspace.png")
    recon_output_path = os.path.join(args.output_dir, f"{base_name}_recon.wav")

    # --- 2. Load Model and Audio ---
    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=args.device)
    config = cspace_model.config

    print(f"Loading and preparing audio from: {args.audio_path}")
    original_audio_np = load_and_prepare_audio(args.audio_path, int(config.sample_rate), args.device)
    original_audio_np, _ = normalize(original_audio_np)

    # --- 3. Encode and Decode ---
    print("Encoding audio to C-Space...")
    cspace_data = cspace_model.encode(original_audio_np)

    print("Decoding C-Space back to audio...")
    reconstructed_audio_np = cspace_model.decode(cspace_data)

    # --- 4. Normalize, Calculate Stats, and Save Files ---
    reconstructed_audio_np, _ = normalize(reconstructed_audio_np)

    print_reconstruction_stats(original_audio_np, reconstructed_audio_np)

    save_audio(reconstructed_audio_np, recon_output_path, int(config.sample_rate))

    visualize_cspace(cspace_data, cspace_model.freqs, viz_output_path, config.sample_rate)


if __name__ == "__main__":
    main()
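Given the arguments above, a run that writes <basename>_cspace.png and <basename>_recon.wav into viz/ looks like (input path illustrative):

python visualize.py input.wav --config cochlear_config.json --output-dir viz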