cspace/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResizeConv2d(nn.Module):
    """
    Upsamples the input and then performs a convolution.
    This avoids the checkerboard artifacts common with ConvTranspose2d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super().__init__()
        # 'bilinear' is usually smoother for continuous signals such as spectrograms.
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        # Convolution padded to preserve spatial size after upsampling.
        # kernel_size is an (h, w) tuple; for odd kernels, k // 2 == (k - 1) // 2,
        # which gives "same" padding.
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)
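

# A minimal usage sketch (illustrative sizes, not from the original file):
# doubling the time axis of a (batch, channels, freq, time) feature map while
# keeping the frequency resolution fixed.
#
#   up = ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(1, 2))
#   y = up(torch.randn(4, 64, 32, 50))  # -> (4, 32, 32, 100)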


class CSpaceCompressor(nn.Module):
    """Convolutional compressor that treats C-Space as a 2D image (frequency x time)."""

    def __init__(self, num_nodes=32, compression_factor=4, latent_dim=128):
        super().__init__()
        self.num_nodes = num_nodes
        self.compression_factor = compression_factor
        self.latent_dim = latent_dim

        # Encoder: (batch, 2, num_nodes, time) -> latent feature map.
        # Standard strided convolutions are fine on the encoder side;
        # checkerboard artifacts are a decoder-side problem.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2)),
            nn.ReLU(),
            nn.Conv2d(128, latent_dim, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2)),
        )

        # Decoder: ConvTranspose2d replaced with ResizeConv2d. The scale factors
        # mirror the encoder strides in reverse order:
        #   encoder strides: (1,2) -> (2,2) -> (2,2) -> (2,1)
        #   decoder scales:  (2,1) -> (2,2) -> (2,2) -> (1,2)
        self.decoder = nn.Sequential(
            ResizeConv2d(latent_dim, 128, kernel_size=(3, 5), scale_factor=(2, 1)),
            nn.ReLU(),
            ResizeConv2d(128, 64, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(64, 32, kernel_size=(3, 5), scale_factor=(2, 2)),
            nn.ReLU(),
            ResizeConv2d(32, 2, kernel_size=(3, 5), scale_factor=(1, 2)),
            # No final ReLU: the output must carry raw (signed) coordinate values.
        )

    def forward(self, cspace_complex):
        batch_size, time_steps, num_nodes = cspace_complex.shape

        # Split the complex tensor into real and imaginary parts and stack them
        # as channels: (batch, 2, time, num_nodes).
        real = cspace_complex.real
        imag = cspace_complex.imag
        x = torch.stack([real, imag], dim=1)

        # Transpose to frequency-first layout for the 2D convs: (batch, 2, num_nodes, time).
        x = x.transpose(2, 3)

        latent = self.encoder(x)
        x_recon = self.decoder(latent)

        # Force an exact size match: depending on rounding, upsampling can leave
        # the reconstruction one or two pixels off, so interpolate to the exact
        # target size when needed.
        if x_recon.shape[2] != num_nodes or x_recon.shape[3] != time_steps:
            x_recon = F.interpolate(
                x_recon,
                size=(num_nodes, time_steps),
                mode='bilinear',
                align_corners=False,
            )

        # Extract the real and imaginary channels.
        real_recon = x_recon[:, 0, :, :]
        imag_recon = x_recon[:, 1, :, :]

        # Transpose back to (batch, time, num_nodes).
        real_recon = real_recon.transpose(1, 2)
        imag_recon = imag_recon.transpose(1, 2)
        return torch.complex(real_recon, imag_recon)
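

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes a
# complex-valued input of shape (batch, time, num_nodes); the sizes below are
# illustrative, not taken from any real training configuration.
if __name__ == "__main__":
    model = CSpaceCompressor(num_nodes=32, latent_dim=128)
    cspace = torch.complex(torch.randn(4, 128, 32), torch.randn(4, 128, 32))
    recon = model(cspace)
    # The round trip preserves shape; internally, the encoder compresses the
    # input to a latent map of shape (4, 128, 4, 16) before reconstruction.
    assert recon.shape == cspace.shape
    print("input:", tuple(cspace.shape), "-> recon:", tuple(recon.shape))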