Spaces:
Runtime error
Runtime error
Ahsen Khaliq
commited on
Commit
·
9769454
1
Parent(s):
6bd4dd4
add spiderverse model
Browse files
app.py
CHANGED
@@ -57,9 +57,7 @@ generatorarcanemulti = deepcopy(original_generator)
|
|
57 |
|
58 |
generatorart = deepcopy(original_generator)
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
|
64 |
|
65 |
transform = transforms.Compose(
|
@@ -108,6 +106,11 @@ os.system("gdown https://drive.google.com/uc?id=1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLK
|
|
108 |
ckptart = torch.load('art.pt', map_location=lambda storage, loc: storage)
|
109 |
generatorart.load_state_dict(ckptart["g"], strict=False)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
def inference(img, model):
|
113 |
aligned_face = align_face(img)
|
@@ -131,9 +134,12 @@ def inference(img, model):
|
|
131 |
elif model == 'Arcane Multi':
|
132 |
with torch.no_grad():
|
133 |
my_sample = generatorarcanemulti(my_w, input_is_latent=True)
|
134 |
-
|
135 |
with torch.no_grad():
|
136 |
my_sample = generatorart(my_w, input_is_latent=True)
|
|
|
|
|
|
|
137 |
|
138 |
|
139 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
|
|
57 |
|
58 |
generatorart = deepcopy(original_generator)
|
59 |
|
60 |
+
generatorspider = deepcopy(original_generator)
|
|
|
|
|
61 |
|
62 |
|
63 |
transform = transforms.Compose(
|
|
|
106 |
ckptart = torch.load('art.pt', map_location=lambda storage, loc: storage)
|
107 |
generatorart.load_state_dict(ckptart["g"], strict=False)
|
108 |
|
109 |
+
os.system("wget https://huggingface.co/akhaliq/jojo-gan-spiderverse/blob/main/spiderverse-checkpoint-3-face-500iters.pt")
|
110 |
+
|
111 |
+
ckptspider = torch.load('spiderverse-checkpoint-3-face-500iters.pt', map_location=lambda storage, loc: storage)
|
112 |
+
generatorspider.load_state_dict(ckptspider["g"], strict=False)
|
113 |
+
|
114 |
|
115 |
def inference(img, model):
|
116 |
aligned_face = align_face(img)
|
|
|
134 |
elif model == 'Arcane Multi':
|
135 |
with torch.no_grad():
|
136 |
my_sample = generatorarcanemulti(my_w, input_is_latent=True)
|
137 |
+
elif model == 'Art':
|
138 |
with torch.no_grad():
|
139 |
my_sample = generatorart(my_w, input_is_latent=True)
|
140 |
+
else:
|
141 |
+
with torch.no_grad():
|
142 |
+
my_sample = generatorspider(my_w, input_is_latent=True)
|
143 |
|
144 |
|
145 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|