"""Registry of the pipeline's models and a helper to build them with weights.

``models`` maps a short name to the URL of its weights and the class used to
construct it. :func:`get_model` resolves the weights (an explicit local file,
a copy cached under ``./weights``, or a download) and builds the model.
"""
from collections import namedtuple
import os
from pathlib import Path
import tempfile

import requests
import torch
from torch.utils import model_zoo
from tqdm import tqdm

from src.FaceDetector.face_detector import FaceDetector
from src.FaceId.faceid import FaceId
from src.Generator.fs_networks_fix import Generator_Adain_Upsample
from src.PostProcess.ParsingModel.model import BiSeNet
from src.PostProcess.GFPGAN.gfpgan import GFPGANer
from src.Blend.blend import BlendModule

# Registry entry: where to fetch the weights and which class to instantiate.
# The lowercase name is kept unchanged for any code that imports it.
model = namedtuple("model", ["url", "model"])

# NOTE(review): every URL below is empty in this revision, so weights must be
# supplied through ``model_path``; get_model raises ValueError otherwise.
models = {
    "face_detector": model(
        url="",
        model=FaceDetector,
    ),
    "arcface": model(
        url="",
        model=FaceId,
    ),
    "generator_224": model(
        url="",
        model=Generator_Adain_Upsample,
    ),
    "generator_512": model(
        url="",
        model=Generator_Adain_Upsample,
    ),
    "parsing_model": model(
        url="",
        model=BiSeNet,
    ),
    "gfpgan": model(
        url="",
        model=GFPGANer,
    ),
    "blend_module": model(
        url="",
        model=BlendModule,
    ),
}

# Bytes per MiB; the download progress bar is reported in MiB.
_MIB = 2 ** 20


def _missing_weights_error(model_name: str) -> ValueError:
    """Build the error raised when a model has neither local weights nor a URL."""
    return ValueError(
        f"No weights URL is registered for the '{model_name}' model. "
        f"Specify its weights manually via model_path."
    )


def _load_state_dict(url: str, dst_dir: Path, model_name: str):
    """Return a state dict read from the local file ``url`` or fetched from it into ``dst_dir``."""
    if Path(url).is_file():
        # Load onto the CPU, like the model_zoo branch below, so a checkpoint
        # saved from a GPU also loads on a CPU-only machine; get_model moves the
        # model to the requested device afterwards.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        return torch.load(url, map_location="cpu")
    if not url:
        raise _missing_weights_error(model_name)
    return model_zoo.load_url(
        url,
        model_dir=str(dst_dir),
        progress=True,
        map_location="cpu",
    )


def _download(
    url: str,
    dst_path: Path,
    model_name: str,
    chunk_size: int = _MIB,
    timeout: float = 60.0,
) -> None:
    """Stream ``url`` to ``dst_path``, creating the target only once the transfer completes.

    Raises:
        ValueError: ``url`` is empty or the server does not answer with HTTP 200.
    """
    if not url:
        raise _missing_weights_error(model_name)
    print(f"Downloading: '{url}' to {dst_path}")
    # The context manager closes the streamed connection; the timeout keeps a
    # stalled server from hanging the caller forever.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise ValueError(
                f"Couldn't download weights {url}. Specify weights for the '{model_name}' model manually."
            )
        # Content-Length may be absent (e.g. chunked transfer encoding); fall
        # back to a progress bar without a total instead of raising KeyError.
        content_length = response.headers.get("Content-Length")
        if content_length is not None:
            total = int(content_length) / _MIB
            bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n:3.1f}M/{total:3.1f}M [{elapsed}<{remaining}]"
        else:
            total = None
            bar_format = "{desc}: {n:3.1f}M [{elapsed}]"
        # Write into a temporary file in the same directory and rename it into
        # place, so an interrupted download never leaves a truncated file that a
        # later call would accept as a valid cached copy.
        fd, tmp_name = tempfile.mkstemp(
            dir=str(dst_path.parent), prefix=dst_path.name + ".", suffix=".part"
        )
        try:
            with os.fdopen(fd, "wb") as handle, tqdm(total=total, bar_format=bar_format) as pbar:
                for data in response.iter_content(chunk_size=chunk_size):
                    handle.write(data)
                    pbar.update(len(data) / _MIB)
            os.replace(tmp_name, dst_path)
        finally:
            # Only still present if something failed before os.replace.
            if os.path.exists(tmp_name):
                os.remove(tmp_name)


def get_model(
    model_name: str,
    device: torch.device,
    load_state_dice: bool,
    model_path: Path,
    **kwargs,
):
    """Build the registered model ``model_name`` with its weights.

    Args:
        model_name: Key into ``models``.
        device: Target device. In the state-dict branch the built model is moved
            to it; otherwise it is passed to the constructor as ``device=``.
        load_state_dice: (sic; name kept for callers) If True, construct
            ``models[model_name].model(**kwargs)`` and load a state dict into it.
            If False, resolve a weights file and pass its path to the
            constructor as ``model_path=`` together with ``device=``.
        model_path: Local weights file. Used instead of the registry URL when it
            exists; ``None`` is treated as "not given".
        **kwargs: Extra constructor arguments.

    Returns:
        The constructed model (switched to eval mode in the state-dict branch).

    Raises:
        KeyError: ``model_name`` is not registered.
        ValueError: No local weights and no usable URL, or the download failed.
    """
    dst_dir = Path.cwd() / "weights"
    dst_dir.mkdir(exist_ok=True)

    try:
        entry = models[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(models)}"
        ) from None

    # An existing local weights file takes precedence over the registry URL.
    local_path = Path(model_path) if model_path is not None else None
    if local_path is not None and local_path.is_file():
        url = str(local_path)
    else:
        url = entry.url

    if load_state_dice:
        net = entry.model(**kwargs)
        net.load_state_dict(_load_state_dict(url, dst_dir, model_name))
        # The state dict was loaded on the CPU; honour the requested device.
        net.to(device)
        net.eval()
        return net

    # The constructor receives a file path, so make sure one exists locally:
    # the given file itself, else a copy in dst_dir named after the URL.
    dst_path = Path(url)
    if not dst_path.is_file():
        dst_path = dst_dir / dst_path.name
        if not dst_path.is_file():
            _download(url, dst_path, model_name)
    kwargs.update({"model_path": str(dst_path), "device": device})
    return entry.model(**kwargs)