Spanicin commited on
Commit
af60c81
1 Parent(s): b9a690d

Update videoretalking/third_part/GPEN/face_model/face_gan.py

Browse files
videoretalking/third_part/GPEN/face_model/face_gan.py CHANGED
@@ -1,55 +1,55 @@
1
- '''
2
- @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3
- @author: yangxy (yangtao9009@gmail.com)
4
- '''
5
- import torch
6
- import os
7
- import cv2
8
- import glob
9
- import numpy as np
10
- from torch import nn
11
- import torch.nn.functional as F
12
- from torchvision import transforms, utils
13
- from face_model.gpen_model import FullGenerator
14
-
15
- class FaceGAN(object):
16
- def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1, is_norm=True, device='cuda'):
17
- self.mfile = os.path.join(base_dir, model+'.pth')
18
- self.n_mlp = 8
19
- self.device = device
20
- self.is_norm = is_norm
21
- self.resolution = size
22
- self.load_model(channel_multiplier, narrow)
23
-
24
- def load_model(self, channel_multiplier=2, narrow=1):
25
- self.model = FullGenerator(self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device)
26
- pretrained_dict = torch.load(self.mfile, map_location=torch.device('cpu'))
27
- self.model.load_state_dict(pretrained_dict)
28
- self.model.to(self.device)
29
- self.model.eval()
30
-
31
- def process(self, img):
32
- img = cv2.resize(img, (self.resolution, self.resolution))
33
- img_t = self.img2tensor(img)
34
-
35
- with torch.no_grad():
36
- out, __ = self.model(img_t)
37
-
38
- out = self.tensor2img(out)
39
-
40
- return out
41
-
42
- def img2tensor(self, img):
43
- img_t = torch.from_numpy(img).to(self.device)/255.
44
- if self.is_norm:
45
- img_t = (img_t - 0.5) / 0.5
46
- img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB
47
- return img_t
48
-
49
- def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
50
- if self.is_norm:
51
- img_t = img_t * 0.5 + 0.5
52
- img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
53
- img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax
54
-
55
- return img_np.astype(imtype)
 
1
+ '''
2
+ @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3
+ @author: yangxy (yangtao9009@gmail.com)
4
+ '''
5
+ import torch
6
+ import os
7
+ import cv2
8
+ import glob
9
+ import numpy as np
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+ from torchvision import transforms, utils
13
+ from videoretalking.third_part.GPEN.face_model.gpen_model import FullGenerator
14
+
15
+ class FaceGAN(object):
16
+ def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1, is_norm=True, device='cuda'):
17
+ self.mfile = os.path.join(base_dir, model+'.pth')
18
+ self.n_mlp = 8
19
+ self.device = device
20
+ self.is_norm = is_norm
21
+ self.resolution = size
22
+ self.load_model(channel_multiplier, narrow)
23
+
24
+ def load_model(self, channel_multiplier=2, narrow=1):
25
+ self.model = FullGenerator(self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device)
26
+ pretrained_dict = torch.load(self.mfile, map_location=torch.device('cpu'))
27
+ self.model.load_state_dict(pretrained_dict)
28
+ self.model.to(self.device)
29
+ self.model.eval()
30
+
31
+ def process(self, img):
32
+ img = cv2.resize(img, (self.resolution, self.resolution))
33
+ img_t = self.img2tensor(img)
34
+
35
+ with torch.no_grad():
36
+ out, __ = self.model(img_t)
37
+
38
+ out = self.tensor2img(out)
39
+
40
+ return out
41
+
42
+ def img2tensor(self, img):
43
+ img_t = torch.from_numpy(img).to(self.device)/255.
44
+ if self.is_norm:
45
+ img_t = (img_t - 0.5) / 0.5
46
+ img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB
47
+ return img_t
48
+
49
+ def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
50
+ if self.is_norm:
51
+ img_t = img_t * 0.5 + 0.5
52
+ img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
53
+ img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax
54
+
55
+ return img_np.astype(imtype)