import torch
import numpy as np
import argparse
import soundfile as sf
import torchaudio

# Local imports
from cspace import CSpace
from model import CSpaceCompressor


def load_and_prepare_audio(audio_path, target_sample_rate, device):
    """Loads, resamples, mono-mixes, and peak-normalizes audio."""
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Mix to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        waveform = resampler(waveform)

    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)

    # Normalize input to [-1.0, 1.0]; return the peak so callers can
    # restore the original level later if they want to.
    peak = np.max(np.abs(audio_np))
    if peak > 0:
        audio_np = audio_np / peak

    return audio_np, peak


def main():
    parser = argparse.ArgumentParser(description="Run inference on an audio file using CSpaceCompressor.")
    parser.add_argument("input_wav", type=str, help="Path to input .wav file")
    parser.add_argument("checkpoint", type=str, help="Path to model checkpoint (.pt)")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to config json")
    parser.add_argument("--output", type=str, default="output.wav", help="Path to save output .wav")
    parser.add_argument("--num-nodes", type=int, default=32, help="Must match training config")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run on (default: cuda if available)")
    args = parser.parse_args()

    print(f"Using device: {args.device}")

    # 1. Load CSpace (the transformation layer)
    print(f"Loading CSpace configuration from {args.config}...")
    cspace_model = CSpace(config=args.config, device=args.device)
    sample_rate = int(cspace_model.config.sample_rate)

    # 2. Load Compressor Model
    print(f"Loading model checkpoint from {args.checkpoint}...")
    model = CSpaceCompressor(num_nodes=args.num_nodes).to(args.device)
    state_dict = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    # 3. Load Audio
    print(f"Loading audio: {args.input_wav}")
    audio_input, original_peak = load_and_prepare_audio(args.input_wav, sample_rate, args.device)
    print(f"Audio loaded: {len(audio_input)} samples @ {sample_rate}Hz")

    # 4. Encode to C-Space
    print("Encoding to C-Space...")
    # cspace_model.encode returns a complex tensor on args.device
    cspace_data = cspace_model.encode(audio_input)

    # Model expects (Batch, Time, Nodes), so unsqueeze a batch dim
    cspace_batch = cspace_data.unsqueeze(0)  # -> (1, Time, Nodes)

    # 5. Model Forward Pass
    print("Running Compressor...")
    with torch.no_grad():
        reconstructed_cspace = model(cspace_batch)

    # Remove batch dim
    reconstructed_cspace = reconstructed_cspace.squeeze(0)  # -> (Time, Nodes)

    # 6. Decode back to Audio
    print("Decoding C-Space to Audio...")
    audio_output = cspace_model.decode(reconstructed_cspace)

    # decode may return a torch tensor on args.device (encode works
    # on-device); move to NumPy on CPU before peak-normalizing and writing.
    if torch.is_tensor(audio_output):
        audio_output = audio_output.detach().cpu().numpy()

    # 7. Renormalize/Scale
    # Peak-normalize the output to [-1.0, 1.0] to prevent clipping.
    # (original_peak from step 3 could be used here instead to restore the
    # input's original level.)
    out_peak = np.max(np.abs(audio_output))
    if out_peak > 0:
        audio_output = audio_output / out_peak

    # 8. Save
    print(f"Saving to {args.output}...")
    sf.write(args.output, audio_output, sample_rate)
    print("Done.")


if __name__ == "__main__":
    main()
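
# Example invocation, assuming this file is saved as infer.py (the file name
# and paths below are illustrative, not part of the repo):
#
#   python infer.py speech.wav checkpoints/compressor.pt \
#       --config cochlear_config.json --output reconstructed.wav --num-nodes 32
#
# Note that --num-nodes must match the value the checkpoint was trained with,
# and the output sample rate comes from the CSpace config, not the input file.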