# sagan_model.py
# SAGAN-style building blocks: a self-attention layer plus a spectrally
# normalized generator and discriminator for 64x64 images.
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
# -------------------------
# Self-Attention Module
# -------------------------
class Self_Attn(nn.Module):
    """Self-attention over the spatial positions of a conv feature map."""

    def __init__(self, in_dim):
        super().__init__()
        # 1x1 convs project the feature map to query/key/value spaces;
        # query and key use a reduced channel dimension (in_dim // 8).
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        # Learnable residual weight, initialized to 0 so the block starts
        # out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        N = H * W  # number of spatial positions
        proj_q = self.query_conv(x).view(B, -1, N).permute(0, 2, 1)  # B x N x C'
        proj_k = self.key_conv(x).view(B, -1, N)                     # B x C' x N
        energy = torch.bmm(proj_q, proj_k)                           # B x N x N
        attention = self.softmax(energy)                             # rows sum to 1
        proj_v = self.value_conv(x).view(B, -1, N)                   # B x C x N
        out = torch.bmm(proj_v, attention.permute(0, 2, 1))          # B x C x N
        out = out.view(B, C, H, W)
        # Attention output is added back to the input, scaled by gamma
        return self.gamma * out + x
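
# Quick illustration (a sketch, not part of the original file): Self_Attn
# preserves the (B, C, H, W) shape of its input, so it can be dropped between
# any two conv blocks. The tensor sizes below are illustrative assumptions.
#   attn = Self_Attn(in_dim=64)
#   y = attn(torch.randn(2, 64, 16, 16))
#   print(y.shape)  # torch.Size([2, 64, 16, 16])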
# -------------------------
# Generator & Discriminator
# -------------------------
class Generator(nn.Module):
    """Upsamples a latent vector z of shape (B, z_dim, 1, 1) to a 64x64 image."""

    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            # 1x1 -> 4x4
            spectral_norm(nn.ConvTranspose2d(z_dim, base_channels * 8, 4, 1, 0)),
            nn.BatchNorm2d(base_channels * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            spectral_norm(nn.ConvTranspose2d(base_channels * 8, base_channels * 4, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(True),
            # self-attention on the 8x8 feature map
            Self_Attn(base_channels * 4),
            # 8x8 -> 16x16
            spectral_norm(nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            spectral_norm(nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1)),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),
            # 32x32 -> 64x64
            spectral_norm(nn.ConvTranspose2d(base_channels, img_channels, 4, 2, 1)),
            nn.Tanh()
        )

    def forward(self, z):
        # Expect z shape: (B, z_dim, 1, 1)
        return self.net(z)
class Discriminator(nn.Module):
    """Maps a 64x64 image to a single real/fake logit per sample."""

    def __init__(self, img_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            # 64x64 -> 32x32
            spectral_norm(nn.Conv2d(img_channels, base_channels, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 32x32 -> 16x16
            spectral_norm(nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # self-attention on the 16x16 feature map
            Self_Attn(base_channels * 2),
            # 16x16 -> 8x8
            spectral_norm(nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 8x8 -> 4x4 (mirrors the generator; needed so the final 4x4 conv
            # below reduces the map to a single 1x1 logit per image)
            spectral_norm(nn.Conv2d(base_channels * 4, base_channels * 8, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 4x4 -> 1x1 logit
            spectral_norm(nn.Conv2d(base_channels * 8, 1, 4, 1, 0))
        )

    def forward(self, x):
        # (B, 1, 1, 1) -> (B,)
        return self.net(x).view(-1)
# -------------------------
# High-Level Wrapper
# -------------------------
class SAGANModel(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.gen = Generator(z_dim, img_channels, base_channels)
        self.dis = Discriminator(img_channels, base_channels)

    def forward(self, z):
        # Only the generator's forward pass is typically used during inference
        return self.gen(z)
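
# Minimal smoke test (a sketch, not part of the original file): the tensor
# sizes below assume the 64x64 architecture and default hyperparameters above.
if __name__ == "__main__":
    model = SAGANModel(z_dim=128, img_channels=3, base_channels=64)
    z = torch.randn(4, 128, 1, 1)       # batch of 4 latent vectors
    fake_images = model(z)              # generator forward pass
    print(fake_images.shape)            # expected: torch.Size([4, 3, 64, 64])
    logits = model.dis(fake_images)     # per-image real/fake logits
    print(logits.shape)                 # expected: torch.Size([4])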