Мясников Филипп Сергеевич commited on
Commit
9e6273b
1 Parent(s): 6c92b57
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -2,10 +2,7 @@ import os
2
  from PIL import Image
3
  import torch
4
  import gradio as gr
5
- import torch
6
  torch.backends.cudnn.benchmark = True
7
- from torchvision import transforms, utils
8
- from PIL import Image
9
  import math
10
  import random
11
  import numpy as np
@@ -18,11 +15,8 @@ import time
18
  from copy import deepcopy
19
  import imageio
20
 
21
- import os
22
  import sys
23
- import numpy as np
24
  from PIL import Image
25
- import torch
26
  import torchvision.transforms as transforms
27
  from argparse import Namespace
28
  from e4e.utils.common import tensor2im
@@ -114,9 +108,9 @@ def run_alignment(image_path):
114
  def gen_im(ffhq_codes, dog_codes, cat_codes, model_type='ffhq'):
115
  if model_type=='ffhq':
116
  imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
117
- elif model_type=='dog':
118
  imgs, _ = dog_decoder([dog_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
119
- elif model_type=='cat':
120
  imgs, _ = cat_decoder([cat_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
121
  else:
122
  imgs, _ = custom_decoder([custom_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
@@ -125,7 +119,7 @@ def gen_im(ffhq_codes, dog_codes, cat_codes, model_type='ffhq'):
125
  def set_seed(rd):
126
  torch.manual_seed(rd)
127
 
128
- def inference(img):
129
  random_seed = round(time.time() * 1000)
130
  set_seed(random_seed)
131
 
@@ -143,17 +137,17 @@ def inference(img):
143
  dog_codes = dog_encoder(transformed_image.unsqueeze(0).to(device).float())
144
  dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
145
 
146
- animal = "cat"
147
- npimage = gen_im(ffhq_codes, dog_codes, cat_codes, animal)
148
 
149
  imageio.imwrite('filename.jpeg', npimage)
150
  return 'filename.jpeg'
151
 
152
  title = "PetBreeder v1.1"
153
- description = "Gradio Demo for PetBreeder."
154
 
155
  gr.Interface(inference,
156
- [gr.inputs.Image(type="pil")],
 
157
  gr.outputs.Image(type="file"),
158
  title=title,
159
  description=description).launch()
 
2
  from PIL import Image
3
  import torch
4
  import gradio as gr
 
5
  torch.backends.cudnn.benchmark = True
 
 
6
  import math
7
  import random
8
  import numpy as np
 
15
  from copy import deepcopy
16
  import imageio
17
 
 
18
  import sys
 
19
  from PIL import Image
 
20
  import torchvision.transforms as transforms
21
  from argparse import Namespace
22
  from e4e.utils.common import tensor2im
 
108
  def gen_im(ffhq_codes, dog_codes, cat_codes, model_type='ffhq'):
109
  if model_type=='ffhq':
110
  imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
111
+ elif model_type=='Dog':
112
  imgs, _ = dog_decoder([dog_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
113
+ elif model_type=='Cat':
114
  imgs, _ = cat_decoder([cat_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
115
  else:
116
  imgs, _ = custom_decoder([custom_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
 
119
  def set_seed(rd):
120
  torch.manual_seed(rd)
121
 
122
+ def inference(img, model):
123
  random_seed = round(time.time() * 1000)
124
  set_seed(random_seed)
125
 
 
137
  dog_codes = dog_encoder(transformed_image.unsqueeze(0).to(device).float())
138
  dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
139
 
140
+ npimage = gen_im(ffhq_codes, dog_codes, cat_codes, model)
 
141
 
142
  imageio.imwrite('filename.jpeg', npimage)
143
  return 'filename.jpeg'
144
 
145
  title = "PetBreeder v1.1"
146
+ description = "Gradio Demo for PetBreeder. Based on [Colab](https://colab.research.google.com/github/tg-bomze/collection-of-notebooks/blob/master/PetBreeder.ipynb) by [@MLArt](https://t.me/MLArt)."
147
 
148
  gr.Interface(inference,
149
+ [gr.inputs.Image(type="pil"),
150
+ gr.inputs.Dropdown(choices=['Cat','Dog'], type='value', default='Cat', label='Model')]
151
  gr.outputs.Image(type="file"),
152
  title=title,
153
  description=description).launch()