Sreeharshan committed on
Commit 099578f
1 Parent(s): ff6f7d1

Upload 9 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/garden_in.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import sys
+ import torch
+ import gradio as gr
+ import numpy as np
+ import torchvision.transforms as transforms
+
+ from torch.autograd import Variable
+ from network.Transformer import Transformer
+ from huggingface_hub import hf_hub_download
+
+ from PIL import Image
+
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Constants
+
+ MAX_DIMENSION = 1280
+ MODEL_PATH = "models"
+ COLOUR_MODEL = "RGB"
+
+ STYLE_SHINKAI = "Makoto Shinkai"
+ STYLE_HOSODA = "Mamoru Hosoda"
+ STYLE_MIYAZAKI = "Hayao Miyazaki"
+ STYLE_KON = "Satoshi Kon"
+ DEFAULT_STYLE = STYLE_SHINKAI
+ STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON]
+
+ MODEL_REPO_SHINKAI = "akiyamasho/AnimeBackgroundGAN-Shinkai"
+ MODEL_FILE_SHINKAI = "shinkai_makoto.pth"
+
+ MODEL_REPO_HOSODA = "akiyamasho/AnimeBackgroundGAN-Hosoda"
+ MODEL_FILE_HOSODA = "hosoda_mamoru.pth"
+
+ MODEL_REPO_MIYAZAKI = "akiyamasho/AnimeBackgroundGAN-Miyazaki"
+ MODEL_FILE_MIYAZAKI = "miyazaki_hayao.pth"
+
+ MODEL_REPO_KON = "akiyamasho/AnimeBackgroundGAN-Kon"
+ MODEL_FILE_KON = "kon_satoshi.pth"
+
+ # Model initialisation: fetch the pretrained weights from the Hugging Face Hub
+ shinkai_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_SHINKAI, filename=MODEL_FILE_SHINKAI)
+ hosoda_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_HOSODA, filename=MODEL_FILE_HOSODA)
+ miyazaki_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_MIYAZAKI, filename=MODEL_FILE_MIYAZAKI)
+ kon_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_KON, filename=MODEL_FILE_KON)
+
+ shinkai_model = Transformer()
+ hosoda_model = Transformer()
+ miyazaki_model = Transformer()
+ kon_model = Transformer()
+
+ enable_gpu = torch.cuda.is_available()
+
+ if enable_gpu:
+     # With multiple cards you can assign a specific one,
+     # e.g. "cuda:0" (same as "cuda") or "cuda:1"; the first card is used by default.
+     device = torch.device("cuda")
+ else:
+     device = "cpu"
+
+ shinkai_model.load_state_dict(
+     torch.load(shinkai_model_hfhub, device)
+ )
+ hosoda_model.load_state_dict(
+     torch.load(hosoda_model_hfhub, device)
+ )
+ miyazaki_model.load_state_dict(
+     torch.load(miyazaki_model_hfhub, device)
+ )
+ kon_model.load_state_dict(
+     torch.load(kon_model_hfhub, device)
+ )
+
+ if enable_gpu:
+     shinkai_model = shinkai_model.to(device)
+     hosoda_model = hosoda_model.to(device)
+     miyazaki_model = miyazaki_model.to(device)
+     kon_model = kon_model.to(device)
+
+ shinkai_model.eval()
+ hosoda_model.eval()
+ miyazaki_model.eval()
+ kon_model.eval()
+
+
+ # Functions
+
+ def get_model(style):
+     if style == STYLE_SHINKAI:
+         return shinkai_model
+     elif style == STYLE_HOSODA:
+         return hosoda_model
+     elif style == STYLE_MIYAZAKI:
+         return miyazaki_model
+     elif style == STYLE_KON:
+         return kon_model
+     else:
+         logger.warning(
+             f"Style {style} not found. Defaulting to Makoto Shinkai"
+         )
+         return shinkai_model
+
+
+ def adjust_image_for_model(img):
+     logger.info(f"Image Height: {img.height}, Image Width: {img.width}")
+     if img.height > MAX_DIMENSION or img.width > MAX_DIMENSION:
+         logger.info(f"Dimensions too large. Resizing to {MAX_DIMENSION}px.")
+         img.thumbnail((MAX_DIMENSION, MAX_DIMENSION), Image.ANTIALIAS)
+
+     return img
+
+
+ def inference(img, style):
+     img = adjust_image_for_model(img)
+
+     # load image
+     input_image = img.convert(COLOUR_MODEL)
+     input_image = np.asarray(input_image)
+     # RGB -> BGR
+     input_image = input_image[:, :, [2, 1, 0]]
+     input_image = transforms.ToTensor()(input_image).unsqueeze(0)
+     # preprocess: scale from [0, 1] to [-1, 1]
+     input_image = -1 + 2 * input_image
+
+     if enable_gpu:
+         logger.info("CUDA found. Using GPU.")
+         # move the input to the selected CUDA device
+         input_image = Variable(input_image).to(device)
+     else:
+         logger.info("CUDA not found. Using CPU.")
+         input_image = Variable(input_image).float()
+
+     # forward pass
+     model = get_model(style)
+     output_image = model(input_image)
+     output_image = output_image[0]
+     # BGR -> RGB
+     output_image = output_image[[2, 1, 0], :, :]
+     # rescale from [-1, 1] back to [0, 1]
+     output_image = output_image.data.cpu().float() * 0.5 + 0.5
+
+     return transforms.ToPILImage()(output_image)
+
+
+ # Gradio setup
+
+ title = "Anime Background GAN"
+ description = "Gradio demo for CartoonGAN by Chen et al. Available styles: Makoto Shinkai, Mamoru Hosoda, Satoshi Kon, and Hayao Miyazaki."
+ article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN whitepaper from Chen et al.</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>GitHub repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center>"
+
+ examples = [
+     ["examples/garden_in.jpg", STYLE_SHINKAI],
+     ["examples/library_in.jpg", STYLE_KON],
+ ]
+
+
+ gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.inputs.Image(
+             type="pil",
+             label="Input Photo (less than 1280px on both width and height)",
+         ),
+         gr.inputs.Dropdown(
+             STYLE_CHOICE_LIST,
+             type="value",
+             default=DEFAULT_STYLE,
+             label="Style",
+         ),
+     ],
+     outputs=gr.outputs.Image(
+         type="pil",
+         label="Output Image",
+     ),
+     title=title,
+     description=description,
+     article=article,
+     examples=examples,
+     allow_flagging="never",
+     allow_screenshot=False,
+ ).launch(enable_queue=True)
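
For reference, the tensor convention that inference() wraps around the generator (RGB PIL image to a BGR tensor scaled to [-1, 1], and the reverse on the way out) can be reproduced in isolation. The sketch below is illustrative only and separate from the uploaded files; it assumes nothing beyond numpy, torchvision, and Pillow, and uses a synthetic image so it runs without the example JPEGs.

import numpy as np
import torchvision.transforms as transforms
from PIL import Image

# synthetic 64x64 RGB image standing in for an uploaded photo
img = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), "RGB")

x = np.asarray(img)[:, :, [2, 1, 0]]        # RGB -> BGR, as in inference()
x = transforms.ToTensor()(x).unsqueeze(0)   # (1, 3, H, W), values in [0, 1]
x = -1 + 2 * x                              # scale to [-1, 1] for the generator

y = x[0]                                    # stand-in for the generator output
y = y[[2, 1, 0], :, :]                      # BGR -> RGB
y = (y * 0.5 + 0.5).clamp(0, 1)             # back to [0, 1]
out = transforms.ToPILImage()(y)            # round-trips to the original image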
examples/garden_in.jpg ADDED

Git LFS Details

  • SHA256: 40e4981ebc9c5e51185b451ac90726e48faadb3fb1e24797fafa30a30f13b42d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
examples/library_in.jpg ADDED
models/hosoda_mamoru.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:813504096c42ab7fa965c67cdbc24608400dd2c5a9ddaf8171d165d7344492d1
+ size 133
models/kon_satoshi.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d5f9a0b193c1d7c019951a9886289a0536661d1ec3a2dcd98fcd213402bad28
+ size 133
models/miyazaki_hayao.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c2aee56380168b266a7c747e0c26b6f939b7fdac41a8fd620d94450dad12061
+ size 133
models/shinkai_makoto.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e662cf1194c6633f409dfbffcc8454118593f96719e92dc268b74d0a74892cd
+ size 133
network/Transformer.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Transformer(nn.Module):
+     def __init__(self):
+         super(Transformer, self).__init__()
+         # encoder: 3 -> 64 -> 128 -> 256 channels, downsampling twice
+         self.refpad01_1 = nn.ReflectionPad2d(3)
+         self.conv01_1 = nn.Conv2d(3, 64, 7)
+         self.in01_1 = InstanceNormalization(64)
+         # relu
+         self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
+         self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
+         self.in02_1 = InstanceNormalization(128)
+         # relu
+         self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
+         self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)
+         self.in03_1 = InstanceNormalization(256)
+         # relu
+
+         ## res block 1
+         self.refpad04_1 = nn.ReflectionPad2d(1)
+         self.conv04_1 = nn.Conv2d(256, 256, 3)
+         self.in04_1 = InstanceNormalization(256)
+         # relu
+         self.refpad04_2 = nn.ReflectionPad2d(1)
+         self.conv04_2 = nn.Conv2d(256, 256, 3)
+         self.in04_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 2
+         self.refpad05_1 = nn.ReflectionPad2d(1)
+         self.conv05_1 = nn.Conv2d(256, 256, 3)
+         self.in05_1 = InstanceNormalization(256)
+         # relu
+         self.refpad05_2 = nn.ReflectionPad2d(1)
+         self.conv05_2 = nn.Conv2d(256, 256, 3)
+         self.in05_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 3
+         self.refpad06_1 = nn.ReflectionPad2d(1)
+         self.conv06_1 = nn.Conv2d(256, 256, 3)
+         self.in06_1 = InstanceNormalization(256)
+         # relu
+         self.refpad06_2 = nn.ReflectionPad2d(1)
+         self.conv06_2 = nn.Conv2d(256, 256, 3)
+         self.in06_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 4
+         self.refpad07_1 = nn.ReflectionPad2d(1)
+         self.conv07_1 = nn.Conv2d(256, 256, 3)
+         self.in07_1 = InstanceNormalization(256)
+         # relu
+         self.refpad07_2 = nn.ReflectionPad2d(1)
+         self.conv07_2 = nn.Conv2d(256, 256, 3)
+         self.in07_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 5
+         self.refpad08_1 = nn.ReflectionPad2d(1)
+         self.conv08_1 = nn.Conv2d(256, 256, 3)
+         self.in08_1 = InstanceNormalization(256)
+         # relu
+         self.refpad08_2 = nn.ReflectionPad2d(1)
+         self.conv08_2 = nn.Conv2d(256, 256, 3)
+         self.in08_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 6
+         self.refpad09_1 = nn.ReflectionPad2d(1)
+         self.conv09_1 = nn.Conv2d(256, 256, 3)
+         self.in09_1 = InstanceNormalization(256)
+         # relu
+         self.refpad09_2 = nn.ReflectionPad2d(1)
+         self.conv09_2 = nn.Conv2d(256, 256, 3)
+         self.in09_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 7
+         self.refpad10_1 = nn.ReflectionPad2d(1)
+         self.conv10_1 = nn.Conv2d(256, 256, 3)
+         self.in10_1 = InstanceNormalization(256)
+         # relu
+         self.refpad10_2 = nn.ReflectionPad2d(1)
+         self.conv10_2 = nn.Conv2d(256, 256, 3)
+         self.in10_2 = InstanceNormalization(256)
+         # + input
+
+         ## res block 8
+         self.refpad11_1 = nn.ReflectionPad2d(1)
+         self.conv11_1 = nn.Conv2d(256, 256, 3)
+         self.in11_1 = InstanceNormalization(256)
+         # relu
+         self.refpad11_2 = nn.ReflectionPad2d(1)
+         self.conv11_2 = nn.Conv2d(256, 256, 3)
+         self.in11_2 = InstanceNormalization(256)
+         # + input
+
+         ##------------------------------------##
+         # decoder: upsample 256 -> 128 -> 64 and project back to 3 channels
+         self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
+         self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
+         self.in12_1 = InstanceNormalization(128)
+         # relu
+         self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
+         self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
+         self.in13_1 = InstanceNormalization(64)
+         # relu
+         self.refpad12_1 = nn.ReflectionPad2d(3)
+         self.deconv03_1 = nn.Conv2d(64, 3, 7)
+         # tanh
+
+     def forward(self, x):
+         y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
+         y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y))))
+         t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y))))
+
+         ## residual blocks with skip connections
+         y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04))))
+         t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04
+
+         y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05))))
+         t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05
+
+         y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06))))
+         t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06
+
+         y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07))))
+         t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07
+
+         y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08))))
+         t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08
+
+         y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09))))
+         t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09
+
+         y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10))))
+         t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10
+
+         y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11))))
+         y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11
+         ##
+
+         y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
+         y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
+         y = torch.tanh(self.deconv03_1(self.refpad12_1(y)))
+
+         return y
+
+
+ class InstanceNormalization(nn.Module):
+     def __init__(self, dim, eps=1e-9):
+         super(InstanceNormalization, self).__init__()
+         self.scale = nn.Parameter(torch.FloatTensor(dim))
+         self.shift = nn.Parameter(torch.FloatTensor(dim))
+         self.eps = eps
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         self.scale.data.uniform_()
+         self.shift.data.zero_()
+
+     def __call__(self, x):
+         n = x.size(2) * x.size(3)
+         t = x.view(x.size(0), x.size(1), n)
+         mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
+         # Calculate the biased var. torch.var returns unbiased var
+         var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * (
+             (n - 1) / float(n)
+         )
+         scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+         scale_broadcast = scale_broadcast.expand_as(x)
+         shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+         shift_broadcast = shift_broadcast.expand_as(x)
+         out = (x - mean) / torch.sqrt(var + self.eps)
+         out = out * scale_broadcast + shift_broadcast
+         return out
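
The generator above can also be driven on its own, without app.py. The following is a minimal sketch under the assumption that the network/ package is importable from the working directory and that the checkpoints referenced in app.py are reachable on the Hugging Face Hub; the 256x256 random input is arbitrary and only illustrates the expected shapes and the tanh output range.

import torch
from huggingface_hub import hf_hub_download
from network.Transformer import Transformer

# checkpoint repo and filename taken from app.py in this commit
weights = hf_hub_download(
    repo_id="akiyamasho/AnimeBackgroundGAN-Shinkai",
    filename="shinkai_makoto.pth",
)

model = Transformer()
model.load_state_dict(torch.load(weights, map_location="cpu"))
model.eval()

with torch.no_grad():
    dummy = torch.rand(1, 3, 256, 256) * 2 - 1  # BGR-ordered input in [-1, 1]
    out = model(dummy)

print(out.shape)               # torch.Size([1, 3, 256, 256])
print(out.min(), out.max())    # tanh output, so values stay within (-1, 1)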
network/__init__.py ADDED
File without changes