lmzjms's picture
Upload 35 files
15ac91d
raw
history blame
No virus
5.11 kB
import os
import logging
import h5py
import soundfile
import librosa
import numpy as np
import pandas as pd
from scipy import stats
import datetime
import pickle
def create_folder(fd):
    """Create directory `fd`, including missing parents, if it is absent.

    Args:
      fd: str, path of the directory to create
    """
    # exist_ok=True replaces the former "check, then create" sequence, which
    # could raise when another process created the directory in between.
    os.makedirs(fd, exist_ok=True)
def get_filename(path):
    """Return the file name of `path` without directory and last extension.

    Args:
      path: str, path of a file; it is resolved with os.path.realpath first

    Returns:
      na: str, e.g. 'clip' for '/data/audio/clip.wav'
    """
    resolved = os.path.realpath(path)

    # os.path.basename honours the platform's path separator, whereas
    # splitting on '/' only works where '/' is the separator.
    na_ext = os.path.basename(resolved)

    # Only the last extension is stripped: 'a.tar.gz' -> 'a.tar'
    return os.path.splitext(na_ext)[0]
def get_sub_filepaths(folder):
    """Collect the paths of all files found under `folder`, recursively.

    Args:
      folder: str, directory to walk

    Returns:
      paths: list of str, one entry per file, each joined with the directory
        it was found in (in os.walk order)
    """
    return [
        os.path.join(current_dir, file_name)
        for current_dir, _, file_names in os.walk(folder)
        for file_name in file_names
    ]
def create_logging(log_dir, filemode):
    """Set up logging to a new numbered file in `log_dir` and to the console.

    The log file is named '{:04d}.log' using the first index for which no
    file exists yet in `log_dir`.

    Args:
      log_dir: str, directory that receives the log file (created if needed)
      filemode: str, file mode forwarded to logging.basicConfig

    Returns:
      the `logging` module itself
    """
    create_folder(log_dir)

    # Advance the index until the candidate file name is unused
    index = 0
    log_path = os.path.join(log_dir, '{:04d}.log'.format(index))
    while os.path.isfile(log_path):
        index += 1
        log_path = os.path.join(log_dir, '{:04d}.log'.format(index))

    # File output: DEBUG level and above
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=log_path,
        filemode=filemode)

    # Console output: INFO level and above, attached to the root logger
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)

    return logging
def read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    Args:
      csv_path: str, path of the csv file; its first three lines are treated
        as header and skipped
      classes_num: int, number of columns of the target matrix
      id_to_ix: dict, maps a label id string (e.g. '/m/068hy') to its column
        index in the target matrix

    Returns:
      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}

    Raises:
      KeyError: if a label id of the csv file is missing from `id_to_ix`
    """
    with open(csv_path, 'r') as fr:
        lines = fr.readlines()

    # Remove heads. Blank lines (e.g. a trailing newline at the end of the
    # file) carry no record and would otherwise fail on items[3].
    lines = [line for line in lines[3:] if line.strip()]

    audios_num = len(lines)

    # Builtin bool is used as dtype: the `np.bool` alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    for n, line in enumerate(lines):
        items = line.split(', ')
        # items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']

        # Audios are started with an extra 'Y' when downloading
        audio_names.append('Y{}.wav'.format(items[0]))

        # The labels are the comma separated ids between the double quotes
        label_ids = items[3].split('"')[1].split(',')

        # Target
        for label_id in label_ids:
            targets[n, id_to_ix[label_id]] = True

    meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
    return meta_dict
def float32_to_int16(x):
    """Convert float samples to int16 samples.

    Values are clipped to [-1, 1] and then scaled by 32767. Inputs whose
    absolute peak exceeds 1.2 fail the assertion instead of being clipped.

    Args:
      x: array of float samples

    Returns:
      array of np.int16 with the same shape as `x`
    """
    peak = np.max(np.abs(x))
    assert peak <= 1.2

    clipped = np.clip(x, -1, 1)
    scaled = clipped * 32767.
    return scaled.astype(np.int16)
def int16_to_float32(x):
    """Convert int16 samples to float32 samples by dividing by 32767.

    Args:
      x: array of integer samples

    Returns:
      array of np.float32 with the same shape as `x`
    """
    scaled = x / 32767.
    return scaled.astype(np.float32)
def pad_or_truncate(x, audio_length):
    """Bring `x` to exactly `audio_length` samples.

    Longer input is cut after `audio_length` samples; shorter (or equally
    long) input is extended with trailing zeros.

    Args:
      x: 1-D array of samples
      audio_length: int, number of samples of the result

    Returns:
      array of length `audio_length`. The padded branch concatenates with a
      float64 zero array, so its dtype follows NumPy's promotion rules; the
      truncated branch is a slice of `x`.
    """
    missing = audio_length - len(x)

    if missing < 0:
        return x[:audio_length]

    return np.concatenate((x, np.zeros(missing)), axis=0)
def d_prime(auc):
    """Compute d-prime from an AUC value.

    Args:
      auc: float or array of floats

    Returns:
      sqrt(2) times the standard normal quantile (inverse CDF) of `auc`
    """
    standard_normal = stats.norm()
    quantile = standard_normal.ppf(auc)
    return quantile * np.sqrt(2.0)
class Mixup(object):
    """Generator of mixup coefficients backed by its own seeded random state."""

    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator.

        Args:
          mixup_alpha: float, both shape parameters of the beta distribution
            the coefficients are drawn from
          random_seed: int, seed of the private np.random.RandomState
        """
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.

        One value `lam` is drawn per pair of samples and the pair receives
        `lam` and `1 - lam`.

        Args:
          batch_size: int

        Returns:
          mixup_lambdas: (batch_size,) for even `batch_size`; an odd
            `batch_size` still produces a complete last pair, i.e.
            batch_size + 1 values
        """
        alpha = self.mixup_alpha
        coefficients = []

        # One draw per pair keeps the random stream at one sample per step
        for _ in range(0, batch_size, 2):
            lam = self.random_state.beta(alpha, alpha, 1)[0]
            coefficients.extend((lam, 1. - lam))

        return np.array(coefficients)
class StatisticsContainer(object):
    """Collect statistics of training iterations and persist them with pickle."""

    def __init__(self, statistics_path):
        """Contain statistics of different training iterations.

        Args:
          statistics_path: str, path of the pickle file written by dump()
            and read by load_state_dict()
        """
        self.statistics_path = statistics_path

        # Second copy next to the main file, named
        # '<statistics_path without extension>_<creation timestamp>.pkl'
        self.backup_statistics_path = '{}_{}.pkl'.format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

        self.statistics_dict = {'bal': [], 'test': []}

    def append(self, iteration, statistics, data_type):
        """Store the statistics of one iteration.

        Args:
          iteration: int, stored in `statistics` under the key 'iteration'
            (the passed dict is modified in place)
          statistics: dict
          data_type: str, key of the list to append to, e.g. 'bal' | 'test'
        """
        statistics['iteration'] = iteration
        self.statistics_dict[data_type].append(statistics)

    def dump(self):
        """Write all statistics to the main and to the backup pickle file."""
        # Context managers close each file deterministically, so the data is
        # flushed to disk before the log messages announce it.
        for path in (self.statistics_path, self.backup_statistics_path):
            with open(path, 'wb') as f:
                pickle.dump(self.statistics_dict, f)

        logging.info(' Dump statistics to {}'.format(self.statistics_path))
        logging.info(' Dump statistics to {}'.format(self.backup_statistics_path))

    def load_state_dict(self, resume_iteration):
        """Reload the statistics and drop entries after `resume_iteration`.

        Args:
          resume_iteration: int, entries whose 'iteration' is larger than
            this value are discarded
        """
        # NOTE: pickle.load can execute arbitrary code; only point
        # statistics_path at files produced by a trusted dump().
        with open(self.statistics_path, 'rb') as f:
            loaded_statistics = pickle.load(f)

        # 'bal' and 'test' always exist afterwards, even if the file lacks them
        resume_statistics_dict = {'bal': [], 'test': []}

        for key, entries in loaded_statistics.items():
            resume_statistics_dict[key] = [
                statistics for statistics in entries
                if statistics['iteration'] <= resume_iteration]

        self.statistics_dict = resume_statistics_dict