Initial Commit

2025-12-12 20:41:37 -06:00
commit 782d258660
11 changed files with 3464 additions and 0 deletions

cspace.py (new file, 154 lines added)

@@ -0,0 +1,154 @@
import json
import math
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
import torch.fft
import torch.nn as nn

@dataclass
class CochlearConfig:
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"  # "log", "linear", or "custom"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str) -> "CochlearConfig":
        with open(path, "r") as f:
            data = json.load(f)
        return cls(**data)
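
# Illustrative example (not part of the original file): a JSON document that
# CochlearConfig.from_json() would accept; the specific values are assumptions.
#
#   {
#       "sample_rate": 48000.0,
#       "min_freq": 50.0,
#       "max_freq": 16000.0,
#       "num_nodes": 96,
#       "q_factor": 6.0,
#       "node_distribution": "log"
#   }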

class CSpace(nn.Module):
    def __init__(
        self,
        config: Union[CochlearConfig, str],
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        super().__init__()
        # Accept either a CochlearConfig instance or a path to a JSON config.
        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config
        self.device = torch.device(device)
        self.setup_kernel()
    def setup_kernel(self):
        # 1. Frequency setup: place node centre frequencies on a log,
        # linear, or user-supplied grid.
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq, self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)
        self.freqs = freqs

        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt
        t = t.unsqueeze(1)                   # (seg_len, 1)
        omegas_broad = omegas.unsqueeze(0)   # (1, num_nodes)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Unnormalized decay kernel exp(-gamma*t) * exp(i*omega*t),
        # used for the zero-input (state decay) response.
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t),
        )
        # Single-step transition factors (kernel value at t = dt) for carrying
        # state across segment boundaries.
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Normalized filter kernel for FFT convolution.
        self.n_fft = seg_len * 2
        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Normalization factors so each node's peak magnitude response is 1.0.
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])
        node_peak_responses = torch.zeros(self.config.num_nodes, device=self.device)
        for i in range(self.config.num_nodes):
            peak = torch.max(response_matrix[:, i])
            node_peak_responses[i] = peak if peak > 1e-8 else 1.0
        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY; the decay kernel
        # starts at 1 and just propagates the already-normalized state.
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)
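
    # Worked example of the kernel above (illustrative, not in the original):
    # each node is a damped complex exponential
    #     h(t) = exp(-gamma * t) * exp(i * omega * t),  gamma = omega / (2 * Q).
    # For a 1 kHz node with Q = 4: gamma = 2*pi*1000 / 8 ~= 785 s^-1, so the
    # envelope falls to 1/e of its value after roughly 1.27 ms.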
    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        """Project a mono waveform onto the bank of complex resonators.

        Returns a (num_samples, num_nodes) complex64 tensor.
        """
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)

        # Allocate the output on the configured device if it fits, otherwise fall back to CPU.
        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")

        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT convolution setup: the final (shorter) chunk gets its own FFT size.
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2 ** math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
                c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            # Zero-state response: the chunk filtered from rest.
            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]

            # Zero-input response: the previous segment's state decaying freely.
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve

            total_response = zs_response + zi_response
            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State update for the next segment: advance the last sample by one step.
            current_state = total_response[-1] * self.step_factors

        return results
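
    # Note on the state hand-off (illustrative): each node is linear, so its
    # free response obeys the one-step recurrence
    #     y[n+1] = y[n] * exp((-gamma + i*omega) * dt)   (no new input),
    # which is why multiplying the last sample by step_factors (the kernel
    # evaluated at t = dt) yields the zero-input initial condition that
    # zir_decay_kernel then propagates through the next segment.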
    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        """Reconstruct a waveform by summing the real parts across nodes."""
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # Assume it is already real
        # Simple summation of real parts across nodes serves as the inverse operation.
        recon = torch.sum(real_part, dim=1)
        return recon.cpu().numpy()
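
For reference, a minimal round-trip sketch of the API added in this commit (the
input signal and parameter values below are assumptions, not part of the commit):

import numpy as np
from cspace import CSpace, CochlearConfig

config = CochlearConfig(num_nodes=64, segment_size_seconds=0.5)
cs = CSpace(config)  # uses CUDA if available, otherwise CPU

# One second of a 440 Hz test tone (hypothetical input)
t = np.arange(int(config.sample_rate)) / config.sample_rate
audio = 0.5 * np.sin(2.0 * np.pi * 440.0 * t).astype(np.float32)

encoded = cs.encode(audio)   # (num_samples, num_nodes), complex64
recon = cs.decode(encoded)   # real-valued reconstruction, shape (num_samples,)
print(encoded.shape, recon.shape)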