# diffsingerkr / Datasets.py
import torch
import numpy as np
import logging
from typing import Dict, List, Optional
import hgtk

from Pattern_Generator import Convert_Feature_Based_Music, Expand_by_Duration


def Decompose(syllable: str):
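    '''Decompose a Hangul syllable into (onset, nucleus, coda) using hgtk.
    '_' is appended to the coda, so an open syllable still yields an explicit coda token.'''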
onset, nucleus, coda = hgtk.letter.decompose(syllable)
coda += '_'
return onset, nucleus, coda


def Lyric_to_Token(lyric: List[str], token_dict: Dict[str, int]):
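    '''Map each letter of a lyric to its integer token id.'''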
return [
token_dict[letter]
for letter in list(lyric)
]


def Token_Stack(tokens: List[List[int]], token_dict: Dict[str, int], max_length: Optional[int]= None):
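    '''Right-pad each token sequence with the '<X>' token and stack into a [Batch, Time] array.'''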
max_token_length = max_length or max([len(token) for token in tokens])
tokens = np.stack(
[np.pad(token[:max_token_length], [0, max_token_length - len(token[:max_token_length])], constant_values= token_dict['<X>']) for token in tokens],
axis= 0
)
return tokens


def Note_Stack(notes: List[List[int]], max_length: Optional[int]= None):
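    '''Right-pad each note sequence with 0 and stack into a [Batch, Time] array.'''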
max_note_length = max_length or max([len(note) for note in notes])
notes = np.stack(
[np.pad(note[:max_note_length], [0, max_note_length - len(note[:max_note_length])], constant_values= 0) for note in notes],
axis= 0
)
return notes


def Duration_Stack(durations: List[List[int]], max_length: Optional[int]= None):
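    '''Right-pad each duration sequence with 0 and stack into a [Batch, Time] array.'''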
max_duration_length = max_length or max([len(duration) for duration in durations])
durations = np.stack(
[np.pad(duration[:max_duration_length], [0, max_duration_length - len(duration[:max_duration_length])], constant_values= 0) for duration in durations],
axis= 0
)
return durations


def Feature_Stack(features: List[np.ndarray], max_length: Optional[int]= None):
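    '''Right-pad each [Time, Dim] feature with -1.0 along the time axis and stack into a [Batch, Time, Dim] array.'''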
max_feature_length = max_length or max([feature.shape[0] for feature in features])
features = np.stack(
[np.pad(feature, [[0, max_feature_length - feature.shape[0]], [0, 0]], constant_values= -1.0) for feature in features],
axis= 0
)
return features


def Log_F0_Stack(log_f0s: List[np.ndarray], max_length: Optional[int]= None):
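    '''Right-pad each log-F0 contour with 0.0 and stack into a [Batch, Time] array.'''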
max_log_f0_length = max_length or max([len(log_f0) for log_f0 in log_f0s])
log_f0s = np.stack(
[np.pad(log_f0, [0, max_log_f0_length - len(log_f0)], constant_values= 0.0) for log_f0 in log_f0s],
axis= 0
)
return log_f0s


class Inference_Dataset(torch.utils.data.Dataset):
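    '''Inference-time dataset: converts note-level (duration, lyric, note) music with
    singer and genre labels into frame-level patterns, skipping unknown singers or genres.'''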
def __init__(
self,
token_dict: Dict[str, int],
singer_info_dict: Dict[str, int],
genre_info_dict: Dict[str, int],
durations: List[List[float]],
lyrics: List[List[str]],
notes: List[List[int]],
singers: List[str],
genres: List[str],
sample_rate: int,
frame_shift: int,
equality_duration: bool= False,
consonant_duration: int= 3
):
super().__init__()
self.token_dict = token_dict
self.singer_info_dict = singer_info_dict
self.genre_info_dict = genre_info_dict
self.equality_duration = equality_duration
self.consonant_duration = consonant_duration
self.patterns = []
for index, (duration, lyric, note, singer, genre) in enumerate(zip(durations, lyrics, notes, singers, genres)):
            if singer not in self.singer_info_dict:
                logging.warning('The singer \'{}\' is unknown. The pattern \'{}\' is ignored.'.format(singer, index))
                continue
            if genre not in self.genre_info_dict:
                logging.warning('The genre \'{}\' is unknown. The pattern \'{}\' is ignored.'.format(genre, index))
                continue
            music = list(zip(duration, lyric, note))
singer_label = singer
text = lyric
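            # Convert_Feature_Based_Music maps the note-level music into
            # frame-based lyric, note, and duration sequences.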
lyric, note, duration = Convert_Feature_Based_Music(
music= music,
sample_rate= sample_rate,
frame_shift= frame_shift,
consonant_duration= consonant_duration,
equality_duration= equality_duration
)
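            # Expand_by_Duration expands each sequence to frame resolution,
            # repeating each entry by its duration.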
lyric_expand, note_expand, duration_expand = Expand_by_Duration(lyric, note, duration)
singer = self.singer_info_dict[singer]
genre = self.genre_info_dict[genre]
self.patterns.append((lyric_expand, note_expand, duration_expand, singer, genre, singer_label, text))

    def __getitem__(self, idx):
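        '''Tokenize the lyric at fetch time; the rest of the pattern is precomputed.'''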
lyric, note, duration, singer, genre, singer_label, text = self.patterns[idx]
return Lyric_to_Token(lyric, self.token_dict), note, duration, singer, genre, singer_label, text

    def __len__(self):
return len(self.patterns)


class Inference_Collater:
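    '''Collate function for inference: pads a batch of dataset items to a common length and tensorizes them.'''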
def __init__(self,
token_dict: Dict[str, int]
):
self.token_dict = token_dict

    def __call__(self, batch):
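        '''Pad token/note/duration sequences to the longest in the batch and convert to LongTensors.'''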
tokens, notes, durations, singers, genres, singer_labels, lyrics = zip(*batch)
lengths = np.array([len(token) for token in tokens])
max_length = max(lengths)
tokens = Token_Stack(tokens, self.token_dict, max_length)
notes = Note_Stack(notes, max_length)
durations = Duration_Stack(durations, max_length)
tokens = torch.LongTensor(tokens) # [Batch, Time]
notes = torch.LongTensor(notes) # [Batch, Time]
durations = torch.LongTensor(durations) # [Batch, Time]
lengths = torch.LongTensor(lengths) # [Batch]
singers = torch.LongTensor(singers) # [Batch]
genres = torch.LongTensor(genres) # [Batch]
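        # Replace the '<X>' padding/rest token with a space to rebuild readable lyric strings.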
lyrics = [''.join([(x if x != '<X>' else ' ') for x in lyric]) for lyric in lyrics]
return tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics
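

# A minimal usage sketch (an assumption for illustration, not part of the original
# repository wiring). 'token_dict', 'singer_info_dict', and 'genre_info_dict' are
# assumed to be loaded from the trained checkpoint's hyper parameters elsewhere;
# the singer/genre names, sample_rate, and frame_shift below are placeholders.
#
# dataset = Inference_Dataset(
#     token_dict= token_dict,
#     singer_info_dict= singer_info_dict,
#     genre_info_dict= genre_info_dict,
#     durations= [[0.5, 0.5, 1.0]],
#     lyrics= [['한', '국', '어']],
#     notes= [[64, 66, 68]],
#     singers= ['<singer>'],
#     genres= ['<genre>'],
#     sample_rate= 22050,
#     frame_shift= 256
#     )
# loader = torch.utils.data.DataLoader(
#     dataset= dataset,
#     collate_fn= Inference_Collater(token_dict= token_dict),
#     batch_size= 1
#     )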