Initial Commit
This commit is contained in:
98
inference.py
Normal file
98
inference.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
import soundfile as sf
|
||||
import os
|
||||
import torchaudio
|
||||
|
||||
# Local imports
|
||||
from cspace import CSpace
|
||||
from model import CSpaceCompressor
|
||||
|
||||
def load_and_prepare_audio(audio_path, target_sample_rate, device):
|
||||
"""Loads, resamples, mono-mixes, and normalizes audio."""
|
||||
waveform, sr = torchaudio.load(audio_path)
|
||||
waveform = waveform.to(device)
|
||||
|
||||
# Mix to mono
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
|
||||
# Resample
|
||||
if sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
|
||||
waveform = resampler(waveform)
|
||||
|
||||
audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)
|
||||
|
||||
# Normalize input
|
||||
peak = np.max(np.abs(audio_np))
|
||||
if peak > 0:
|
||||
audio_np = audio_np / peak
|
||||
|
||||
return audio_np, peak
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run inference on an audio file using CSpaceCompressor.")
|
||||
parser.add_argument("input_wav", type=str, help="Path to input .wav file")
|
||||
parser.add_argument("checkpoint", type=str, help="Path to model checkpoint (.pt)")
|
||||
parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to config json")
|
||||
parser.add_argument("--output", type=str, default="output.wav", help="Path to save output .wav")
|
||||
parser.add_argument("--num-nodes", type=int, default=32, help="Must match training config")
|
||||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Using device: {args.device}")
|
||||
|
||||
# 1. Load CSpace (the transformation layer)
|
||||
print(f"Loading CSpace configuration from {args.config}...")
|
||||
cspace_model = CSpace(config=args.config, device=args.device)
|
||||
sample_rate = int(cspace_model.config.sample_rate)
|
||||
|
||||
# 2. Load Compressor Model
|
||||
print(f"Loading model checkpoint from {args.checkpoint}...")
|
||||
model = CSpaceCompressor(num_nodes=args.num_nodes).to(args.device)
|
||||
state_dict = torch.load(args.checkpoint, map_location=args.device)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# 3. Load Audio
|
||||
print(f"Loading audio: {args.input_wav}")
|
||||
audio_input, original_peak = load_and_prepare_audio(args.input_wav, sample_rate, args.device)
|
||||
print(f"Audio loaded: {len(audio_input)} samples @ {sample_rate}Hz")
|
||||
|
||||
# 4. Encode to C-Space
|
||||
print("Encoding to C-Space...")
|
||||
# cspace_model.encode returns a complex tensor on args.device
|
||||
cspace_data = cspace_model.encode(audio_input)
|
||||
|
||||
# Model expects (Batch, Time, Nodes), so unsqueeze batch dim
|
||||
cspace_batch = cspace_data.unsqueeze(0) # -> (1, Time, Nodes)
|
||||
|
||||
# 5. Model Forward Pass
|
||||
print("Running Compressor...")
|
||||
with torch.no_grad():
|
||||
reconstructed_cspace = model(cspace_batch)
|
||||
|
||||
# Remove batch dim
|
||||
reconstructed_cspace = reconstructed_cspace.squeeze(0) # -> (Time, Nodes)
|
||||
|
||||
# 6. Decode back to Audio
|
||||
print("Decoding C-Space to Audio...")
|
||||
audio_output = cspace_model.decode(reconstructed_cspace)
|
||||
|
||||
# 7. Renormalize/Scale
|
||||
# Normalize output to prevent clipping, but try to respect original volume dynamics if desired
|
||||
# Here we just normalize the output to -1.0 to 1.0 safely.
|
||||
out_peak = np.max(np.abs(audio_output))
|
||||
if out_peak > 0:
|
||||
audio_output = audio_output / out_peak
|
||||
|
||||
# 8. Save
|
||||
print(f"Saving to {args.output}...")
|
||||
sf.write(args.output, audio_output, sample_rate)
|
||||
print("Done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user