import torch
import torch.nn as nn


class ResizeConv2d(nn.Module):
    """
    Upsamples the input and then performs a convolution.
    This avoids the checkerboard artifacts common with ConvTranspose2d.
    """
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super().__init__()
        # 'bilinear' is usually smoother for continuous signals like spectrograms
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        # Standard convolution with padding to maintain size after upsampling
        # Padding = (kernel_size - 1) // 2 for odd kernels
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)


class CSpaceCompressor(nn.Module):
    """Convolutional compressor for C-Space as a 2D image (frequency x time)."""
    def __init__(self, num_nodes=32, compression_factor=4, latent_dim=128):
        super().__init__()
        self.num_nodes = num_nodes
        self.compression_factor = compression_factor
        self.latent_dim = latent_dim

        # Encoder: (batch, 2, num_nodes, time) -> latent
        # We keep the encoder largely the same; standard Convs are fine here.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(128, latent_dim, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2)),
        )

        # Decoder: ConvTranspose2d replaced with ResizeConv2d.
        # The encoder strides determine the decoder scale factors (in reverse):
        #   Encoder strides: (1,2) -> (2,2) -> (2,2) -> (2,1)
        #   Decoder scales:  (2,1) -> (2,2) -> (2,2) -> (1,2)
        self.decoder = nn.Sequential(
            ResizeConv2d(latent_dim, 128, kernel_size=(3, 5), scale_factor=(2, 1)),
            nn.ReLU(),
            ResizeConv2d(128, 64, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(32, 2, kernel_size=(3, 5), scale_factor=(1, 2)),
            # No final ReLU: we need raw coordinate values
        )

    def forward(self, cspace_complex):
        batch_size, time_steps, num_nodes = cspace_complex.shape

        # Split into real and imaginary parts
        real = cspace_complex.real
        imag = cspace_complex.imag

        # Stack: (batch, 2, time, num_nodes)
        x = torch.stack([real, imag], dim=1)
        # Transpose to frequency-first for 2D conv logic: (batch, 2, num_nodes, time)
        x = x.transpose(2, 3)

        # Encode
        latent = self.encoder(x)

        # Decode
        x_recon = self.decoder(latent)

        # Force an exact size match (upsampling may be off by 1-2 pixels
        # depending on rounding); interpolate to the target size if needed.
        if x_recon.shape[2] != num_nodes or x_recon.shape[3] != time_steps:
            x_recon = torch.nn.functional.interpolate(
                x_recon, size=(num_nodes, time_steps), mode='bilinear', align_corners=False
            )

        # Extract real and imaginary channels
        real_recon = x_recon[:, 0, :, :]
        imag_recon = x_recon[:, 1, :, :]

        # Transpose back to (batch, time, num_nodes)
        real_recon = real_recon.transpose(1, 2)
        imag_recon = imag_recon.transpose(1, 2)

        return torch.complex(real_recon, imag_recon)
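

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal round-trip sanity check, assuming the input is a batch of
# complex C-Space tensors shaped (batch, time, num_nodes). The batch
# size, time length, and loss choice below are arbitrary examples.
if __name__ == "__main__":
    model = CSpaceCompressor(num_nodes=32, latent_dim=128)

    # Random complex input: 4 sequences, 128 time steps, 32 nodes.
    # complex64 splits into float32 real/imag channels inside forward().
    cspace = torch.randn(4, 128, 32, dtype=torch.complex64)

    recon = model(cspace)
    assert recon.shape == cspace.shape, "round trip must preserve shape"

    # On an untrained model this is only a shape/dtype check; for training,
    # an MSE on the complex magnitude of the residual is one plausible loss.
    loss = torch.mean(torch.abs(recon - cspace) ** 2)
    print(f"input {tuple(cspace.shape)} -> recon {tuple(recon.shape)}, mse={loss.item():.4f}")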