|
|
"""Model definitions for a super-resolution GAN: an EDSR-style Generator, a simple CNN
Discriminator, and a shape smoke test under ``__main__``."""

import math
import os

import torch
import torch.nn as nn
|
|
class ResidualBlock(nn.Module):
    """Residual block: conv -> (BN) -> act -> conv -> (BN), with a scaled skip connection."""

    def __init__(self, num_features, kernel_size=3, bn=False, act=nn.ReLU(True), res_scale=1.0):
        super(ResidualBlock, self).__init__()
        padding = kernel_size // 2
        m = [nn.Conv2d(num_features, num_features, kernel_size, padding=padding)]
        if bn:
            m.append(nn.BatchNorm2d(num_features))
        m.append(act)
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn:
            m.append(nn.BatchNorm2d(num_features))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        # Scaling the residual branch (res_scale < 1) helps stabilize very deep EDSR-style networks.
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
|
|
class Upsampler(nn.Module):
    """Sub-pixel upsampling: a conv that expands channels by scale_factor**2, then PixelShuffle."""

    def __init__(self, scale_factor, num_features, act=nn.ReLU(True)):
        super(Upsampler, self).__init__()
        m = []
        m.append(nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=1))
        m.append(nn.PixelShuffle(scale_factor))
        if act:
            m.append(act)
        self.body = nn.Sequential(*m)

    def forward(self, x):
        return self.body(x)
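# Note on the sub-pixel upsampling above: the 3x3 convolution expands the channel count from
# num_features to num_features * scale_factor**2, and nn.PixelShuffle(r) then rearranges a
# (B, C * r**2, H, W) tensor into (B, C, H * r, W * r), trading channel depth for spatial
# resolution. With kernel_size=3 and padding=1, the convolution itself leaves H and W unchanged.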
|
|
|
|
|
class Generator(nn.Module):
    """EDSR-style generator: head conv -> residual body with a global skip -> pixel-shuffle tail."""

    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64, num_res_blocks=16, res_scale=1.0):
        super(Generator, self).__init__()
        self.scale_factor = scale_factor
        act = nn.ReLU(True)

        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)

        res_blocks = [ResidualBlock(num_features, kernel_size=3, act=act, res_scale=res_scale)
                      for _ in range(num_res_blocks)]
        res_blocks.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*res_blocks)

        # Build the upsampling tail: powers of two are handled as repeated x2 stages, x3 as a single stage.
        m_tail = []
        if (scale_factor & (scale_factor - 1)) == 0:  # scale_factor is a power of two
            for _ in range(int(math.log2(scale_factor))):
                m_tail.append(Upsampler(scale_factor=2, num_features=num_features, act=None))
        elif scale_factor == 3:
            m_tail.append(Upsampler(scale_factor=3, num_features=num_features, act=None))
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*m_tail)

        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        x = self.head(lr_img)
        res = self.body(x)
        res += x  # global residual connection around the body
        x = self.tail(res)
        x = self.final_conv(x)
        return x
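# The classes above rely on PyTorch's default layer initialization. The helper below is an
# optional, hedged sketch of an explicit scheme (Kaiming-normal for convolutions, unit/zero
# for BatchNorm), applied via `net.apply(init_weights)`; the function name and the choice of
# scheme are assumptions, not something specified elsewhere in this file.
def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, a=0.2, nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)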
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    """
    Simple CNN discriminator (a PatchGAN-style network is common, but this is simpler).
    Takes an image (real HR or generated SR) and outputs a single logit per image.
    """

    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4):
        super(Discriminator, self).__init__()

        layers = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        # Alternate stride-1 and stride-2 conv blocks; channels double on every stride-2 block.
        current_features = num_features_start
        for i in range(num_blocks):
            stride = 1 if i % 2 == 0 else 2
            next_features = current_features * 2 if stride == 2 else current_features
            layers.extend([
                nn.Conv2d(current_features, next_features, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(next_features),
                nn.LeakyReLU(0.2, inplace=True)
            ])
            current_features = next_features

        self.features = nn.Sequential(*layers)

        # Global average pooling makes the classifier head independent of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(current_features, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Input image tensor (B, C, H, W), either real HR or fake SR.

        Returns:
            torch.Tensor: Output logits of shape (B, 1). Higher values -> more likely "real".
        """
        batch_size = img.size(0)
        features = self.features(img)
        pooled = self.avgpool(features)
        pooled = pooled.view(batch_size, -1)
        output = self.classifier(pooled)
        return output
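# The discriminator returns raw logits, so a numerically stable adversarial loss can be built
# with binary cross-entropy on logits. The helper below is a minimal sketch of that pairing;
# the name `adversarial_loss` and the 1-for-real / 0-for-fake labelling convention are
# assumptions, not part of the training code this file belongs to. Typical usage would be:
#   d_loss = adversarial_loss(discriminator(hr), True) + adversarial_loss(discriminator(sr.detach()), False)
def adversarial_loss(logits, is_real):
    targets = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
    return nn.functional.binary_cross_entropy_with_logits(logits, targets)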
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # --- Testing Generator ---
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8

    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy low-resolution input for the generator.
    gen_batch_size = 1
    lr_height = 32
    lr_width = 32
    in_channels = 3
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, num_features=GEN_FEATURES, num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")

    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")

    print("\n--- Testing Discriminator ---")
    DISC_FEATURES = 64
    DISC_BLOCKS = 3

    # Dummy high-resolution input at the generator's output size.
    disc_batch_size = 4
    hr_height = output_sr.shape[2]
    hr_width = output_sr.shape[3]
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    discriminator = Discriminator(in_channels=in_channels,
                                  num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()

    with torch.no_grad():
        output_logits = discriminator(dummy_hr)
    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"

    print("Discriminator definition test successful!")

    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")