#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import os | |
import torch | |
import pickle | |
from configparser import ConfigParser | |
from pipelines.model import AVSR | |
from pipelines.data.data_module import AVSRDataLoader | |
class InferencePipeline(torch.nn.Module):
    """Speech-recognition inference pipeline configured from an INI file.

    The config file is parsed with ``ConfigParser`` and the following keys are
    read from it:

    * ``[input]``: ``modality``, ``v_fps``
    * ``[model]``: ``v_fps``, ``model_path``, ``model_conf``, ``rnnlm``,
      ``rnnlm_conf``
    * ``[decode]``: ``penalty``, ``ctc_weight``, ``lm_weight``, ``beam_size``

    Args:
        config_filename: Path to the INI config file; must exist.
        detector: Landmark detector name. ``"mediapipe"`` and
            ``"retinaface"`` are recognised here; the value is also forwarded
            to ``AVSRDataLoader``.
        face_track: When true and the modality is ``"video"`` or
            ``"audiovisual"``, a landmark detector is built so landmarks can be
            computed when no landmarks file is supplied to ``forward``.
        device: Device string passed to ``AVSR`` and to the retinaface
            landmark detector.
    """

    def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"):
        super().__init__()
        assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."
        config = ConfigParser()
        config.read(config_filename)

        # Modality configuration.
        modality = config.get("input", "modality")
        self.modality = modality

        # Data configuration: speed_rate is the ratio of the input video frame
        # rate to the model's frame rate (presumably used by the loader to
        # resample frames — confirm in AVSRDataLoader).
        input_v_fps = config.getfloat("input", "v_fps")
        model_v_fps = config.getfloat("model", "v_fps")

        # Model configuration.
        model_path = config.get("model", "model_path")
        model_conf = config.get("model", "model_conf")

        # Language model and decoding configuration.
        rnnlm = config.get("model", "rnnlm")
        rnnlm_conf = config.get("model", "rnnlm_conf")
        penalty = config.getfloat("decode", "penalty")
        ctc_weight = config.getfloat("decode", "ctc_weight")
        lm_weight = config.getfloat("decode", "lm_weight")
        beam_size = config.getint("decode", "beam_size")

        self.dataloader = AVSRDataLoader(modality, speed_rate=input_v_fps / model_v_fps, detector=detector)
        self.model = AVSR(modality, model_path, model_conf, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size, device)

        # A detector is only built for video-bearing modalities with tracking
        # enabled. In every other case (including an unrecognised detector
        # name) the attribute is still defined as None, so process_landmarks
        # can report a clear error instead of failing on a missing attribute.
        landmarks_detector = None
        if face_track and modality in ("video", "audiovisual"):
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.detector import LandmarksDetector
                landmarks_detector = LandmarksDetector()
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.detector import LandmarksDetector
                # Use the caller's device rather than a fixed "cuda:0".
                landmarks_detector = LandmarksDetector(device=device)
        self.landmarks_detector = landmarks_detector

    def process_landmarks(self, data_filename, landmarks_filename):
        """Return facial landmarks for ``data_filename``.

        Args:
            data_filename: Path to the media file being transcribed.
            landmarks_filename: Path to a pickled landmarks file. If it is not
                a ``str``, landmarks are computed with the detector built in
                ``__init__``.

        Returns:
            ``None`` unless the modality is ``"video"`` or ``"audiovisual"``;
            otherwise the unpickled or detected landmarks.

        Raises:
            ValueError: If landmarks must be detected but no detector is
                available (``face_track`` was false or ``detector`` was not
                recognised).
        """
        if self.modality not in ("video", "audiovisual"):
            return None
        if isinstance(landmarks_filename, str):
            # SECURITY: unpickling can execute arbitrary code; only load
            # landmarks files from trusted sources.
            with open(landmarks_filename, "rb") as handle:
                return pickle.load(handle)
        if self.landmarks_detector is None:
            raise ValueError(
                "No landmarks_filename was given and no landmarks detector is available; "
                "pass a landmarks file or construct the pipeline with face_track=True "
                "and detector='mediapipe' or 'retinaface'."
            )
        return self.landmarks_detector(data_filename)

    def forward(self, data_filename, landmarks_filename=None):
        """Run inference on one media file.

        Args:
            data_filename: Path to the media file; must exist.
            landmarks_filename: Optional path to a pickled landmarks file (see
                ``process_landmarks``).

        Returns:
            Whatever ``self.model.infer`` returns for the loaded data
            (presumably the decoded transcript — confirm in ``AVSR.infer``).
        """
        assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."
        landmarks = self.process_landmarks(data_filename, landmarks_filename)
        data = self.dataloader.load_data(data_filename, landmarks)
        return self.model.infer(data)