Ahsen Khaliq commited on
Commit
c3d2c2f
1 Parent(s): 49ce528

add yasuho model

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -52,6 +52,9 @@ generatorjinx = deepcopy(original_generator)
52
 
53
  generatorcaitlyn = deepcopy(original_generator)
54
 
 
 
 
55
 
56
 
57
  transform = transforms.Compose(
@@ -85,6 +88,11 @@ os.system("gdown https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-a
85
  ckptcaitlyn = torch.load('arcane_caitlyn_preserve_color.pt', map_location=lambda storage, loc: storage)
86
  generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
87
 
 
 
 
 
 
88
 
89
  def inference(img, model):
90
  aligned_face = align_face(img)
@@ -99,9 +107,12 @@ def inference(img, model):
99
  elif model == 'Jinx':
100
  with torch.no_grad():
101
  my_sample = generatorjinx(my_w, input_is_latent=True)
102
- else:
103
  with torch.no_grad():
104
  my_sample = generatorcaitlyn(my_w, input_is_latent=True)
 
 
 
105
 
106
 
107
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
52
 
53
  generatorcaitlyn = deepcopy(original_generator)
54
 
55
+ generatoryasuho = deepcopy(original_generator)
56
+
57
+
58
 
59
 
60
  transform = transforms.Compose(
88
  ckptcaitlyn = torch.load('arcane_caitlyn_preserve_color.pt', map_location=lambda storage, loc: storage)
89
  generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
90
 
91
+ os.system("gdown https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L")
92
+
93
+ ckptyasuho = torch.load('jojo_yasuho_preserve_color.pt', map_location=lambda storage, loc: storage)
94
+ generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
95
+
96
 
97
  def inference(img, model):
98
  aligned_face = align_face(img)
107
  elif model == 'Jinx':
108
  with torch.no_grad():
109
  my_sample = generatorjinx(my_w, input_is_latent=True)
110
+ elif model == 'Caitlyn':
111
  with torch.no_grad():
112
  my_sample = generatorcaitlyn(my_w, input_is_latent=True)
113
+ else:
114
+ with torch.no_grad():
115
+ my_sample = generatoryasuho(my_w, input_is_latent=True)
116
 
117
 
118
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()