Initial Commit
This commit is contained in:
154
cspace.py
Normal file
154
cspace.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import json
|
||||
import math
|
||||
import torch.fft
|
||||
from typing import Optional, List, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
class CochlearConfig:
    """Configuration for the cochlear filterbank model.

    Field names map one-to-one to the JSON keys accepted by ``from_json``.
    """

    # Audio sampling rate in Hz.
    sample_rate: float = 44100.0
    # Lowest node center frequency in Hz.
    min_freq: float = 20.0
    # Highest node center frequency in Hz.
    max_freq: float = 20000.0
    # Number of filterbank channels (overridden when custom frequencies are given).
    num_nodes: int = 128
    # Resonator quality factor; larger Q means narrower, slower-decaying filters.
    q_factor: float = 4.0
    # Processing chunk length in seconds.
    segment_size_seconds: float = 1.0
    # Warm-up duration in seconds.
    warmup_seconds: float = 0.25
    # Whether filter outputs should be peak-normalized.
    normalize_output: bool = True
    # Node spacing strategy: "log", "linear", or "custom".
    node_distribution: str = "log"
    # Explicit center frequencies, consulted when node_distribution == "custom".
    custom_frequencies: Optional[List[float]] = None

    @classmethod
    def from_json(cls, path: str):
        """Construct a config from a JSON file of field/value pairs."""
        with open(path, 'r') as handle:
            fields = json.load(handle)
        return cls(**fields)
|
||||
|
||||
class CSpace(nn.Module):
    """Complex-valued cochlear filterbank encoder/decoder.

    A bank of exponentially damped complex oscillators (one per frequency
    "node") is applied to audio via segmented FFT convolution. Filter state
    is carried across segment boundaries so arbitrarily long signals can be
    processed in fixed-size chunks.
    """

    def __init__(self, config: Union[CochlearConfig, str], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Build the filterbank.

        Args:
            config: a CochlearConfig instance, or a path to a JSON config file.
            device: torch device string; defaults to CUDA when available.
        """
        super().__init__()

        if isinstance(config, str):
            # A string config is treated as a path to a JSON file.
            self.config = CochlearConfig.from_json(config)
        else:
            self.config = config

        self.device = torch.device(device)
        self.setup_kernel()

    def setup_kernel(self):
        """Precompute oscillator kernels, FFTs, and normalization factors."""
        # 1. Frequency Setup
        freqs = self._node_frequencies()
        self.freqs = freqs  # numpy array of node center frequencies (Hz)

        omegas = torch.tensor(2.0 * math.pi * freqs, device=self.device, dtype=torch.float32)
        # Damping chosen so each resonator realizes the configured Q factor.
        gammas = omegas / (2.0 * self.config.q_factor)

        seg_len = int(self.config.segment_size_seconds * self.config.sample_rate)
        dt = 1.0 / self.config.sample_rate
        t = (torch.arange(seg_len, device=self.device, dtype=torch.float32) * dt).unsqueeze(1)

        omegas_broad = omegas.unsqueeze(0)
        gammas_broad = gammas.unsqueeze(0)

        # 2. Unnormalized decay kernel exp(-gamma*t) * exp(i*omega*t),
        # shape (seg_len, num_nodes). Used both as the impulse response and
        # as the zero-input (state decay) curve.
        envelope = torch.exp(-gammas_broad * t)
        self.zir_decay_kernel = torch.complex(
            envelope * torch.cos(omegas_broad * t),
            envelope * torch.sin(omegas_broad * t)
        )

        # Single-sample transition factor for advancing the carried state by
        # one step at a segment boundary.
        self.step_factors = self.zir_decay_kernel[1].clone()

        # 3. Normalized filter kernel (for convolution). n_fft = 2 * seg_len
        # gives linear (non-circular) convolution for a full segment.
        self.n_fft = seg_len * 2
        self.impulse_response = self.zir_decay_kernel.clone()
        self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

        # Per-node normalization so each filter's peak magnitude response is 1.
        # Vectorized (was a Python loop over nodes); guard near-zero peaks to
        # avoid division blow-ups.
        num_bins = self.n_fft // 2 + 1
        response_matrix = torch.abs(self.kernel_fft[:num_bins])
        peaks = response_matrix.max(dim=0).values
        peaks = torch.where(peaks > 1e-8, peaks, torch.ones_like(peaks))
        self.normalization_factors = 1.0 / peaks

        # BUGFIX: config.normalize_output was previously ignored and
        # normalization was always applied. Honor the flag; the default is
        # True, so default behavior is unchanged.
        if self.config.normalize_output:
            # Apply normalization to the convolution kernel ONLY (the state
            # decay kernel stays unnormalized).
            # NOTE(review): the carried state decays via the *unnormalized*
            # kernel while the zero-state path uses the normalized IR —
            # verify scaling consistency across segment boundaries.
            self.impulse_response = self.impulse_response * self.normalization_factors.unsqueeze(0)
            self.kernel_fft = torch.fft.fft(self.impulse_response, n=self.n_fft, dim=0)

    def _node_frequencies(self) -> np.ndarray:
        """Return node center frequencies per the configured distribution."""
        cfg = self.config
        if cfg.node_distribution == "custom" and cfg.custom_frequencies:
            freqs = np.array(cfg.custom_frequencies)
            cfg.num_nodes = len(freqs)  # keep num_nodes consistent with the list
        elif cfg.node_distribution == "linear":
            freqs = np.linspace(cfg.min_freq, cfg.max_freq, cfg.num_nodes)
        else:
            # Default: logarithmic spacing between min and max frequency.
            freqs = np.logspace(math.log10(cfg.min_freq), math.log10(cfg.max_freq), cfg.num_nodes)
        return freqs

    def encode(self, audio_array: np.ndarray) -> torch.Tensor:
        """Encode a 1-D audio signal into complex node responses.

        Args:
            audio_array: real-valued 1-D sample array.

        Returns:
            Tensor of shape (num_samples, num_nodes), dtype complex64.
        """
        signal = torch.tensor(audio_array, dtype=torch.float32, device=self.device)
        total_samples = len(signal)
        seg_samples = int(self.config.segment_size_seconds * self.config.sample_rate)

        try:
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device=self.device)
        except RuntimeError:
            # Device allocation failed (e.g. OOM): fall back to a CPU buffer.
            results = torch.empty((total_samples, self.config.num_nodes), dtype=torch.complex64, device="cpu")

        current_state = torch.zeros(self.config.num_nodes, dtype=torch.complex64, device=self.device)
        num_segments = math.ceil(total_samples / seg_samples)

        for i in range(num_segments):
            start = i * seg_samples
            end = min(start + seg_samples, total_samples)
            chunk_len = end - start
            chunk = signal[start:end]

            # Zero-state response via FFT convolution. A short final chunk
            # uses a dedicated power-of-two FFT size and a truncated kernel
            # (only the first chunk_len kernel samples can influence the
            # portion of the output we keep).
            if chunk_len != seg_samples:
                fft_size = 2 ** math.ceil(math.log2(chunk_len + self.impulse_response.shape[0] - 1))
                k_fft = torch.fft.fft(self.impulse_response[:chunk_len], n=fft_size, dim=0)
            else:
                fft_size = self.n_fft
                k_fft = self.kernel_fft
            # Hoisted: this line was duplicated identically in both branches.
            c_fft = torch.fft.fft(chunk, n=fft_size).unsqueeze(1)

            zs_response = torch.fft.ifft(c_fft * k_fft, n=fft_size, dim=0)[:chunk_len]

            # Zero-input response: the previous boundary state decaying
            # through this segment.
            zi_response = current_state.unsqueeze(0) * self.zir_decay_kernel[:chunk_len]

            total_response = zs_response + zi_response
            # .to() is a no-op when devices already match, so the explicit
            # device-equality branch was redundant.
            results[start:end] = total_response.to(results.device)

            # State update for the next segment: advance the last output
            # sample by one step.
            current_state = total_response[-1] * self.step_factors

        return results

    def decode(self, cspace: torch.Tensor) -> np.ndarray:
        """Reconstruct a real signal by summing node responses' real parts."""
        real_part = cspace.real if cspace.is_complex() else cspace  # assume already real
        recon = torch.sum(real_part, dim=1)
        return recon.cpu().numpy()
|
||||
Reference in New Issue
Block a user