Ahsen Khaliq committed on
Commit
49ce528
1 Parent(s): eb212a1

add caitlyn model

Browse files
Files changed (1) hide show
  1. app.py +13 -32
app.py CHANGED
@@ -4,12 +4,6 @@ import torch
4
  import gradio as gr
5
  os.system("pip install gradio==2.5.3")
6
 
7
- #os.system("pip install facexlib")
8
-
9
- #from facexlib.utils.face_restoration_helper import FaceRestoreHelper
10
- #os.system("pip install autocrop")
11
- #os.system("pip install dlib")
12
- #from autocrop import Cropper
13
  import torch
14
  torch.backends.cudnn.benchmark = True
15
  from torchvision import transforms, utils
@@ -38,15 +32,6 @@ os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
38
  os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
39
  os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
40
 
41
- #cropper = Cropper(face_percent=80)
42
-
43
- #face_helper = FaceRestoreHelper(
44
- #upscale_factor=0,
45
- #face_size=512,
46
- #crop_ratio=(1, 1),
47
- #det_model='retinaface_resnet50',
48
- #save_ext='png',
49
- #device='cpu')
50
 
51
  device = 'cpu'
52
 
@@ -54,19 +39,19 @@ os.system("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_bt
54
 
55
  latent_dim = 512
56
 
57
- # Load original generator
58
  original_generator = Generator(1024, latent_dim, 8, 2).to(device)
59
  ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
60
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
61
  mean_latent = original_generator.mean_latent(10000)
62
 
63
- # to be finetuned generator
64
  generatorjojo = deepcopy(original_generator)
65
 
66
  generatordisney = deepcopy(original_generator)
67
 
68
  generatorjinx = deepcopy(original_generator)
69
 
 
 
70
 
71
 
72
  transform = transforms.Compose(
@@ -95,22 +80,15 @@ os.system("gdown https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ne
95
  ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
96
  generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
97
 
 
 
 
 
 
98
 
99
  def inference(img, model):
100
- #face_helper.clean_all()
101
  aligned_face = align_face(img)
102
- #cropped_array = cropper.crop(img[:,:,::-1])
103
-
104
- #if cropped_array.any():
105
- #aligned_face = Image.fromarray(cropped_array)
106
- #else:
107
- #aligned_face = Image.fromarray(img[:,:,::-1])
108
-
109
- #face_helper.read_image(img)
110
- #face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=10)
111
- #face_helper.align_warp_face(save_cropped_path="/home/user/app/")
112
- #pilimg = Image.open("/home/user/app/_02.png")
113
-
114
  my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
115
  if model == 'JoJo':
116
  with torch.no_grad():
@@ -118,9 +96,12 @@ def inference(img, model):
118
  elif model == 'Disney':
119
  with torch.no_grad():
120
  my_sample = generatordisney(my_w, input_is_latent=True)
121
- else:
122
  with torch.no_grad():
123
  my_sample = generatorjinx(my_w, input_is_latent=True)
 
 
 
124
 
125
 
126
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
@@ -133,4 +114,4 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
133
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
134
 
135
  examples=[['iu.jpeg','Jinx']]
136
- gr.Interface(inference, [gr.inputs.Image(type="filepath"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
4
  import gradio as gr
5
  os.system("pip install gradio==2.5.3")
6
 
 
 
 
 
 
 
7
  import torch
8
  torch.backends.cudnn.benchmark = True
9
  from torchvision import transforms, utils
32
  os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
33
  os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
34
 
 
 
 
 
 
 
 
 
 
35
 
36
  device = 'cpu'
37
 
39
 
40
  latent_dim = 512
41
 
 
42
  original_generator = Generator(1024, latent_dim, 8, 2).to(device)
43
  ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
44
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
45
  mean_latent = original_generator.mean_latent(10000)
46
 
 
47
  generatorjojo = deepcopy(original_generator)
48
 
49
  generatordisney = deepcopy(original_generator)
50
 
51
  generatorjinx = deepcopy(original_generator)
52
 
53
+ generatorcaitlyn = deepcopy(original_generator)
54
+
55
 
56
 
57
  transform = transforms.Compose(
80
  ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
81
  generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
82
 
83
+ os.system("gdown https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH")
84
+
85
+ ckptcaitlyn = torch.load('arcane_caitlyn_preserve_color.pt', map_location=lambda storage, loc: storage)
86
+ generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
87
+
88
 
89
  def inference(img, model):
 
90
  aligned_face = align_face(img)
91
+
 
 
 
 
 
 
 
 
 
 
 
92
  my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
93
  if model == 'JoJo':
94
  with torch.no_grad():
96
  elif model == 'Disney':
97
  with torch.no_grad():
98
  my_sample = generatordisney(my_w, input_is_latent=True)
99
+ elif model == 'Jinx':
100
  with torch.no_grad():
101
  my_sample = generatorjinx(my_w, input_is_latent=True)
102
+ else:
103
+ with torch.no_grad():
104
+ my_sample = generatorcaitlyn(my_w, input_is_latent=True)
105
 
106
 
107
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
114
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
115
 
116
  examples=[['iu.jpeg','Jinx']]
117
+ gr.Interface(inference, [gr.inputs.Image(type="filepath"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()