deepkyu's picture
Reinitialize demo with published github repository. With Gradio 4.x
16c8067
raw history blame
No virus
3.37 kB
import logging
import sys
import threading
from pathlib import Path

from nota_wav2lip.demo import Wav2LipModelComparisonDemo
class Wav2LipModelComparisonGradio(Wav2LipModelComparisonDemo):
    """UI callbacks for comparing the original and compressed Wav2Lip models.

    Extends ``Wav2LipModelComparisonDemo`` with label -> sample lookups and
    best-effort wrappers around ``save_as_video``. Every public callback
    returns ``None`` (after logging the traceback) instead of raising.
    Presumably these are wired to Gradio components by the caller; confirm
    against the app entry point.
    """

    def __init__(
        self,
        device='cpu',
        result_dir='./temp',
        video_label_dict=None,
        audio_label_list=None,
        default_video='v1',
        default_audio='a1'
    ) -> None:
        """Set up the sample tables, fallback labels and the inference lock.

        Args:
            device: Forwarded positionally to the base demo class.
            result_dir: Forwarded positionally to the base demo class.
            video_label_dict: Mapping of video label -> path. Each path's
                suffix is replaced with ``.mp4``. Defaults to an empty dict.
            audio_label_list: Despite its name, a *mapping* of audio label ->
                sample (it is checked with ``in`` and indexed by label).
                Defaults to an empty dict.
            default_video: Label used when an unknown video is selected.
            default_audio: Label used when an unknown audio is selected.
        """
        if audio_label_list is None:
            audio_label_list = {}
        if video_label_dict is None:
            video_label_dict = {}
        super().__init__(device, result_dir)
        self._video_label_dict = {k: Path(v).with_suffix('.mp4') for k, v in video_label_dict.items()}
        self._audio_label_dict = audio_label_list
        self._default_video = default_video
        self._default_audio = default_audio
        # Serializes save_as_video calls so at most one inference runs at a
        # time (the original intent: behave as if concurrency_count == 1).
        self._lock = threading.Lock()

    def _is_valid_input(self, video_selection, audio_selection):
        """Raise ``ValueError`` unless both selections are known sample labels.

        An explicit raise is used instead of ``assert`` so the check still
        runs under ``python -O``.
        """
        if video_selection not in self._video_label_dict:
            raise ValueError(f"Your input ({video_selection}) is not in {self._video_label_dict}!!!")
        if audio_selection not in self._audio_label_dict:
            raise ValueError(f"Your input ({audio_selection}) is not in {self._audio_label_dict}!!!")

    @staticmethod
    def _call_best_effort(description, func, *args):
        """Return ``func(*args)``; on failure log the traceback and return ``None``.

        ``KeyboardInterrupt`` is converted into a process exit via
        ``sys.exit()``; every other ``Exception`` is logged and swallowed.
        """
        try:
            return func(*args)
        except KeyboardInterrupt:
            sys.exit()
        except Exception:
            logging.getLogger(__name__).exception("%s failed", description)
            return None

    def _generate_with_model(self, video_selection, audio_selection, model_type):
        """Validate the selections, run ``save_as_video`` under the lock and format the results.

        Returns:
            ``(str(output_video_path), inference_time formatted to 2 decimals,
            inference_fps formatted to 1 decimal)``.

        Raises:
            ValueError: if either selection is not a known label.
        """
        self._is_valid_input(video_selection, audio_selection)
        with self._lock:
            output_video_path, inference_time, inference_fps = self.save_as_video(
                audio_name=audio_selection,
                video_name=video_selection,
                model_type=model_type,
            )
        return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f")

    @staticmethod
    def _select_sample(label_dict, selection, default_label):
        """Return ``label_dict[selection]``, or ``label_dict[default_label]`` for unknown selections."""
        key = selection if selection in label_dict else default_label
        return label_dict[key]

    def generate_original_model(self, video_selection, audio_selection):
        """Run the original ``'wav2lip'`` model on the selected video/audio labels.

        Returns:
            ``(output_video_path, inference_time, inference_fps)`` as strings,
            or ``None`` if validation or inference fails (the error is logged).
        """
        return self._call_best_effort(
            "wav2lip inference",
            self._generate_with_model, video_selection, audio_selection, 'wav2lip',
        )

    def generate_compressed_model(self, video_selection, audio_selection):
        """Run the compressed ``'nota_wav2lip'`` model on the selected video/audio labels.

        Returns:
            ``(output_video_path, inference_time, inference_fps)`` as strings,
            or ``None`` if validation or inference fails (the error is logged).
        """
        return self._call_best_effort(
            "nota_wav2lip inference",
            self._generate_with_model, video_selection, audio_selection, 'nota_wav2lip',
        )

    def switch_video_samples(self, video_selection):
        """Return the ``.mp4`` path for ``video_selection``, falling back to the default video.

        Returns ``None`` (and logs) if the default label is missing as well.
        """
        return self._call_best_effort(
            "Video sample lookup",
            self._select_sample, self._video_label_dict, video_selection, self._default_video,
        )

    def switch_audio_samples(self, audio_selection):
        """Return the audio sample for ``audio_selection``, falling back to the default audio.

        Returns ``None`` (and logs) if the default label is missing as well.
        """
        return self._call_best_effort(
            "Audio sample lookup",
            self._select_sample, self._audio_label_dict, audio_selection, self._default_audio,
        )