#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torchaudio
import torchvision

from .transforms import AudioTransform, VideoTransform

class AVSRDataLoader:
    """Loads audio and/or video clips and applies the AVSR preprocessing pipeline."""

    def __init__(self, modality, speed_rate=1, transform=True, detector="retinaface", convert_gray=True):
        self.modality = modality
        self.transform = transform
        if self.modality in ["audio", "audiovisual"]:
            self.audio_transform = AudioTransform()
        if self.modality in ["video", "audiovisual"]:
            # Select the face/landmark detector backend used for video preprocessing.
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.video_process import VideoProcess
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.video_process import VideoProcess
            else:
                raise ValueError(f"Unknown detector: {detector}")
            self.video_process = VideoProcess(convert_gray=convert_gray)
            self.video_transform = VideoTransform(speed_rate=speed_rate)

    def load_data(self, data_filename, landmarks=None):
        # Whether transforms are applied is controlled by self.transform, set at construction time.
        if self.modality == "audio":
            audio, sample_rate = self.load_audio(data_filename)
            audio = self.audio_process(audio, sample_rate)
            return self.audio_transform(audio) if self.transform else audio
        if self.modality == "video":
            video = self.load_video(data_filename)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            return self.video_transform(video) if self.transform else video
        if self.modality == "audiovisual":
            # 16 kHz audio at 25 fps video = 640 audio samples per video frame.
            rate_ratio = 640
            audio, sample_rate = self.load_audio(data_filename)
            audio = self.audio_process(audio, sample_rate)
            video = self.load_video(data_filename)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            # Trim both streams to a common duration before transforming.
            min_t = min(len(video), audio.size(1) // rate_ratio)
            audio = audio[:, : min_t * rate_ratio]
            video = video[:min_t]
            if self.transform:
                audio = self.audio_transform(audio)
                video = self.video_transform(video)
            return video, audio

    def load_audio(self, data_filename):
        # Returns (waveform, sample_rate); waveform is float in [-1, 1] with shape (channels, samples).
        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
        return waveform, sample_rate

    def load_video(self, data_filename):
        # Returns the video frames as a uint8 numpy array of shape (T, H, W, C).
        return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()

    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
        # Resample to the target rate if needed, then downmix to mono.
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
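

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline: "example.mp4" is a
    # hypothetical clip with an audio track and a visible speaker, and the
    # retinaface detector dependencies are assumed to be installed. Because this
    # module uses a relative import, it must be run as part of the `pipelines`
    # package rather than as a stand-alone script.
    loader = AVSRDataLoader(modality="audiovisual", detector="retinaface")
    video, audio = loader.load_data("example.mp4")
    # `video` holds the transformed frames and `audio` the matching 16 kHz mono
    # waveform, trimmed to a common duration (640 samples per frame).
    print(video.shape, audio.shape)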