from collections import namedtuple
import torch
from torch.utils import model_zoo
import requests
from tqdm import tqdm
from pathlib import Path

from src.FaceDetector.face_detector import FaceDetector
from src.FaceId.faceid import FaceId
from src.Generator.fs_networks_fix import Generator_Adain_Upsample
from src.PostProcess.ParsingModel.model import BiSeNet
from src.PostProcess.GFPGAN.gfpgan import GFPGANer
from src.Blend.blend import BlendModule


Model = namedtuple("Model", ["url", "model"])

models = {
    "face_detector": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx",
        model=FaceDetector,
    ),
    "arcface": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit",
        model=FaceId,
    ),
    "generator_224": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth",
        model=Generator_Adain_Upsample,
    ),
    "generator_512": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth",
        model=Generator_Adain_Upsample,
    ),
    "parsing_model": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth",
        model=BiSeNet,
    ),
    "gfpgan": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth",
        model=GFPGANer,
    ),
    "blend_module": Model(
        url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit",
        model=BlendModule,
    ),
}
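# The registry above is a name -> (url, constructor) table built from a
# namedtuple. A minimal, self-contained sketch of the same pattern, using
# placeholder names and a placeholder URL (not real release assets):

```python
from collections import namedtuple

# Same shape as the registry above: each entry pairs a weights URL with the
# class that will consume those weights. Everything here is a placeholder.
Entry = namedtuple("Entry", ["url", "model"])


class DummyModel:
    """Stand-in for a real model class such as FaceDetector."""


demo_registry = {
    "dummy": Entry(url="https://example.com/weights/dummy.pth", model=DummyModel),
}

entry = demo_registry["dummy"]
instance = entry.model()  # instantiate the class stored in the entry
```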


def get_model(
        model_name: str,
        device: torch.device,
        load_state_dict: bool,
        model_path: Path,
        **kwargs,
):
    """Build the requested model, downloading its weights if they are absent.

    When ``load_state_dict`` is True, the weights are loaded into a freshly
    constructed model as a state dict; otherwise the weights path is handed
    to the model's constructor via ``model_path``.
    """
    dst_dir = Path.cwd() / "weights"
    dst_dir.mkdir(exist_ok=True)

    # An explicitly supplied local weights file takes precedence over the URL.
    url = str(model_path) if model_path.is_file() else models[model_name].url

    if load_state_dict:
        model = models[model_name].model(**kwargs)

        if Path(url).is_file():
            state_dict = torch.load(url, map_location="cpu")
        else:
            state_dict = model_zoo.load_url(
                url,
                model_dir=str(dst_dir),
                progress=True,
                map_location="cpu",
            )

        model.load_state_dict(state_dict)

        model.to(device)
        model.eval()
    else:
        dst_path = Path(url)

        # If `url` is not a local file, look for a cached copy in the weights dir.
        if not dst_path.is_file():
            dst_path = dst_dir / Path(url).name

        if not dst_path.is_file():
            print(f"Downloading: '{url}' to {dst_path}")
            response = requests.get(url, stream=True)
            if response.status_code == 200:
                # Convert the byte count to MiB for a human-readable progress bar.
                file_size = int(response.headers["Content-Length"]) / (2 ** 20)
                chunk_size = 1024
                bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n:3.1f}M/{total:3.1f}M [{elapsed}<{remaining}]"
                with open(dst_path, "wb") as handle:
                    with tqdm(total=file_size, bar_format=bar_format) as pbar:
                        for data in response.iter_content(chunk_size=chunk_size):
                            handle.write(data)
                            pbar.update(len(data) / (2 ** 20))
            else:
                raise ValueError(
                    f"Couldn't download weights {url}. Specify weights for the '{model_name}' model manually."
                )

        kwargs.update({"model_path": str(dst_path), "device": device})

        model = models[model_name].model(**kwargs)

    return model
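
# A minimal, self-contained sketch of the source-resolution rule used in
# get_model above: an explicit local weights file wins over the registry URL.
# The registry entry and URL here are placeholders, not real release assets.

```python
from collections import namedtuple
from pathlib import Path

Entry = namedtuple("Entry", ["url", "model"])

# Placeholder registry entry; the URL does not point at a real asset.
demo_models = {
    "demo": Entry(url="https://example.com/weights/demo.pth", model=None),
}


def resolve_source(model_name: str, model_path: Path) -> str:
    # Prefer a user-supplied file on disk; otherwise fall back to the URL.
    return str(model_path) if model_path.is_file() else demo_models[model_name].url


print(resolve_source("demo", Path("no_such_file.pth")))
```

# Since "no_such_file.pth" does not exist, the fallback URL is returned.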