Spaces:
Build error
Build error
Upload model.py (#5)
Browse files- Upload model.py (125f2597f0105a40caa7147feb1e0f39cf5eed0e)
Co-authored-by: leonel hernandez <leonelhs@users.noreply.huggingface.co>
- models/model.py +29 -56
models/model.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from typing import Dict
|
| 3 |
-
from typing import Optional
|
| 4 |
from typing import Tuple
|
| 5 |
|
| 6 |
import kornia
|
|
@@ -10,28 +10,30 @@ import torch.nn as nn
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
from loguru import logger
|
| 12 |
|
| 13 |
-
from arcface_torch.backbones.iresnet import iresnet100
|
| 14 |
-
from configs.train_config import TrainConfig
|
| 15 |
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
|
| 16 |
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
|
| 17 |
from HRNet.hrnet import HighResolutionNet
|
|
|
|
| 18 |
from models.discriminator import Discriminator
|
| 19 |
from models.gan_loss import GANLoss
|
| 20 |
from models.generator import Generator
|
| 21 |
from models.init_weight import init_net
|
| 22 |
|
| 23 |
-
|
| 24 |
class HifiFace:
|
| 25 |
def __init__(
|
| 26 |
self,
|
| 27 |
identity_extractor_config,
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
):
|
| 32 |
super(HifiFace, self).__init__()
|
|
|
|
|
|
|
| 33 |
self.generator = Generator(identity_extractor_config)
|
| 34 |
self.is_training = is_training
|
|
|
|
|
|
|
| 35 |
|
| 36 |
if self.is_training:
|
| 37 |
self.lr = TrainConfig().lr
|
|
@@ -80,10 +82,9 @@ class HifiFace:
|
|
| 80 |
|
| 81 |
self.dilation_kernel = torch.ones(5, 5)
|
| 82 |
|
| 83 |
-
|
| 84 |
-
self.load(load_checkpoint[0], load_checkpoint[1])
|
| 85 |
|
| 86 |
-
self.setup(device)
|
| 87 |
|
| 88 |
def save(self, path, idx=None):
|
| 89 |
os.makedirs(path, exist_ok=True)
|
|
@@ -100,18 +101,9 @@ class HifiFace:
|
|
| 100 |
torch.save(self.generator.state_dict(), g_path)
|
| 101 |
torch.save(self.discriminator.state_dict(), d_path)
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
d_path = os.path.join(path, "discriminator.pth")
|
| 107 |
-
else:
|
| 108 |
-
g_path = os.path.join(path, f"generator_{idx}.pth")
|
| 109 |
-
d_path = os.path.join(path, f"discriminator_{idx}.pth")
|
| 110 |
-
logger.info(f"Loading generator from {g_path}")
|
| 111 |
-
self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
|
| 112 |
-
if self.is_training:
|
| 113 |
-
logger.info(f"Loading discriminator from {d_path}")
|
| 114 |
-
self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))
|
| 115 |
|
| 116 |
def setup(self, device):
|
| 117 |
self.generator.to(device)
|
|
@@ -399,37 +391,18 @@ class HifiFace:
|
|
| 399 |
}
|
| 400 |
|
| 401 |
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
# src = src.transpose(2, 0, 1)[None, ...]
|
| 418 |
-
# tgt = tgt.transpose(2, 0, 1)[None, ...]
|
| 419 |
-
# source_img = torch.from_numpy(src).float() / 255.0
|
| 420 |
-
# target_img = torch.from_numpy(tgt).float() / 255.0
|
| 421 |
-
# same_id_mask = torch.Tensor([1]).unsqueeze(0)
|
| 422 |
-
# tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
|
| 423 |
-
# if torch.cuda.is_available():
|
| 424 |
-
# model.to("cuda:3")
|
| 425 |
-
# source_img = source_img.to("cuda:3")
|
| 426 |
-
# target_img = target_img.to("cuda:3")
|
| 427 |
-
# tgt_mask = tgt_mask.to("cuda:3")
|
| 428 |
-
# same_id_mask = same_id_mask.to("cuda:3")
|
| 429 |
-
# source_img = source_img.repeat(16, 1, 1, 1)
|
| 430 |
-
# target_img = target_img.repeat(16, 1, 1, 1)
|
| 431 |
-
# tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
|
| 432 |
-
# same_id_mask = same_id_mask.repeat(16, 1)
|
| 433 |
-
# while True:
|
| 434 |
-
# x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
|
| 435 |
-
# print(x[0]["loss_generator"])
|
|
|
|
| 1 |
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
from typing import Dict
|
|
|
|
| 4 |
from typing import Tuple
|
| 5 |
|
| 6 |
import kornia
|
|
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
from loguru import logger
|
| 12 |
|
|
|
|
|
|
|
| 13 |
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
|
| 14 |
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
|
| 15 |
from HRNet.hrnet import HighResolutionNet
|
| 16 |
+
from arcface_torch.backbones.iresnet import iresnet100
|
| 17 |
from models.discriminator import Discriminator
|
| 18 |
from models.gan_loss import GANLoss
|
| 19 |
from models.generator import Generator
|
| 20 |
from models.init_weight import init_net
|
| 21 |
|
|
|
|
| 22 |
class HifiFace:
|
| 23 |
def __init__(
|
| 24 |
self,
|
| 25 |
identity_extractor_config,
|
| 26 |
+
generator_path,
|
| 27 |
+
is_training=False,
|
| 28 |
+
device="cpu"
|
| 29 |
):
|
| 30 |
super(HifiFace, self).__init__()
|
| 31 |
+
self.d_optimizer = None
|
| 32 |
+
self.g_optimizer = None
|
| 33 |
self.generator = Generator(identity_extractor_config)
|
| 34 |
self.is_training = is_training
|
| 35 |
+
self.device = device
|
| 36 |
+
self.generator_path = generator_path
|
| 37 |
|
| 38 |
if self.is_training:
|
| 39 |
self.lr = TrainConfig().lr
|
|
|
|
| 82 |
|
| 83 |
self.dilation_kernel = torch.ones(5, 5)
|
| 84 |
|
| 85 |
+
self.load_checkpoint()
|
|
|
|
| 86 |
|
| 87 |
+
self.setup(self.device)
|
| 88 |
|
| 89 |
def save(self, path, idx=None):
|
| 90 |
os.makedirs(path, exist_ok=True)
|
|
|
|
| 101 |
torch.save(self.generator.state_dict(), g_path)
|
| 102 |
torch.save(self.discriminator.state_dict(), d_path)
|
| 103 |
|
| 104 |
+
@abstractmethod
|
| 105 |
+
def load_checkpoint(self):
|
| 106 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
def setup(self, device):
|
| 109 |
self.generator.to(device)
|
|
|
|
| 391 |
}
|
| 392 |
|
| 393 |
|
| 394 |
+
class HifiFaceST(HifiFace):
|
| 395 |
+
def __init__(self, identity_extractor_config, device, generator_path):
|
| 396 |
+
super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
|
| 397 |
+
|
| 398 |
+
def load_checkpoint(self):
|
| 399 |
+
self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
|
| 400 |
+
logger.info(f"Loading generator from {self.generator_path}")
|
| 401 |
+
|
| 402 |
+
class HifiFaceWGM(HifiFace):
|
| 403 |
+
def __init__(self, identity_extractor_config, device, generator_path):
|
| 404 |
+
super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
|
| 405 |
+
|
| 406 |
+
def load_checkpoint(self):
|
| 407 |
+
self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
|
| 408 |
+
logger.info(f"Loading generator from {self.generator_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|