from typing import Dict, Type

import torch

from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase

# Maps a lower-case model name to the class that implements it.
MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = {
    'wav2lip': Wav2Lip,
    'nota_wav2lip': NotaWav2Lip,
}

# Device strings accepted by `_load` / `load_model`.
_SUPPORTED_DEVICES = ('cpu', 'cuda')


def _validate_device(device: str) -> None:
    """Raise ValueError unless *device* is one of the supported device strings.

    An explicit check is used instead of `assert` so that validation still
    happens when Python runs with `-O` (which strips assertions).
    """
    if device not in _SUPPORTED_DEVICES:
        raise ValueError(
            f"Unsupported device {device!r}; expected one of {_SUPPORTED_DEVICES}"
        )


def _load(checkpoint_path, device: str):
    """Deserialize the checkpoint at *checkpoint_path* with `torch.load`.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: 'cpu' or 'cuda'.

    Returns:
        Whatever object `torch.load` deserializes from the file.

    Raises:
        ValueError: If *device* is not 'cpu' or 'cuda'.

    NOTE: `torch.load` uses pickle; only load checkpoints from trusted sources.
    """
    _validate_device(device)
    print(f"Load checkpoint from: {checkpoint_path}")
    if device == 'cuda':
        # No map_location: tensors are restored onto the devices they were saved from.
        return torch.load(checkpoint_path)
    # Remap every storage to CPU so GPU-saved checkpoints load on CPU-only hosts.
    return torch.load(checkpoint_path, map_location='cpu')


def load_model(model_name: str, device: str, checkpoint, **kwargs) -> Wav2LipBase:
    """Build a registered model, load its weights, and return it in eval mode.

    Args:
        model_name: Key into MODEL_REGISTRY (case-insensitive).
        device: 'cpu' or 'cuda'; the model is moved to this device.
        checkpoint: Path to a checkpoint whose contents are passed directly to
            `model.load_state_dict`.
        **kwargs: Forwarded unchanged to the model class constructor.

    Returns:
        The model on *device*, switched to eval mode.

    Raises:
        KeyError: If *model_name* is not registered.
        ValueError: If *device* is not 'cpu' or 'cuda'.
    """
    key = model_name.lower()
    try:
        cls = MODEL_REGISTRY[key]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; available: {sorted(MODEL_REGISTRY)}"
        ) from None
    # Internal invariant of the registry, not user input, so `assert` is appropriate.
    assert issubclass(cls, Wav2LipBase)

    # Validate before constructing the model so a bad device fails fast.
    _validate_device(device)

    model = cls(**kwargs)
    state_dict = _load(checkpoint, device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model.eval()


def count_params(model):
    """Return the total number of elements across all of *model*'s parameters."""
    return sum(p.numel() for p in model.parameters())