from typing import Iterable, Iterator, List, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

from config import hparams as hp
from nota_wav2lip.models.util import count_params, load_model


class Wav2LipInferenceImpl:
    """Thin inference wrapper around a Wav2Lip-style lip-sync model."""

    def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
        self.model: nn.Module = load_model(
            model_name,
            device=device,
            **hp_inference_model
        )
        self.device = device
        self._params: str = self._format_param(count_params(self.model))

    @property
    def params(self):
        return self._params

    @staticmethod
    def _format_param(num_params: int) -> str:
        params_in_million = num_params / 1e6
        return f"{params_in_million:.1f}M"

    @staticmethod
    def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
        return [], [], [], []

    def get_data_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
        """Yield batches of (face crops, mel chunks, full frames, face coords)."""
        img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        for i, m in enumerate(audio_iterable):
            # Loop over the video frames if there are more mel chunks than frames.
            idx = i % len(video_iterable)
            _frame_to_save, coords = video_iterable[idx]
            frame_to_save = _frame_to_save.copy()
            # coords are stored as [y1, y2, x1, x2].
            face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()

            face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))

            img_batch.append(face)
            mel_batch.append(m)
            frame_batch.append(frame_to_save)
            coords_batch.append(coords)

            if len(img_batch) >= hp.inference.batch_size:
                img_batch = np.asarray(img_batch)
                mel_batch = np.asarray(mel_batch)

                # Zero out the lower half of each face; the model inpaints the mouth region.
                img_masked = img_batch.copy()
                img_masked[:, hp.face.img_size // 2:] = 0

                # Concatenate masked and reference faces along the channel axis (6 channels), scale to [0, 1].
                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                # Add a trailing channel dimension to the mel spectrogram batch.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                yield img_batch, mel_batch, frame_batch, coords_batch
                img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()

        # Flush the final, possibly smaller, batch.
        if len(img_batch) > 0:
            img_batch = np.asarray(img_batch)
            mel_batch = np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, hp.face.img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch

    @torch.no_grad()
    def inference_with_iterator(
        self,
        audio_iterable: Iterable[np.ndarray],
        video_iterable: List[Tuple[np.ndarray, List[int]]]
    ) -> Iterator[np.ndarray]:
        """Run batched inference and yield full frames with the synthesized mouth pasted back in."""
        data_iterator = self.get_data_iterator(audio_iterable, video_iterable)

        # audio_iterable is expected to be sized (e.g. a list of mel chunks) so the progress total is known.
        for (img_batch, mel_batch, frames, coords) in tqdm(
            data_iterator,
            total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))
        ):
            # NHWC -> NCHW for the PyTorch model.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)

            preds: torch.Tensor = self.forward(mel_batch, img_batch)

            # Back to NHWC uint8 range, then resize each prediction to its original face box.
            preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for pred, frame, coord in zip(preds, frames, coords):
                y1, y2, x1, x2 = coord
                pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
                frame[y1:y2, x1:x2] = pred
                yield frame

    @torch.no_grad()
    def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
        return self.model(audio_sequences, face_sequences)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
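
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The model key, config node, and the way
# mel chunks and face crops are produced below are assumptions made for this
# example; they are not defined in this module. In practice these inputs come
# from the repository's audio/video preprocessing pipeline.
#
#   inference = Wav2LipInferenceImpl(
#       model_name='wav2lip',                    # hypothetical model key
#       hp_inference_model=hp_model_config,      # hypothetical DictConfig for load_model
#       device='cuda' if torch.cuda.is_available() else 'cpu',
#   )
#   mel_chunks = [...]        # list of mel spectrogram windows (np.ndarray)
#   frame_boxes = [...]       # list of (frame, [y1, y2, x1, x2]) tuples
#   for out_frame in inference.inference_with_iterator(mel_chunks, frame_boxes):
#       ...                   # write out_frame to the output video
# ---------------------------------------------------------------------------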