Ahsen Khaliq committed
Commit aa0db9c • Parent: 0f2a39a

Update app.py

Files changed (1): app.py (+39, −58)
app.py CHANGED
@@ -45,7 +45,7 @@ latent_dim = 512
 
 # Load original generator
 original_generator = Generator(1024, latent_dim, 8, 2).to(device)
-ckpt = torch.load(os.path.join('models', ckpt), map_location=lambda storage, loc: storage)
+ckpt = torch.load(os.path.join('models', 'stylegan2-ffhq-config-f.pt'), map_location=lambda storage, loc: storage)
 original_generator.load_state_dict(ckpt["g_ema"], strict=False)
 mean_latent = original_generator.mean_latent(10000)
 
@@ -61,68 +61,49 @@ transform = transforms.Compose(
 )
 
 plt.rcParams['figure.dpi'] = 150
-
-filepath = f'test_input/{filename}'
-
-name = strip_path_extension(filepath)+'.pt'
-
-aligned_face = align_face(filepath)
-
 os.system("gdown https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7")
 os.system("mv e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
-my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)
 
 
-plt.rcParams['figure.dpi'] = 150
-pretrained = 'jojo' #@param ['supergirl', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
-#@markdown Preserve color tries to preserve color of original image by limiting family of allowable transformations. Otherwise, the stylized image will inherit the colors of the reference images, leading to heavier stylizations.
-preserve_color = False #@param{type:"boolean"}
-
-if preserve_color:
-    ckpt = f'{pretrained}_preserve_color.pt'
-else:
-    ckpt = f'{pretrained}.pt'
-
-downloader.download_file(ckpt)
-ckpt = torch.load(os.path.join('models', ckpt), map_location=lambda storage, loc: storage)
-generator.load_state_dict(ckpt["g"], strict=False)
-
-#@title Generate results
-n_sample = 1 #@param {type:"number"}
-seed = 3000 #@param {type:"number"}
-
-torch.manual_seed(seed)
-with torch.no_grad():
-    generator.eval()
-    z = torch.randn(n_sample, latent_dim, device=device)
-
-    original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
-    sample = generator([z], truncation=0.7, truncation_latent=mean_latent)
-
-    original_my_sample = original_generator(my_w, input_is_latent=True)
-    my_sample = generator(my_w, input_is_latent=True)
-
-# display reference images
-style_path = f'style_images_aligned/{pretrained}.png'
-style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)
-face = transform(aligned_face).unsqueeze(0).to(device)
-
-my_output = torch.cat([style_image, face, my_sample], 0)
-display_image(utils.make_grid(my_output, normalize=True, range=(-1, 1)), title='My sample')
-
-output = torch.cat([original_sample, sample], 0)
-display_image(utils.make_grid(output, normalize=True, range=(-1, 1), nrow=n_sample), title='Random samples')
-
-def inference(img, ver):
-    if ver == 'version 2 (🔺 robustness, 🔻 stylization)':
-        out = face2paint(model2, img)
-    else:
-        out = face2paint(model1, img)
-    return out
+os.system("gdown https://drive.google.com/uc?id=13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4")
+os.system("mv e4e_ffhq_encode.pt models/jojo.pt")
+
+os.system("gdown https://drive.google.com/uc?id=1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2")
+os.system("mv e4e_ffhq_encode.pt models/jojo_preserve_color.pt")
+
+def inference(img):
+    filepath = img
+
+    name = strip_path_extension(filepath)+'.pt'
+
+    aligned_face = align_face(filepath)
+
+    my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)
+
+
+    plt.rcParams['figure.dpi'] = 150
+    pretrained = 'jojo' #@param ['supergirl', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
+    #@markdown Preserve color tries to preserve color of original image by limiting family of allowable transformations. Otherwise, the stylized image will inherit the colors of the reference images, leading to heavier stylizations.
+    preserve_color = False #@param{type:"boolean"}
+
+
+    ckpt = torch.load(os.path.join('models', 'jojo.pt'), map_location=lambda storage, loc: storage)
+    generator.load_state_dict(ckpt["g"], strict=False)
+
+    with torch.no_grad():
+        generator.eval()
+
+        original_my_sample = original_generator(my_w, input_is_latent=True)
+        my_sample = generator(my_w, input_is_latent=True)
+
+    return
+
+    my_output = torch.cat([style_image, face, my_sample], 0)
+
+    return my_output
 
 title = "AnimeGANv2"
 description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
 article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/><img src='https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg' alt='animation'/></p>"
-examples=[['groot.jpeg','version 2 (🔺 robustness, 🔻 stylization)'],['bill.png','version 1 (🔺 stylization, 🔻 robustness)'],['tony.png','version 1 (🔺 stylization, 🔻 robustness)'],['elon.png','version 2 (🔺 robustness, 🔻 stylization)'],['IU.png','version 1 (🔺 stylization, 🔻 robustness)'],['billie.png','version 2 (🔺 robustness, 🔻 stylization)'],['will.png','version 2 (🔺 robustness, 🔻 stylization)'],['beyonce.png','version 1 (🔺 stylization, 🔻 robustness)'],['gongyoo.jpeg','version 1 (🔺 stylization, 🔻 robustness)']]
-gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Radio(['version 1 (🔺 stylization, 🔻 robustness)','version 2 (🔺 robustness, 🔻 stylization)'], type="value", default='version 2 (🔺 robustness, 🔻 stylization)', label='version')
-], gr.outputs.Image(type="pil"),title=title,description=description,article=article,enable_queue=True,examples=examples,allow_flagging=False).launch()
+
+gr.Interface(inference, [gr.inputs.Image(type="file")], gr.outputs.Image(type="pil"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False).launch()
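
For reference, a minimal sketch of what the new inference() path appears to be aiming at. As committed, the bare `return` at new line 99 exits before `my_output` is built, so the Gradio output is None; the sketch below returns the stylized image instead. It assumes the JoJoGAN helpers (`align_face`, `strip_path_extension`, `e4e_projection`) and the `generator`/`device` globals defined earlier in app.py, and the file-handle and tensor-to-PIL handling are illustrative assumptions, not part of the commit:

```python
# Hypothetical sketch, not the committed code.
import torch
from torchvision import transforms, utils

@torch.no_grad()
def inference(img):
    # gr.inputs.Image(type="file") passes a tempfile object; its .name is the path.
    filepath = img.name if hasattr(img, "name") else img

    aligned_face = align_face(filepath)                   # JoJoGAN face alignment (assumed helper)
    name = strip_path_extension(filepath) + '.pt'         # cache path for the projected latent
    my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)  # e4e encoder -> W+ latent

    generator.eval()
    my_sample = generator(my_w, input_is_latent=True)     # stylized tensor in [-1, 1]

    # Map [-1, 1] -> [0, 1] and convert to PIL for the Gradio output.
    grid = utils.make_grid(my_sample, normalize=True, range=(-1, 1))
    return transforms.ToPILImage()(grid.cpu())
```

With an output of this shape, the `gr.outputs.Image(type="pil")` in the commit's final gr.Interface call would receive the PIL image it expects.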