import torch
import torch.nn as nn
import numpy as np
import json
import math
import torch.fft
from typing import Optional, List, Union
from dataclasses import dataclass

@dataclass
class CochlearConfig:
    """Configuration for the cochlear filterbank: frequency range, number and
    spacing of nodes, per-node bandwidth (Q), and segmenting options."""
    sample_rate: float = 44100.0
    min_freq: float = 20.0
    max_freq: float = 20000.0
    num_nodes: int = 128
    q_factor: float = 4.0
    segment_size_seconds: float = 1.0
    warmup_seconds: float = 0.25
    normalize_output: bool = True
    node_distribution: str = "log"  # "log", "linear", or "custom"
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str) -> "CochlearConfig":
        """Build a config from a JSON file whose keys match the field names."""
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)
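
# The JSON loader expects keys matching the dataclass fields. A minimal sketch
# of such a file (illustrative values, hypothetical filename):
#
#     {"sample_rate": 16000.0, "min_freq": 50.0, "max_freq": 8000.0,
#      "num_nodes": 64, "q_factor": 4.0}
#
#     cfg = CochlearConfig.from_json("cochlear_config.json")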

class CSpace(nn.Module):
    """Complex resonator bank: encodes audio into a (samples, nodes) complex
    tensor and decodes it back by summing the real parts across nodes."""

    def __init__(self, config: Union[CochlearConfig, str],
                 device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        super().__init__()

        # Accept either a ready-made config or a path to a JSON config file.
        if isinstance(config, str):
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config

        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        # 1. Frequency setup: place node center frequencies according to the
        # configured distribution ("custom", "linear", or logarithmic).
        if self.config.node_distribution == "custom" and self.config.custom_frequencies:
            freqs = np.array(self.config.custom_frequencies)
            self.config.num_nodes = len(freqs)
        elif self.config.node_distribution == "linear":
            freqs = np.linspace(self.config.min_freq, self.config.max_freq, self.config.num_nodes)
        else:
            low = math.log10(self.config.min_freq)
            high = math.log10(self.config.max_freq)
            freqs = np.logspace(low, high, self.config.num_nodes)
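
        # With the defaults (20 Hz to 20 kHz over 128 log-spaced nodes), adjacent
        # center frequencies keep a constant ratio: (20000 / 20) ** (1 / 127)
        # ~= 1.056, i.e. a little under a semitone (~1.059) per node.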

        self.freqs = freqs
        # Angular frequencies and damping rates; gamma = omega / (2 * Q) keeps
        # each node's bandwidth proportional to its center frequency (constant Q).
        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt

        # Reshape for broadcasting to (time, nodes).
        t = t.unsqueeze(1)
        omegas_broad = omegas.unsqueeze(0)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Unnormalized decay kernel (for state transitions):
        # h(t) = exp(-gamma * t) * exp(i * omega * t), one column per node.
        self.zir_decay_kernel = torch.complex(
            torch.exp(-gammas_broad * t) * torch.cos(omegas_broad * t),
            torch.exp(-gammas_broad * t) * torch.sin(omegas_broad * t)
        )

        # Single-step transition factors exp((-gamma + i*omega) * dt), i.e. the
        # kernel at t = dt; used to carry state across segment boundaries.
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Normalized filter kernel (for convolution).
        self.n_fft = seg_len * 2

        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Normalization factors so each node's peak frequency response is 1.0.
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])

        node_peak_responses = torch.zeros(self.config.num_nodes, device=self.device)
        for i in range(self.config.num_nodes):
            peak = torch.max(response_matrix[:, i])
            node_peak_responses[i] = peak if peak > 1e-8 else 1.0  # guard against division by zero

        self.normalization_factors = 1.0 / node_peak_responses

        # Apply normalization to the convolution kernel ONLY; the ZIR kernel stays
        # unnormalized so that state decay matches the underlying resonators.
        self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)
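
    # Back-of-the-envelope check (not used by the code): near resonance the
    # unnormalized kernel's FFT magnitude is about 1 / (1 - exp(-gamma * dt)),
    # roughly sample_rate / gamma, so normalization_factors ~= gammas * dt.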

    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        """Encode a 1-D audio signal into a (samples, nodes) complex64 tensor."""
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)

        # Allocate the output on the compute device; fall back to CPU if the full
        # (samples, nodes) buffer does not fit (e.g. CUDA out of memory).
        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")

        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)
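
        # Memory note: the output is complex64, i.e. 8 bytes per (sample, node)
        # entry; e.g. 60 s at 44.1 kHz with 128 nodes is ~2.7 GB, which is why
        # the CPU fallback above exists.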

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # FFT convolution setup. The last segment may be short: pick a
            # power-of-two FFT size that fits the linear convolution and truncate
            # the kernel to match. Full segments reuse the precomputed kernel FFT.
            if chunk_len != seg_samples:
                fft_size = chunk_len + self.impulse_response.shape[0] - 1
                fft_size = 2 ** math.ceil(math.log2(fft_size))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
            c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            # Zero-state response: the chunk filtered from rest.
            conv_spectrum = c_fft * k_fft
            conv_time = torch.fft.ifft(conv_spectrum, n=fft_size, dim=0)
            zs_response = conv_time[:chunk_len]

            # Zero-input response: ringing of the state carried in from earlier
            # segments, decayed along the unnormalized kernel.
            decay_curve = self.zir_decay_kernel[:chunk_len]
            zi_response = current_state.unsqueeze(0) * decay_curve

            total_response = zs_response + zi_response

            if results.device != total_response.device:
                results[start:end] = total_response.to(results.device)
            else:
                results[start:end] = total_response

            # State update for the next segment: advance the last output by one
            # sample step so the next segment's ZIR starts with the right phase.
            current_state = total_response[-1] * self.step_factors

        return results
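
    # Note on segmentation: each node is a one-pole complex resonator, so its
    # entire state is one complex number per node (the last output advanced by
    # one sample step). Filtering each segment from rest (ZSR) and adding the
    # carried-over ringing (ZIR) therefore reproduces, up to FFT round-off,
    # the result of filtering the whole signal at once.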

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        """Reconstruct a 1-D signal by summing the per-node real parts."""
        if cspace.is_complex():
            real_part = cspace.real
        else:
            real_part = cspace  # assume the real parts were passed in directly

        # Sum the real parts across nodes; with peak-normalized filters this is
        # the synthesis step matching the analysis above.
        recon = torch.sum(real_part, dim=1)

        return recon.cpu().numpy()
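

# Minimal usage sketch (illustrative values; not part of the core module):
# encode a short test tone and round-trip it through decode.
if __name__ == "__main__":
    cfg = CochlearConfig(sample_rate=16000.0, min_freq=50.0, max_freq=8000.0,
                         num_nodes=64, segment_size_seconds=0.5)
    model = CSpace(cfg, device="cpu")

    # Two seconds of a 440 Hz tone.
    t = np.arange(int(2 * cfg.sample_rate)) / cfg.sample_rate
    tone = np.sin(2.0 * np.pi * 440.0 * t).astype(np.float32)

    encoded = model.encode(tone)   # (samples, nodes), complex64
    recon = model.decode(encoded)  # (samples,), float32

    print("encoded:", tuple(encoded.shape), "recon:", recon.shape)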