# Hugging Face file-page residue captured with this source (uploader: aiface,
# "Upload 11 files", commit 907b7f3, 7.41 kB); kept as a comment so the module parses.
import os
import glob
import torch
import random
import librosa
import numpy as np
import sys
from lipreading.utils import read_txt_lines
# MyDataset as constructed in dataloaders.py:
# dsets = {partition: MyDataset(
# modality=args.modality,
# data_partition=partition,
# data_dir=args.data_dir,
# label_fp=args.label_path,
# annonation_direc=args.annonation_direc,
# preprocessing_func=preprocessing[partition],
# data_suffix='.npz'
# ) for partition in ['train', 'val', 'test']}
class MyDataset(object):
    """Dataset indexing pre-extracted lipreading samples on disk.

    Samples are discovered under ``<data_dir>/<label>/<partition>/<id>.<ext>``
    (``ext`` in npz/npy/mp4). The label of a sample is the path component at
    ``self.label_idx`` (third from the end). Its class index is the position of
    that label in the list read from ``label_fp``.

    Args:
        modality: ``"video"`` sets ``self.fps`` to 25, anything else to 16000.
            The rate converts the annotated utterance duration into indices
            along axis 0 of a sample.
        data_partition: partition sub-directory name, e.g. ``"train"``.
        data_dir: root directory of the samples; a falsy value yields no files.
        label_fp: path of the labels file (parsed by ``read_txt_lines``).
        annonation_direc: root of the per-sample ``.txt`` annotation files used
            by the training-time variable-length augmentation. When ``None``,
            that augmentation is skipped.
        preprocessing_func: callable applied to every sample in ``__getitem__``;
            ``None`` means the sample is returned unchanged.
        data_suffix: stored on the instance; not used by this class's methods.
    """

    def __init__(self, modality, data_partition, data_dir, label_fp, annonation_direc=None,
                 preprocessing_func=None, data_suffix='.npz'):
        assert os.path.isfile( label_fp ), "File path provided for the labels does not exist. Path iput: {}".format(label_fp)
        self._data_partition = data_partition
        self._data_dir = data_dir
        self._data_suffix = data_suffix
        self._label_fp = label_fp
        self._annonation_direc = annonation_direc
        # Rate used to turn the annotated duration into array indices.
        self.fps = 25 if modality == "video" else 16000
        # Enables the random temporal crop on the training partition.
        self.is_var_length = True
        # Path component (from the end) holding the label directory name.
        self.label_idx = -3
        self.preprocessing_func = preprocessing_func
        self._data_files = []
        self.load_dataset()

    @staticmethod
    def _path_parts(path):
        """Split *path* into components on '/' and on the OS separator."""
        return path.replace(os.sep, '/').split('/')

    def load_dataset(self):
        """Read the labels, discover this partition's files, and index them.

        Fills ``self.list`` (sample index -> ``[path, class_index]``) and
        ``self.instance_ids`` (sample index -> file name without extension).
        """
        self._labels = read_txt_lines(self._label_fp)
        # First occurrence wins (same as list.index), but with O(1) lookups.
        label_to_idx = {}
        for position, name in enumerate(self._labels):
            label_to_idx.setdefault(name, position)
        self._get_files_for_partition()
        self.list = dict()
        self.instance_ids = dict()
        for i, x in enumerate(self._data_files):
            label = self._get_label_from_path(x)
            self.list[i] = [x, label_to_idx[label]]
            self.instance_ids[i] = self._get_instance_id_from_path(x)
        print('Partition {} loaded'.format(self._data_partition))

    def _get_instance_id_from_path(self, x):
        """Return the file name of *x* without its extension."""
        # Works for npz/npy files; might break for image folders.
        return os.path.splitext(self._path_parts(x)[-1])[0]

    def _get_label_from_path(self, x):
        """Return the label directory name of sample path *x*."""
        return self._path_parts(x)[self.label_idx]

    def _get_files_for_partition(self):
        """Append this partition's sample paths to ``self._data_files``.

        Globs ``<data_dir>/*/<partition>/*.npz``, then ``*.npy``, then ``*.mp4``.
        Each extension's matches are sorted because ``glob`` order is
        filesystem-dependent, and sorting keeps sample indices reproducible.
        Only files whose label appears in ``self._labels`` are kept. Does
        nothing when ``self._data_dir`` is falsy.
        """
        dir_fp = self._data_dir
        if not dir_fp:
            return
        # npz: several arrays in one archive; npy: a single array; mp4: decoded by librosa.
        for ext in ('npz', 'npy', 'mp4'):
            pattern = os.path.join(dir_fp, '*', self._data_partition, '*.' + ext)
            self._data_files.extend(sorted(glob.glob(pattern)))
        # If we are not using the full set of labels, remove examples for labels not used.
        known_labels = set(self._labels)
        self._data_files = [f for f in self._data_files
                            if self._get_label_from_path(f) in known_labels]

    def load_data(self, filename):
        """Load one sample from *filename*.

        ``.npz`` files return their ``'data'`` array. ``.mp4`` files are decoded
        with ``librosa.load`` at a 16000 Hz sampling rate, and only the last
        19456 samples are kept. Anything else is read with ``np.load``.

        On ``IOError`` the file name is printed and the process exits with
        status 1.
        """
        try:
            if filename.endswith('npz'):
                # Close the archive's file handle once 'data' is read into memory.
                with np.load(filename) as archive:
                    return archive['data']
            elif filename.endswith('mp4'):
                # librosa.load returns (waveform, sr); keep only the waveform's tail.
                return librosa.load(filename, sr=16000)[0][-19456:]
            else:
                return np.load(filename)
        except IOError:
            print("Error when reading file: {}".format(filename))
            # Non-zero status: a bare sys.exit() would report success.
            sys.exit(1)

    def _apply_variable_length_aug(self, filename, raw_data):
        """Randomly crop *raw_data* along axis 0 around its middle.

        The annotation file ``<annonation_direc>/<label>/<partition>/<id>.txt``
        gives the utterance duration as the second space-separated field of
        its fifth line. The crop always retains indices ``mid - half - 1``
        through ``mid + half`` (clipped to the array). It then extends a random
        amount outward on each side.
        """
        # Swap the base folder for the annotation root, then swap the extension.
        info_txt = os.path.join(self._annonation_direc, *self._path_parts(filename)[self.label_idx:])
        info_txt = os.path.splitext(info_txt)[0] + '.txt'
        info = read_txt_lines(info_txt)
        utterance_duration = float(info[4].split(' ')[1])
        half_interval = int(utterance_duration / 2.0 * self.fps)  # half the utterance, in indices
        n_frames = raw_data.shape[0]
        mid_idx = (n_frames - 1) // 2  # middle index of a 0-based sequence of n_frames
        # random.randint(a, b) samples from the closed interval [a, b].
        left_idx = random.randint(0, max(0, mid_idx - half_interval - 1))
        right_idx = random.randint(min(mid_idx + half_interval + 1, n_frames), n_frames)
        return raw_data[left_idx:right_idx]

    def __getitem__(self, idx):
        """Return ``(sample, class_index)`` for sample *idx*."""
        path, label = self.list[idx]
        data = self.load_data(path)
        # Variable-length augmentation: training set only, and only when
        # annotation files exist to supply the utterance duration.
        if (self._data_partition == 'train' and self.is_var_length
                and self._annonation_direc is not None):
            data = self._apply_variable_length_aug(path, data)
        if self.preprocessing_func is not None:
            data = self.preprocessing_func(data)
        return data, label

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self._data_files)
def pad_packed_collate(batch):
    """Collate ``(sample, label)`` pairs into a zero-padded float batch.

    Pairs are sorted by sample length (``shape[0]``), longest first. Each
    sample is copied into the front of a zero array shaped
    ``(batch_size, max_len, *rest)``, where ``rest`` comes from the longest
    sample. Any sample rank >= 1 is supported; previously only 1-D and 3-D
    samples were handled.

    Args:
        batch: non-empty sequence of ``(sample, label)`` pairs. All samples
            are assumed to share every dimension except the first.

    Returns:
        ``(data, lengths, labels)``:
        - ``data``: a ``torch.float32`` tensor.
        - ``lengths``: the unpadded lengths in sorted order. This is a list for
          a single-sample batch and a tuple otherwise, as before.
        - ``labels``: a ``torch.int64`` tensor of labels in the same order.

    Raises:
        ValueError: if *batch* is empty.
    """
    if len(batch) == 0:
        raise ValueError("pad_packed_collate() received an empty batch")
    ordered = sorted(batch, key=lambda pair: pair[0].shape[0], reverse=True)
    data_list = [sample for sample, _ in ordered]
    labels_np = [label for _, label in ordered]
    lengths = tuple(sample.shape[0] for sample in data_list)

    # Since the batch is sorted, the first sample is the longest.
    data_np = np.zeros((len(data_list),) + tuple(data_list[0].shape), dtype=np.float32)
    for idx, sample in enumerate(data_list):
        data_np[idx, :sample.shape[0]] = sample
    data = torch.from_numpy(data_np)

    if len(batch) == 1:
        # Historical contract: a single-sample batch reports lengths as a list.
        lengths = list(lengths)
    labels = torch.LongTensor(labels_np)
    return data, lengths, labels