import torch
import numpy as np
import logging
from typing import Dict, List, Optional
import hgtk

from Pattern_Generator import Convert_Feature_Based_Music, Expand_by_Duration


def Decompose(syllable: str):
    # Decompose a Hangul syllable into onset/nucleus/coda jamo.
    # '_' is appended so that a syllable without a final consonant
    # still yields a three-element jamo sequence.
    onset, nucleus, coda = hgtk.letter.decompose(syllable)
    coda += '_'

    return onset, nucleus, coda

def Lyric_to_Token(lyric: List[str], token_dict: Dict[str, int]):
    return [
        token_dict[letter]
        for letter in list(lyric)
        ]

def Token_Stack(tokens: List[List[int]], token_dict: Dict[str, int], max_length: Optional[int]= None):
    # Right-pad token sequences with the '<X>' (silence/padding) token and
    # stack them into a [Batch, Time] array.
    max_token_length = max_length or max([len(token) for token in tokens])
    tokens = np.stack(
        [np.pad(token[:max_token_length], [0, max_token_length - len(token[:max_token_length])], constant_values= token_dict['<X>']) for token in tokens],
        axis= 0
        )

    return tokens

def Note_Stack(notes: List[List[int]], max_length: Optional[int]= None):
    # Right-pad note sequences with 0 and stack them into [Batch, Time].
    max_note_length = max_length or max([len(note) for note in notes])
    notes = np.stack(
        [np.pad(note[:max_note_length], [0, max_note_length - len(note[:max_note_length])], constant_values= 0) for note in notes],
        axis= 0
        )

    return notes

def Duration_Stack(durations: List[List[int]], max_length: Optional[int]= None):
    # Right-pad duration sequences with 0 and stack them into [Batch, Time].
    max_duration_length = max_length or max([len(duration) for duration in durations])
    durations = np.stack(
        [np.pad(duration[:max_duration_length], [0, max_duration_length - len(duration[:max_duration_length])], constant_values= 0) for duration in durations],
        axis= 0
        )

    return durations

def Feature_Stack(features: List[np.ndarray], max_length: Optional[int]= None):
    # Right-pad spectral features with -1.0 along the time axis and stack
    # them into [Batch, Time, Dim].
    max_feature_length = max_length or max([feature.shape[0] for feature in features])
    features = np.stack(
        [np.pad(feature, [[0, max_feature_length - feature.shape[0]], [0, 0]], constant_values= -1.0) for feature in features],
        axis= 0
        )

    return features

def Log_F0_Stack(log_f0s: List[np.ndarray], max_length: Optional[int]= None):
    # Right-pad log-F0 contours with 0.0 and stack them into [Batch, Time].
    max_log_f0_length = max_length or max([len(log_f0) for log_f0 in log_f0s])
    log_f0s = np.stack(
        [np.pad(log_f0, [0, max_log_f0_length - len(log_f0)], constant_values= 0.0) for log_f0 in log_f0s],
        axis= 0
        )

    return log_f0s


class Inference_Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        token_dict: Dict[str, int],
        singer_info_dict: Dict[str, int],
        genre_info_dict: Dict[str, int],
        durations: List[List[float]],
        lyrics: List[List[str]],
        notes: List[List[int]],
        singers: List[str],
        genres: List[str],
        sample_rate: int,
        frame_shift: int,
        equality_duration: bool= False,
        consonant_duration: int= 3
        ):
        super().__init__()
        self.token_dict = token_dict
        self.singer_info_dict = singer_info_dict
        self.genre_info_dict = genre_info_dict
        self.equality_duration = equality_duration
        self.consonant_duration = consonant_duration

        self.patterns = []
        for index, (duration, lyric, note, singer, genre) in enumerate(zip(durations, lyrics, notes, singers, genres)):
            if singer not in self.singer_info_dict:
                logging.warning('The singer \'{}\' is unknown. The pattern \'{}\' is ignored.'.format(singer, index))
                continue
            if genre not in self.genre_info_dict:
                logging.warning('The genre \'{}\' is unknown. The pattern \'{}\' is ignored.'.format(genre, index))
                continue

            music = list(zip(duration, lyric, note))
            singer_label = singer
            text = lyric

            # Convert the note-level music into frame-level lyric/note/duration
            # sequences, then repeat each element by its duration.
            lyric, note, duration = Convert_Feature_Based_Music(
                music= music,
                sample_rate= sample_rate,
                frame_shift= frame_shift,
                consonant_duration= consonant_duration,
                equality_duration= equality_duration
                )
            lyric_expand, note_expand, duration_expand = Expand_by_Duration(lyric, note, duration)

            singer = self.singer_info_dict[singer]
            genre = self.genre_info_dict[genre]

            self.patterns.append((lyric_expand, note_expand, duration_expand, singer, genre, singer_label, text))

    def __getitem__(self, idx):
        lyric, note, duration, singer, genre, singer_label, text = self.patterns[idx]

        return Lyric_to_Token(lyric, self.token_dict), note, duration, singer, genre, singer_label, text

    def __len__(self):
        return len(self.patterns)


class Inference_Collater:
    def __init__(
        self,
        token_dict: Dict[str, int]
        ):
        self.token_dict = token_dict

    def __call__(self, batch):
        tokens, notes, durations, singers, genres, singer_labels, lyrics = zip(*batch)

        lengths = np.array([len(token) for token in tokens])
        max_length = max(lengths)

        tokens = Token_Stack(tokens, self.token_dict, max_length)
        notes = Note_Stack(notes, max_length)
        durations = Duration_Stack(durations, max_length)

        tokens = torch.LongTensor(tokens)   # [Batch, Time]
        notes = torch.LongTensor(notes) # [Batch, Time]
        durations = torch.LongTensor(durations) # [Batch, Time]
        lengths = torch.LongTensor(lengths)   # [Batch]
        singers = torch.LongTensor(singers)  # [Batch]
        genres = torch.LongTensor(genres)  # [Batch]

        # Replace the '<X>' silence token with a space for human-readable lyrics.
        lyrics = [''.join([(x if x != '<X>' else ' ') for x in lyric]) for lyric in lyrics]

        return tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics
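

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. The dictionaries
    # and the two-note song below are hypothetical placeholders: in practice the
    # token/singer/genre dicts come from the preprocessing metadata, and
    # Pattern_Generator must be importable for Inference_Dataset to build patterns.
    token_dict = {'<X>': 0, 'ᄀ': 1, 'ᅡ': 2}    # hypothetical jamo vocabulary
    singer_info_dict = {'Singer_0': 0}    # hypothetical singer id map
    genre_info_dict = {'Ballade': 0}    # hypothetical genre id map

    dataset = Inference_Dataset(
        token_dict= token_dict,
        singer_info_dict= singer_info_dict,
        genre_info_dict= genre_info_dict,
        durations= [[0.5, 0.5]],    # note lengths in seconds
        lyrics= [['가', '가']],    # one syllable per note
        notes= [[60, 62]],    # MIDI pitches
        singers= ['Singer_0'],
        genres= ['Ballade'],
        sample_rate= 22050,
        frame_shift= 256
        )
    loader = torch.utils.data.DataLoader(
        dataset= dataset,
        collate_fn= Inference_Collater(token_dict= token_dict),
        batch_size= 1
        )
    for tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics in loader:
        print(tokens.shape, notes.shape, durations.shape, lengths, singers, genres)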