Spaces:
Runtime error
Runtime error
from typing import Dict, Type | |
import torch | |
from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase | |
# Registry of loadable model classes, keyed by lower-case model name.
MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = dict(
    wav2lip=Wav2Lip,
    nota_wav2lip=NotaWav2Lip,
)
def _load(checkpoint_path, device): | |
assert device in ['cpu', 'cuda'] | |
print(f"Load checkpoint from: {checkpoint_path}") | |
if device == 'cuda': | |
return torch.load(checkpoint_path) | |
return torch.load(checkpoint_path, map_location=lambda storage, _: storage) | |
def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase:
    """Instantiate a registered model, load its weights and return it in eval mode.

    Args:
        model_name: Key into ``MODEL_REGISTRY``, matched case-insensitively.
        device: Device used both when loading the checkpoint and for
            ``model.to``.
        checkpoint: Path of the checkpoint file whose contents are passed to
            ``model.load_state_dict``.
        **kwargs: Forwarded unchanged to the model class constructor.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If ``model_name`` is not a registered model name.
    """
    try:
        cls = MODEL_REGISTRY[model_name.lower()]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        raise KeyError(
            f"Unknown model name {model_name!r}; available: {available}"
        ) from None
    # Internal invariant: the registry should only hold Wav2LipBase subclasses.
    assert issubclass(cls, Wav2LipBase)

    model = cls(**kwargs)
    state_dict = _load(checkpoint, device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model.eval()
def count_params(model):
    """Return the total element count over all tensors in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total