import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio

from cspace import CSpace


def normalize(data: np.ndarray) -> tuple[np.ndarray, float]:
    """Normalizes a numpy array to the range [-1.0, 1.0] and returns it and the peak value."""
    peak = np.max(np.abs(data))
    if peak == 0:
        return data, 1.0
    return data / peak, peak


def load_and_prepare_audio(audio_path: str, target_sample_rate: int, device: str) -> np.ndarray:
    """Loads an audio file, resamples it, and converts it to a mono float32 numpy array."""
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Convert to mono by averaging channels
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if necessary
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)
    return audio_np


def save_audio(audio_data: np.ndarray, path: str, sample_rate: int):
    """Saves a numpy audio array to a WAV file."""
    # torchaudio.save expects a tensor of shape (channels, samples)
    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
    torchaudio.save(path, audio_tensor, sample_rate)
    print(f"Reconstructed audio saved to {path}")


def print_reconstruction_stats(original_signal: np.ndarray, reconstructed_signal: np.ndarray):
    """Calculates and prints RMS and MSE between two signals."""
    # Trim both signals to the same length for a fair comparison
    min_len = min(len(original_signal), len(reconstructed_signal))
    original_trimmed = original_signal[:min_len]
    reconstructed_trimmed = reconstructed_signal[:min_len]

    # RMS of each signal
    rms_original = np.sqrt(np.mean(np.square(original_trimmed)))
    rms_reconstructed = np.sqrt(np.mean(np.square(reconstructed_trimmed)))

    # Mean squared error between the signals
    mse = np.mean(np.square(original_trimmed - reconstructed_trimmed))

    print("\n--- Reconstruction Stats ---")
    print(f" - Original RMS: {rms_original:.6f}")
    print(f" - Reconstructed RMS: {rms_reconstructed:.6f}")
    print(f" - Mean Squared Error (MSE): {mse:.8f}")
    print("--------------------------\n")


def visualize_cspace(cspace_data: torch.Tensor, freqs: np.ndarray, output_path: str, sample_rate: float):
    """Generates and saves a visualization of the C-Space data."""
    # cspace_data has shape (samples, nodes); transpose so frequency runs along the y-axis
    magnitude_data = torch.abs(cspace_data).cpu().numpy().T

    fig, ax = plt.subplots(figsize=(30, 5))
    im = ax.imshow(
        magnitude_data,
        aspect='auto',
        origin='lower',
        interpolation='none',
        cmap='magma'
    )
    fig.colorbar(im, ax=ax, label='Magnitude')
    ax.set_title('C-Space Visualization')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    # Label the x-axis in seconds rather than sample indices
    num_samples = cspace_data.shape[0]
    duration_s = num_samples / sample_rate
    time_ticks = np.linspace(0, num_samples - 1, num=10)
    time_labels = np.linspace(0, duration_s, num=10)
    ax.set_xticks(time_ticks)
    ax.set_xticklabels([f'{t:.2f}' for t in time_labels])

    # Label the y-axis with the node center frequencies
    num_nodes = len(freqs)
    freq_indices = np.linspace(0, num_nodes - 1, num=10, dtype=int)
    ax.set_yticks(freq_indices)
    ax.set_yticklabels([f'{freqs[i]:.0f}' for i in freq_indices])

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Visualization saved to {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Generate a C-Space visualization and reconstructed audio from an audio file."
    )
    parser.add_argument("audio_path", type=str, help="Path to the input audio file.")
    parser.add_argument("--config", type=str, default="cochlear_config.json",
                        help="Path to the cochlear config JSON file.")
    parser.add_argument("--output-dir", type=str, default=".",
                        help="Directory to save the output files.")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for computation (e.g., 'cuda', 'cpu').")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # --- 1. Set up paths ---
    base_name = os.path.splitext(os.path.basename(args.audio_path))[0]
    viz_output_path = os.path.join(args.output_dir, f"{base_name}_cspace.png")
    recon_output_path = os.path.join(args.output_dir, f"{base_name}_recon.wav")

    # --- 2. Load model and audio ---
    print(f"Loading CSpace model with config: {args.config}")
    cspace_model = CSpace(config=args.config, device=args.device)
    config = cspace_model.config

    print(f"Loading and preparing audio from: {args.audio_path}")
    original_audio_np = load_and_prepare_audio(args.audio_path, int(config.sample_rate), args.device)
    original_audio_np, _ = normalize(original_audio_np)

    # --- 3. Encode and decode ---
    print("Encoding audio to C-Space...")
    cspace_data = cspace_model.encode(original_audio_np)

    print("Decoding C-Space back to audio...")
    reconstructed_audio_np = cspace_model.decode(cspace_data)

    # --- 4. Normalize, calculate stats, and save files ---
    reconstructed_audio_np, _ = normalize(reconstructed_audio_np)
    print_reconstruction_stats(original_audio_np, reconstructed_audio_np)
    save_audio(reconstructed_audio_np, recon_output_path, int(config.sample_rate))
    visualize_cspace(cspace_data, cspace_model.freqs, viz_output_path, config.sample_rate)


if __name__ == "__main__":
    main()
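# Example invocation (the script and input file names below are illustrative,
# not taken from the repository):
#
#   python visualize_cspace.py speech_sample.wav --config cochlear_config.json --output-dir out
#
# Given that input, the script writes out/speech_sample_cspace.png and
# out/speech_sample_recon.wav, and prints RMS/MSE reconstruction stats to stdout.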