codinglabsong committed
Commit 0587b57 · verified · 1 Parent(s): f347772

Upload 4 files

Files changed (4):
  1. app.py +89 -0
  2. model.py +159 -0
  3. outputs/checkpoints/best.pth +3 -0
  4. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ """
+ Gradio demo for Aging-GAN: upload a face, choose a direction, and get an aged or rejuvenated output.
+ """
+
+ import gradio as gr
+ import torch
+ from pathlib import Path
+ from PIL import Image
+ import torchvision.transforms as T
+
+ from model import initialize_models  # model.py sits at the Space root
+
+
+ # Utils
+ def get_device() -> torch.device:
+     """Return the CUDA device if available, else CPU."""
+     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ # Transforms
+ preprocess = T.Compose(
+     [
+         T.Resize((256 + 50, 256 + 50), antialias=True),
+         T.CenterCrop(256),
+         T.ToTensor(),
+         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+     ]
+ )
+
+ # Inverts the (-1, 1) normalization before converting back to PIL
+ postprocess = T.Compose([T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]), T.ToPILImage()])
+
+ # Load models & checkpoint once
+ device = get_device()
+
+ # initialize G (young→old) and F (old→young)
+ G, F, _, _ = initialize_models()
+ ckpt_path = Path("outputs/checkpoints/best.pth")
+ ckpt = torch.load(ckpt_path, map_location=device)
+
+ G.load_state_dict(ckpt["G"])
+ F.load_state_dict(ckpt["F"])
+ G.eval().to(device)
+ F.eval().to(device)
+
+
+ # Inference function
+ def infer(image: Image.Image, direction: str) -> Image.Image:
+     """Run a single forward pass through the chosen generator."""
+     # preprocess
+     x = preprocess(image).unsqueeze(0).to(device)  # (1, 3, 256, 256)
+
+     # generate
+     with torch.inference_mode():
+         if direction == "young2old":
+             y_hat = G(x)
+         else:
+             y_hat = F(x)
+         y_hat = torch.clamp(y_hat, -1, 1)
+
+     # postprocess & return a PIL image
+     out = postprocess(y_hat.squeeze(0).cpu())
+     return out
+
+
+ # Launch Gradio
+ demo = gr.Interface(
+     fn=infer,
+     inputs=[
+         gr.Image(type="pil", label="Input Face"),
+         gr.Radio(
+             choices=["young2old", "old2young"],
+             value="young2old",
+             label="Transformation Direction",
+         ),
+     ],
+     outputs=gr.Image(type="pil", label="Output Face"),
+     title="Aging-GAN Demo",
+     description=(
+         "Upload a portrait, then select “young2old” to age it or “old2young” to rejuvenate it. "
+         "Powered by a ResNet-style CycleGAN generator. "
+         "TIP: Upload a close-up photo of the face, similar to the examples in the GitHub README."
+     ),
+     allow_flagging="never",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
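To sanity-check the inference path without launching the UI, a minimal sketch (assuming app.py and model.py sit in the working directory, the LFS checkpoint has been pulled, and face.jpg is a placeholder path):

from PIL import Image

from app import infer  # importing app builds both generators and loads best.pth

img = Image.open("face.jpg").convert("RGB")
aged = infer(img, "young2old")    # G: young -> old
young = infer(img, "old2young")   # F: old -> young
aged.save("aged.png")
young.save("young.png")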
model.py ADDED
@@ -0,0 +1,159 @@
+ """Model definitions for the CycleGAN-style architecture."""
+
+ from torch import Tensor
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ResidualBlock(nn.Module):
+     """Simple residual block with two conv layers."""
+
+     def __init__(self, in_features: int) -> None:
+         super().__init__()
+
+         conv_block = [
+             nn.ReflectionPad2d(1),                   # (B, C, H+2, W+2)
+             nn.Conv2d(in_features, in_features, 3),  # (B, C, H, W)
+             nn.BatchNorm2d(in_features),             # (B, C, H, W)
+             nn.ReLU(),                               # (B, C, H, W)
+             nn.ReflectionPad2d(1),                   # (B, C, H+2, W+2)
+             nn.Conv2d(in_features, in_features, 3),  # (B, C, H, W)
+             nn.BatchNorm2d(in_features),             # (B, C, H, W)
+         ]
+
+         self.conv_block = nn.Sequential(*conv_block)
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Apply the residual block."""
+         return x + self.conv_block(x)
+
+
+ class Generator(nn.Module):
+     """ResNet-style generator used for domain translation."""
+
+     def __init__(self, ngf: int, n_residual_blocks: int = 9) -> None:
+         super().__init__()
+
+         # Initial convolution block
+         model = [
+             # (B, 3, H+6, W+6): reflection padding of 3 pixels on all four sides
+             nn.ReflectionPad2d(3),
+             # (B, ngf, H, W): 3 in_channels, ngf out_channels, kernel size 7 (keeps the image size)
+             nn.Conv2d(3, ngf, 7),
+             # (B, ngf, H, W): normalized per channel across B, H, W
+             nn.BatchNorm2d(ngf),
+             nn.ReLU(),
+         ]
+
+         # Downsampling
+         in_features = ngf  # number of generator filters
+         out_features = in_features * 2
+         for _ in range(2):
+             model += [
+                 # (B, in_features*2, H//2, W//2): doubles the channels, halves H and W
+                 nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                 nn.BatchNorm2d(out_features),
+                 nn.ReLU(),
+             ]
+             in_features = out_features
+             out_features = in_features * 2
+
+         # Residual blocks
+         for _ in range(n_residual_blocks):
+             # (B, in_features, H, W): same size as the input
+             model += [ResidualBlock(in_features)]
+
+         # Upsampling
+         out_features = in_features // 2
+         for _ in range(2):
+             model += [
+                 # (B, in_features//2, H*2, W*2): doubles H and W with half the channels
+                 nn.ConvTranspose2d(
+                     in_features, out_features, 3, stride=2, padding=1, output_padding=1
+                 ),
+                 nn.BatchNorm2d(out_features),
+                 nn.ReLU(),
+             ]
+             in_features = out_features
+             out_features = in_features // 2
+
+         # Output layer
+         model += [
+             nn.ReflectionPad2d(3),  # (B, ngf, H+6, W+6)
+             nn.Conv2d(ngf, 3, 7),   # (B, 3, H, W)
+             nn.Tanh(),              # (B, 3, H, W), values in (-1, 1)
+         ]
+
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Generate an image from ``x``."""
+         return self.model(x)
+
+
+ class Discriminator(nn.Module):
+     """PatchGAN discriminator."""
+
+     def __init__(self, ndf: int) -> None:
+         super().__init__()
+
+         model = [
+             # (B, ndf, H//2, W//2): channels from 3 -> ndf
+             nn.Conv2d(3, ndf, 4, stride=2, padding=1),
+             nn.LeakyReLU(0.2, inplace=True),
+         ]
+
+         model += [
+             nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),  # (B, ndf*2, H//4, W//4)
+             nn.BatchNorm2d(ndf * 2),
+             nn.LeakyReLU(0.2, inplace=True),
+         ]
+
+         model += [
+             nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),  # (B, ndf*4, H//8, W//8)
+             nn.InstanceNorm2d(ndf * 4),
+             nn.LeakyReLU(0.2, inplace=True),
+         ]
+
+         model += [
+             nn.Conv2d(ndf * 4, ndf * 8, 4, padding=1),  # (B, ndf*8, H//8-1, W//8-1)
+             nn.InstanceNorm2d(ndf * 8),
+             nn.LeakyReLU(0.2, inplace=True),
+         ]
+
+         # FCN classification layer
+         model += [nn.Conv2d(ndf * 8, 1, 4, padding=1)]  # (B, 1, H//8-2, W//8-2)
+
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Return discriminator logits for input ``x``."""
+         # x: (B, 3, H, W)
+         x = self.model(x)  # (B, 1, H//8-2, W//8-2)
+         # Global average pool to (B, 1, 1, 1), then flatten to (B, 1)
+         return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
+
+
+ def initialize_models(
+     ngf: int = 32,
+     ndf: int = 32,
+     n_blocks: int = 9,
+ ) -> tuple[Generator, Generator, Discriminator, Discriminator]:
+     """Instantiate the two generators and two discriminators with default sizes."""
+     G = Generator(ngf, n_blocks)  # young -> old
+     F = Generator(ngf, n_blocks)  # old -> young
+     DX = Discriminator(ndf)
+     DY = Discriminator(ndf)
+
+     return G, F, DX, DY
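As a quick check of the shape comments above, one can run freshly initialized models on a dummy 256×256 batch (a sketch; the sizes match what app.py feeds the networks):

import torch
from model import initialize_models

G, F, DX, DY = initialize_models(ngf=32, ndf=32, n_blocks=9)
for m in (G, F, DX, DY):
    m.eval()  # use running BatchNorm stats, as the demo does

x = torch.randn(1, 3, 256, 256)  # dummy normalized image batch
with torch.inference_mode():
    y = G(x)        # the generator preserves spatial size
    logits = DX(y)  # PatchGAN map pooled and flattened to one logit
print(y.shape)       # torch.Size([1, 3, 256, 256])
print(logits.shape)  # torch.Size([1, 1])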
outputs/checkpoints/best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c720561c96c4366f6368c99f526ad4d85632899751364274b13a80be765b2fd4
+ size 85499149
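The three lines above are a Git LFS pointer, not the weights themselves: the ~85 MB checkpoint is materialized from LFS storage on download (e.g. via git lfs pull after cloning). A sketch for fetching it programmatically; the repo id below is a placeholder assumption, so substitute the actual Space id:

from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="codinglabsong/aging-gan",  # hypothetical id; use the real Space
    repo_type="space",
    filename="outputs/checkpoints/best.pth",
)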
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ torchvision
+ ipykernel
+ matplotlib
+ accelerate
+ segmentation-models-pytorch
+ gdown
+ tqdm
+ torchmetrics[image]
+ wandb
+ numpy
+ python-dotenv
+ boto3
+ gradio