Ahsen Khaliq commited on
Commit
c5e7a23
1 Parent(s): 1fbd109

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -54,6 +54,9 @@ generatoryasuho = deepcopy(original_generator)
54
 
55
  generatorarcanemulti = deepcopy(original_generator)
56
 
 
 
 
57
 
58
 
59
 
@@ -99,6 +102,11 @@ os.system("gdown https://drive.google.com/uc?id=1enJgrC08NpWpx2XGBmLt1laimjpGCyf
99
  ckptarcanemulti = torch.load('arcane_multi_preserve_color.pt', map_location=lambda storage, loc: storage)
100
  generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
101
 
 
 
 
 
 
102
 
103
  def inference(img, model):
104
  aligned_face = align_face(img)
@@ -119,9 +127,12 @@ def inference(img, model):
119
  elif model == 'Yasuho':
120
  with torch.no_grad():
121
  my_sample = generatoryasuho(my_w, input_is_latent=True)
122
- else:
123
  with torch.no_grad():
124
  my_sample = generatorarcanemulti(my_w, input_is_latent=True)
 
 
 
125
 
126
 
127
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
54
 
55
  generatorarcanemulti = deepcopy(original_generator)
56
 
57
+ generatorart = deepcopy(original_generator)
58
+
59
+
60
 
61
 
62
 
102
  ckptarcanemulti = torch.load('arcane_multi_preserve_color.pt', map_location=lambda storage, loc: storage)
103
  generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
104
 
105
+ os.system("gdown https://drive.google.com/uc?id=1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT")
106
+
107
+ ckptart = torch.load('art.pt', map_location=lambda storage, loc: storage)
108
+ generatorart.load_state_dict(ckptart["g"], strict=False)
109
+
110
 
111
  def inference(img, model):
112
  aligned_face = align_face(img)
127
  elif model == 'Yasuho':
128
  with torch.no_grad():
129
  my_sample = generatoryasuho(my_w, input_is_latent=True)
130
+ elif model == 'Arcane Multi':
131
  with torch.no_grad():
132
  my_sample = generatorarcanemulti(my_w, input_is_latent=True)
133
+ else:
134
+ with torch.no_grad():
135
+ my_sample = generatorart(my_w, input_is_latent=True)
136
 
137
 
138
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()