hyxue leonelhs commited on
Commit
51094e2
·
verified ·
1 Parent(s): 54a5078

Upload model.py (#5)

Browse files

- Upload model.py (125f2597f0105a40caa7147feb1e0f39cf5eed0e)


Co-authored-by: leonel hernandez <leonelhs@users.noreply.huggingface.co>

Files changed (1) hide show
  1. models/model.py +29 -56
models/model.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
 
2
  from typing import Dict
3
- from typing import Optional
4
  from typing import Tuple
5
 
6
  import kornia
@@ -10,28 +10,30 @@ import torch.nn as nn
10
  import torch.nn.functional as F
11
  from loguru import logger
12
 
13
- from arcface_torch.backbones.iresnet import iresnet100
14
- from configs.train_config import TrainConfig
15
  from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
16
  from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
17
  from HRNet.hrnet import HighResolutionNet
 
18
  from models.discriminator import Discriminator
19
  from models.gan_loss import GANLoss
20
  from models.generator import Generator
21
  from models.init_weight import init_net
22
 
23
-
24
  class HifiFace:
25
  def __init__(
26
  self,
27
  identity_extractor_config,
28
- is_training=True,
29
- device="cpu",
30
- load_checkpoint: Optional[Tuple[str, int]] = None,
31
  ):
32
  super(HifiFace, self).__init__()
 
 
33
  self.generator = Generator(identity_extractor_config)
34
  self.is_training = is_training
 
 
35
 
36
  if self.is_training:
37
  self.lr = TrainConfig().lr
@@ -80,10 +82,9 @@ class HifiFace:
80
 
81
  self.dilation_kernel = torch.ones(5, 5)
82
 
83
- if load_checkpoint is not None:
84
- self.load(load_checkpoint[0], load_checkpoint[1])
85
 
86
- self.setup(device)
87
 
88
  def save(self, path, idx=None):
89
  os.makedirs(path, exist_ok=True)
@@ -100,18 +101,9 @@ class HifiFace:
100
  torch.save(self.generator.state_dict(), g_path)
101
  torch.save(self.discriminator.state_dict(), d_path)
102
 
103
- def load(self, path, idx=None):
104
- if idx is None:
105
- g_path = os.path.join(path, "generator.pth")
106
- d_path = os.path.join(path, "discriminator.pth")
107
- else:
108
- g_path = os.path.join(path, f"generator_{idx}.pth")
109
- d_path = os.path.join(path, f"discriminator_{idx}.pth")
110
- logger.info(f"Loading generator from {g_path}")
111
- self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
112
- if self.is_training:
113
- logger.info(f"Loading discriminator from {d_path}")
114
- self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))
115
 
116
  def setup(self, device):
117
  self.generator.to(device)
@@ -399,37 +391,18 @@ class HifiFace:
399
  }
400
 
401
 
402
- if __name__ == "__main__":
403
- import torch
404
- import cv2
405
- from configs.train_config import TrainConfig
406
-
407
- identity_extractor_config = TrainConfig().identity_extractor_config
408
-
409
- model = HifiFace(identity_extractor_config, is_training=True)
410
-
411
- # src = cv2.imread("/home/xuehongyang/data/test1.jpg")
412
- # tgt = cv2.imread("/home/xuehongyang/data/test2.jpg")
413
- # src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
414
- # tgt = cv2.cvtColor(tgt, cv2.COLOR_BGR2RGB)
415
- # src = cv2.resize(src, (256, 256))
416
- # tgt = cv2.resize(tgt, (256, 256))
417
- # src = src.transpose(2, 0, 1)[None, ...]
418
- # tgt = tgt.transpose(2, 0, 1)[None, ...]
419
- # source_img = torch.from_numpy(src).float() / 255.0
420
- # target_img = torch.from_numpy(tgt).float() / 255.0
421
- # same_id_mask = torch.Tensor([1]).unsqueeze(0)
422
- # tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
423
- # if torch.cuda.is_available():
424
- # model.to("cuda:3")
425
- # source_img = source_img.to("cuda:3")
426
- # target_img = target_img.to("cuda:3")
427
- # tgt_mask = tgt_mask.to("cuda:3")
428
- # same_id_mask = same_id_mask.to("cuda:3")
429
- # source_img = source_img.repeat(16, 1, 1, 1)
430
- # target_img = target_img.repeat(16, 1, 1, 1)
431
- # tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
432
- # same_id_mask = same_id_mask.repeat(16, 1)
433
- # while True:
434
- # x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
435
- # print(x[0]["loss_generator"])
 
1
  import os
2
+ from abc import abstractmethod
3
  from typing import Dict
 
4
  from typing import Tuple
5
 
6
  import kornia
 
10
  import torch.nn.functional as F
11
  from loguru import logger
12
 
 
 
13
  from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
14
  from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
15
  from HRNet.hrnet import HighResolutionNet
16
+ from arcface_torch.backbones.iresnet import iresnet100
17
  from models.discriminator import Discriminator
18
  from models.gan_loss import GANLoss
19
  from models.generator import Generator
20
  from models.init_weight import init_net
21
 
 
22
  class HifiFace:
23
  def __init__(
24
  self,
25
  identity_extractor_config,
26
+ generator_path,
27
+ is_training=False,
28
+ device="cpu"
29
  ):
30
  super(HifiFace, self).__init__()
31
+ self.d_optimizer = None
32
+ self.g_optimizer = None
33
  self.generator = Generator(identity_extractor_config)
34
  self.is_training = is_training
35
+ self.device = device
36
+ self.generator_path = generator_path
37
 
38
  if self.is_training:
39
  self.lr = TrainConfig().lr
 
82
 
83
  self.dilation_kernel = torch.ones(5, 5)
84
 
85
+ self.load_checkpoint()
 
86
 
87
+ self.setup(self.device)
88
 
89
  def save(self, path, idx=None):
90
  os.makedirs(path, exist_ok=True)
 
101
  torch.save(self.generator.state_dict(), g_path)
102
  torch.save(self.discriminator.state_dict(), d_path)
103
 
104
+ @abstractmethod
105
+ def load_checkpoint(self):
106
+ pass
 
 
 
 
 
 
 
 
 
107
 
108
  def setup(self, device):
109
  self.generator.to(device)
 
391
  }
392
 
393
 
394
+ class HifiFaceST(HifiFace):
395
+ def __init__(self, identity_extractor_config, device, generator_path):
396
+ super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
397
+
398
+ def load_checkpoint(self):
399
+ self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
400
+ logger.info(f"Loading generator from {self.generator_path}")
401
+
402
+ class HifiFaceWGM(HifiFace):
403
+ def __init__(self, identity_extractor_config, device, generator_path):
404
+ super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
405
+
406
+ def load_checkpoint(self):
407
+ self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
408
+ logger.info(f"Loading generator from {self.generator_path}")