# cspace/dataset.py
import torch
import numpy as np
from torch.utils.data import Dataset
from datasets import load_dataset, Audio


def collate_variable_length(batch):
    """Custom collate for variable-length C-Space tensors. Returns (padded_data, mask)."""
    max_time = max(x.shape[0] for x in batch)
    batch_size = len(batch)
    num_nodes = batch[0].shape[1]
    padded = torch.zeros(batch_size, max_time, num_nodes, dtype=torch.complex64)
    mask = torch.zeros(batch_size, max_time, 1, dtype=torch.bool)
    for i, seq in enumerate(batch):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
        mask[i, :seq_len, :] = True
    return padded, mask
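# Illustrative only (loader, pred and target are not defined in this file): the mask
# lets a training loop ignore zero-padded frames, e.g.
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_variable_length)
#   per_frame = (pred - target).abs() ** 2            # (B, T_max, num_nodes)
#   loss = (per_frame * mask).sum() / (mask.sum() * per_frame.shape[-1])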
class CSpaceDataset(Dataset):
    """Dataset that processes audio on-the-fly to C-Space representations with sliding windows."""

    def __init__(self, hf_dataset, cspace_model, sample_rate=24000, segment_seconds=1.0,
                 overlap=0.5, augment=True, max_duration=5.0, min_duration=1.0):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.segment_samples = int(segment_seconds * sample_rate)
        self.hop_samples = int(self.segment_samples * (1 - overlap))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)
        # Build index of (audio_idx, segment_start) pairs
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len >= self.min_samples:
                # Extract sliding windows
                truncated_len = min(audio_len, self.max_samples)
                num_segments = max(1, (truncated_len - self.segment_samples) // self.hop_samples + 1)
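                # With the defaults (sample_rate=24000, segment_seconds=1.0, overlap=0.5),
                # segment_samples=24000 and hop_samples=12000, so e.g. a 3 s clip
                # (72000 samples) yields 5 windows starting at 0, 12000, 24000, 36000, 48000.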
                for seg in range(num_segments):
                    start = seg * self.hop_samples
                    self.segment_indices.append((i, start))
    def __len__(self):
        return len(self.segment_indices)
    def __getitem__(self, idx):
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)
        # Extract segment
        seg_end = seg_start + self.segment_samples
        audio_segment = audio_np[seg_start:seg_end]
        # Pad if too short
        if len(audio_segment) < self.segment_samples:
            audio_segment = np.pad(audio_segment, (0, self.segment_samples - len(audio_segment)))
        # Augmentation (phase-tolerant)
        if self.augment:
            audio_segment = self._augment_audio(audio_segment)
        # Normalize
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak
        # Encode to C-Space
        cspace_data = self.cspace_model.encode(audio_segment)
        # Already a tensor, just ensure dtype
        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)
    def _augment_audio(self, audio):
        """Phase-tolerant augmentations."""
        # Random gain
        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)
        # Time stretch (resample)
        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # Unsqueeze to (1, 1, time) for interpolate
            audio_t = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t,
                scale_factor=stretch_factor,
                mode='linear',
                align_corners=False
            )
            audio = audio_t.squeeze().numpy()
            # Stretching changes the length; crop or zero-pad back to the fixed
            # window size so batching still sees segment_samples-long inputs.
            if len(audio) > self.segment_samples:
                audio = audio[:self.segment_samples]
            elif len(audio) < self.segment_samples:
                audio = np.pad(audio, (0, self.segment_samples - len(audio)))
        # Add small noise (phase-tolerant)
        if np.random.rand() > 0.5:
            noise = np.random.normal(0, 0.01, len(audio)).astype(np.float32)
            audio = audio + noise
        return audio
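# Minimal smoke-test sketch. DummyEncoder is a hypothetical stand-in for the real
# C-Space model, which is assumed to expose encode(audio) -> (time, num_nodes) complex
# data; the in-memory fake_dataset only mimics the {"audio": {"array": ...}} layout a
# decoded Hugging Face audio dataset would provide. Swap in the real model and dataset
# in practice.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class DummyEncoder:
        def encode(self, audio):
            # Fake (time, num_nodes) complex output sized from the waveform length.
            frames = max(1, len(audio) // 256)
            return torch.randn(frames, 64, dtype=torch.complex64)

    fake_dataset = [
        {"audio": {"array": np.random.randn(48000).astype(np.float32)}},
        {"audio": {"array": np.random.randn(30000).astype(np.float32)}},
    ]

    ds = CSpaceDataset(fake_dataset, DummyEncoder(), augment=False)
    loader = DataLoader(ds, batch_size=4, collate_fn=collate_variable_length)
    padded, mask = next(iter(loader))
    print(padded.shape, padded.dtype, mask.shape)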