File size: 17,547 Bytes
34501b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
from pathlib import Path
from itertools import tee
import numpy as np
import torch
from .base_dataset import BaseDataset

def find_example_idx(n, cum_sums, idx = 0):
    """Map a global frame index to the index of the example containing it.

    ``cum_sums`` is a non-decreasing list of cumulative frame counts (one entry
    per example).  Returns the smallest ``i`` such that ``n < cum_sums[i]``,
    plus the caller-supplied base offset ``idx`` — i.e. a right bisection.
    When ``n`` is past the last cumulative sum, the returned index is one past
    the end, matching the original recursive implementation.

    The original hand-rolled recursion copied a list slice at every level
    (O(n log n) work per lookup plus recursion-depth risk); ``bisect_right``
    does the same search in O(log n) with no copies.
    """
    from bisect import bisect_right  # local import keeps module-level deps unchanged
    return idx + bisect_right(cum_sums, n)


class MultimodalDataset(BaseDataset):
    """Dataset of aligned, windowed examples drawn from multiple modalities.

    Each modality of a sequence is stored as a numpy array in a file named
    ``<base_filename>.<mod>.npy`` under ``opt.data_dir``.  One example is a
    (sequence, start-frame) pair: ``__getitem__`` slices a window out of every
    input and output modality relative to that start frame, using the
    per-modality lengths and time offsets configured in ``opt``.
    """

    def __init__(self, opt, split="train"):
        """Index all usable sequences for *split* and precompute window counts.

        Raises ValueError when ``opt.data_dir`` is not a directory, and
        Exception when a listed sequence is missing a preprocessed modality.
        """
        super().__init__()
        self.opt = opt
        data_path = Path(opt.data_dir)
        if not data_path.is_dir():
            raise ValueError('Invalid directory:'+opt.data_dir)

        print(opt.base_filenames_file)
        # Read the list of sequence base names for this split.
        if split == "train":
            names_file = data_path.joinpath(opt.base_filenames_file)
        else:
            names_file = data_path.joinpath("base_filenames_"+split+".txt")
        # Fixed: use a context manager (the handles were never closed), and
        # strip with rstrip("\n") instead of x[:-1], which silently dropped the
        # last character of the final line when the file had no trailing newline.
        with open(names_file, "r") as f:
            temp_base_filenames = [x.rstrip("\n") for x in f.readlines()]
        if opt.num_train_samples > 0:
            temp_base_filenames = np.random.choice(temp_base_filenames, size=opt.num_train_samples, replace=False)
        self.base_filenames = []

        input_mods = self.opt.input_modalities.split(",")
        output_mods = self.opt.output_modalities.split(",")
        self.input_lengths = input_lengths = [int(x) for x in str(self.opt.input_lengths).split(",")]
        self.output_lengths = output_lengths = [int(x) for x in str(self.opt.output_lengths).split(",")]
        self.output_time_offsets = output_time_offsets = [int(x) for x in str(self.opt.output_time_offsets).split(",")]
        self.input_time_offsets = input_time_offsets = [int(x) for x in str(self.opt.input_time_offsets).split(",")]

        # Per-modality value type: "c" continuous, "d" discrete; default continuous.
        if self.opt.input_types is None:
            input_types = ["c" for inp in input_mods]
        else:
            input_types = self.opt.input_types.split(",")

        # How to reconcile unequal per-modality sequence lengths:
        # "end" = cut the end, "beg" = cut the beginning, "single" = a
        # one-element sequence-level feature (e.g. a label) that is never windowed.
        if self.opt.input_fix_length_types is None:
            input_fix_length_types = ["end" for inp in input_mods]
        else:
            input_fix_length_types = self.opt.input_fix_length_types.split(",")

        if self.opt.output_fix_length_types is None:
            # Fixed: the default list was previously sized by input_mods.
            output_fix_length_types = ["end" for out in output_mods]
        else:
            output_fix_length_types = self.opt.output_fix_length_types.split(",")

        fix_length_types_dict = {mod:output_fix_length_types[i] for i,mod in enumerate(output_mods)}
        fix_length_types_dict.update({mod:input_fix_length_types[i] for i,mod in enumerate(input_mods)})

        assert len(input_types) == len(input_mods)
        assert len(input_fix_length_types) == len(input_mods)
        # Fixed: was compared against len(input_mods); the dict comprehension
        # above indexes output_fix_length_types by output_mods positions.
        assert len(output_fix_length_types) == len(output_mods)
        self.input_types = input_types
        self.input_fix_length_types = input_fix_length_types
        self.output_fix_length_types = output_fix_length_types

        if self.opt.input_num_tokens is None:
            self.input_num_tokens = [0 for inp in input_mods]
        else:
            self.input_num_tokens = [int(x) for x in self.opt.input_num_tokens.split(",")]

        if self.opt.output_num_tokens is None:
            self.output_num_tokens = [0 for inp in output_mods]
        else:
            self.output_num_tokens = [int(x) for x in self.opt.output_num_tokens.split(",")]

        # A single offset may be broadcast to all modalities.
        if len(output_time_offsets) < len(output_mods):
            if len(output_time_offsets) == 1:
                self.output_time_offsets = output_time_offsets = output_time_offsets*len(output_mods)
            else:
                raise Exception("number of output_time_offsets doesnt match number of output_mods")

        if len(input_time_offsets) < len(input_mods):
            if len(input_time_offsets) == 1:
                self.input_time_offsets = input_time_offsets = input_time_offsets*len(input_mods)
            else:
                raise Exception("number of input_time_offsets doesnt match number of input_mods")

        # features[mod][base_filename] -> numpy array for that sequence/modality.
        self.features = {mod:{} for mod in input_mods+output_mods}
        if opt.fix_lengths:
            self.features_filenames = {mod:{} for mod in input_mods+output_mods}

        # Shortest sequence length that still admits at least one training window.
        min_length = max(max(np.array(input_lengths) + np.array(input_time_offsets)), max(np.array(output_time_offsets) + np.array(output_lengths)) ) - min(0,min(output_time_offsets))
        print(min_length)

        fix_lengths = opt.fix_lengths

        self.total_frames = 0
        self.frame_cum_sums = []

        # Get the list of files containing features (in numpy format for now),
        # and populate the dictionaries of input and output features (separated
        # by modality), skipping sequences too short to yield a single window.
        for base_filename in temp_base_filenames:
            file_too_short = False
            first_length=True
            for i, mod in enumerate(input_mods):
                feature_file = data_path.joinpath(base_filename+"."+mod+".npy")
                if self.input_fix_length_types[i] == "single": continue
                try:
                    features = np.load(feature_file)
                    length = features.shape[0]
                    if not fix_lengths:
                        # Without fix_lengths, all windowed modalities must
                        # already agree on length.
                        if first_length:
                            length_0 = length
                            first_length=False
                        else:
                            assert length == length_0
                    if length < min_length:
                        file_too_short = True
                        break
                except FileNotFoundError:
                    raise Exception("An unprocessed input feature found "+base_filename+"."+mod+"; need to run preprocessing script before starting to train with them")

            if file_too_short: continue

            first_length=True
            for i, mod in enumerate(output_mods):
                feature_file = data_path.joinpath(base_filename+"."+mod+".npy")
                if self.output_fix_length_types[i] == "single": continue
                try:
                    features = np.load(feature_file)
                    length = features.shape[0]
                    if not fix_lengths:
                        if first_length:
                            length_0 = length
                            first_length=False
                        else:
                            assert length == length_0
                    if length < min_length:
                        file_too_short = True
                        break
                except FileNotFoundError:
                    raise Exception("An unprocessed output feature found "+base_filename+"."+mod+"; need to run preprocessing script before starting to train with them")

            if file_too_short: continue

            for mod in input_mods+output_mods:
                feature_file = data_path.joinpath(base_filename+"."+mod+".npy")
                features = np.load(feature_file)
                self.features[mod][base_filename] = features
                if fix_lengths:
                    self.features_filenames[mod][base_filename] = feature_file

            if fix_lengths:
                # Find the shortest length among the windowed (non-"single")
                # modalities; warn when modalities disagree by more than 2 frames.
                shortest_length = 99999999999
                first_match = True
                for mod in input_mods+output_mods:
                    if fix_length_types_dict[mod] == "single": continue
                    length = self.features[mod][base_filename].shape[0]
                    if length < shortest_length:
                        if first_match:
                            first_match = False
                        else:
                            if np.abs(length-shortest_length) > 2:
                                print("sequence length difference")
                                print(np.abs(length-shortest_length))
                                print(base_filename)
                        shortest_length = length
                for i,mod in enumerate(input_mods):
                    if self.input_fix_length_types[i] == "end":
                        np.save(self.features_filenames[mod][base_filename],self.features[mod][base_filename][:shortest_length])
                    elif self.input_fix_length_types[i] == "beg":
                        # Fixed: was [shortest_length:], which kept only the
                        # length-shortest_length trailing frames (typically 0-2)
                        # instead of cutting the beginning down to shortest_length.
                        np.save(self.features_filenames[mod][base_filename],self.features[mod][base_filename][-shortest_length:])
                    elif self.input_fix_length_types[i] == "single":
                        assert self.features[mod][base_filename].shape[0] == 1
                    else:
                        # Fixed: error path referenced self.input_fix_length_type (typo)
                        # and would have raised AttributeError instead of this message.
                        raise NotImplementedError("Haven't implemented input_fix_length_type "+self.input_fix_length_types[i])

                for i,mod in enumerate(output_mods):
                    if mod not in input_mods:
                        if self.output_fix_length_types[i] == "end":
                            np.save(self.features_filenames[mod][base_filename],self.features[mod][base_filename][:shortest_length])
                        elif self.output_fix_length_types[i] == "beg":
                            # Fixed: same "beg" slicing bug as the input loop above.
                            np.save(self.features_filenames[mod][base_filename],self.features[mod][base_filename][-shortest_length:])
                        elif self.output_fix_length_types[i] == "single":
                            assert self.features[mod][base_filename].shape[0] == 1
                        else:
                            # Fixed: self.output_fix_length_type typo, as above.
                            raise NotImplementedError("Haven't implemented output_fix_length_type "+self.output_fix_length_types[i])

                # Reload the truncated arrays and verify the windowed modalities
                # now agree on length.  Fixed: this check previously tested
                # `if i == 0` with the stale `i` left over from the enumerate
                # loop above, so it never tracked its own first iteration.
                first_length = True
                for mod in input_mods+output_mods:
                    self.features[mod][base_filename] = np.load(self.features_filenames[mod][base_filename])
                    if fix_length_types_dict[mod] == "single":
                        continue  # "single" features are length-1 by design
                    length = self.features[mod][base_filename].shape[0]
                    if first_length:
                        length_0 = length
                        first_length = False
                    else:
                        assert length == length_0

            #TODO: implement this!
            ## we pad the song features with zeros to imitate during training what happens during generation
            #x = [np.concatenate((np.zeros(( xx.shape[0],max(0,max(output_time_offsets)) )),xx),0) for xx in x]
            ## we also pad at the end to allow generation to be of the same length of sequence, by padding an amount corresponding to time_offset
            #x = [np.concatenate((xx,np.zeros(( xx.shape[0],max(0,max(input_lengths)+max(input_time_offsets)-(min(output_time_offsets)+min(output_lengths)-1)) ))),0) for xx in x]

            # Count the usable window start frames for this sequence; if every
            # input modality is "single" there is exactly one example.
            found_full_seq = False
            for i,mod in enumerate(input_mods):
                if self.input_fix_length_types[i] != "single":
                    sequence_length = self.features[mod][base_filename].shape[0]
                    found_full_seq = True
            if not found_full_seq:
                sequence_length = 1
            possible_init_frames = sequence_length-max(max(input_lengths)+max(input_time_offsets),max(output_time_offsets)+max(output_lengths))+1
            self.total_frames += possible_init_frames
            self.frame_cum_sums.append(self.total_frames)

            self.base_filenames.append(base_filename)

        print("sequences added: "+str(len(self.base_filenames)))
        assert len(self.base_filenames)>0, "List of files for training cannot be empty"
        for mod in input_mods+output_mods:
            assert len(self.features[mod].values()) == len(self.base_filenames)

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this dataset's command-line options on *parser*."""
        parser.add_argument('--sampling_rate', default=44100, type=float)
        parser.add_argument('--dins', default=None, help="input dimension for continuous inputs. Embedding dimension for discrete inputs")
        parser.add_argument('--douts', default=None)
        parser.add_argument('--input_modalities', default='mp3_mel_100')
        parser.add_argument('--output_modalities', default='mp3_mel_100')
        parser.add_argument('--input_lengths', help='input sequence length')
        parser.add_argument('--input_num_tokens', help='num_tokens. use 0 for continuous inputs')
        parser.add_argument('--output_num_tokens', help='num_tokens. use 0 for continuous inputs')
        parser.add_argument('--input_types', default=None, help='Comma-separated list of input types: d for discrete, c for continuous. E.g. d,c,c. Assumes continuous if not specified')
        parser.add_argument('--input_fix_length_types', default=None, help='Comma-separated list of approaches to fix length: end for cut end, beg for cut beginning, single for single-element sequence (e.g. sequence-level label). E.g. single,end,end. Assumes cut end if not specified')
        parser.add_argument('--output_fix_length_types', default=None, help='Comma-separated list of approaches to fix length: end for cut end, beg for cut beginning, single for single-element sequence (e.g. sequence-level label). E.g. single,end,end. Assumes cut end if not specified')
        parser.add_argument('--output_lengths', help='output sequence length')
        parser.add_argument('--output_time_offsets', default="1", help='time shift between the last read input, and the output predicted. The default value of 1 corresponds to predicting the next output')
        parser.add_argument('--input_time_offsets', default="0", help='time shift between the beginning of each modality and the first modality')
        parser.add_argument('--max_token_seq_len', type=int, default=1024)
        parser.add_argument('--fix_lengths', action='store_true', help='fix unmatching length of sequences')
        parser.add_argument('--num_train_samples', type=int, default=0, help='if 0 then use all of them')

        return parser

    def name(self):
        """Human-readable dataset name."""
        return "MultiModalDataset"

    def process_input(self, j, xx, index):
        """Cut the input window for modality *j* from array *xx* at start *index*.

        Windowed modalities return a float tensor of shape
        (input_lengths[j], feature_dim); "single" modalities return the whole
        (length-1) feature as a long tensor with a trailing unit dimension.
        """
        input_lengths = self.input_lengths
        input_time_offsets = self.input_time_offsets
        if self.input_fix_length_types[j] != "single":
            start = index + input_time_offsets[j]
            return torch.tensor(xx[start:start+input_lengths[j]]).float()
        else:
            return torch.tensor(xx).long().unsqueeze(1)

    def process_output(self, j, yy, index):
        """Cut the output window for modality *j* from array *yy* at start *index*.

        Mirrors process_input, but offsets by output_time_offsets[j] (1 means
        "predict the next frame") and spans output_lengths[j] frames.
        """
        output_lengths = self.output_lengths
        output_time_offsets = self.output_time_offsets
        if self.output_fix_length_types[j] != "single":
            start = index + output_time_offsets[j]
            return torch.tensor(yy[start:start+output_lengths[j]]).float()
        else:
            return torch.tensor(yy).long().unsqueeze(1)

    def __getitem__(self, item):
        """Return a dict of windowed tensors for global frame index *item*.

        Keys are "in_<mod>" for input modalities and "out_<mod>" for output
        modalities.
        """
        # Locate the sequence containing this global frame index.
        idx = find_example_idx(item, self.frame_cum_sums)
        base_filename = self.base_filenames[idx]

        input_mods = self.opt.input_modalities.split(",")
        output_mods = self.opt.output_modalities.split(",")

        x = [self.features[mod][base_filename] for mod in input_mods]
        y = [self.features[mod][base_filename] for mod in output_mods]

        # Per-sequence feature normalization used to happen here; removed in
        # favour of normalizing over all examples.

        # Convert the global frame index into an offset within the sequence.
        if idx > 0: index = item - self.frame_cum_sums[idx-1]
        else: index = item

        ## CONSTRUCT TENSOR OF INPUT FEATURES ##
        input_windows = [self.process_input(j, xx, index) for j, xx in enumerate(x)]

        ## CONSTRUCT TENSOR OF OUTPUT FEATURES ##
        output_windows = [self.process_output(j, yy, index) for j, yy in enumerate(y)]

        return_dict = {}
        for i, mod in enumerate(input_mods):
            return_dict["in_"+mod] = input_windows[i]
        for i, mod in enumerate(output_mods):
            return_dict["out_"+mod] = output_windows[i]

        return return_dict

    def __len__(self):
        """Total number of (sequence, start-frame) examples across all sequences."""
        return self.total_frames


def pairwise(iterable):
    """Yield successive overlapping pairs: s -> (s0, s1), (s1, s2), (s2, s3), ...

    Yields nothing for an empty or single-element iterable.
    """
    it = iter(iterable)
    prev = next(it, None)
    for cur in it:
        yield prev, cur
        prev = cur