Spaces:
Paused
Paused
Update videoretalking/third_part/GPEN/face_model/face_gan.py
Browse files
videoretalking/third_part/GPEN/face_model/face_gan.py
CHANGED
@@ -1,55 +1,55 @@
|
|
1 |
-
'''
|
2 |
-
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
|
3 |
-
@author: yangxy (yangtao9009@gmail.com)
|
4 |
-
'''
|
5 |
-
import torch
|
6 |
-
import os
|
7 |
-
import cv2
|
8 |
-
import glob
|
9 |
-
import numpy as np
|
10 |
-
from torch import nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
from torchvision import transforms, utils
|
13 |
-
from face_model.gpen_model import FullGenerator
|
14 |
-
|
15 |
-
class FaceGAN(object):
    """Inference wrapper around the GPEN ``FullGenerator`` network.

    The constructor builds the generator, loads ``<base_dir>/<model>.pth``
    into it and switches it to eval mode.  ``process`` then maps one image
    (OpenCV-style BGR array) to the generator's output image.
    """

    def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1, is_norm=True, device='cuda'):
        """Build the generator and load its weights.

        Args:
            base_dir: Directory that contains the checkpoint file.
            size: Square resolution the generator is built for; inputs to
                ``process`` are resized to ``size x size``.
            model: Checkpoint name without the ``.pth`` extension. Required.
            channel_multiplier: Forwarded to ``FullGenerator``.
            narrow: Forwarded to ``FullGenerator``.
            is_norm: If true, pixel values are mapped from [0, 1] to [-1, 1]
                before the network and back afterwards.
            device: Torch device the model and input tensors are moved to.

        Raises:
            TypeError: If ``model`` is not given.
        """
        if model is None:
            # Without this guard the concatenation below fails with an opaque
            # "unsupported operand type(s) for +: 'NoneType' and 'str'".
            raise TypeError("FaceGAN requires 'model': the checkpoint name without the '.pth' extension")
        self.mfile = os.path.join(base_dir, model + '.pth')
        self.n_mlp = 8
        self.device = device
        self.is_norm = is_norm
        self.resolution = size
        self.load_model(channel_multiplier, narrow)

    def load_model(self, channel_multiplier=2, narrow=1):
        """Instantiate ``FullGenerator`` and load the checkpoint at ``self.mfile``."""
        self.model = FullGenerator(self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # Weights are materialised on CPU first, then moved with the model.
        pretrained_dict = torch.load(self.mfile, map_location=torch.device('cpu'))
        self.model.load_state_dict(pretrained_dict)
        self.model.to(self.device)
        self.model.eval()

    def process(self, img):
        """Run the generator on one image and return the output image.

        Args:
            img: Image array accepted by ``cv2.resize``; treated as BGR
                (assumes H x W x 3 -- TODO confirm against callers).

        Returns:
            The first output of the generator converted by ``tensor2img``
            (a ``uint8`` BGR array with the default arguments).
        """
        img = cv2.resize(img, (self.resolution, self.resolution))
        img_t = self.img2tensor(img)

        with torch.no_grad():
            # The generator returns a pair; only the first element is used.
            out, _ = self.model(img_t)

        return self.tensor2img(out)

    def img2tensor(self, img):
        """Convert an H x W x C BGR array into a 1 x C x H x W RGB float tensor.

        Values are divided by 255 and, when ``self.is_norm`` is set, mapped
        to [-1, 1].
        """
        # torch.from_numpy rejects negative-stride views (e.g. img[:, :, ::-1]),
        # so force a contiguous buffer first.
        img = np.ascontiguousarray(img)
        # Cast explicitly so float64 input does not yield a float64 tensor.
        img_t = torch.from_numpy(img).to(self.device).float() / 255.
        if self.is_norm:
            img_t = (img_t - 0.5) / 0.5
        img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1)  # BGR->RGB
        return img_t

    def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
        """Convert a 1 x C x H x W RGB tensor back into an H x W x C BGR array.

        Args:
            img_t: Tensor to convert; the leading batch dimension is squeezed.
            pmax: Value that 1.0 is scaled to.
            imtype: NumPy dtype of the returned array.

        Returns:
            Array of dtype ``imtype`` with values clipped to [0, pmax].  The
            final cast truncates rather than rounds.
        """
        if self.is_norm:
            img_t = img_t * 0.5 + 0.5
        img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2)  # RGB->BGR
        img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax

        return img_np.astype(imtype)
|
|
|
1 |
+
'''
|
2 |
+
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
|
3 |
+
@author: yangxy (yangtao9009@gmail.com)
|
4 |
+
'''
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import cv2
|
8 |
+
import glob
|
9 |
+
import numpy as np
|
10 |
+
from torch import nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torchvision import transforms, utils
|
13 |
+
from videoretalking.third_part.GPEN.face_model.gpen_model import FullGenerator
|
14 |
+
|
15 |
+
class FaceGAN(object):
    """Inference wrapper around the GPEN ``FullGenerator`` network.

    The constructor builds the generator, loads ``<base_dir>/<model>.pth``
    into it and switches it to eval mode.  ``process`` then maps one image
    (OpenCV-style BGR array) to the generator's output image.
    """

    def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1, is_norm=True, device='cuda'):
        """Build the generator and load its weights.

        Args:
            base_dir: Directory that contains the checkpoint file.
            size: Square resolution the generator is built for; inputs to
                ``process`` are resized to ``size x size``.
            model: Checkpoint name without the ``.pth`` extension. Required.
            channel_multiplier: Forwarded to ``FullGenerator``.
            narrow: Forwarded to ``FullGenerator``.
            is_norm: If true, pixel values are mapped from [0, 1] to [-1, 1]
                before the network and back afterwards.
            device: Torch device the model and input tensors are moved to.

        Raises:
            TypeError: If ``model`` is not given.
        """
        if model is None:
            # Without this guard the concatenation below fails with an opaque
            # "unsupported operand type(s) for +: 'NoneType' and 'str'".
            raise TypeError("FaceGAN requires 'model': the checkpoint name without the '.pth' extension")
        self.mfile = os.path.join(base_dir, model + '.pth')
        self.n_mlp = 8
        self.device = device
        self.is_norm = is_norm
        self.resolution = size
        self.load_model(channel_multiplier, narrow)

    def load_model(self, channel_multiplier=2, narrow=1):
        """Instantiate ``FullGenerator`` and load the checkpoint at ``self.mfile``."""
        self.model = FullGenerator(self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # Weights are materialised on CPU first, then moved with the model.
        pretrained_dict = torch.load(self.mfile, map_location=torch.device('cpu'))
        self.model.load_state_dict(pretrained_dict)
        self.model.to(self.device)
        self.model.eval()

    def process(self, img):
        """Run the generator on one image and return the output image.

        Args:
            img: Image array accepted by ``cv2.resize``; treated as BGR
                (assumes H x W x 3 -- TODO confirm against callers).

        Returns:
            The first output of the generator converted by ``tensor2img``
            (a ``uint8`` BGR array with the default arguments).
        """
        img = cv2.resize(img, (self.resolution, self.resolution))
        img_t = self.img2tensor(img)

        with torch.no_grad():
            # The generator returns a pair; only the first element is used.
            out, _ = self.model(img_t)

        return self.tensor2img(out)

    def img2tensor(self, img):
        """Convert an H x W x C BGR array into a 1 x C x H x W RGB float tensor.

        Values are divided by 255 and, when ``self.is_norm`` is set, mapped
        to [-1, 1].
        """
        # torch.from_numpy rejects negative-stride views (e.g. img[:, :, ::-1]),
        # so force a contiguous buffer first.
        img = np.ascontiguousarray(img)
        # Cast explicitly so float64 input does not yield a float64 tensor.
        img_t = torch.from_numpy(img).to(self.device).float() / 255.
        if self.is_norm:
            img_t = (img_t - 0.5) / 0.5
        img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1)  # BGR->RGB
        return img_t

    def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
        """Convert a 1 x C x H x W RGB tensor back into an H x W x C BGR array.

        Args:
            img_t: Tensor to convert; the leading batch dimension is squeezed.
            pmax: Value that 1.0 is scaled to.
            imtype: NumPy dtype of the returned array.

        Returns:
            Array of dtype ``imtype`` with values clipped to [0, pmax].  The
            final cast truncates rather than rounds.
        """
        if self.is_norm:
            img_t = img_t * 0.5 + 0.5
        img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2)  # RGB->BGR
        img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax

        return img_np.astype(imtype)
|