sztanki commited on
Commit
979953f
·
1 Parent(s): 59506ce

Add White Walker model

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -70,6 +70,7 @@ mean_latent = original_generator.mean_latent(10000)
70
  generatorzombie = deepcopy(original_generator)
71
  generatorhulk = deepcopy(original_generator)
72
  generatorjojo = deepcopy(original_generator)
 
73
 
74
  transform = transforms.Compose(
75
  [
@@ -89,11 +90,18 @@ modelzombie = hf_hub_download(repo_id="Awesimo/jojogan-zombie", filename="zombie
89
  ckptzombie = torch.load(modelzombie, map_location=lambda storage, loc: storage)
90
  generatorzombie.load_state_dict(ckptzombie["g"], strict=False)
91
 
 
 
 
 
 
92
  #JOJO
93
  modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
94
  ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
95
  generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
96
 
 
 
97
  def inference(img, model):
98
  img.save('out.jpg')
99
  aligned_face = align_face('out.jpg')
@@ -105,6 +113,9 @@ def inference(img, model):
105
  elif model == 'Zombie':
106
  with torch.no_grad():
107
  my_sample = generatorzombie(my_w, input_is_latent=True)
 
 
 
108
  elif model == 'JoJo':
109
  with torch.no_grad():
110
  my_sample = generatorjojo(my_w, input_is_latent=True)
 
70
  generatorzombie = deepcopy(original_generator)
71
  generatorhulk = deepcopy(original_generator)
72
  generatorjojo = deepcopy(original_generator)
73
+ generatorwalker = deepcopy(original_generator)
74
 
75
  transform = transforms.Compose(
76
  [
 
90
  ckptzombie = torch.load(modelzombie, map_location=lambda storage, loc: storage)
91
  generatorzombie.load_state_dict(ckptzombie["g"], strict=False)
92
 
93
+ #WHITE WALKER
94
+ modelwalker = hf_hub_download(repo_id="Awesimo/jojogan-white-walker", filename="white_walker.pt")
95
+ ckptwalker = torch.load(modelwalker, map_location=lambda storage, loc: storage)
96
+ generatorwalker.load_state_dict(ckptwalker["g"], strict=False)
97
+
98
  #JOJO
99
  modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
100
  ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
101
  generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
102
 
103
+
104
+
105
  def inference(img, model):
106
  img.save('out.jpg')
107
  aligned_face = align_face('out.jpg')
 
113
  elif model == 'Zombie':
114
  with torch.no_grad():
115
  my_sample = generatorzombie(my_w, input_is_latent=True)
116
+ elif model == 'White-Walker':
117
+ with torch.no_grad():
118
+ my_sample = generatorwalker(my_w, input_is_latent=True)
119
  elif model == 'JoJo':
120
  with torch.no_grad():
121
  my_sample = generatorjojo(my_w, input_is_latent=True)