# LoCoNet_ASD/dataLoaderTalkSet.py
import os, torch, numpy, cv2, random, python_speech_features
from scipy.io import wavfile
from glob import glob
from scipy import signal
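
# TalkSet loader for active speaker detection: each item is a pre-batched set of
# (MFCC audio features, 112x112 grayscale face tracks, frame-level speaking labels),
# with optional RIR / MUSAN audio augmentation and flip/crop/rotate visual augmentation.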
def get_noise_list(musanPath, rirPath):
    # Collect MUSAN wav files grouped by their top-level category, and load the RIR filters
    augment_files = glob(os.path.join(musanPath, '*/*/*/*.wav'))
    noiselist = {}
    rir = numpy.load(rirPath)
    for file in augment_files:
        category = file.split('/')[-4]
        if category not in noiselist:
            noiselist[category] = []
        noiselist[category].append(file)
    return rir, noiselist
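
# augment_wav applies a single augmentation to a 16 kHz int16 waveform: either
# reverberation (convolution with a random room impulse response from `rir`) or
# additive noise/music from MUSAN, mixed at a random SNR drawn per category.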
def augment_wav(audio, aug_type, rir, noiselist):
    if aug_type == 'rir':
        # Reverberation: convolve with a randomly chosen, randomly scaled impulse response
        rir_gains = numpy.random.uniform(-7, 3, 1)
        rir_filt = random.choice(rir) * pow(10, 0.1 * rir_gains)
        audio = signal.convolve(audio, rir_filt, mode='full')[:len(audio)]
    else:
        # Additive noise: mix in a MUSAN file at a random SNR for its category
        noisecat = aug_type
        noisefile = random.choice(noiselist[noisecat])
        snr_low, snr_high = {'noise': [0, 15], 'music': [5, 15]}[noisecat]
        snr = random.uniform(snr_low, snr_high)
        _, noiseaudio = wavfile.read(noisefile)
        if len(noiseaudio) < len(audio):
            shortage = len(audio) - len(noiseaudio)
            noiseaudio = numpy.pad(noiseaudio, (0, shortage), 'wrap')
        else:
            noiseaudio = noiseaudio[:len(audio)]
        # Cast to float before squaring to avoid silent int16 overflow in the power estimate
        noise_db = 10 * numpy.log10(numpy.mean(noiseaudio.astype(numpy.float64) ** 2) + 1e-4)
        clean_db = 10 * numpy.log10(numpy.mean(audio.astype(numpy.float64) ** 2) + 1e-4)
        noise = numpy.sqrt(10 ** ((clean_db - noise_db - snr) / 10)) * noiseaudio
        audio = audio + noise
    return audio.astype(numpy.int16)
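
# Worked example of the SNR scaling above (illustrative numbers): with
# clean_db = -10, noise_db = -20 and snr = 15, the amplitude factor is
# sqrt(10 ** ((-10 - (-20) - 15) / 10)) = 10 ** -0.25 ~= 0.56, so the scaled
# noise power ends up exactly snr dB below the clean signal's power.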
def load_audio(data, data_path, length, start, end, audio_aug, rirlist=None, noiselist=None):
    # Find the path of the audio data
    data_type = data[0]
    id_name = data[1][:8]
    file_name = data[1].split('/')[0] + '_' + data[1].split('/')[1] + '_' + data[1].split('/')[2] + \
        '_' + data[2].split('/')[0] + '_' + data[2].split('/')[1] + '_' + data[2].split('/')[2] + '.wav'
    audio_file_path = os.path.join(data_path, data_type, id_name, file_name)
    # Load audio, compute MFCC, cut it to the required length
    _, audio = wavfile.read(audio_file_path)
    if audio_aug:
        augtype = random.randint(0, 3)  # 0: no augmentation
        if augtype == 1:
            audio = augment_wav(audio, 'rir', rirlist, noiselist)
        elif augtype == 2:
            audio = augment_wav(audio, 'noise', rirlist, noiselist)
        elif augtype == 3:
            audio = augment_wav(audio, 'music', rirlist, noiselist)
    # 13-dim MFCC at 100 frames per second (25 ms window, 10 ms hop)
    feature = python_speech_features.mfcc(audio, 16000, numcep=13, winlen=0.025, winstep=0.010)
    length_audio = int(round(length * 100))
    if feature.shape[0] < length_audio:
        shortage = length_audio - feature.shape[0]
        feature = numpy.pad(feature, ((0, shortage), (0, 0)), 'wrap')
    feature = feature[int(round(start * 100)):int(round(end * 100)), :]
    return feature
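
# MFCC frames arrive at 100 per second while video runs at 25 fps, so four audio
# frames correspond to one video frame; load_label below relies on this 4:1 ratio.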
def load_video(data, data_path, length, start, end, visual_aug):
    # Find the path of the visual data
    data_type = data[0]
    id_name = data[1][:8]
    file_name = data[1].split('/')[0] + '_' + data[1].split('/')[1] + '_' + data[1].split('/')[2] + \
        '_' + data[2].split('/')[0] + '_' + data[2].split('/')[1] + '_' + data[2].split('/')[2] + '.mp4'
    video_file_path = os.path.join(data_path, data_type, id_name, file_name)
    # Load the visual frames one by one, cut to the required length (25 fps)
    length_video = int(round((end - start) * 25))
    video = cv2.VideoCapture(video_file_path)
    faces = []
    augtype = 'orig'
    if visual_aug:
        # Pre-sample the crop window and rotation matrix so all frames share one transform
        new = int(112 * random.uniform(0.7, 1))
        x, y = numpy.random.randint(0, 112 - new), numpy.random.randint(0, 112 - new)
        M = cv2.getRotationMatrix2D((112 / 2, 112 / 2), random.uniform(-15, 15), 1)
        augtype = random.choice(['orig', 'flip', 'crop', 'rotate'])
    num_frame = 0
    while video.isOpened():
        ret, frames = video.read()
        if ret:
            num_frame += 1
            if num_frame >= int(round(start * 25)) and num_frame < int(round(end * 25)):
                face = cv2.cvtColor(frames, cv2.COLOR_BGR2GRAY)
                face = cv2.resize(face, (224, 224))
                # Keep the central 112x112 region of the resized face crop
                face = face[int(112 - (112 / 2)):int(112 + (112 / 2)), int(112 - (112 / 2)):int(112 + (112 / 2))]
                if augtype == 'orig':
                    faces.append(face)
                elif augtype == 'flip':
                    faces.append(cv2.flip(face, 1))
                elif augtype == 'crop':
                    faces.append(cv2.resize(face[y:y + new, x:x + new], (112, 112)))
                elif augtype == 'rotate':
                    faces.append(cv2.warpAffine(face, M, (112, 112)))
        else:
            break
    video.release()
    faces = numpy.array(faces)
    if faces.shape[0] < length_video:
        shortage = length_video - faces.shape[0]
        faces = numpy.pad(faces, ((0, shortage), (0, 0), (0, 0)), 'wrap')
    # faces = numpy.array(faces)[int(round(start * 25)):int(round(end * 25)), :, :]
    return faces
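
# load_label derives per-video-frame speaking labels from the trial line: 'TAudio'
# clips are speaking throughout, 'FAudio'/'FSilence' are silent throughout, and mixed
# types are speaking only inside the [start_T, end_T] interval.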
def load_label(data, length, start, end):
    labels_all = []
    labels = []
    data_type = data[0]
    start_T, end_T, start_F, end_F = float(data[4]), float(data[5]), float(data[6]), float(data[7])
    # Build a 100 Hz frame-level label track: 1 = speaking, 0 = not speaking
    for i in range(int(round(length * 100))):
        if data_type == 'TAudio':
            labels_all.append(1)
        elif data_type == 'FAudio' or data_type == 'FSilence':
            labels_all.append(0)
        else:
            if i >= int(round(start_T * 100)) and i <= int(round(end_T * 100)):
                labels_all.append(1)
            else:
                labels_all.append(0)
    # Downsample to 25 fps by averaging and rounding every group of 4 audio frames
    for i in range(int(round(length * 25))):
        labels.append(int(round(sum(labels_all[i * 4:(i + 1) * 4]) / 4)))
    return labels[round(start * 25):round(end * 25)]
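
# loader_TalkSet buckets trials by duration into pre-formed minibatches, so one
# __getitem__ call already returns a whole batch; wrap it with a DataLoader using
# batch_size=1 (a usage sketch follows at the end of this file).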
class loader_TalkSet(object):
    def __init__(self, trial_file_name, data_path, audio_aug, visual_aug, musanPath, rirPath, **kwargs):
        self.data_path = data_path
        self.audio_aug = audio_aug
        self.visual_aug = visual_aug
        self.minibatch = []
        self.rir, self.noiselist = get_noise_list(musanPath, rirPath)
        mix_lst = open(trial_file_name).read().splitlines()
        mix_lst = list(filter(lambda x: float(x.split()[3]) >= 1, mix_lst))  # filter out videos shorter than 1 s
        # mix_lst = list(filter(lambda x: x.split()[0] == 'TSilence', mix_lst))
        # Sort by clip length so each minibatch holds clips of similar duration
        sorted_mix_lst = sorted(mix_lst, key=lambda data: (float(data.split()[3]), int(data.split()[-1])), reverse=True)
        start = 0
        while True:
            length_total = float(sorted_mix_lst[start].split()[3])
            batch_size = int(250 / length_total)  # keep roughly 250 s of media per minibatch
            end = min(len(sorted_mix_lst), start + batch_size)
            self.minibatch.append(sorted_mix_lst[start:end])
            if end == len(sorted_mix_lst):
                break
            start = end
        # self.minibatch = self.minibatch[0:5]
    def __getitem__(self, index):
        batch_lst = self.minibatch[index]
        length_total = float(batch_lst[-1].split()[3])
        # Truncate to a multiple of 0.04 s so audio (100 Hz) and video (25 fps) frames align
        length_total = (int(round(length_total * 100)) - int(round(length_total * 100)) % 4) / 100
        audio_feature, video_feature, labels = [], [], []
        duration = random.choice([1, 2, 4, 6])
        # duration = 6
        length = min(length_total, duration)
        if length == duration:
            # Sample a random window, snapped to the 0.04 s video-frame grid
            start = int(round(random.randint(0, round(length_total * 25) - round(length * 25)) * 0.04 * 100)) / 100
            end = int(round((start + length) * 100)) / 100
        else:
            start, end = 0, length
        for line in batch_lst:
            data = line.split()
            audio_feature.append(load_audio(data, self.data_path, length_total, start, end, audio_aug=self.audio_aug, rirlist=self.rir, noiselist=self.noiselist))
            video_feature.append(load_video(data, self.data_path, length_total, start, end, visual_aug=self.visual_aug))
            labels.append(load_label(data, length_total, start, end))
        return torch.FloatTensor(numpy.array(audio_feature)), \
            torch.FloatTensor(numpy.array(video_feature)), \
            torch.LongTensor(numpy.array(labels))

    def __len__(self):
        return len(self.minibatch)
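
# --- Usage sketch (not part of the original file) ---
# A minimal example of driving this loader with torch.utils.data.DataLoader. The
# paths ('TalkSet_trials.txt', 'TalkSet/', 'musan/', 'rir.npy') are hypothetical
# placeholders; batch_size must stay 1 because each item is already a full minibatch.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = loader_TalkSet(
        trial_file_name='TalkSet_trials.txt',  # hypothetical trial-list path
        data_path='TalkSet/',                  # hypothetical media root
        audio_aug=True,
        visual_aug=True,
        musanPath='musan/',                    # hypothetical MUSAN root
        rirPath='rir.npy')                     # hypothetical RIR archive
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
    for audio, video, labels in loader:
        # With T the sampled window in seconds and B the clips per minibatch:
        # audio (1, B, 100*T, 13), video (1, B, 25*T, 112, 112), labels (1, B, 25*T)
        print(audio.shape, video.shape, labels.shape)
        break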