Upload 4 files

- app.py +89 -0
- model.py +159 -0
- outputs/checkpoints/best.pth +3 -0
- requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,89 @@
"""
Gradio demo for Aging-GAN: upload a face, choose a direction, and get an aged or rejuvenated output.
"""

import gradio as gr
import torch
from pathlib import Path
from PIL import Image
import torchvision.transforms as T

from model import initialize_models  # model.py sits at the repo root in this Space


# Utils
def get_device() -> torch.device:
    """Return the CUDA device if available, else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Transforms
preprocess = T.Compose(
    [
        T.Resize((256 + 50, 256 + 50), antialias=True),
        T.CenterCrop(256),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]
)

# Inverts the normalization: (y + 1) / 2 maps [-1, 1] back to [0, 1]
postprocess = T.Compose([T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()])

# Load models & checkpoint once
device = get_device()

# initialize G (young→old) and F (old→young)
G, F, _, _ = initialize_models()
ckpt_path = Path("outputs/checkpoints/best.pth")
ckpt = torch.load(ckpt_path, map_location=device)

G.load_state_dict(ckpt["G"])
F.load_state_dict(ckpt["F"])
G.eval().to(device)
F.eval().to(device)


# Inference function
def infer(image: Image.Image, direction: str) -> Image.Image:
    """Run a single forward pass through the chosen generator."""
    # preprocess (force RGB so RGBA or grayscale uploads don't break ToTensor)
    x = preprocess(image.convert("RGB")).unsqueeze(0).to(device)  # (1, 3, 256, 256)

    # generate
    with torch.inference_mode():
        y_hat = G(x) if direction == "young2old" else F(x)
        y_hat = torch.clamp(y_hat, -1, 1)

    # postprocess & return a PIL image
    return postprocess(y_hat.squeeze(0).cpu())


# Launch Gradio
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="pil", label="Input Face"),
        gr.Radio(
            choices=["young2old", "old2young"],
            value="young2old",
            label="Transformation Direction",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output Face"),
    title="Aging-GAN Demo",
    description=(
        "Upload a portrait, then select “young2old” to age it or “old2young” to rejuvenate it. "
        "Powered by a ResNet-style CycleGAN generator. "
        "TIP: upload a close-up photo of the face, similar to the examples in the GitHub README."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
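The preprocess/postprocess pair above round-trips through the [-1, 1] range the generators expect. A minimal sketch checking that the denormalization actually inverts the normalization (the tensor x here is an illustrative stand-in for a decoded image):

import torch
import torchvision.transforms as T

norm = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))    # [0, 1] -> [-1, 1]
denorm = T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2])   # (y + 1) / 2: [-1, 1] -> [0, 1]

x = torch.rand(3, 8, 8)  # stand-in image tensor in [0, 1]
y = norm(x)
assert y.min() >= -1.0 and y.max() <= 1.0
assert torch.allclose(denorm(y), x, atol=1e-6)  # round trip recovers the input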
model.py ADDED
@@ -0,0 +1,159 @@
"""Model definitions for the CycleGAN-style architecture."""

from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Simple residual block with two conv layers."""

    def __init__(self, in_features: int) -> None:
        super().__init__()

        conv_block = [
            nn.ReflectionPad2d(1),                   # (B, C, H+2, W+2)
            nn.Conv2d(in_features, in_features, 3),  # (B, C, H, W)
            nn.BatchNorm2d(in_features),             # (B, C, H, W)
            nn.ReLU(),                               # (B, C, H, W)
            nn.ReflectionPad2d(1),                   # (B, C, H+2, W+2)
            nn.Conv2d(in_features, in_features, 3),  # (B, C, H, W)
            nn.BatchNorm2d(in_features),             # (B, C, H, W)
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual block."""
        return x + self.conv_block(x)


class Generator(nn.Module):
    """ResNet-style generator used for domain translation."""

    def __init__(self, ngf: int, n_residual_blocks: int = 9) -> None:
        super().__init__()

        # Initial convolution block
        model = [
            # (B, 3, H+6, W+6): "reflection" padding of 3 pixels on all four sides
            nn.ReflectionPad2d(3),
            # (B, ngf, H, W): 3 in_channels, ngf out_channels, kernel size 7 (keeps spatial size)
            nn.Conv2d(3, ngf, 7),
            # (B, ngf, H, W): normalized per channel across B, H, W
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
        ]  # (B, ngf, H, W)

        # Downsampling
        in_features = ngf  # number of generator filters
        out_features = in_features * 2
        for _ in range(2):
            model += [
                # (B, in_features*2, H//2, W//2): doubles channels, halves H and W
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.BatchNorm2d(out_features),
                nn.ReLU(),
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]  # (B, in_features, H, W), same size as input

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                # (B, in_features//2, H*2, W*2): doubles H and W with half the channels
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_features),
                nn.ReLU(),
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [
            nn.ReflectionPad2d(3),  # (B, in_features, H+6, W+6)
            nn.Conv2d(ngf, 3, 7),   # (B, 3, H, W)
            nn.Tanh(),              # (B, 3, H, W), in [-1, 1]
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x: Tensor) -> Tensor:
        """Generate an image from ``x``."""
        return self.model(x)


class Discriminator(nn.Module):
    """PatchGAN discriminator."""

    def __init__(self, ndf: int) -> None:
        super().__init__()

        model = [
            nn.Conv2d(3, ndf, 4, stride=2, padding=1),  # (B, ndf, H//2, W//2), channels 3 -> ndf
            nn.LeakyReLU(0.2, inplace=True),
        ]

        model += [
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),  # (B, ndf*2, H//4, W//4)
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        model += [
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),  # (B, ndf*4, H//8, W//8)
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        model += [
            nn.Conv2d(ndf * 4, ndf * 8, 4, padding=1),  # (B, ndf*8, H//8-1, W//8-1)
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # FCN classification layer
        model += [nn.Conv2d(ndf * 8, 1, 4, padding=1)]  # (B, 1, H//8-2, W//8-2)

        self.model = nn.Sequential(*model)

    def forward(self, x: Tensor) -> Tensor:
        """Return discriminator logits for input ``x``."""
        # x: (B, 3, H, W)
        x = self.model(x)  # (B, 1, H//8-2, W//8-2)
        # Global average pooling -> (B, 1, 1, 1), then flatten to (B, 1)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)


# Initialize and return the generators and discriminators used for training
def initialize_models(
    ngf: int = 32,
    ndf: int = 32,
    n_blocks: int = 9,
) -> tuple[Generator, Generator, Discriminator, Discriminator]:
    """Instantiate generators and discriminators with default sizes."""
    G = Generator(ngf, n_blocks)
    F = Generator(ngf, n_blocks)
    DX = Discriminator(ndf)
    DY = Discriminator(ndf)

    return G, F, DX, DY
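A quick shape check for these definitions (a minimal sketch, assuming model.py is importable from the repo root; the 256x256 input matches what app.py feeds the generator):

import torch
from model import initialize_models

G, F, DX, DY = initialize_models(ngf=32, ndf=32, n_blocks=9)
for m in (G, DX):
    m.eval()  # avoid BatchNorm running-stat updates during the dry run

x = torch.randn(1, 3, 256, 256)  # dummy normalized input
with torch.inference_mode():
    y = G(x)   # generator preserves the spatial size
    d = DX(y)  # PatchGAN logits pooled and flattened to (B, 1)
print(y.shape, d.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 1])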
outputs/checkpoints/best.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c720561c96c4366f6368c99f526ad4d85632899751364274b13a80be765b2fd4
size 85499149
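This file is a Git LFS pointer; the actual ~85 MB checkpoint is fetched by LFS at clone time. A minimal sketch of how app.py consumes it (the "G" and "F" keys come from app.py's load_state_dict calls):

import torch

ckpt = torch.load("outputs/checkpoints/best.pth", map_location="cpu")
print(sorted(ckpt))  # expected to include the "G" and "F" state dicts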
requirements.txt ADDED
@@ -0,0 +1,14 @@
torch
torchvision
ipykernel
matplotlib
accelerate
segmentation-models-pytorch
gdown
tqdm
torchmetrics[image]
wandb
numpy
python-dotenv
boto3
gradio