Ahsen Khaliq commited on
Commit
9769454
1 Parent(s): 6bd4dd4

add spiderverse model

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -57,9 +57,7 @@ generatorarcanemulti = deepcopy(original_generator)
57
 
58
  generatorart = deepcopy(original_generator)
59
 
60
-
61
-
62
-
63
 
64
 
65
  transform = transforms.Compose(
@@ -108,6 +106,11 @@ os.system("gdown https://drive.google.com/uc?id=1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLK
108
  ckptart = torch.load('art.pt', map_location=lambda storage, loc: storage)
109
  generatorart.load_state_dict(ckptart["g"], strict=False)
110
 
 
 
 
 
 
111
 
112
  def inference(img, model):
113
  aligned_face = align_face(img)
@@ -131,9 +134,12 @@ def inference(img, model):
131
  elif model == 'Arcane Multi':
132
  with torch.no_grad():
133
  my_sample = generatorarcanemulti(my_w, input_is_latent=True)
134
- else:
135
  with torch.no_grad():
136
  my_sample = generatorart(my_w, input_is_latent=True)
 
 
 
137
 
138
 
139
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
57
 
58
  generatorart = deepcopy(original_generator)
59
 
60
+ generatorspider = deepcopy(original_generator)
 
 
61
 
62
 
63
  transform = transforms.Compose(
106
  ckptart = torch.load('art.pt', map_location=lambda storage, loc: storage)
107
  generatorart.load_state_dict(ckptart["g"], strict=False)
108
 
109
+ os.system("wget https://huggingface.co/akhaliq/jojo-gan-spiderverse/blob/main/spiderverse-checkpoint-3-face-500iters.pt")
110
+
111
+ ckptspider = torch.load('spiderverse-checkpoint-3-face-500iters.pt', map_location=lambda storage, loc: storage)
112
+ generatorspider.load_state_dict(ckptspider["g"], strict=False)
113
+
114
 
115
  def inference(img, model):
116
  aligned_face = align_face(img)
134
  elif model == 'Arcane Multi':
135
  with torch.no_grad():
136
  my_sample = generatorarcanemulti(my_w, input_is_latent=True)
137
+ elif model == 'Art':
138
  with torch.no_grad():
139
  my_sample = generatorart(my_w, input_is_latent=True)
140
+ else:
141
+ with torch.no_grad():
142
+ my_sample = generatorspider(my_w, input_is_latent=True)
143
 
144
 
145
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()