Initial Commit
This commit is contained in:
143
visualize.py
Normal file
143
visualize.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
from cspace import CSpace
|
||||
import os
|
||||
|
||||
def normalize(data: np.ndarray) -> tuple[np.ndarray, float]:
    """Scale *data* into [-1.0, 1.0] by its absolute peak.

    Returns the scaled array together with the peak that was divided out,
    so callers can undo the scaling by multiplying. A silent (all-zero)
    signal is returned unchanged with a reported peak of 1.0 to avoid a
    division by zero.
    """
    peak = np.abs(data).max()
    if not peak:
        return data, 1.0
    return data / peak, peak
|
||||
|
||||
def load_and_prepare_audio(audio_path: str, target_sample_rate: int, device: str) -> np.ndarray:
    """Load an audio file, downmix to mono, resample, and return float32 samples.

    Args:
        audio_path: Path to any audio file torchaudio can read.
        target_sample_rate: Desired output sample rate in Hz.
        device: Torch device string ('cuda', 'cpu', ...) used for the
            mixdown / resampling work.

    Returns:
        1-D float32 numpy array of mono samples at ``target_sample_rate``.
    """
    waveform, sr = torchaudio.load(audio_path)

    waveform = waveform.to(device)

    # Downmix multi-channel audio to mono by averaging across channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample only when the file's native rate differs from the target.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
        waveform = resampler(waveform)

    # Use squeeze(0) rather than squeeze(): a bare squeeze() also collapses
    # a length-1 sample dimension, turning a degenerate one-sample clip into
    # a 0-d array. squeeze(0) only removes the (known size-1) channel dim.
    audio_np = waveform.squeeze(0).cpu().numpy().astype(np.float32)
    return audio_np
|
||||
|
||||
def save_audio(audio_data: np.ndarray, path: str, sample_rate: int):
    """Write a 1-D numpy audio signal to *path* as an audio file."""
    # torchaudio.save expects a (channels, samples) tensor, so prepend
    # a channel axis to the mono signal.
    mono_tensor = torch.from_numpy(audio_data).unsqueeze(0)
    torchaudio.save(path, mono_tensor, sample_rate)
    print(f"Reconstructed audio saved to {path}")
|
||||
|
||||
def print_reconstruction_stats(original_signal: np.ndarray, reconstructed_signal: np.ndarray):
    """Print the RMS of each signal and the MSE between them.

    The longer signal is trimmed so both are compared over the same span.
    """
    overlap = min(len(original_signal), len(reconstructed_signal))
    a = original_signal[:overlap]
    b = reconstructed_signal[:overlap]

    rms_original = np.sqrt(np.mean(a * a))
    rms_reconstructed = np.sqrt(np.mean(b * b))
    mse = np.mean((a - b) ** 2)

    print("\n--- Reconstruction Stats ---")
    print(f" - Original RMS: {rms_original:.6f}")
    print(f" - Reconstructed RMS: {rms_reconstructed:.6f}")
    print(f" - Mean Squared Error (MSE): {mse:.8f}")
    print("--------------------------\n")
|
||||
|
||||
def visualize_cspace(cspace_data: torch.Tensor, freqs: np.ndarray, output_path: str, sample_rate: float):
    """Render the magnitude of the C-Space data as a time/frequency image.

    Args:
        cspace_data: Tensor with time along dim 0 and frequency nodes along
            dim 1 (it is transposed for display so frequency runs vertically).
        freqs: Center frequency in Hz of each node, used for y-axis labels.
        output_path: Path the PNG image is written to.
        sample_rate: Samples per second, used to convert the x-axis to seconds.
    """
    magnitude_data = torch.abs(cspace_data).cpu().numpy().T

    fig, ax = plt.subplots(figsize=(30, 5))

    im = ax.imshow(
        magnitude_data,
        aspect='auto',
        origin='lower',
        interpolation='none',
        cmap='magma'
    )

    fig.colorbar(im, ax=ax, label='Magnitude')

    ax.set_title('C-Space Visualization')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    # Relabel the x-axis (pixel columns) in seconds.
    num_samples = cspace_data.shape[0]
    duration_s = num_samples / sample_rate
    time_ticks = np.linspace(0, num_samples, num=10)
    time_labels = np.linspace(0, duration_s, num=10)
    ax.set_xticks(time_ticks)
    ax.set_xticklabels([f'{t:.2f}' for t in time_labels])

    # Relabel the y-axis (pixel rows) with the node center frequencies.
    num_nodes = len(freqs)
    freq_indices = np.linspace(0, num_nodes - 1, num=10, dtype=int)
    ax.set_yticks(freq_indices)
    ax.set_yticklabels([f'{freqs[i]:.0f}' for i in freq_indices])

    plt.tight_layout()
    plt.savefig(output_path)
    # Close the figure explicitly: matplotlib keeps figures alive otherwise,
    # so repeated calls would accumulate memory and trigger warnings.
    plt.close(fig)
    print(f"Visualization saved to {output_path}")
|
||||
|
||||
def main():
    """Parse CLI arguments, run the encode/decode round trip, and save outputs."""
    parser = argparse.ArgumentParser(
        description="Generate a C-Space visualization and reconstructed audio from an audio file."
    )
    parser.add_argument("audio_path", type=str, help="Path to the input audio file.")
    parser.add_argument("--config", type=str, default="cochlear_config.json", help="Path to the cochlear config JSON file.")
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save the output files.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for computation (e.g., 'cuda', 'cpu').")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Derive the output file names from the input file's base name.
    stem = os.path.splitext(os.path.basename(args.audio_path))[0]
    viz_path = os.path.join(args.output_dir, f"{stem}_cspace.png")
    wav_path = os.path.join(args.output_dir, f"{stem}_recon.wav")

    print(f"Loading CSpace model with config: {args.config}")
    model = CSpace(config=args.config, device=args.device)
    cfg = model.config

    print(f"Loading and preparing audio from: {args.audio_path}")
    audio = load_and_prepare_audio(args.audio_path, int(cfg.sample_rate), args.device)
    audio, _ = normalize(audio)

    print("Encoding audio to C-Space...")
    encoded = model.encode(audio)

    print("Decoding C-Space back to audio...")
    reconstructed = model.decode(encoded)

    # Normalize the reconstruction so the RMS/MSE comparison and the saved
    # WAV are on the same scale as the normalized input.
    reconstructed, _ = normalize(reconstructed)

    print_reconstruction_stats(audio, reconstructed)

    save_audio(reconstructed, wav_path, int(cfg.sample_rate))

    visualize_cspace(encoded, model.freqs, viz_path, cfg.sample_rate)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user