import torch
import torch.nn as nn


class ResizeConv2d(nn.Module):
    """
    Upsamples the input and then performs a convolution.

    This avoids the checkerboard artifacts common with ConvTranspose2d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super().__init__()
        # 'bilinear' is usually smoother for continuous signals like spectrograms
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)

        # 'Same' padding to maintain size after upsampling
        # (for odd kernels, k // 2 == (k - 1) // 2)
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)
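

# Usage sketch (illustrative; kept as a comment so the module stays import-clean,
# and the channel counts here are arbitrary):
#   up = ResizeConv2d(8, 4, kernel_size=(3, 3), scale_factor=(2, 2))
#   up(torch.randn(1, 8, 16, 16)).shape  # -> torch.Size([1, 4, 32, 32])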


class CSpaceCompressor(nn.Module):
    """Convolutional compressor for C-Space treated as a 2D image (frequency x time)."""

    def __init__(self, num_nodes=32, compression_factor=4, latent_dim=128):
        super().__init__()
        self.num_nodes = num_nodes
        self.compression_factor = compression_factor
        self.latent_dim = latent_dim

        # Encoder: (batch, 2, num_nodes, time) -> latent.
        # The encoder keeps standard strided convolutions: checkerboard
        # artifacts come from upsampling, so only the decoder needs ResizeConv2d.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(128, latent_dim, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2)),
        )
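        # Shape trace, assuming the defaults (num_nodes=32) and a time
        # dimension T divisible by 8:
        # (B, 2, 32, T) -> (B, 32, 32, T/2) -> (B, 64, 16, T/4)
        #              -> (B, 128, 8, T/8) -> (B, latent_dim, 4, T/8)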

        # Decoder: ConvTranspose2d replaced with ResizeConv2d.
        # The encoder strides determine the decoder scale factors, applied in
        # reverse order:
        #   Encoder strides: (1, 2) -> (2, 2) -> (2, 2) -> (2, 1)
        #   Decoder scales:  (2, 1) -> (2, 2) -> (2, 2) -> (1, 2)
        self.decoder = nn.Sequential(
            ResizeConv2d(latent_dim, 128, kernel_size=(3, 5), scale_factor=(2, 1)),
            nn.ReLU(),
            ResizeConv2d(128, 64, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(32, 2, kernel_size=(3, 5), scale_factor=(1, 2)),
            # No final ReLU: the output must cover raw (signed) coordinate values
        )
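        # Reverse shape trace (same assumptions as the encoder trace above):
        # (B, latent_dim, 4, T/8) -> (B, 128, 8, T/8) -> (B, 64, 16, T/4)
        #                         -> (B, 32, 32, T/2) -> (B, 2, 32, T)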

    def forward(self, cspace_complex):
        batch_size, time_steps, num_nodes = cspace_complex.shape

        # Split into real and imaginary parts
        real = cspace_complex.real
        imag = cspace_complex.imag

        # Stack: (batch, 2, time, num_nodes)
        x = torch.stack([real, imag], dim=1)
        # Transpose to frequency-first layout for the 2D convs: (batch, 2, num_nodes, time)
        x = x.transpose(2, 3)

        # Encode
        latent = self.encoder(x)

        # Decode
        x_recon = self.decoder(latent)

        # Force an exact size match: depending on rounding, upsampling can be
        # off by a pixel or two, so interpolate to the target size if needed.
        if x_recon.shape[2] != num_nodes or x_recon.shape[3] != time_steps:
            x_recon = torch.nn.functional.interpolate(
                x_recon,
                size=(num_nodes, time_steps),
                mode='bilinear',
                align_corners=False,
            )

        # Extract real and imaginary parts
        real_recon = x_recon[:, 0, :, :]
        imag_recon = x_recon[:, 1, :, :]

        # Transpose back to (batch, time, num_nodes)
        real_recon = real_recon.transpose(1, 2)
        imag_recon = imag_recon.transpose(1, 2)

        return torch.complex(real_recon, imag_recon)
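

# Minimal smoke test: a sketch assuming the default architecture above and a
# time length divisible by 8, so the decoder reproduces the input shape exactly.
if __name__ == "__main__":
    model = CSpaceCompressor(num_nodes=32, compression_factor=4, latent_dim=128)
    # Hypothetical C-Space batch: (batch, time, num_nodes), complex-valued
    cspace = torch.complex(torch.randn(4, 64, 32), torch.randn(4, 64, 32))
    recon = model(cspace)
    assert recon.shape == cspace.shape and recon.is_complex()
    print(f"in: {tuple(cspace.shape)} -> out: {tuple(recon.shape)}")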