Ahsen Khaliq committed on
Commit
8d4d98f
β€’
1 Parent(s): f6223e8
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from PIL import Image
import torch
import gradio as gr

# Clone the JoJoGAN repo at startup and chdir into it so its local modules
# (util, model, e4e_projection) become importable below.
os.system("git clone https://github.com/mchong6/JoJoGAN.git")
os.chdir("JoJoGAN")


import torch  # NOTE(review): duplicate of the import above
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *  # JoJoGAN repo module — presumably provides align_face, strip_path_extension, display_image; confirm
from PIL import Image  # NOTE(review): duplicate of the import above
import math
import random

import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *  # JoJoGAN repo module — presumably provides Generator; confirm
from e4e_projection import projection as e4e_projection

# NOTE(review): the Colab/pydrive imports below will fail outside Google Colab
# (e.g. in a Gradio/HF Space runtime); of this group only `deepcopy` is used
# later in this file. Verify these are needed at all.
from google.colab import files
from copy import deepcopy
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
33
# Create the working directories the pipeline writes into (idempotent).
for folder in ('inversion_codes', 'style_images', 'style_images_aligned', 'models'):
    os.makedirs(folder, exist_ok=True)

# Fetch and unpack dlib's 68-point facial landmark model, then place it where
# the face-alignment helper expects to find it.
os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
41
+
42
+
43
# Run everything on CPU — no GPU is assumed in this environment.
device = 'cpu'

# Colab form flag.
# NOTE(review): no pydrive downloader object is ever constructed in this file,
# so this flag is currently unused — verify against the original notebook.
download_with_pydrive = True #@param {type:"boolean"}

# Google Drive file ids for every checkpoint this demo can download:
# the base FFHQ StyleGAN2 generator, the e4e/restyle encoders, and each
# finetuned style generator (with and without the color-preserving variant).
drive_ids = {
    "stylegan2-ffhq-config-f.pt": "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
    "e4e_ffhq_encode.pt": "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",
    "restyle_psp_ffhq_encode.pt": "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",
    "arcane_caitlyn.pt": "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",
    "arcane_caitlyn_preserve_color.pt": "1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
    "arcane_jinx_preserve_color.pt": "1jElwHxaYPod5Itdy18izJk49K1nl4ney",
    "arcane_jinx.pt": "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",
    "disney.pt": "1zbE2upakFUAx8ximYnLofFwfT8MilqJA",
    "disney_preserve_color.pt": "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
    "jojo.pt": "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",
    "jojo_preserve_color.pt": "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",
    "jojo_yasuho.pt": "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",
    "jojo_yasuho_preserve_color.pt": "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
    "supergirl.pt": "1L0y9IYgzLNzB-33xTpXpecsKU-t9DpVC",
    "supergirl_preserve_color.pt": "1VmKGuvThWHym7YuayXxjv0fSn32lfDpE",
}
64
+
65
+
66
+
67
+
68
+
69
# Download the base FFHQ StyleGAN2 checkpoint and place it under models/.
os.system("gdown https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK")
os.system("mv stylegan2-ffhq-config-f.pt models/stylegan2-ffhq-config-f.pt")


# Dimensionality of the StyleGAN2 latent space.
latent_dim = 512

# Load original generator
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
# BUG FIX: the original read `ckpt = torch.load(os.path.join('models', ckpt), ...)`
# before `ckpt` was ever assigned (NameError). The checkpoint to load here is
# the file downloaded just above; map_location keeps tensors on CPU.
ckpt = torch.load(
    os.path.join('models', 'stylegan2-ffhq-config-f.pt'),
    map_location=lambda storage, loc: storage,
)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
# Mean latent over 10k samples — used later as the truncation target.
mean_latent = original_generator.mean_latent(10000)

# to be finetuned generator
generator = deepcopy(original_generator)

# Preprocessing: resize to the generator's 1024x1024 resolution and map pixel
# values into [-1, 1] (matching the Normalize mean/std of 0.5).
transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
91
+
92
# Larger DPI for the matplotlib display helpers.
# NOTE(review): `plt` is never imported by name in this file — presumably it
# arrives via `from util import *`; confirm util exposes matplotlib.pyplot as plt.
plt.rcParams['figure.dpi'] = 150

# NOTE(review): this f-string contains no placeholder — 'test_input/(unknown)'
# looks like a lost variable substitution for a real input-image filename;
# verify against the original notebook/app.
filepath = f'test_input/(unknown)'

# Cache path for the e4e inversion latent of this image (same basename, .pt).
name = strip_path_extension(filepath)+'.pt'

# Detect and align the face (uses the dlib landmark model downloaded earlier).
aligned_face = align_face(filepath)

# Download the e4e FFHQ encoder, then project the aligned face into latent
# space; unsqueeze adds the batch dimension.
os.system("gdown https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7")
os.system("mv e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)
103
+
104
+
105
plt.rcParams['figure.dpi'] = 150
# Which finetuned style generator to demo.
pretrained = 'jojo' #@param ['supergirl', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
#@markdown Preserve color tries to preserve color of original image by limiting family of allowable transformations. Otherwise, the stylized image will inherit the colors of the reference images, leading to heavier stylizations.
preserve_color = False #@param{type:"boolean"}

# Checkpoint filename follows the naming convention used in `drive_ids`.
if preserve_color:
    ckpt = f'{pretrained}_preserve_color.pt'
else:
    ckpt = f'{pretrained}.pt'

# BUG FIX: the original called `downloader.download_file(ckpt)`, but no
# `downloader` object is ever created in this file (NameError at runtime).
# Fetch the checkpoint with gdown via the drive-id table defined above —
# the same pattern this script already uses for its other model downloads.
os.system(f"gdown https://drive.google.com/uc?id={drive_ids[ckpt]}")
os.system(f"mv {ckpt} models/{ckpt}")
# Load the finetuned generator weights ("g") on CPU into the finetune target.
ckpt = torch.load(os.path.join('models', ckpt), map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"], strict=False)
118
+
119
#@title Generate results
# Number of random z samples and the RNG seed for reproducible outputs.
n_sample = 1#@param {type:"number"}
seed = 3000 #@param {type:"number"}

torch.manual_seed(seed)
# Inference only — no gradients needed.
# NOTE(review): indentation was lost in extraction; the no_grad scope below is
# reconstructed to cover the generator forward passes — confirm against the
# original notebook.
with torch.no_grad():
    generator.eval()
    # Random latents for the original-vs-finetuned comparison grid.
    z = torch.randn(n_sample, latent_dim, device=device)

    # Truncated sampling toward the mean latent computed earlier.
    original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
    sample = generator([z], truncation=0.7, truncation_latent=mean_latent)

    # Stylize the user's projected face latent (my_w from e4e above).
    original_my_sample = original_generator(my_w, input_is_latent=True)
    my_sample = generator(my_w, input_is_latent=True)

# display reference images
# NOTE(review): `display_image` is not defined in this file — presumably from
# `from util import *`; confirm.
style_path = f'style_images_aligned/{pretrained}.png'
style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)
face = transform(aligned_face).unsqueeze(0).to(device)

# Grid: style reference | aligned input | stylized output.
my_output = torch.cat([style_image, face, my_sample], 0)
display_image(utils.make_grid(my_output, normalize=True, range=(-1, 1)), title='My sample')

# Grid: original-generator samples next to finetuned-generator samples.
output = torch.cat([original_sample, sample], 0)
display_image(utils.make_grid(output, normalize=True, range=(-1, 1), nrow=n_sample), title='Random samples')
144
+
145
def inference(img, ver):
    """Stylize *img* with the model matching the requested version string.

    NOTE(review): `face2paint`, `model1` and `model2` are not defined anywhere
    in this file — they appear to be leftovers from the AnimeGANv2 demo;
    confirm where they are meant to come from.
    """
    chosen = model2 if ver == 'version 2 (πŸ”Ί robustness,πŸ”» stylization)' else model1
    return face2paint(chosen, img)
151
+
152
# Gradio UI metadata.
# NOTE(review): the title/description/article/examples below all describe the
# AnimeGANv2 face-portrait demo, not the JoJoGAN pipeline assembled above —
# this looks like copy-paste from another Space; verify intent before shipping.
title = "AnimeGANv2"
description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/><img src='https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg' alt='animation'/></p>"
examples=[['groot.jpeg','version 2 (πŸ”Ί robustness,πŸ”» stylization)'],['bill.png','version 1 (πŸ”Ί stylization, πŸ”» robustness)'],['tony.png','version 1 (πŸ”Ί stylization, πŸ”» robustness)'],['elon.png','version 2 (πŸ”Ί robustness,πŸ”» stylization)'],['IU.png','version 1 (πŸ”Ί stylization, πŸ”» robustness)'],['billie.png','version 2 (πŸ”Ί robustness,πŸ”» stylization)'],['will.png','version 2 (πŸ”Ί robustness,πŸ”» stylization)'],['beyonce.png','version 1 (πŸ”Ί stylization, πŸ”» robustness)'],['gongyoo.jpeg','version 1 (πŸ”Ί stylization, πŸ”» robustness)']]
# Launch the app. Uses the gradio 2.x `gr.inputs`/`gr.outputs` API; the Radio
# choice strings must stay byte-identical to the strings `inference` compares
# against, or version selection silently falls back to model1.
gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Radio(['version 1 (πŸ”Ί stylization, πŸ”» robustness)','version 2 (πŸ”Ί robustness,πŸ”» stylization)'], type="value", default='version 2 (πŸ”Ί robustness,πŸ”» stylization)', label='version')
], gr.outputs.Image(type="pil"),title=title,description=description,article=article,enable_queue=True,examples=examples,allow_flagging=False).launch()