import glob
import json
import os

import pickle
import random
import re
import subprocess
from functools import partial

import librosa.core
import numpy as np
import torch
import torch.distributions
import torch.distributed as dist
import torch.optim
import torch.utils.data

from utils.commons.indexed_datasets import IndexedDataset
from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import csv
from utils.commons.hparams import hparams, set_hparams
from utils.commons.meters import Timer
from data_util.face3d_helper import Face3DHelper
from utils.audio import librosa_wav2mfcc
from utils.commons.dataset_utils import collate_xd


class SyncNet_Dataset(Dataset):
    def __init__(self, prefix='train', data_dir=None):
        self.hparams = hparams
        self.db_key = prefix
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir
        self.ds = None
        self.sizes = None
        self.x_maxframes = 200  # max HuBERT frames per clip, i.e. 100 video frames (video runs at half the HuBERT rate)
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8

    def __len__(self):
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        Open the indexed dataset lazily, on first access, so that each
        DataLoader worker opens its own file handle.
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]
    
    def __getitem__(self, idx):
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_id': raw_item['img_dir'],
            'id': torch.from_numpy(raw_item['id']).float(), # [T_y, c=80], or [1, c=80] for a global identity
            'exp': torch.from_numpy(raw_item['exp']).float(), # [T_y, c=64]
        }
        if item['id'].shape[0] == 1: # global_id
            item['id'] = item['id'].repeat([item['exp'].shape[0], 1])
        item['hubert'] = torch.from_numpy(raw_item['hubert']).float() # [T_x, 1024]
        x_len = len(item['hubert'])
        y_len = x_len // 2 # video is 25 fps, half the HuBERT frame rate
        item['id'] = item['id'][:y_len]
        item['exp'] = item['exp'][:y_len]
        
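        # Randomly select a fixed-length clip. The start is rounded down to an
        # even HuBERT frame so that it maps to a whole video frame (start_frames // 2).
        start_frames = random.randint(0, max(0, x_len - self.x_maxframes))
        start_frames = start_frames // 2 * 2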
        item['hubert'] = item['hubert'][start_frames: start_frames + self.x_maxframes]
        item['id'] = item['id'][start_frames//2: start_frames//2 + self.x_maxframes//2]
        item['exp'] = item['exp'][start_frames//2: start_frames//2 + self.x_maxframes//2]
        return item


    def get_dataloader(self, batch_size=1, num_workers=0):
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader


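    def collater(self, samples):
        # Drop items that failed to load (__getitem__ returns None for them).
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return None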
        x_len = max(s['hubert'].size(0) for s in samples)
        y_len = x_len // 2
        batch = {
            'item_id': [s['item_id'] for s in samples],
        }
        batch['hubert'] = collate_xd([s["hubert"] for s in samples], max_len=x_len, pad_idx=0) # [b, t_max_x, 1024]
        batch['x_mask'] = (batch['hubert'].abs().sum(dim=-1) > 0).float() # [b, t_max_x]

        batch['id'] = collate_xd([s["id"] for s in samples], max_len=y_len, pad_idx=0) # [b, t_max_y, 80]
        batch['exp'] = collate_xd([s["exp"] for s in samples], max_len=y_len, pad_idx=0) # [b, t_max_y, 64]
        batch['y_mask'] = (batch['id'].abs().sum(dim=-1) > 0).float() # [b, t_max_y]
        return batch


if __name__ == '__main__':
    os.environ["OMP_NUM_THREADS"] = "1"

    ds = SyncNet_Dataset("train", 'data/binary/th1kh')
    dl = ds.get_dataloader()
    for b in tqdm(dl):
        pass