Spaces:
Runtime error
Runtime error
Ahsen Khaliq
committed on
Commit
•
96d399d
1
Parent(s):
f74a230
Update app.py
Browse files
app.py
CHANGED
@@ -17,11 +17,46 @@ import lpips
|
|
17 |
from model import *
|
18 |
|
19 |
|
20 |
-
from e4e_projection import projection as e4e_projection
|
21 |
|
22 |
from copy import deepcopy
|
23 |
import imageio
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
# Ensure the local checkpoint directory exists before downloading weights.
os.makedirs('models', exist_ok=True)
# Download the pretrained e4e FFHQ encoder checkpoint from a hard-coded
# Google Drive id via the gdown CLI.
# NOTE(review): runs at import time and requires network access — fails offline;
# os.system gives no error handling if the download or copy fails.
os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
# Place the downloaded checkpoint where projection() expects it
# ('models/e4e_ffhq_encode.pt').
os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
|
@@ -113,7 +148,7 @@ generatorspider.load_state_dict(ckptspider["g"], strict=False)
|
|
113 |
def inference(img, model):
|
114 |
aligned_face = align_face(img)
|
115 |
|
116 |
-
my_w =
|
117 |
if model == 'JoJo':
|
118 |
with torch.no_grad():
|
119 |
my_sample = generatorjojo(my_w, input_is_latent=True)
|
17 |
from model import *
|
18 |
|
19 |
|
20 |
+
#from e4e_projection import projection as e4e_projection
|
21 |
|
22 |
from copy import deepcopy
|
23 |
import imageio
|
24 |
|
25 |
+
import os
|
26 |
+
import sys
|
27 |
+
import numpy as np
|
28 |
+
from PIL import Image
|
29 |
+
import torch
|
30 |
+
import torchvision.transforms as transforms
|
31 |
+
from argparse import Namespace
|
32 |
+
from e4e.models.psp import pSp
|
33 |
+
from util import *
|
34 |
+
|
35 |
+
@torch.no_grad()
def projection(img, name, device='cuda'):
    """Encode a face image into e4e latent space and persist the latent.

    Loads the e4e FFHQ encoder checkpoint from
    'models/e4e_ffhq_encode.pt', preprocesses *img* (resize + center-crop
    to 256, normalize to [-1, 1]), runs the encoder, saves the resulting
    latent under key 'latent' to the file *name* via torch.save, and
    returns that latent tensor.

    Args:
        img: input image accepted by torchvision transforms (presumably a
            PIL image — confirm against the caller).
        name: destination path for the saved latent dict.
        device: torch device string for the encoder and tensors.

    Returns:
        The first latent from the encoder output (w_plus[0]); exact shape
        is determined by the pSp model — not verifiable from this file.
    """
    checkpoint_path = 'models/e4e_ffhq_encode.pt'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Rebuild the training options stored inside the checkpoint, pointing
    # them back at the checkpoint we just loaded.
    options = checkpoint['opts']
    options['checkpoint_path'] = checkpoint_path
    options = Namespace(**options)

    encoder = pSp(options, device).eval().to(device)

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    batch = preprocess(img).unsqueeze(0).to(device)
    images, w_plus = encoder(batch, randomize_noise=False, return_latents=True)

    # Persist the latent so later runs can reuse it without re-encoding.
    torch.save({'latent': w_plus[0]}, name)
    return w_plus[0]
|
59 |
+
|
60 |
# Ensure the local checkpoint directory exists before downloading weights.
os.makedirs('models', exist_ok=True)
# Download the pretrained e4e FFHQ encoder checkpoint from a hard-coded
# Google Drive id via the gdown CLI.
# NOTE(review): runs at import time and requires network access — fails offline;
# os.system gives no error handling if the download or copy fails.
os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
# Place the downloaded checkpoint where projection() expects it
# ('models/e4e_ffhq_encode.pt').
os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
|
148 |
def inference(img, model):
|
149 |
aligned_face = align_face(img)
|
150 |
|
151 |
+
my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
|
152 |
if model == 'JoJo':
|
153 |
with torch.no_grad():
|
154 |
my_sample = generatorjojo(my_w, input_is_latent=True)
|