import sys
import os

sys.path.append(os.getcwd())

import pickle

from tqdm import tqdm
from data_utils.utils import *
import torch.utils.data as data
from data_utils.mesh_dataset import SmplxDataset
from transformers import Wav2Vec2Processor


class MultiVidData():
    """Aggregate per-clip ``SmplxDataset`` instances for several speakers.

    Depending on ``config.dataset_load_mode`` the clip index is either
    restored from a pickle file (``'pickle'``), parsed from a CSV of
    annotated intervals (``'csv'``), or discovered by walking
    ``data_root/<speaker>/<video>/<split>/<clip>`` on disk (``'json'`` —
    this mode also pickles the assembled index for faster later runs).
    """

    def __init__(self,
                 data_root,
                 speakers,
                 split='train',
                 limbscaling=False,
                 normalization=False,
                 norm_method='new',
                 split_trans_zero=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=None,
                 aud_feat_win_size=None,
                 aud_feat_dim=64,
                 feat_method='mel_spec',
                 context_info=False,
                 smplx=False,
                 audio_sr=16000,
                 convert_to_6d=False,
                 expression=False,
                 config=None
                 ):
        """Build the clip index according to ``config.dataset_load_mode``.

        Args:
            data_root: root directory holding one sub-directory per speaker
                (json mode), or a pickle path root (other modes).
            speakers: iterable of speaker names to load.
            split: dataset split name; ``'pre'`` is treated as an alias of
                ``'train'``.
            limbscaling: forwarded to each per-clip dataset.
            normalization: whether ``get_dataset`` applies normalization.
            norm_method: normalization method tag (stored, not used here).
            split_trans_zero: keep translation/zero sequences in separate
                concatenated datasets instead of one combined dataset.
            num_frames: clip window length in frames; also used as fps value
                handed to each per-clip dataset.
            num_pre_frames: number of conditioning frames.
            num_generate_length: generated-sequence length; defaults to
                ``num_frames`` when None.
            aud_feat_win_size: audio feature window size, forwarded.
            aud_feat_dim: audio feature dimensionality, forwarded.
            feat_method: audio feature extraction method, forwarded.
            context_info: forwarded to each per-clip dataset.
            smplx: NOTE(review): declared with a False default but compared
                against the string 'pose' in json mode — callers apparently
                pass a string here; confirm before changing.
            audio_sr: audio sample rate, forwarded.
            convert_to_6d: use 6D rotation representation, forwarded.
            expression: include expression parameters, forwarded.
            config: project config object; must provide
                ``dataset_load_mode`` and ``Data.pklname``.
        """
        self.data_root = data_root
        self.speakers = speakers
        # 'pre' is an alias of the training split.
        self.split = 'train' if split == 'pre' else split
        self.norm_method = norm_method
        self.normalization = normalization
        self.limbscaling = limbscaling
        self.convert_to_6d = convert_to_6d
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        # Default the generated-sequence length to the window length.
        if num_generate_length is None:
            self.num_generate_length = num_frames
        else:
            self.num_generate_length = num_generate_length
        self.split_trans_zero = split_trans_zero

        if self.split_trans_zero:
            self.trans_dataset_list = []
            self.zero_dataset_list = []
        else:
            self.all_dataset_list = []
        self.dataset = {}
        self.complete_data = []
        self.config = config
        load_mode = self.config.dataset_load_mode

        # Keyword arguments shared by every per-clip SmplxDataset, whichever
        # load mode constructs it.
        shared_kwargs = dict(
            audio_sr=audio_sr,
            fps=num_frames,
            feat_method=feat_method,
            audio_feat_dim=aud_feat_dim,
            train=(self.split == 'train'),
            load_all=True,
            split_trans_zero=self.split_trans_zero,
            limbscaling=self.limbscaling,
            num_frames=self.num_frames,
            num_pre_frames=self.num_pre_frames,
            num_generate_length=self.num_generate_length,
            audio_feat_win_size=aud_feat_win_size,
            context_info=context_info,
            convert_to_6d=convert_to_6d,
            expression=expression,
            config=self.config,
        )

        if load_mode == 'pickle':
            self._load_from_pickle()
        elif load_mode == 'csv':
            self._load_from_csv(shared_kwargs)
        elif load_mode == 'json':
            self._load_from_json(shared_kwargs, smplx)

        # Raises ValueError if nothing was loaded (empty list), matching the
        # original behaviour.
        self.complete_data = np.concatenate(self.complete_data, axis=0)
        self.normalize_stats = {}
        self.data_mean = None
        self.data_std = None

    def _load_from_pickle(self):
        """Restore a previously pickled clip index from '<split><pklname>'."""
        # Read-only: the file is produced by a 'json'-mode run (was 'rb+').
        with open(self.split + self.config.Data.pklname, 'rb') as f:
            self.dataset = pickle.load(f)
        for key in self.dataset:
            self.complete_data.append(self.dataset[key].complete_data)

    def _load_from_csv(self, shared_kwargs):
        """Build the clip index from a CSV of annotated intervals.

        NOTE(review): originally imported from an external code folder and
        still unfinished — ``df_intervals`` is never actually read from disk
        below, so this branch fails at the first subscript until the
        ``pd.read_csv`` call is restored.  Kept as-is to preserve behaviour.
        """
        try:
            sys.path.append(self.config.config_root_path)
            from config import config_path
            from csv_parser import csv_parse
        except ImportError as e:
            print(f'err: {e}')
            raise ImportError('config root path error...') from e

        for speaker_name in self.speakers:
            # df_intervals = pd.read_csv(self.config.voca_csv_file_path)
            df_intervals = None
            df_intervals = df_intervals[df_intervals['speaker'] == speaker_name]
            df_intervals = df_intervals[df_intervals['dataset'] == self.split]
            print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
            for iter_index, (_, interval) in tqdm(
                enumerate(df_intervals.iterrows()), desc=f'load {speaker_name}'
            ):
                (
                    interval_index,
                    interval_speaker,
                    interval_video_fn,
                    interval_id,
                    start_time,
                    end_time,
                    duration_time,
                    start_time_10,
                    over_flow_flag,
                    short_dur_flag,
                    big_video_dir,
                    small_video_dir_name,
                    speaker_video_path,
                    voca_basename,
                    json_basename,
                    wav_basename,
                    voca_top_clip_path,
                    voca_json_clip_path,
                    voca_wav_clip_path,
                    audio_output_fn,
                    image_output_path,
                    pifpaf_output_path,
                    mp_output_path,
                    op_output_path,
                    deca_output_path,
                    pixie_output_path,
                    cam_output_path,
                    ours_output_path,
                    merge_output_path,
                    multi_output_path,
                    gt_output_path,
                    ours_images_path,
                    pkl_fil_path,
                ) = csv_parse(interval)
                # Skip intervals whose motion pickle or audio clip is missing.
                if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
                    continue
                key = f'{interval_video_fn}/{small_video_dir_name}'
                self.dataset[key] = SmplxDataset(
                    data_root=pkl_fil_path,
                    speaker=speaker_name,
                    audio_fn=audio_output_fn,
                    **shared_kwargs,
                )
                self.complete_data.append(self.dataset[key].complete_data)

    def _load_from_json(self, shared_kwargs, smplx):
        """Discover clips under ``data_root/<speaker>/<video>/<split>/<clip>``.

        Each clip directory must contain '<clip>.wav' and '<clip>.pkl'.
        The assembled index is pickled to '<split><pklname>' so later runs
        can use the much faster 'pickle' load mode.
        """
        # Phoneme-level wav2vec2 processor handed to every clip dataset.
        am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
        am_sr = 16000

        for speaker_name in self.speakers:
            speaker_root = os.path.join(self.data_root, speaker_name)
            videos = [v for v in os.listdir(speaker_root)]
            print(videos)
            # Counters of good/broken clips ('haode' = good, 'huaide' = broken).
            haode = huaide = 0
            for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
                source_vid = vid
                vid_pth = os.path.join(speaker_root, source_vid, self.split)
                # NOTE(review): smplx defaults to False but is compared to the
                # string 'pose' — see the parameter docstring.
                if smplx == 'pose':
                    seqs = [s for s in os.listdir(vid_pth) if s.startswith('clip')]
                else:
                    try:
                        seqs = [s for s in os.listdir(vid_pth)]
                    except OSError:
                        # Split directory missing for this video: skip it
                        # (was a bare `except:`).
                        continue
                for s in seqs:
                    seq_root = os.path.join(vid_pth, s)
                    key = seq_root  # corresponds to clip******
                    audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
                    motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
                    if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
                        huaide = huaide + 1
                        continue
                    self.dataset[key] = SmplxDataset(
                        data_root=seq_root,
                        speaker=speaker_name,
                        motion_fn=motion_fname,
                        audio_fn=audio_fname,
                        am=am,
                        am_sr=am_sr,
                        whole_video=self.config.Data.whole_video,
                        **shared_kwargs,
                    )
                    self.complete_data.append(self.dataset[key].complete_data)
                    haode = haode + 1
            print("huaide:{}, haode:{}".format(huaide, haode))

        # Cache the index so subsequent runs can use 'pickle' load mode.
        with open(self.split + self.config.Data.pklname, 'wb') as f:
            pickle.dump(self.dataset, f)

    def get_dataset(self):
        """Finalise per-clip datasets and concatenate them.

        Clips shorter than ``num_generate_length`` are dropped.  The (still
        None) normalisation statistics are forwarded to every clip dataset.

        NOTE(review): when ``split_trans_zero`` is True, ``__init__`` never
        creates ``all_dataset_list``, yet the loop below appends to it —
        latent AttributeError preserved from the original; confirm intended
        usage before fixing.
        """
        self.normalize_stats['mean'] = self.data_mean
        self.normalize_stats['std'] = self.data_std

        for key in list(self.dataset.keys()):
            if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
                continue
            self.dataset[key].num_generate_length = self.num_generate_length
            self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
            self.all_dataset_list.append(self.dataset[key].all_dataset)

        if self.split_trans_zero:
            self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
            self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
        else:
            self.all_dataset = data.ConcatDataset(self.all_dataset_list)