import json
import math
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
import torch.fft
import torch.nn as nn


@dataclass
class CochlearConfig:
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"  # "log", "linear", or "custom"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str):
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)


class CSpace(nn.Module):
    def __init__(self, config: Union[CochlearConfig, str],
                 device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()
        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config
        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        # 1. Frequency Setup
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq,
                                self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)
        self.freqs = freqs

        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device,
                              dtype=torch.float32)
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt
        t = t.unsqueeze(1)                  # (seg_len, 1)
        omegas_broad = omegas.unsqueeze(0)  # (1, num_nodes)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Create Unnormalized Decay Kernel (For State Transitions)
        # h(t) = exp((-gamma + i*omega) * t): the analytic impulse response of
        # each damped resonator, shape (seg_len, num_nodes).
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t)
        )
        # Calculate single-step transition factors for the segment boundaries:
        # row 1 of the kernel is exp((-gamma + i*omega) * dt).
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Create Normalized Filter Kernel (For Convolution)
        self.n_fft = seg_len * 2
        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Calculate Normalization Factors (Peak = 1.0). The kernel is analytic,
        # so its energy sits in the positive-frequency bins (the first half of
        # the full spectrum).
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])
        peaks = response_matrix.max(dim=0).values
        node_peak_responses = torch.where(peaks > 1e-8, peaks,
                                          torch.ones_like(peaks))
        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY; the unnormalized
        # decay kernel keeps propagating state, and decay is scale-independent.
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)
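
    # --- Illustrative addition (a minimal sketch, not part of the original
    # module): a direct per-sample implementation of the same resonator bank,
    # s[n] = a * s[n-1] + g * x[n], with a = exp((-gamma + i*omega) * dt) and
    # g the per-node normalization gain. `encode` below computes the identical
    # recursion via FFT convolution; this O(samples * nodes) loop is only
    # intended for validating the fast path on short inputs. The method name
    # `encode_reference` is an assumption, not part of the original API.
    def encode_reference(self, audio_array: np.ndarray) -> torch.Tensor:
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        state = torch.zeros(self.config.num_nodes, dtype=torch.complex64,
                            device=self.device)
        out = torch.empty((len(signal), self.config.num_nodes),
                          dtype=torch.complex64, device=self.device)
        for n in range(len(signal)):
            # Decay/rotate the state by one sample, then inject the new input
            # through the normalized gain (matching impulse_response[0]).
            state = state * self.step_factors + signal[n] * self.normalization_factors
            out[n] = state
        return out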
    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)

        # Allocate the full output on the compute device; fall back to CPU if
        # the allocation does not fit in device memory.
        try:
            results = torch.empty((total_samples, self.config.num_nodes),
                                  dtype=torch.complex64, device=self.device)
        except RuntimeError:
            results = torch.empty((total_samples, self.config.num_nodes),
                                  dtype=torch.complex64, device="cpu")

        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64,
                                    device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT Convolution setup: a short final chunk needs its own FFT
            # size; full segments reuse the precomputed kernel spectrum.
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2 ** math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len],
                                      n=fft_size, dim=0)
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            # Zero-state response: the chunk convolved with the normalized kernel.
            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]

            # ZIR Calculation (State Decay): the zero-input response carries the
            # previous segment's final state through the exponential decay.
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve

            total_response = zs_response + zi_response
            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State Update for Next Segment: advance the last sample by one step.
            current_state = total_response[-1] * self.step_factors

        return results

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # Assume it's already real
        # Sum the real parts across nodes; with the unit-peak normalization in
        # setup_kernel this yields an approximate reconstruction of the input.
        recon = torch.sum(real_part, dim=1)
        return recon.cpu().numpy()
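

# --- Illustrative usage sketch (a minimal example, not from the original
# module): the config values, the 440 Hz test tone, and the round-trip checks
# below are assumptions chosen for a quick smoke test on CPU. max_freq is kept
# below the Nyquist frequency of the assumed 16 kHz sample rate.
if __name__ == "__main__":
    config = CochlearConfig(sample_rate=16000.0, min_freq=40.0, max_freq=7000.0,
                            num_nodes=64, segment_size_seconds=0.5)
    model = CSpace(config, device="cpu")

    sr = int(config.sample_rate)
    t = np.arange(int(sr * 1.2)) / sr  # 1.2 s -> two full segments plus a tail
    tone = np.sin(2.0 * math.pi * 440.0 * t).astype(np.float32)

    cspace = model.encode(tone)   # (samples, num_nodes) complex64
    recon = model.decode(cspace)  # summed real parts
    print("C-space shape:", tuple(cspace.shape))
    print("reconstruction peak:", float(np.max(np.abs(recon))))

    # Cross-check the FFT path against the per-sample reference on a short clip.
    short = tone[:400]
    diff = (model.encode(short) - model.encode_reference(short)).abs().max()
    print("max encode mismatch vs reference:", float(diff))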