import os
import random

import librosa
import numpy as np
import python_speech_features
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
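
# Dataset-style loader: for each clip it aligns audio features (HuBERT or
# MFCC) with facial landmarks, head pose (yaw/pitch/roll), and motion
# latents, then serves random fixed-length windows through __getitem__.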

class LatentDataLoader(object):

    def __init__(
        self, 
        window_size, 
        frame_jpgs, 
        lmd_feats_prefix, 
        audio_prefix, 
        raw_audio_prefix,
        motion_latents_prefix, 
        pose_prefix, 
        db_name, 
        video_fps=25, 
        audio_hz=50, 
        size=256, 
        mfcc_mode=False,
    ):
        self.window_size = window_size
        self.lmd_feats_prefix = lmd_feats_prefix
        self.audio_prefix = audio_prefix
        self.pose_prefix = pose_prefix
        self.video_fps = video_fps
        self.audio_hz = audio_hz
        self.db_name = db_name
        self.raw_audio_prefix = raw_audio_prefix
        self.mfcc_mode = mfcc_mode
        

        # Map frames to [-1, 1] normalized tensors of shape (3, size, size).
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
    
        self.data = []
        # NOTE: the dataset list is hard-coded here; the `db_name` argument is
        # stored on self but shadowed by this loop variable.
        for db_name in ['VoxCeleb2', 'HDTF']:
            db_png_path = os.path.join(frame_jpgs, db_name)
            for clip_name in tqdm(os.listdir(db_png_path)):

                item_dict = dict()
                item_dict['clip_name'] = clip_name
                item_dict['frame_count'] = len(os.listdir(os.path.join(db_png_path, clip_name)))
                item_dict['hubert_path'] = os.path.join(audio_prefix, db_name, clip_name + ".npy")
                item_dict['wav_path'] = os.path.join(raw_audio_prefix, db_name, clip_name + ".wav")
                
                item_dict['yaw_pitch_roll_path'] = os.path.join(pose_prefix, db_name, 'raw_videos_pose_yaw_pitch_roll', clip_name + ".npy")
                if not os.path.exists(item_dict['yaw_pitch_roll_path']):
                    print(f"{db_name}'s {clip_name} is missing yaw_pitch_roll_path")
                    continue

                # Clamp pose angles to [-90, 90] degrees and normalize to [-1, 1].
                item_dict['yaw_pitch_roll'] = np.load(item_dict['yaw_pitch_roll_path'])
                item_dict['yaw_pitch_roll'] = np.clip(item_dict['yaw_pitch_roll'], -90, 90) / 90.0

                if not os.path.exists(item_dict['wav_path']):
                    print(f"{db_name}'s {clip_name} is missing wav_path")
                    continue

                if not os.path.exists(item_dict['hubert_path']):
                    print(f"{db_name}'s {clip_name} is missing hubert_path")
                    continue
                
                
                if self.mfcc_mode:
                    # 13 MFCCs plus first- and second-order deltas (39 dims per
                    # step) at a 10 ms hop, i.e. 100 feature frames per second.
                    wav, sr = librosa.load(item_dict['wav_path'], sr=16000)
                    input_values = python_speech_features.mfcc(signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
                    d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
                    d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
                    input_values = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
                    item_dict['hubert_obj'] = input_values
                else:
                    item_dict['hubert_obj'] = np.load(item_dict['hubert_path'], mmap_mode='r')
                item_dict['lmd_path'] = os.path.join(lmd_feats_prefix, db_name, clip_name + ".txt")
                item_dict['lmd_obj_full'] = self.read_landmark_info(item_dict['lmd_path'], upper_face=False)  # full landmark set

                motion_start_path = os.path.join(motion_latents_prefix, db_name, 'motions', clip_name +".npy")
                motion_direction_path = os.path.join(motion_latents_prefix, db_name, 'directions',  clip_name +".npy")
                
                if not os.path.exists(motion_start_path):
                    print(f"{db_name}'s {clip_name} is missing motion_start_path")
                    continue
                if not os.path.exists(motion_direction_path):
                    print(f"{db_name}'s {clip_name} is missing motion_direction_path")
                    continue

                item_dict['motion_start_obj'] = np.load(motion_start_path)
                item_dict['motion_direction_obj'] = np.load(motion_direction_path)                

                # Align all per-clip streams to a common length in video frames.
                # MFCC features run at 100 Hz (4 per 25 fps video frame); HuBERT
                # features run at self.audio_hz (50 Hz, i.e. 2 per video frame).
                if self.mfcc_mode:
                    min_len = min(
                        item_dict['lmd_obj_full'].shape[0],
                        item_dict['yaw_pitch_roll'].shape[0],
                        item_dict['motion_start_obj'].shape[0],
                        item_dict['motion_direction_obj'].shape[0],
                        item_dict['hubert_obj'].shape[0] // 4,
                        item_dict['frame_count'],
                    )
                    item_dict['frame_count'] = min_len
                    item_dict['hubert_obj'] = item_dict['hubert_obj'][:min_len * 4, :]
                else:
                    min_len = min(
                        item_dict['lmd_obj_full'].shape[0],
                        item_dict['yaw_pitch_roll'].shape[0],
                        item_dict['motion_start_obj'].shape[0],
                        item_dict['motion_direction_obj'].shape[0],
                        item_dict['hubert_obj'].shape[1] // 2,
                        item_dict['frame_count'],
                    )
                    item_dict['frame_count'] = min_len
                    item_dict['hubert_obj'] = item_dict['hubert_obj'][:, :min_len * 2, :]
                
                # Skip clips too short to sample a full window plus a hint frame.
                if min_len < self.window_size * self.video_fps + 5:
                    continue

                self.data.append(item_dict)

        print('Db count:', len(self.data))

    def get_single_image(self, image_path):
        # Load one frame and apply the resize/normalize transform.
        img_source = Image.open(image_path).convert('RGB')
        img_source = self.transform(img_source)
        return img_source

    def get_multiple_ranges(self, lists, multi_ranges):
        """Concatenate several slices of `lists`, e.g.
        get_multiple_ranges(list(range(10)), [(0, 2), (5, 7)]) -> [0, 1, 5, 6]."""
        # Ensure that multi_ranges is a list of (start, end) tuples.
        if not all(isinstance(item, tuple) and len(item) == 2 for item in multi_ranges):
            raise ValueError("multi_ranges must be a list of (start, end) tuples with exactly two elements each")
        extracted_elements = [lists[start:end] for start, end in multi_ranges]
        flat_list = [item for sublist in extracted_elements for item in sublist]
        return flat_list
    
    
    def read_landmark_info(self, lmd_path, upper_face=True):
        with open(lmd_path, 'r') as file:
            lmd_lines = file.readlines()
        lmd_lines.sort()  # lines are prefixed with the frame filename, so this orders them by frame

        total_lmd_obj = []
        for line in lmd_lines:
            # Split the coordinates and filter out any empty strings.
            coords = [c for c in line.strip().split(' ') if c]
            coords = coords[1:]  # drop the frame filename at the start of each line
            lmd_obj = []
            if upper_face:
                # Keep only the upper-face subset: 3 + 13 + 12 = 28 points.
                for coord_pair in self.get_multiple_ranges(coords, [(0, 3), (14, 27), (36, 48)]):
                    x, y = coord_pair.split('_')
                    # Coordinates are stored as "x_y" integer pixels; normalize by the 512-pixel frame size.
                    lmd_obj.append((int(x) / 512, int(y) / 512))
            else:
                for coord_pair in coords:
                    x, y = coord_pair.split('_')
                    lmd_obj.append((int(x) / 512, int(y) / 512))
            total_lmd_obj.append(lmd_obj)

        return np.array(total_lmd_obj, dtype=np.float32)

    def calculate_face_height(self, landmarks):
        # Per-frame distance from the midpoint of the inner eyebrows
        # (landmarks 21 and 22) to the chin bottom (landmark 8).
        forehead_center = (landmarks[:, 21, :] + landmarks[:, 22, :]) / 2
        chin_bottom = landmarks[:, 8, :]
        distances = np.linalg.norm(forehead_center - chin_bottom, axis=1, keepdims=True)
        return distances
            
    def __getitem__(self, index):
        data_item = self.data[index]
        hubert_obj = data_item['hubert_obj']
        frame_count = data_item['frame_count']
        lmd_obj_full = data_item['lmd_obj_full']
        yaw_pitch_roll = data_item['yaw_pitch_roll']
        motion_start_obj = data_item['motion_start_obj']
        motion_direction_obj = data_item['motion_direction_obj']

        # Sample a random window of window_size * video_fps frames, plus one
        # extra "hint" frame immediately before it to serve as the reference.
        frame_end_index = random.randint(self.window_size * self.video_fps + 1, frame_count - 1)
        frame_start_index = frame_end_index - self.window_size * self.video_fps
        frame_hint_index = frame_start_index - 1

        # Map video-frame indices to audio-feature indices (audio_hz / video_fps
        # features per frame, e.g. 2 for 50 Hz HuBERT at 25 fps video).
        audio_start_index = int(frame_start_index * (self.audio_hz / self.video_fps))
        audio_end_index = int(frame_end_index * (self.audio_hz / self.video_fps))
        
        if self.mfcc_mode:
            audio_feats = hubert_obj[audio_start_index:audio_end_index, :]
        else:
            audio_feats = hubert_obj[:, audio_start_index:audio_end_index, :]

        # The landmark slice keeps the hint frame at position 0; it is dropped
        # again via '[1:]' when building the targets below.
        lmd_obj_full = lmd_obj_full[frame_hint_index:frame_end_index, :]

        yaw_pitch_roll = yaw_pitch_roll[frame_start_index:frame_end_index, :]

        motion_start = motion_start_obj[frame_hint_index]
        motion_direction_start = motion_direction_obj[frame_hint_index]
        motion_direction = motion_direction_obj[frame_start_index:frame_end_index, :]

        return {
            'motion_start': motion_start,
            'motion_direction': motion_direction,
            'audio_feats': audio_feats,
            # '[1:]' skips the hint frame; landmark 30 is the nose tip and the
            # trailing 0 selects its x coordinate (horizontal face location).
            'face_location': lmd_obj_full[1:, 30, 0],
            'face_scale': self.calculate_face_height(lmd_obj_full[1:, :, :]),
            'yaw_pitch_roll': yaw_pitch_roll,
            'motion_direction_start': motion_direction_start,
        }
  
    def __len__(self):
        return len(self.data)
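

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original training pipeline. The
    # directory layout below is an assumption for illustration; substitute the
    # real frame/landmark/audio/pose/latent roots.
    loader = LatentDataLoader(
        window_size=4,
        frame_jpgs='data/frame_jpgs',
        lmd_feats_prefix='data/landmarks',
        audio_prefix='data/hubert',
        raw_audio_prefix='data/wavs',
        motion_latents_prefix='data/motion_latents',
        pose_prefix='data/poses',
        db_name='VoxCeleb2',
    )
    if len(loader) > 0:
        sample = loader[0]
        for key, value in sample.items():
            print(key, getattr(value, 'shape', value))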