# M4Singer/data_gen/singing/binarize.py
# Uploaded by kevinwang676 ("Duplicate from zlc99/M4Singer"), commit 26925fd, 17.2 kB.
import os
import random
from copy import deepcopy
import pandas as pd
import logging
from tqdm import tqdm
import json
import glob
import re
from resemblyzer import VoiceEncoder
import traceback
import numpy as np
import pretty_midi
import librosa
from scipy.interpolate import interp1d
import torch
from textgrid import TextGrid
from utils.hparams import hparams
from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch
from utils.pitch_utils import f0_to_coarse
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
from data_gen.tts.binarizer_zh import ZhBinarizer
from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
from vocoders.base_vocoder import VOCODERS
class SingingBinarizer(BaseBinarizer):
    """Binarizer for singing corpora laid out as ``<dir>/<subdir>/<piece>_wf0.wav``.

    Each wav piece must have sibling files sharing its stem: ``.txt`` (text),
    ``_ph.txt`` (space-separated phonemes) and ``.TextGrid`` (alignment).
    The speaker is the part of ``<subdir>`` before the first ``-`` or ``#``.
    Several processed data directories may be given as a comma-separated
    list; item and speaker names are then prefixed with ``ds<index>_``.
    """

    def __init__(self, processed_data_dir=None):
        """Set up empty per-item lookup tables.

        Args:
            processed_data_dir: comma-separated list of processed data
                directories; defaults to ``hparams['processed_data_dir']``.
        """
        # NOTE(review): BaseBinarizer.__init__ is not called; confirm nothing
        # it initialises is needed by inherited methods.
        if processed_data_dir is None:
            processed_data_dir = hparams['processed_data_dir']
        self.processed_data_dirs = processed_data_dir.split(",")
        self.binarization_args = hparams['binarization_args']
        self.pre_align_args = hparams['pre_align_args']
        self.item2txt = {}
        self.item2ph = {}
        self.item2wavfn = {}
        self.item2f0fn = {}
        self.item2tgfn = {}
        self.item2spk = {}

    def split_train_test_set(self, item_names):
        """Split ``item_names`` into ``(train, test)`` lists, preserving order.

        An item is a test item when any entry of ``hparams['test_prefixes']``
        occurs anywhere in its name (substring match, not only a prefix).
        """
        test_prefixes = hparams['test_prefixes']
        test_item_names = [x for x in item_names if any(ts in x for ts in test_prefixes)]
        # Build the membership set once, not once per item (was O(n^2)).
        test_set = set(test_item_names)
        train_item_names = [x for x in item_names if x not in test_set]
        logging.info("train %d", len(train_item_names))
        logging.info("test %d", len(test_item_names))
        return train_item_names, test_item_names

    def load_meta_data(self):
        """Scan every processed data dir and fill the per-item lookup tables.

        Also sets ``self.item_names`` (sorted, then shuffled with a fixed seed
        when ``binarization_args['shuffle']``) and the train/test split.
        """
        wav_suffix = '_wf0.wav'
        txt_suffix = '.txt'
        ph_suffix = '_ph.txt'
        tg_suffix = '.TextGrid'
        multi_ds = len(self.processed_data_dirs) > 1
        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
            for piece_path in all_wav_pieces:
                # Every globbed path ends with wav_suffix, so slicing it off
                # yields the shared stem. Unlike str.replace, slicing never
                # touches a directory component that contains the suffix.
                stem = piece_path[:-len(wav_suffix)]
                item_name = stem[len(processed_data_dir) + 1:].replace('/', '-')
                if multi_ds:
                    item_name = f'ds{ds_id}_{item_name}'
                # Close the files promptly and decode explicitly as UTF-8
                # (the text may be non-ASCII).
                with open(stem + txt_suffix, encoding='utf-8') as f:
                    self.item2txt[item_name] = f.readline()
                with open(stem + ph_suffix, encoding='utf-8') as f:
                    self.item2ph[item_name] = f.readline()
                self.item2wavfn[item_name] = piece_path
                spk = re.split('-|#', piece_path.split('/')[-2])[0]
                if multi_ds:
                    spk = f"ds{ds_id}_{spk}"
                self.item2spk[item_name] = spk
                self.item2tgfn[item_name] = stem + tg_suffix
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            # Reseeds the global RNG so the shuffle is reproducible.
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @property
    def train_item_names(self):
        """Item names of the training split."""
        return self._train_item_names

    @property
    def valid_item_names(self):
        """Item names of the validation split (same list as the test split)."""
        return self._test_item_names

    @property
    def test_item_names(self):
        """Item names of the test split."""
        return self._test_item_names

    def process(self):
        """Load metadata, write ``spk_map.json``, build the phone encoder and
        binarize the valid, test and train splits (in that order)."""
        self.load_meta_data()
        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
        self.spk_map = self.build_spk_map()
        print("| spk_map: ", self.spk_map)
        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
        with open(spk_map_fn, 'w') as f:
            json.dump(self.spk_map, f)
        self.phone_encoder = self._phone_encoder()
        self.process_data('valid')
        self.process_data('test')
        self.process_data('train')

    def _phone_encoder(self):
        """Build (or reuse) ``phone_set.json`` and return the phone encoder.

        The phone set is rebuilt from ``self.item2ph`` when
        ``hparams['reset_phone_dict']`` is set or the file does not exist yet.
        Otherwise the existing file is only loaded for logging.
        """
        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
            ph_set = sorted({ph for ph_sent in self.item2ph.values() for ph in ph_sent.split(' ')})
            # Close (and flush) before build_phone_encoder runs; it presumably
            # reads phone_set.json from binary_data_dir. TODO confirm.
            with open(ph_set_fn, 'w') as f:
                json.dump(ph_set, f)
            print("| Build phone set: ", ph_set)
        else:
            with open(ph_set_fn, 'r') as f:
                ph_set = json.load(f)
            print("| Load phone set: ", ph_set)
        return build_phone_encoder(hparams['binary_data_dir'])

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract wav/mel and, optionally, f0, phone ids and alignment for one item.

        Returns:
            The feature dict, or ``None`` when a ``BinarizationError`` caused
            the item to be skipped.
        """
        vocoder_name = hparams['vocoder']
        if vocoder_name not in VOCODERS:
            # Fall back to the last component of a dotted class path.
            vocoder_name = vocoder_name.split('.')[-1]
        wav, mel = VOCODERS[vocoder_name].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav, mel, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
            if binarization_args['with_align']:
                # NOTE(review): requires with_txt; otherwise phone_encoded is unbound.
                cls.get_align(tg_fn, ph, mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
class MidiSingingBinarizer(SingingBinarizer):
    """Singing binarizer driven by a per-directory ``meta.json`` with MIDI notes.

    The per-item tables below are class attributes because the static
    ``get_pitch`` reads them through ``MidiSingingBinarizer.<table>``.
    ``load_meta_data`` fills them through ``self``, which resolves to these
    class-level dicts.
    """
    item2midi = {}
    item2midi_dur = {}
    item2is_slur = {}
    item2ph_durs = {}
    item2wdb = {}

    def load_meta_data(self):
        """Read ``meta.json`` from each processed dir and fill the lookup tables.

        Also sets ``self.item_names`` (sorted, then shuffled with a fixed seed
        when ``binarization_args['shuffle']``) and the train/test split.
        """
        multi_ds = len(self.processed_data_dirs) > 1
        # Phones flagged as word boundaries. A set is built once instead of
        # concatenating a new list for every phone.
        wdb_phones = set(ALL_YUNMU) | {'AP', 'SP', '<SIL>'}
        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            with open(os.path.join(processed_data_dir, 'meta.json'), encoding='utf-8') as f:
                meta_midi = json.load(f)  # [list of dict]
            for song_item in meta_midi:
                item_name = song_item['item_name']
                if multi_ds:
                    item_name = f'ds{ds_id}_{item_name}'
                phs = song_item['phs']
                self.item2wavfn[item_name] = song_item['wav_fn']
                self.item2txt[item_name] = song_item['txt']
                self.item2ph[item_name] = ' '.join(phs)
                self.item2wdb[item_name] = [1 if x in wdb_phones else 0 for x in phs]
                self.item2ph_durs[item_name] = song_item['ph_dur']
                self.item2midi[item_name] = song_item['notes']
                self.item2midi_dur[item_name] = song_item['notes_dur']
                self.item2is_slur[item_name] = song_item['is_slur']
                spk = 'pop-cs'
                if multi_ds:
                    spk = f"ds{ds_id}_{spk}"
                self.item2spk[item_name] = spk
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @staticmethod
    def get_pitch(wav_fn, wav, spec, ph, res):
        """Fill ``res`` with MIDI note info from the metadata and ground-truth f0.

        Sets ``pitch_midi``, ``midi_dur``, ``is_slur``, ``word_boundary``,
        ``f0`` and ``pitch``.

        Raises:
            BinarizationError: if the extracted f0 sums to zero.
        """
        # Lookup key: "<parent dir>/<file stem>" with any "_wf0" removed.
        # NOTE(review): load_meta_data keys items by meta.json 'item_name'
        # (ds-prefixed for multiple dirs); this lookup only works when the two agree.
        item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
        res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
        res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
        res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
        res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
        # Ground-truth f0 via the module-level get_pitch helper (not this method).
        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
        if sum(gt_f0) == 0:
            raise BinarizationError("Empty **gt** f0")
        res['f0'] = gt_f0
        res['pitch'] = gt_pitch_coarse

    @staticmethod
    def get_align(ph_durs, mel, phone_encoded, res, hop_size=None, audio_sample_rate=None):
        """Build the frame-level ``mel2ph`` map from per-phone durations.

        ``res['mel2ph'][t]`` is the 1-based index of the phone covering frame
        ``t``; frames that no phone covers stay 0. Phone boundaries are the
        cumulative durations scaled by ``audio_sample_rate / hop_size`` and
        rounded half-up. Frames past ``mel.shape[0]`` are dropped.

        Args:
            ph_durs: per-phone durations, in the time unit that
                ``audio_sample_rate / hop_size`` converts to frames.
            mel: spectrogram; only ``mel.shape[0]`` (frame count) is used.
            phone_encoded: unused; kept for signature compatibility.
            res: output dict; ``res['mel2ph']`` is set.
            hop_size: defaults to ``hparams['hop_size']``, read at call time.
            audio_sample_rate: defaults to ``hparams['audio_sample_rate']``,
                read at call time.
        """
        # Resolve defaults lazily. Defaults written as hparams[...] in the
        # signature would be evaluated once at import time.
        if hop_size is None:
            hop_size = hparams['hop_size']
        if audio_sample_rate is None:
            audio_sample_rate = hparams['audio_sample_rate']
        mel2ph = np.zeros([mel.shape[0]], int)
        start_time = 0
        for i_ph, dur in enumerate(ph_durs):
            start_frame = int(start_time * audio_sample_rate / hop_size + 0.5)
            end_frame = int((start_time + dur) * audio_sample_rate / hop_size + 0.5)
            mel2ph[start_frame:end_frame] = i_ph + 1
            start_time = start_time + dur
        res['mel2ph'] = mel2ph

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract wav/mel and, optionally, MIDI/f0, phone ids and alignment for one item.

        Returns:
            The feature dict, or ``None`` when a ``BinarizationError`` caused
            the item to be skipped.
        """
        vocoder_name = hparams['vocoder']
        if vocoder_name not in VOCODERS:
            # Fall back to the last component of a dotted class path.
            vocoder_name = vocoder_name.split('.')[-1]
        wav, mel = VOCODERS[vocoder_name].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav_fn, wav, mel, ph, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
            if binarization_args['with_align']:
                # NOTE(review): requires with_txt; otherwise phone_encoded is unbound.
                cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
    """Singing binarizer that mixes in ZhBinarizer via multiple inheritance.

    ZhBinarizer is listed first, so it precedes SingingBinarizer in the MRO
    and its methods win wherever both classes define one.
    """
class M4SingerBinarizer(MidiSingingBinarizer):
    """Binarizer for the M4Singer corpus, read from ``<raw_data_dir>/meta.json``.

    Item names have the form ``<singer>#<song>#<sentence id>``. The wav lives
    at ``<raw_data_dir>/<singer>#<song>/<sentence id>.wav``, and the singer is
    used as the speaker. The tables are redeclared here so this class keeps
    its own copies, separate from MidiSingingBinarizer's; the static methods
    read them through ``M4SingerBinarizer.<table>``.
    """
    item2midi = {}
    item2midi_dur = {}
    item2is_slur = {}
    item2ph_durs = {}
    item2wdb = {}

    def split_train_test_set(self, item_names):
        """Split ``item_names`` into ``(train, test)`` lists, preserving order.

        An item is a test item when its name starts with any entry of
        ``hparams['test_prefixes']``.
        """
        # str.startswith accepts a tuple and checks all prefixes in one call.
        test_prefixes = tuple(hparams['test_prefixes'])
        test_item_names = [x for x in item_names if x.startswith(test_prefixes)]
        # Build the membership set once, not once per item (was O(n^2)).
        test_set = set(test_item_names)
        train_item_names = [x for x in item_names if x not in test_set]
        logging.info("train %d", len(train_item_names))
        logging.info("test %d", len(test_item_names))
        return train_item_names, test_item_names

    def load_meta_data(self):
        """Read ``meta.json`` from ``hparams['raw_data_dir']`` and fill the tables.

        Also sets ``self.item_names`` (sorted, then shuffled with a fixed seed
        when ``binarization_args['shuffle']``) and the train/test split.
        """
        raw_data_dir = hparams['raw_data_dir']
        with open(os.path.join(raw_data_dir, 'meta.json'), encoding='utf-8') as f:
            song_items = json.load(f)  # [list of dict]
        # Interior phones flagged as word boundaries. Built once, instead of
        # concatenating a new list for every phone.
        wdb_phones = set(ALL_YUNMU) | {'<SP>', '<AP>'}
        for song_item in song_items:
            item_name = song_item['item_name']
            singer, song_name, sent_id = item_name.split("#")
            phs = song_item['phs']
            last = len(phs) - 1
            self.item2wavfn[item_name] = f'{raw_data_dir}/{singer}#{song_name}/{sent_id}.wav'
            self.item2txt[item_name] = song_item['txt']
            self.item2ph[item_name] = ' '.join(phs)
            self.item2ph_durs[item_name] = song_item['ph_dur']
            self.item2midi[item_name] = song_item['notes']
            self.item2midi_dur[item_name] = song_item['notes_dur']
            self.item2is_slur[item_name] = song_item['is_slur']
            # Boundary = 1 for the last phone, and for interior phones (not the
            # first) that are in ALL_YUNMU or are <SP>/<AP>.
            self.item2wdb[item_name] = [
                1 if (0 < i < last and p in wdb_phones) or i == last else 0
                for i, p in enumerate(phs)
            ]
            self.item2spk[item_name] = singer
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @staticmethod
    def get_pitch(item_name, wav, spec, ph, res):
        """Fill ``res`` with MIDI note info for ``item_name`` and ground-truth f0.

        Sets ``pitch_midi``, ``midi_dur``, ``is_slur``, ``word_boundary``,
        ``f0`` and ``pitch``.

        Raises:
            BinarizationError: if the extracted f0 sums to zero.
        """
        res['pitch_midi'] = np.asarray(M4SingerBinarizer.item2midi[item_name])
        res['midi_dur'] = np.asarray(M4SingerBinarizer.item2midi_dur[item_name])
        res['is_slur'] = np.asarray(M4SingerBinarizer.item2is_slur[item_name])
        res['word_boundary'] = np.asarray(M4SingerBinarizer.item2wdb[item_name])
        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
        # Ground-truth f0 via the module-level get_pitch helper (not this method).
        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
        if sum(gt_f0) == 0:
            raise BinarizationError("Empty **gt** f0")
        res['f0'] = gt_f0
        res['pitch'] = gt_pitch_coarse

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract wav/mel and, optionally, MIDI/f0, phone ids and alignment for one item.

        Returns:
            The feature dict, or ``None`` when a ``BinarizationError`` caused
            the item to be skipped.
        """
        vocoder_name = hparams['vocoder']
        if vocoder_name not in VOCODERS:
            # Fall back to the last component of a dotted class path.
            vocoder_name = vocoder_name.split('.')[-1]
        wav, mel = VOCODERS[vocoder_name].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(item_name, wav, mel, ph, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
            if binarization_args['with_align']:
                # NOTE(review): requires with_txt; otherwise phone_encoded is unbound.
                cls.get_align(M4SingerBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
if __name__ == "__main__":
    # Script entry point: binarize with the default SingingBinarizer, which
    # reads hparams['processed_data_dir'].
    # NOTE(review): nothing here populates hparams; confirm it is set before this runs.
    binarizer = SingingBinarizer()
    binarizer.process()