import numpy as np
import torch
from torch.utils.data import Dataset


def collate_variable_length(batch):
    """Custom collate for variable-length C-Space tensors.

    Returns (padded_data, mask), where padded_data has shape
    (batch, max_time, num_nodes) and mask marks the valid timesteps.
    """
    max_time = max(x.shape[0] for x in batch)
    batch_size = len(batch)
    num_nodes = batch[0].shape[1]

    padded = torch.zeros(batch_size, max_time, num_nodes, dtype=torch.complex64)
    mask = torch.zeros(batch_size, max_time, 1, dtype=torch.bool)

    for i, seq in enumerate(batch):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
        mask[i, :seq_len, :] = True

    return padded, mask


class CSpaceDataset(Dataset):
    """Dataset that processes audio on-the-fly to C-Space representations
    with sliding windows."""

    def __init__(self, hf_dataset, cspace_model, sample_rate=24000,
                 segment_seconds=1.0, overlap=0.5, augment=True,
                 max_duration=5.0, min_duration=1.0):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration

        self.segment_samples = int(segment_seconds * sample_rate)
        self.hop_samples = int(self.segment_samples * (1 - overlap))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)

        # Build an index of (audio_idx, segment_start) pairs so __getitem__
        # can address individual sliding windows directly. Note: accessing
        # ["audio"]["array"] decodes each file once up front.
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len >= self.min_samples:
                # Extract sliding windows over the (truncated) clip.
                truncated_len = min(audio_len, self.max_samples)
                num_segments = max(
                    1, (truncated_len - self.segment_samples) // self.hop_samples + 1
                )
                for seg in range(num_segments):
                    start = seg * self.hop_samples
                    self.segment_indices.append((i, start))

    def __len__(self):
        return len(self.segment_indices)

    def __getitem__(self, idx):
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)

        # Extract the segment for this sliding window.
        seg_end = seg_start + self.segment_samples
        audio_segment = audio_np[seg_start:seg_end]

        # Pad if too short.
        if len(audio_segment) < self.segment_samples:
            audio_segment = np.pad(
                audio_segment, (0, self.segment_samples - len(audio_segment))
            )

        # Augmentation (phase-tolerant).
        if self.augment:
            audio_segment = self._augment_audio(audio_segment)

        # Peak-normalize.
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak

        # Encode to C-Space; ensure a complex64 tensor either way.
        cspace_data = self.cspace_model.encode(audio_segment)
        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)

    def _augment_audio(self, audio):
        """Phase-tolerant augmentations: random gain, slight time stretch,
        and low-level additive noise."""
        target_len = len(audio)

        # Random gain.
        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)

        # Time stretch via linear resampling.
        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # interpolate expects (batch, channels, time).
            audio_t = torch.tensor(audio).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t, scale_factor=stretch_factor,
                mode='linear', align_corners=False
            )
            audio = audio_t.squeeze().numpy()
            # Stretching changes the length; crop or pad back to the original
            # window so batch shapes stay consistent downstream.
            if len(audio) > target_len:
                audio = audio[:target_len]
            elif len(audio) < target_len:
                audio = np.pad(audio, (0, target_len - len(audio)))

        # Add small noise (phase-tolerant).
        if np.random.rand() > 0.5:
            noise = np.random.normal(0, 0.01, len(audio)).astype(np.float32)
            audio = audio + noise

        return audio
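

# Usage sketch, with assumptions flagged: the real `cspace_model` is whatever
# this repo provides with an `encode(np.ndarray) -> complex tensor` method.
# `_MockCSpaceEncoder` and the in-memory rows below are hypothetical stand-ins
# so the pipeline can be exercised without the real model or a downloaded
# Hugging Face dataset; any dict-like rows exposing ["audio"]["array"] work.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class _MockCSpaceEncoder:
        """Hypothetical encoder: STFT-like complex frames, (time, num_nodes)."""

        def encode(self, audio):
            # Frame the signal (win=512, hop=256) and keep 64 complex bins.
            frames = torch.tensor(audio, dtype=torch.float32).unfold(0, 512, 256)
            return torch.fft.rfft(frames)[:, :64]

    sr = 24000
    fake_rows = [
        {"audio": {"array": np.random.randn(int(sr * d)).astype(np.float32),
                   "sampling_rate": sr}}
        for d in (1.2, 2.5, 4.0)
    ]

    ds = CSpaceDataset(fake_rows, _MockCSpaceEncoder(), sample_rate=sr)
    loader = DataLoader(ds, batch_size=4, shuffle=True,
                        collate_fn=collate_variable_length)

    padded, mask = next(iter(loader))
    # e.g. torch.Size([4, 92, 64]) torch.complex64 torch.Size([4, 92, 1])
    print(padded.shape, padded.dtype, mask.shape)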