# sagan_model.py
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
# -------------------------
# Self-Attention Module
# -------------------------
class Self_Attn(nn.Module):
    """Self-attention layer from SAGAN: non-local attention over spatial positions."""

    def __init__(self, in_dim):
        super().__init__()
        # 1x1 convs project features to query/key/value spaces; queries and
        # keys use in_dim // 8 channels to keep the attention map cheap.
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        # gamma starts at zero, so the block is the identity at init and
        # gradually learns how much attention to mix in.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        N = H * W
        proj_q = self.query_conv(x).view(B, -1, N).permute(0, 2, 1)  # B x N x C'
        proj_k = self.key_conv(x).view(B, -1, N)                     # B x C' x N
        energy = torch.bmm(proj_q, proj_k)                           # B x N x N
        attention = self.softmax(energy)
        proj_v = self.value_conv(x).view(B, -1, N)                   # B x C x N
        out = torch.bmm(proj_v, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        # Residual connection scaled by the learned gamma.
        return self.gamma * out + x
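
# A quick shape sanity check for the attention block (an illustrative helper,
# not part of the original file): Self_Attn must be shape-preserving, since it
# is dropped into the middle of the generator and discriminator below.
def _check_self_attn_shapes():
    attn = Self_Attn(in_dim=64)
    x = torch.randn(2, 64, 8, 8)
    assert attn(x).shape == x.shape, "Self_Attn must preserve input shape"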
# -------------------------
# Generator & Discriminator
# -------------------------
class Generator(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        # Upsampling path: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        self.net = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(z_dim, base_channels*8, 4, 1, 0)),
            nn.BatchNorm2d(base_channels*8),
            nn.ReLU(True),
            spectral_norm(nn.ConvTranspose2d(base_channels*8, base_channels*4, 4, 2, 1)),
            nn.BatchNorm2d(base_channels*4),
            nn.ReLU(True),
            # self-attention on the 8x8 feature maps
            Self_Attn(base_channels*4),
            spectral_norm(nn.ConvTranspose2d(base_channels*4, base_channels*2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels*2),
            nn.ReLU(True),
            spectral_norm(nn.ConvTranspose2d(base_channels*2, base_channels, 4, 2, 1)),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),
            spectral_norm(nn.ConvTranspose2d(base_channels, img_channels, 4, 2, 1)),
            nn.Tanh()
        )

    def forward(self, z):
        # Expect z shape: (B, z_dim, 1, 1); returns (B, img_channels, 64, 64) in [-1, 1]
        return self.net(z)
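
# Example usage (a sketch, not from the original file): sample a batch of
# images from Gaussian noise. The noise is shaped (B, z_dim, 1, 1) because
# the first layer is a 4x4 transposed convolution applied to a 1x1 map.
def _sample_images(gen, batch_size=8, z_dim=128):
    z = torch.randn(batch_size, z_dim, 1, 1)
    with torch.no_grad():
        return gen(z)  # (B, 3, 64, 64)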
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, base_channels=64):
        super().__init__()
        # Downsampling path: 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4 -> 1x1
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(img_channels, base_channels, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            spectral_norm(nn.Conv2d(base_channels, base_channels*2, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # self-attention on the 16x16 feature maps
            Self_Attn(base_channels*2),
            spectral_norm(nn.Conv2d(base_channels*2, base_channels*4, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # extra downsampling block so the generator's 64x64 output reduces
            # to a single logit (with only three stride-2 convs, the final 4x4
            # conv left a 5x5 map that view(-1) flattened across the batch)
            spectral_norm(nn.Conv2d(base_channels*4, base_channels*8, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            spectral_norm(nn.Conv2d(base_channels*8, 1, 4, 1, 0))
        )

    def forward(self, x):
        # One real/fake logit per image: shape (B,)
        return self.net(x).view(-1)
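
# Shape check for the discriminator (illustrative helper, not part of the
# original file): a 64x64 batch should map to one logit per image.
def _check_discriminator_output():
    dis = Discriminator()
    logits = dis(torch.randn(4, 3, 64, 64))
    assert logits.shape == (4,), "expected one logit per image"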
# -------------------------
# High-Level Wrapper
# -------------------------
class SAGANModel(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.gen = Generator(z_dim, img_channels, base_channels)
        self.dis = Discriminator(img_channels, base_channels)

    def forward(self, z):
        # Only the generator's forward pass is typically used during inference.
        return self.gen(z)
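
# -------------------------
# Training sketch
# -------------------------
# A minimal single-step training sketch. SAGAN is commonly trained with the
# hinge loss and separate Adam optimizers; the learning rates and betas below
# follow the SAGAN paper's convention but are assumptions, not values taken
# from this repo.
if __name__ == "__main__":
    model = SAGANModel(z_dim=128)
    opt_g = torch.optim.Adam(model.gen.parameters(), lr=1e-4, betas=(0.0, 0.9))
    opt_d = torch.optim.Adam(model.dis.parameters(), lr=4e-4, betas=(0.0, 0.9))

    real = torch.randn(4, 3, 64, 64)  # stand-in for a batch of real images
    z = torch.randn(4, 128, 1, 1)

    # Discriminator step: hinge loss on real and fake logits.
    fake = model.gen(z).detach()
    d_loss = (torch.relu(1.0 - model.dis(real)).mean()
              + torch.relu(1.0 + model.dis(fake)).mean())
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: raise the discriminator's score on fresh fakes.
    g_loss = -model.dis(model.gen(z)).mean()
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

    print(f"d_loss={d_loss.item():.3f}  g_loss={g_loss.item():.3f}")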