# sagan_model.py

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# -------------------------
#  Self-Attention Module
# -------------------------
class Self_Attn(nn.Module):
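    """SAGAN-style self-attention block.

    1×1 convolutions produce query/key (C/8 channels) and value (C channels)
    maps; attention is computed over all spatial positions, and the result is
    added back to the input scaled by a learnable gamma initialised to zero,
    so the block starts out as an identity mapping.
    """
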
    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key_conv   = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma      = nn.Parameter(torch.zeros(1))
        self.softmax    = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()                                            # PyTorch layout is (B, C, H, W)
        proj_q = self.query_conv(x).view(B, -1, H*W).permute(0, 2, 1)    # B × (HW) × C//8
        proj_k = self.key_conv(x).view(B, -1, H*W)                       # B × C//8 × (HW)
        energy = torch.bmm(proj_q, proj_k)                               # B × (HW) × (HW)
        attention = self.softmax(energy)                                 # each row sums to 1
        proj_v = self.value_conv(x).view(B, -1, H*W)                     # B × C × (HW)

        out = torch.bmm(proj_v, attention.permute(0, 2, 1))              # B × C × (HW)
        out = out.view(B, C, H, W)
        return self.gamma * out + x                                      # learnable residual blend

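# Usage note (sketch): Self_Attn preserves its input shape, e.g.
#   Self_Attn(64)(torch.randn(2, 64, 8, 8))  ->  shape (2, 64, 8, 8)
# so it can be inserted between any two layers of a convolutional stack.
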
# -------------------------
#  Generator & Discriminator
# -------------------------
class Generator(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(z_dim, base_channels*8, 4, 1, 0)),
            nn.BatchNorm2d(base_channels*8),
            nn.ReLU(True),

            spectral_norm(nn.ConvTranspose2d(base_channels*8, base_channels*4, 4, 2, 1)),
            nn.BatchNorm2d(base_channels*4),
            nn.ReLU(True),

            # self-attention on the 8×8 feature maps (spatial size so far: 1→4→8)
            Self_Attn(base_channels*4),

            spectral_norm(nn.ConvTranspose2d(base_channels*4, base_channels*2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels*2),
            nn.ReLU(True),

            spectral_norm(nn.ConvTranspose2d(base_channels*2, base_channels, 4, 2, 1)),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),

            spectral_norm(nn.ConvTranspose2d(base_channels, img_channels, 4, 2, 1)),
            nn.Tanh()
        )

    def forward(self, z):
        # Expect z shape: (B, z_dim, 1, 1)
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(img_channels, base_channels, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),

            spectral_norm(nn.Conv2d(base_channels, base_channels*2, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),

            # self-attention on the 16×16 feature maps (64→32→16 for 64×64 inputs)
            Self_Attn(base_channels*2),

            spectral_norm(nn.Conv2d(base_channels*2, base_channels*4, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),

            # extra stride-2 stage: 8×8 → 4×4, so the final conv yields a single logit per image
            spectral_norm(nn.Conv2d(base_channels*4, base_channels*8, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),

            spectral_norm(nn.Conv2d(base_channels*8, 1, 4, 1, 0))
        )

    def forward(self, x):
        return self.net(x).view(-1)  # one realness logit per image, shape (B,)

# -------------------------
#  High-Level Wrapper
# -------------------------
class SAGANModel(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.gen = Generator(z_dim, img_channels, base_channels)
        self.dis = Discriminator(img_channels, base_channels)

    def forward(self, z):
        # Only generator’s forward is typically used during inference
        return self.gen(z)
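
# -------------------------
#  Quick Smoke Test
# -------------------------
# Minimal sanity-check sketch, assuming the defaults above (z_dim=128,
# 3-channel images): the generator maps (B, 128, 1, 1) noise to 64×64
# images and the discriminator returns one logit per image.
if __name__ == "__main__":
    model = SAGANModel()
    z = torch.randn(4, 128, 1, 1)      # batch of 4 latent vectors

    fake = model(z)                    # generator forward
    print(fake.shape)                  # expected: torch.Size([4, 3, 64, 64])

    scores = model.dis(fake)           # discriminator forward
    print(scores.shape)                # expected: torch.Size([4])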