Ahsen Khaliq commited on
Commit
96d399d
1 Parent(s): f74a230

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -17,11 +17,46 @@ import lpips
17
  from model import *
18
 
19
 
20
- from e4e_projection import projection as e4e_projection
21
 
22
  from copy import deepcopy
23
  import imageio
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  os.makedirs('models', exist_ok=True)
26
  os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
27
  os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
@@ -113,7 +148,7 @@ generatorspider.load_state_dict(ckptspider["g"], strict=False)
113
  def inference(img, model):
114
  aligned_face = align_face(img)
115
 
116
- my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
117
  if model == 'JoJo':
118
  with torch.no_grad():
119
  my_sample = generatorjojo(my_w, input_is_latent=True)
17
  from model import *
18
 
19
 
20
+ #from e4e_projection import projection as e4e_projection
21
 
22
  from copy import deepcopy
23
  import imageio
24
 
25
+ import os
26
+ import sys
27
+ import numpy as np
28
+ from PIL import Image
29
+ import torch
30
+ import torchvision.transforms as transforms
31
+ from argparse import Namespace
32
+ from e4e.models.psp import pSp
33
+ from util import *
34
+
35
+ @ torch.no_grad()
36
+
37
+ def projection(img, name, device='cuda'):
38
+
39
+ model_path = 'models/e4e_ffhq_encode.pt'
40
+ ckpt = torch.load(model_path, map_location='cpu')
41
+ opts = ckpt['opts']
42
+ opts['checkpoint_path'] = model_path
43
+ opts= Namespace(**opts)
44
+ net = pSp(opts, device).eval().to(device)
45
+ transform = transforms.Compose(
46
+ [
47
+ transforms.Resize(256),
48
+ transforms.CenterCrop(256),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
51
+ ]
52
+ )
53
+ img = transform(img).unsqueeze(0).to(device)
54
+ images, w_plus = net(img, randomize_noise=False, return_latents=True)
55
+ result_file = {}
56
+ result_file['latent'] = w_plus[0]
57
+ torch.save(result_file, name)
58
+ return w_plus[0]
59
+
60
  os.makedirs('models', exist_ok=True)
61
  os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
62
  os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
148
  def inference(img, model):
149
  aligned_face = align_face(img)
150
 
151
+ my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
152
  if model == 'JoJo':
153
  with torch.no_grad():
154
  my_sample = generatorjojo(my_w, input_is_latent=True)