# cspace/cspace.py

import torch
import torch.nn as nn
import numpy as np
import json
import math
import torch.fft
from typing import Optional, List, Union
from dataclasses import dataclass


@dataclass
class CochlearConfig:
    """Configuration for the cochlear filterbank transform."""
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"  # "log", "linear", or "custom"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str):
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)
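
# Example JSON accepted by CochlearConfig.from_json (a sketch; the keys mirror
# the dataclass fields above, and any subset may be given since every field
# has a default):
#
#     {
#         "sample_rate": 16000.0,
#         "num_nodes": 64,
#         "node_distribution": "log"
#     }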


class CSpace(nn.Module):
    """Complex resonator filterbank ("cochlear space") encoder/decoder."""

    def __init__(self, config: Union[CochlearConfig, str],
                 device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()
        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config
        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        # 1. Frequency setup: place node center frequencies according to the
        # configured distribution.
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq, self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)
        self.freqs = freqs

        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        # Damping rate per node; bandwidth scales with center frequency / Q.
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt
        t = t.unsqueeze(1)                   # (seg_len, 1)
        omegas_broad = omegas.unsqueeze(0)   # (1, num_nodes)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Unnormalized decay kernel (for state transitions):
        # h[n] = exp((-gamma + i*omega) * n * dt), a damped complex
        # exponential with h[0] = 1 for every node.
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t)
        )
        # Single-step transition factors a = h[1] for the segment boundaries.
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Normalized filter kernel (for convolution).
        self.n_fft = seg_len * 2
        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Normalization factors so that each node's peak magnitude response
        # equals 1.0 (only the positive-frequency bins matter for the
        # analytic kernel).
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])
        node_peak_responses = response_matrix.max(dim=0).values
        node_peak_responses = torch.where(
            node_peak_responses > 1e-8,
            node_peak_responses,
            torch.ones_like(node_peak_responses),
        )
        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY; the decay
        # kernel must keep h[0] = 1 for exact state propagation.
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)
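
    # A sanity check one might run after setup_kernel (a sketch, not part of
    # the original API; `model` is an assumed instance name): every node's
    # normalized magnitude response should now peak at ~1.0.
    #
    #     peaks = torch.abs(model.kernel_fft[: model.n_fft // 2 + 1]).max(dim=0).values
    #     assert torch.allclose(peaks, torch.ones_like(peaks), atol=1e-4)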

    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)
        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            # Fall back to CPU if the output buffer does not fit on the device.
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")
        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)

        num_segments = math.ceil(total_samples / seg_samples)
        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT convolution setup. For a short final chunk the kernel can be
            # truncated to chunk_len: output sample n depends only on kernel
            # samples 0..n, so samples beyond chunk_len never contribute.
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2 ** math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
            c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]  # zero-state response

            # ZIR calculation: the state carried over from previous segments
            # decays freely under the unnormalized kernel (h[0] = 1).
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve

            total_response = zs_response + zi_response
            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State update for the next segment: advance the last output by
            # one sample step, a = exp((-gamma + i*omega) * dt).
            current_state = total_response[-1] * self.step_factors
        return results
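
    # Note (not in the original file): the encoding is complex analytic, so
    # torch.abs(cspace) gives each node's amplitude envelope and
    # torch.angle(cspace) its instantaneous phase.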

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # Assume it's already real
        # Sum the real parts across nodes. This approximates the inverse
        # transform; it is exact only when the node responses sum to a flat
        # (perfect-reconstruction) frequency response.
        recon = torch.sum(real_part, dim=1)
        return recon.cpu().numpy()
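

# A minimal round-trip sketch (not part of the original file): encode a test
# tone and reconstruct it. The sample rate, duration, and node count below
# are illustrative assumptions.
if __name__ == "__main__":
    cfg = CochlearConfig(sample_rate=16000.0, num_nodes=64,
                         segment_size_seconds=0.5)
    model = CSpace(cfg, device="cpu")

    t = np.arange(int(cfg.sample_rate)) / cfg.sample_rate  # 1 second of audio
    tone = np.sin(2.0 * np.pi * 440.0 * t).astype(np.float32)

    cspace = model.encode(tone)    # (num_samples, num_nodes), complex64
    recon = model.decode(cspace)   # (num_samples,), float32
    print("encoded:", tuple(cspace.shape), "reconstructed:", recon.shape)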