File size: 5,823 Bytes
67d041f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from argparse import Namespace
import torch
import numpy as np
import pickle, os, logging
from typing import Dict, List, Optional
import hgtk

from Pattern_Generator import Convert_Feature_Based_Music, Expand_by_Duration

def Decompose(syllable: str):
    """Split one syllable into ``(onset, nucleus, coda)`` using ``hgtk.letter.decompose``.

    A trailing ``'_'`` is always appended to the coda, so a syllable whose coda
    comes back empty yields ``'_'``.
    NOTE(review): the suffix presumably keeps coda jamo distinct from onset jamo
    in the token vocabulary — confirm against the token dict.
    NOTE(review): assumes ``syllable`` is a single Hangul syllable; hgtk is
    expected to raise for other input — verify.
    """
    parts = hgtk.letter.decompose(syllable)
    onset, nucleus, coda = parts
    return onset, nucleus, coda + '_'

def Lyric_to_Token(lyric: List[str], token_dict: Dict[str, int]):
    """Map every lyric symbol to its integer id.

    Args:
        lyric: Sequence of lyric symbols (any iterable of dict keys works).
        token_dict: Symbol-to-id mapping.

    Returns:
        List of ids in the same order as ``lyric``.

    Raises:
        KeyError: If a symbol is missing from ``token_dict``.
    """
    return list(map(token_dict.__getitem__, list(lyric)))

def Token_Stack(tokens: List[List[int]], token_dict: Dict[str, int], max_length: Optional[int]= None):
    """Truncate/right-pad token sequences to a common length and stack them.

    Args:
        tokens: Token-id sequences of varying length.
        token_dict: Token mapping; its ``'<X>'`` entry is used as the padding id.
        max_length: Target length. When falsy (``None`` or ``0``) the length of
            the longest sequence is used. Longer sequences are cut to fit.

    Returns:
        ``np.ndarray`` of shape ``[len(tokens), target_length]`` for 1-D inputs.
    """
    target_length = max_length if max_length else max(len(token) for token in tokens)

    rows = []
    for token in tokens:
        clipped = token[:target_length]
        rows.append(np.pad(
            clipped,
            [0, target_length - len(clipped)],
            constant_values= token_dict['<X>']
            ))

    return np.stack(rows, axis= 0)

def Note_Stack(notes: List[List[int]], max_length: Optional[int]= None):
    """Truncate/right-pad note sequences with 0 and stack them into one array.

    Args:
        notes: Note sequences of varying length.
        max_length: Target length; when falsy the longest sequence's length is used.

    Returns:
        ``np.ndarray`` of shape ``[len(notes), target_length]`` for 1-D inputs.
    """
    width = max_length if max_length else max(len(note) for note in notes)

    padded = []
    for note in notes:
        kept = note[:width]
        missing = width - len(kept)
        padded.append(np.pad(kept, [0, missing], constant_values= 0))

    return np.stack(padded, axis= 0)

def Duration_Stack(durations: List[List[int]], max_length: Optional[int]= None):
    """Truncate/right-pad duration sequences with 0 and stack them into one array.

    Args:
        durations: Duration sequences of varying length.
        max_length: Target length; when falsy the longest sequence's length is used.

    Returns:
        ``np.ndarray`` of shape ``[len(durations), target_length]`` for 1-D inputs.
    """
    target = max_length or max(len(duration) for duration in durations)

    def _fit(sequence):
        # Cut to `target`, then fill the remainder on the right with zeros.
        sequence = sequence[:target]
        return np.pad(sequence, [0, target - len(sequence)], constant_values= 0)

    return np.stack([_fit(duration) for duration in durations], axis= 0)

def Feature_Stack(features: List[np.ndarray], max_length: Optional[int]= None):
    """Truncate/pad feature arrays along axis 0 with -1.0 and stack them.

    Args:
        features: Arrays whose first axis varies in length; all remaining axes
            must already have matching sizes.
        max_length: Target size of axis 0. When falsy (``None`` or ``0``) the
            largest ``shape[0]`` is used. Longer arrays are truncated, matching
            ``Token_Stack``/``Note_Stack``/``Duration_Stack``, instead of making
            ``np.pad`` fail on a negative pad width.

    Returns:
        ``np.ndarray`` of shape ``[len(features), target_length, *rest]``.
    """
    max_feature_length = max_length or max([feature.shape[0] for feature in features])

    padded = []
    for feature in features:
        feature = feature[:max_feature_length]
        # Pad only the end of axis 0; leave every other axis untouched.
        pad_width = [[0, max_feature_length - feature.shape[0]]] + [[0, 0]] * (feature.ndim - 1)
        padded.append(np.pad(feature, pad_width, constant_values= -1.0))

    return np.stack(padded, axis= 0)

def Log_F0_Stack(log_f0s: List[np.ndarray], max_length: Optional[int]= None):
    """Truncate/right-pad 1-D log-F0 sequences with 0.0 and stack them.

    Args:
        log_f0s: Sequences of varying length.
        max_length: Target length. When falsy (``None`` or ``0``) the longest
            sequence's length is used. Longer sequences are truncated, matching
            the other ``*_Stack`` helpers, instead of making ``np.pad`` fail on a
            negative pad width.

    Returns:
        ``np.ndarray`` of shape ``[len(log_f0s), target_length]``.
    """
    max_log_f0_length = max_length or max([len(log_f0) for log_f0 in log_f0s])

    padded = []
    for log_f0 in log_f0s:
        log_f0 = log_f0[:max_log_f0_length]
        padded.append(np.pad(log_f0, [0, max_log_f0_length - len(log_f0)], constant_values= 0.0))

    return np.stack(padded, axis= 0)

class Inference_Dataset(torch.utils.data.Dataset):
    """Dataset of inference patterns prepared once from raw per-note scores.

    Each score is given as parallel per-note lists (durations, lyrics, notes)
    plus a singer and a genre name. The five per-score inputs are paired
    positionally with ``zip``, so extra entries in a longer list are ignored.
    Scores whose singer or genre is not a key of the corresponding info dict
    are skipped with a warning. Accepted scores are passed through
    ``Convert_Feature_Based_Music`` and ``Expand_by_Duration`` at construction.

    Args:
        token_dict: Lyric-symbol to token-id mapping, used in ``__getitem__``.
        singer_info_dict: Singer name to integer id.
        genre_info_dict: Genre name to integer id.
        durations: Per-score note durations (units are whatever
            ``Convert_Feature_Based_Music`` expects — TODO confirm).
        lyrics: Per-score lyric entries, one per note.
        notes: Per-score note values.
        singers: Singer name for each score.
        genres: Genre name for each score.
        sample_rate: Forwarded to ``Convert_Feature_Based_Music``.
        frame_shift: Forwarded to ``Convert_Feature_Based_Music``.
        equality_duration: Forwarded to ``Convert_Feature_Based_Music``.
        consonant_duration: Forwarded to ``Convert_Feature_Based_Music``.
    """
    def __init__(
        self,
        token_dict: Dict[str, int],
        singer_info_dict: Dict[str, int],
        genre_info_dict: Dict[str, int],
        durations: List[List[float]],
        lyrics: List[List[str]],
        notes: List[List[int]],
        singers: List[str],
        genres: List[str],
        sample_rate: int,
        frame_shift: int,
        equality_duration: bool= False,
        consonant_duration: int= 3
        ):
        super().__init__()
        self.token_dict = token_dict
        self.singer_info_dict = singer_info_dict
        self.genre_info_dict = genre_info_dict
        self.equality_duration = equality_duration
        self.consonant_duration = consonant_duration

        # Each entry: (lyric_expand, note_expand, duration_expand,
        #              singer_id, genre_id, singer_name, raw_lyric)
        self.patterns = []
        for index, (duration, lyric, note, singer, genre) in enumerate(zip(durations, lyrics, notes, singers, genres)):
            if singer not in self.singer_info_dict:
                logging.warning('The singer \'%s\' is incorrect. The pattern \'%s\' is ignored.', singer, index)
                continue
            if genre not in self.genre_info_dict:
                logging.warning('The genre \'%s\' is incorrect. The pattern \'%s\' is ignored.', genre, index)
                continue

            music = list(zip(duration, lyric, note))

            converted_lyric, converted_note, converted_duration = Convert_Feature_Based_Music(
                music= music,
                sample_rate= sample_rate,
                frame_shift= frame_shift,
                consonant_duration= consonant_duration,
                equality_duration= equality_duration
                )
            lyric_expand, note_expand, duration_expand = Expand_by_Duration(
                converted_lyric, converted_note, converted_duration
                )

            self.patterns.append((
                lyric_expand,
                note_expand,
                duration_expand,
                self.singer_info_dict[singer],
                self.genre_info_dict[genre],
                singer,     # original singer name, kept as a human-readable label
                lyric,      # original (unconverted) lyric list
                ))

    def __getitem__(self, idx):
        """Return ``(token_ids, note, duration, singer_id, genre_id, singer_label, text)``."""
        lyric, note, duration, singer, genre, singer_label, text = self.patterns[idx]

        return Lyric_to_Token(lyric, self.token_dict), note, duration, singer, genre, singer_label, text

    def __len__(self):
        return len(self.patterns)

class Inference_Collater:
    """Collate ``Inference_Dataset`` items into padded batch tensors.

    Tokens, notes and durations are all padded/truncated to the longest token
    sequence in the batch and converted to ``torch.LongTensor``.
    """
    def __init__(self,
        token_dict: Dict[str, int]
        ):
        # Needed by Token_Stack, which pads token rows with token_dict['<X>'].
        self.token_dict = token_dict

    def __call__(self, batch):
        """Collate a list of dataset items.

        Args:
            batch: Items shaped like
                ``(tokens, notes, durations, singer, genre, singer_label, lyric)``.

        Returns:
            ``(tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics)``:
            ``tokens``/``notes``/``durations`` are ``[Batch, Time]`` LongTensors,
            ``lengths``/``singers``/``genres`` are ``[Batch]`` LongTensors,
            ``singer_labels`` is a tuple, and ``lyrics`` is a list of strings in
            which every ``'<X>'`` entry is rendered as a space.
        """
        columns = list(zip(*batch))
        token_seqs, note_seqs, duration_seqs, singer_ids, genre_ids, singer_labels, raw_lyrics = columns

        # Unpadded token lengths; the longest one sets the padded time axis.
        sequence_lengths = np.array([len(seq) for seq in token_seqs])
        padded_length = max(sequence_lengths)

        stacked_tokens = Token_Stack(token_seqs, self.token_dict, padded_length)
        stacked_notes = Note_Stack(note_seqs, padded_length)
        stacked_durations = Duration_Stack(duration_seqs, padded_length)

        token_tensor = torch.LongTensor(stacked_tokens)          # [Batch, Time]
        note_tensor = torch.LongTensor(stacked_notes)            # [Batch, Time]
        duration_tensor = torch.LongTensor(stacked_durations)    # [Batch, Time]
        length_tensor = torch.LongTensor(sequence_lengths)       # [Batch]
        singer_tensor = torch.LongTensor(singer_ids)             # [Batch]
        genre_tensor = torch.LongTensor(genre_ids)               # [Batch]

        display_lyrics = [
            ''.join((symbol if symbol != '<X>' else ' ') for symbol in lyric)
            for lyric in raw_lyrics
            ]

        return (
            token_tensor, note_tensor, duration_tensor, length_tensor,
            singer_tensor, genre_tensor, singer_labels, display_lyrics,
            )