#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torchaudio
import torchvision
from .transforms import AudioTransform, VideoTransform


class AVSRDataLoader:
    """Load audio and/or video from a media file and prepare it for a model.

    Depending on ``modality`` the loader returns an audio waveform, a stack of
    video frames, or both (trimmed to a common duration).
    """

    def __init__(self, modality, speed_rate=1, transform=True, detector="retinaface", convert_gray=True):
        """Set up the processing stages required by ``modality``.

        Args:
            modality: ``"audio"``, ``"video"`` or ``"audiovisual"``.
            speed_rate: Forwarded to ``VideoTransform``.
            transform: Whether ``load_data`` applies the audio/video
                transforms after loading.
            detector: ``"mediapipe"`` or ``"retinaface"``; selects which
                ``VideoProcess`` implementation is imported. Only consulted
                when the modality includes video.
            convert_gray: Forwarded to ``VideoProcess``.

        Raises:
            ValueError: If the modality includes video and ``detector`` is not
                one of the supported names.
        """
        self.modality = modality
        self.transform = transform
        if self.modality in ["audio", "audiovisual"]:
            self.audio_transform = AudioTransform()
        if self.modality in ["video", "audiovisual"]:
            # Detector back-ends are imported lazily so that only the selected
            # one needs to be importable.
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.video_process import VideoProcess
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.video_process import VideoProcess
            else:
                # Fail here rather than with an AttributeError on the first
                # load_data() call, when self.video_process would be missing.
                raise ValueError(
                    f"Unsupported detector {detector!r}; expected 'mediapipe' or 'retinaface'."
                )
            self.video_process = VideoProcess(convert_gray=convert_gray)
            self.video_transform = VideoTransform(speed_rate=speed_rate)

    def load_data(self, data_filename, landmarks=None, transform=True):
        """Load, pre-process and optionally transform ``data_filename``.

        Args:
            data_filename: Path of the media file to read.
            landmarks: Passed through to the video processor (unused for the
                ``"audio"`` modality).
            transform: Per-call switch. Transforms are applied only when both
                this flag and the ``transform`` flag given to the constructor
                are true.

        Returns:
            The audio for ``"audio"``, the video for ``"video"``, or a
            ``(video, audio)`` tuple for ``"audiovisual"``.

        Raises:
            ValueError: If ``self.modality`` is not a supported modality.
        """
        apply_transform = self.transform and transform

        if self.modality == "audio":
            audio = self._load_processed_audio(data_filename)
            return self.audio_transform(audio) if apply_transform else audio

        if self.modality == "video":
            video = self._load_processed_video(data_filename, landmarks)
            return self.video_transform(video) if apply_transform else video

        if self.modality == "audiovisual":
            # Audio samples per video frame. With audio resampled to 16 kHz
            # this presumably corresponds to 25 fps video (16000 / 25 = 640)
            # -- TODO confirm the frame-rate assumption against the data.
            rate_ratio = 640
            audio = self._load_processed_audio(data_filename)
            video = self._load_processed_video(data_filename, landmarks)
            # Trim both streams to the same number of whole video frames.
            min_t = min(len(video), audio.size(1) // rate_ratio)
            audio = audio[:, :min_t * rate_ratio]
            video = video[:min_t]
            if apply_transform:
                audio = self.audio_transform(audio)
                video = self.video_transform(video)
            return video, audio

        raise ValueError(
            f"Unsupported modality {self.modality!r}; "
            "expected 'audio', 'video' or 'audiovisual'."
        )

    def _load_processed_audio(self, data_filename):
        """Read the audio track and run it through ``audio_process``."""
        audio, sample_rate = self.load_audio(data_filename)
        return self.audio_process(audio, sample_rate)

    def _load_processed_video(self, data_filename, landmarks):
        """Read the video frames, run ``video_process`` and return a tensor."""
        video = self.load_video(data_filename)
        video = self.video_process(video, landmarks)
        return torch.tensor(video)

    def load_audio(self, data_filename):
        """Read the audio of ``data_filename`` with ``torchaudio.load``.

        Returns:
            A ``(waveform, sample_rate)`` tuple exactly as returned by
            ``torchaudio.load(..., normalize=True)``.
        """
        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
        return waveform, sample_rate

    def load_video(self, data_filename):
        """Read the video frames of ``data_filename`` as a NumPy array.

        Only the first element (the frames) of the
        ``torchvision.io.read_video`` result is kept.
        """
        return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()

    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
        """Resample ``waveform`` if needed and average it over dimension 0.

        Args:
            waveform: Audio tensor; dimension 0 is averaged away (kept with
                size 1), so it is presumably the channel axis.
            sample_rate: Sampling rate of ``waveform``.
            target_sample_rate: Desired sampling rate.

        Returns:
            The waveform at ``target_sample_rate`` with dimension 0 of size 1.
        """
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform