# Source listing: 116 lines, 4.7 KiB, Python
import torch
|
|
import numpy as np
|
|
from torch.utils.data import Dataset
|
|
from datasets import load_dataset, Audio
|
|
|
|
def collate_variable_length(batch):
    """Pad a batch of variable-length C-Space tensors along the time axis.

    Args:
        batch: Non-empty sequence of tensors shaped ``(time, *features)``.
            All items must share the same trailing feature shape; only the
            leading (time) dimension may differ.

    Returns:
        ``(padded, mask)``:
          * ``padded`` is a ``complex64`` tensor of shape
            ``(batch, max_time, *features)``, zero-filled past each item's length.
          * ``mask`` is a ``bool`` tensor of shape ``(batch, max_time, 1)`` that
            is ``True`` on real (non-padded) time steps.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("collate_variable_length received an empty batch")

    lengths = [seq.shape[0] for seq in batch]
    max_time = max(lengths)
    # Generalizes the original 2-D (time, nodes) case to any trailing shape.
    feature_shape = tuple(batch[0].shape[1:])

    padded = torch.zeros((len(batch), max_time, *feature_shape), dtype=torch.complex64)
    mask = torch.zeros(len(batch), max_time, 1, dtype=torch.bool)

    for i, (seq, seq_len) in enumerate(zip(batch, lengths)):
        # Copy the real steps; everything after seq_len stays zero (padding).
        padded[i, :seq_len] = seq
        mask[i, :seq_len] = True

    return padded, mask
|
|
|
|
class CSpaceDataset(Dataset):
    """Slice audio clips into overlapping windows and encode each to C-Space on the fly.

    Each item is one window of ``segment_seconds`` taken from the first
    ``max_duration`` seconds of a clip. Consecutive windows are
    ``segment_samples * (1 - overlap)`` samples apart. Clips shorter than
    ``min_duration`` are skipped. Each window is then:

    1. optionally augmented;
    2. peak-normalized;
    3. passed through ``cspace_model.encode``;
    4. returned as a ``complex64`` tensor.

    Args:
        hf_dataset: Indexable dataset whose items look like
            ``{"audio": {"array": <1-D samples>}}``. Samples are assumed to
            already be at ``sample_rate``; nothing here resamples them.
            TODO confirm the caller casts the audio column beforehand.
        cspace_model: Object with an ``encode(audio_segment)`` method. It must
            return a tensor or an array-like convertible to ``complex64``.
        sample_rate: Samples per second, used to convert the durations below.
        segment_seconds: Window length in seconds.
        overlap: Fraction of a window shared with the next one. It must leave
            a hop of at least one sample, i.e. ``overlap < 1``.
        augment: Whether to apply random gain / speed / noise perturbation.
        max_duration: Only the first ``max_duration`` seconds of each clip
            are windowed.
        min_duration: Clips shorter than this are skipped entirely.

    Raises:
        ValueError: if the configuration yields a hop below one sample.
    """

    def __init__(
        self,
        hf_dataset,
        cspace_model,
        sample_rate=24000,
        segment_seconds=1.0,
        overlap=0.5,
        augment=True,
        max_duration=5.0,
        min_duration=1.0,
    ):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration

        self.segment_samples = int(segment_seconds * sample_rate)
        self.hop_samples = int(self.segment_samples * (1 - overlap))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)

        # A hop of 0 would divide by zero in the window count below. A negative
        # hop would produce negative window starts. Reject both up front.
        if self.hop_samples <= 0:
            raise ValueError(
                f"overlap={overlap} with segment_samples={self.segment_samples} "
                f"gives hop_samples={self.hop_samples}; the hop must be at least "
                "one sample (use overlap < 1)"
            )

        # One (clip_index, window_start_sample) pair per dataset item.
        # NOTE(review): this reads every clip's sample array up front. For a
        # lazily decoded dataset that presumably decodes all audio once here.
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len < self.min_samples:
                continue
            truncated_len = min(audio_len, self.max_samples)
            # Count the full windows that fit in the capped clip. A clip shorter
            # than one window still yields a single (zero-padded) window.
            num_segments = max(1, (truncated_len - self.segment_samples) // self.hop_samples + 1)
            self.segment_indices.extend(
                (i, seg * self.hop_samples) for seg in range(num_segments)
            )

    def __len__(self):
        """Return the number of windows across all indexed clips."""
        return len(self.segment_indices)

    def __getitem__(self, idx):
        """Return the C-Space encoding of window ``idx`` as a ``complex64`` tensor."""
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        # Apply the same max_duration cap the index was built with, so a window
        # never reads past it. This only matters when max_duration < segment_seconds.
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)[: self.max_samples]

        seg_end = seg_start + self.segment_samples
        audio_segment = audio_np[seg_start:seg_end]

        # Zero-pad windows that run past the (capped) end of the clip.
        shortfall = self.segment_samples - len(audio_segment)
        if shortfall > 0:
            audio_segment = np.pad(audio_segment, (0, shortfall))

        if self.augment:
            audio_segment = self._augment_audio(audio_segment)

        # Peak-normalize to [-1, 1]; an all-zero window is left unchanged.
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak

        cspace_data = self.cspace_model.encode(audio_segment)

        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)

    def _augment_audio(self, audio):
        """Randomly perturb a 1-D waveform, keeping its dtype.

        Each step fires independently with probability 0.5, in this order:

        * gain in [0.8, 1.0];
        * resampling by a factor in [0.95, 1.05] with linear interpolation.
          Because the sample rate is unchanged, this is a speed change that
          shifts tempo and pitch together, not a pitch-preserving time stretch;
        * additive Gaussian noise with std 0.01.

        NOTE: the speed change alters the length by up to about +/-5%. It is
        NOT cropped or padded back to ``segment_samples``, so consumers must
        accept variable-length windows (collate_variable_length pads along
        dim 0). TODO confirm ``cspace_model.encode`` handles that.

        NOTE(review): ``__getitem__`` peak-normalizes after this method. That
        cancels the random gain, except for its effect on the relative noise level.
        """
        dtype = audio.dtype

        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)

        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # interpolate expects (batch, channels, time).
            audio_t = torch.tensor(audio).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t,
                scale_factor=stretch_factor,
                mode='linear',
                align_corners=False,
            )
            audio = audio_t.squeeze().numpy()

        if np.random.rand() > 0.5:
            noise = np.random.normal(0, 0.01, len(audio))
            audio = audio + noise

        # np.random.normal yields float64, which would silently upcast the
        # waveform. Cast back so every item reaches encode() with one dtype.
        return audio.astype(dtype, copy=False)