Initial Commit
This commit is contained in:
116
dataset.py
Normal file
116
dataset.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, Audio
|
||||
|
||||
def collate_variable_length(batch):
    """Custom collate for variable-length C-Space tensors.

    Args:
        batch: Non-empty list of (time_i, num_nodes) tensors. Time lengths
            may differ; the node dimension must match across the batch.

    Returns:
        padded: (batch, max_time, num_nodes) tensor, zero past each
            sequence's true length. Dtype follows the first batch element
            (complex64 for C-Space data) instead of being hard-coded, so
            the collate also works for real-valued intermediates.
        mask: (batch, max_time, 1) bool tensor, True at valid time steps.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_variable_length received an empty batch")

    max_time = max(x.shape[0] for x in batch)
    num_nodes = batch[0].shape[1]

    padded = torch.zeros(len(batch), max_time, num_nodes, dtype=batch[0].dtype)
    mask = torch.zeros(len(batch), max_time, 1, dtype=torch.bool)

    for i, seq in enumerate(batch):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
        mask[i, :seq_len, :] = True

    return padded, mask
|
||||
|
||||
class CSpaceDataset(Dataset):
    """Dataset that processes audio on-the-fly into C-Space representations.

    Each raw clip is split into fixed-length sliding windows
    (``segment_seconds`` long, overlapping by ``overlap``), so one clip can
    yield several training examples.

    Args:
        hf_dataset: Indexable dataset whose items expose
            ``item["audio"]["array"]`` (a 1-D waveform).
        cspace_model: Object exposing ``encode(audio)`` mapping a 1-D float
            waveform to a (time, nodes) complex representation.
        sample_rate: Expected waveform sample rate in Hz.
        segment_seconds: Window length in seconds.
        overlap: Fractional overlap between consecutive windows in [0, 1).
        augment: If True, apply phase-tolerant augmentation per item.
        max_duration: Clips are truncated to this many seconds when
            building the window index.
        min_duration: Clips shorter than this are skipped entirely.
    """

    def __init__(self, hf_dataset, cspace_model, sample_rate=24000, segment_seconds=1.0, overlap=0.5, augment=True, max_duration=5.0, min_duration=1.0):
        self.dataset = hf_dataset
        self.cspace_model = cspace_model
        self.sample_rate = sample_rate
        self.segment_seconds = segment_seconds
        self.overlap = overlap
        self.augment = augment
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.segment_samples = int(segment_seconds * sample_rate)
        # Guard against overlap >= 1.0, which would make the hop zero
        # (or negative) and corrupt the window index below.
        self.hop_samples = max(1, int(self.segment_samples * (1 - overlap)))
        self.max_samples = int(max_duration * sample_rate)
        self.min_samples = int(min_duration * sample_rate)

        # Precompute a flat index of (audio_idx, segment_start) pairs so
        # __len__ / __getitem__ are O(1).
        self.segment_indices = []
        for i in range(len(hf_dataset)):
            audio_len = len(hf_dataset[i]["audio"]["array"])
            if audio_len < self.min_samples:
                continue  # too short to yield a meaningful window
            truncated_len = min(audio_len, self.max_samples)
            num_segments = max(1, (truncated_len - self.segment_samples) // self.hop_samples + 1)
            for seg in range(num_segments):
                self.segment_indices.append((i, seg * self.hop_samples))

    def __len__(self):
        return len(self.segment_indices)

    def __getitem__(self, idx):
        """Return one C-Space window as a complex64 tensor of shape (time, nodes)."""
        audio_idx, seg_start = self.segment_indices[idx]
        sample = self.dataset[audio_idx]
        audio_np = np.array(sample["audio"]["array"], dtype=np.float32)

        # Slice the window; zero-pad if the clip ends early.
        audio_segment = audio_np[seg_start:seg_start + self.segment_samples]
        if len(audio_segment) < self.segment_samples:
            audio_segment = np.pad(audio_segment, (0, self.segment_samples - len(audio_segment)))

        # Augmentation (phase-tolerant); preserves the fixed window length.
        if self.augment:
            audio_segment = self._augment_audio(audio_segment)

        # Peak-normalize, skipping silent segments to avoid division by zero.
        peak = np.max(np.abs(audio_segment))
        if peak > 0:
            audio_segment = audio_segment / peak

        cspace_data = self.cspace_model.encode(audio_segment)

        # encode() may return a tensor or an array; ensure complex64 either way.
        if isinstance(cspace_data, torch.Tensor):
            return cspace_data.to(torch.complex64)
        return torch.tensor(cspace_data, dtype=torch.complex64)

    def _augment_audio(self, audio):
        """Apply random phase-tolerant augmentations, preserving length and dtype.

        Fix vs. original: time-stretching changed the waveform length and it
        was never restored (the old ``target_len`` was computed after the
        stretch and unused), so augmented items broke the fixed-window
        contract. The stretched waveform is now cropped / zero-padded back
        to the input length. Noise is also generated as float32 so the
        waveform is not silently upcast to float64.
        """
        target_len = len(audio)

        # Random gain (50% chance).
        if np.random.rand() > 0.5:
            audio = audio * np.random.uniform(0.8, 1.0)

        # Random time stretch via linear resampling (50% chance).
        if np.random.rand() > 0.5:
            stretch_factor = np.random.uniform(0.95, 1.05)
            # interpolate expects (batch, channels, time)
            audio_t = torch.tensor(audio).unsqueeze(0).unsqueeze(0)
            audio_t = torch.nn.functional.interpolate(
                audio_t,
                scale_factor=stretch_factor,
                mode='linear',
                align_corners=False,
            )
            audio = audio_t.squeeze().numpy()
            # Restore the original window length: crop if stretched longer,
            # zero-pad if stretched shorter.
            if len(audio) > target_len:
                audio = audio[:target_len]
            elif len(audio) < target_len:
                audio = np.pad(audio, (0, target_len - len(audio)))

        # Small additive noise (50% chance), kept float32.
        if np.random.rand() > 0.5:
            audio = audio + np.random.normal(0, 0.01, len(audio)).astype(np.float32)

        return audio
|
||||
Reference in New Issue
Block a user