Spaces:
Configuration error
Configuration error
File size: 3,717 Bytes
e6a22e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from collections import namedtuple
import torch
from torch.utils import model_zoo
import requests
from tqdm import tqdm
from pathlib import Path
from src.FaceDetector.face_detector import FaceDetector
from src.FaceId.faceid import FaceId
from src.Generator.fs_networks_fix import Generator_Adain_Upsample
from src.PostProcess.ParsingModel.model import BiSeNet
from src.PostProcess.GFPGAN.gfpgan import GFPGANer
from src.Blend.blend import BlendModule
# Registry entry: where the pretrained weights live and which class wraps them.
model = namedtuple("model", ["url", "model"])

# Registry of supported models, keyed by the name passed to `get_model`.
# Each row is (registry key, weights download URL, wrapper class).
models = {
    name: model(url=weights_url, model=model_cls)
    for name, weights_url, model_cls in (
        (
            "face_detector",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx",
            FaceDetector,
        ),
        (
            "arcface",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit",
            FaceId,
        ),
        (
            "generator_224",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth",
            Generator_Adain_Upsample,
        ),
        (
            "generator_512",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth",
            Generator_Adain_Upsample,
        ),
        (
            "parsing_model",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth",
            BiSeNet,
        ),
        (
            "gfpgan",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth",
            GFPGANer,
        ),
        (
            "blend_module",
            "https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit",
            BlendModule,
        ),
    )
}
def _download_weights(url: str, dst_path: Path, model_name: str) -> None:
    """Stream ``url`` into ``dst_path``, showing a progress bar in MiB.

    The payload is written to a sibling ``*.part`` file and only renamed to
    ``dst_path`` once the transfer finished, so an interrupted download never
    leaves a truncated file behind under the final name.

    Raises:
        ValueError: if the server answers with a status code other than 200.
    """
    mib = 2 ** 20
    chunk_size = 1024
    tmp_path = dst_path.with_name(dst_path.name + ".part")

    print(f"Downloading: '{url}' to {dst_path}")
    # The timeout keeps a stalled connection from blocking forever; the
    # context manager releases the connection on every exit path.
    with requests.get(url, stream=True, timeout=60) as response:
        if int(response.status_code) != 200:
            raise ValueError(
                f"Couldn't download weights {url}. Specify weights for the '{model_name}' model manually."
            )

        # Content-Length is optional (e.g. chunked transfer encoding).
        content_length = response.headers.get("Content-Length")
        if content_length:
            total = int(content_length) / mib
            bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n:3.1f}M/{total:3.1f}M [{elapsed}<{remaining}]"
        else:
            # Without a total, tqdm cannot format `{total:3.1f}`; show only
            # the amount downloaded so far.
            total = None
            bar_format = "{n:3.1f}M [{elapsed}]"

        try:
            with open(tmp_path, "wb") as handle, tqdm(
                total=total, bar_format=bar_format
            ) as pbar:
                for data in response.iter_content(chunk_size=chunk_size):
                    handle.write(data)
                    pbar.update(len(data) / mib)
            tmp_path.replace(dst_path)
        except BaseException:
            # Also covers Ctrl-C: drop the partial file, then re-raise.
            if tmp_path.exists():
                tmp_path.unlink()
            raise


def get_model(
    model_name: str,
    device: torch.device,
    load_state_dice: bool,
    model_path: Path,
    **kwargs,
):
    """Build the model registered under ``model_name`` with its weights.

    Weights come from ``model_path`` when it points to an existing file;
    otherwise they are fetched from the URL in the ``models`` registry and
    cached in ``<cwd>/weights`` (the directory is created on every call).

    Args:
        model_name: key into the module-level ``models`` registry.
        device: device the model is moved to (state-dict mode) or that is
            forwarded to the model constructor (otherwise).
        load_state_dice: when true, construct the model with ``kwargs``, load
            a state dict into it, move it to ``device`` and switch it to eval
            mode. When false, the weights file path and ``device`` are passed
            to the model constructor as ``model_path`` / ``device`` keyword
            arguments. NOTE(review): the name looks like a typo of
            ``load_state_dict``; it is kept because callers pass it by name.
        model_path: optional local weights file; ``None``, an empty string or
            a path that is not a file all mean "use the registry URL".
        **kwargs: extra keyword arguments for the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``model_name`` is not in the registry.
        ValueError: if the weights have to be downloaded and the server does
            not answer with status 200.
    """
    dst_dir = Path.cwd() / "weights"
    dst_dir.mkdir(exist_ok=True)

    spec = models[model_name]
    # Accept None / str as well as Path so a missing path falls back to the
    # registry URL instead of raising AttributeError.
    has_local_weights = bool(model_path) and Path(model_path).is_file()

    if load_state_dice:
        net = spec.model(**kwargs)
        if has_local_weights:
            # Load onto CPU first (same as the download branch below) so a
            # checkpoint saved from a GPU also loads on CPU-only hosts.
            # NOTE(review): torch.load unpickles the file; only point
            # model_path at checkpoints from a trusted source.
            state_dict = torch.load(str(model_path), map_location="cpu")
        else:
            state_dict = model_zoo.load_url(
                spec.url,
                model_dir=str(dst_dir),
                progress=True,
                map_location="cpu",
            )
        net.load_state_dict(state_dict)
        net.to(device)
        net.eval()
        return net

    if has_local_weights:
        weights_path = Path(str(model_path))
    else:
        weights_path = dst_dir / Path(spec.url).name
        if not weights_path.is_file():
            _download_weights(spec.url, weights_path, model_name)

    kwargs.update({"model_path": str(weights_path), "device": device})
    return spec.model(**kwargs)
|