# TalkSHOW: data_utils/dataloader_torch.py
import os
import sys

sys.path.append(os.getcwd())

import numpy as np  # used for np.concatenate below, in case the star import does not provide it
import torch.utils.data as data
from tqdm import tqdm
from transformers import Wav2Vec2Processor

from data_utils.utils import *
from data_utils.mesh_dataset import SmplxDataset


class MultiVidData:
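    """Aggregate per-clip SmplxDataset instances for a list of speakers.

    Clips can be gathered in one of three modes, selected by
    config.dataset_load_mode: 'pickle' (load a previously cached dataset
    dict), 'csv' (interval-based loading via an external csv parser), or
    'json' (walk the on-disk directory layout and cache the result).
    """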
def __init__(self,
data_root,
speakers,
split='train',
limbscaling=False,
normalization=False,
norm_method='new',
split_trans_zero=False,
num_frames=25,
num_pre_frames=25,
num_generate_length=None,
aud_feat_win_size=None,
aud_feat_dim=64,
feat_method='mel_spec',
context_info=False,
smplx=False,
audio_sr=16000,
convert_to_6d=False,
expression=False,
config=None
):
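        # note: num_frames is also forwarded to SmplxDataset as its fps value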
self.data_root = data_root
self.speakers = speakers
self.split = split
        # the 'pre' split reuses the on-disk 'train' data
        if split == 'pre':
            self.split = 'train'
        self.norm_method = norm_method
        self.normalization = normalization
        self.limbscaling = limbscaling
        self.convert_to_6d = convert_to_6d
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
if num_generate_length is None:
self.num_generate_length = num_frames
else:
self.num_generate_length = num_generate_length
        self.split_trans_zero = split_trans_zero
        dataset = SmplxDataset
        # all_dataset_list is always needed by get_dataset(), so create it in
        # both branches; the trans/zero lists are only used when splitting
        self.all_dataset_list = []
        if self.split_trans_zero:
            self.trans_dataset_list = []
            self.zero_dataset_list = []
        self.dataset = {}
        self.complete_data = []
        self.config = config
        load_mode = self.config.dataset_load_mode
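        # load_mode selects how clips are gathered:
        #   'pickle' -> read a cached dict from self.split + config.Data.pklname
        #   'csv'    -> parse intervals from a csv file via csv_parse
        #   'json'   -> walk data_root/<speaker>/<video>/<split>/<clip> and cache the result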
        ######################load with pickle file
        if load_mode == 'pickle':
            import pickle
            # the cache is written by the 'json' branch below, e.g. 'train' + config.Data.pklname
            with open(self.split + config.Data.pklname, 'rb') as f:
                self.dataset = pickle.load(f)
            for key in self.dataset:
                self.complete_data.append(self.dataset[key].complete_data)
######################load with pickle file
######################load with a csv file
        elif load_mode == 'csv':
            import pandas as pd  # required for the interval csv below
            # Imported from an external code folder; to be integrated here properly later.
try:
sys.path.append(self.config.config_root_path)
from config import config_path
from csv_parser import csv_parse
except ImportError as e:
print(f'err: {e}')
raise ImportError('config root path error...')
for speaker_name in self.speakers:
                # read the interval csv (restored from the commented-out line;
                # without it df_intervals is None and the filters below fail)
                df_intervals = pd.read_csv(self.config.voca_csv_file_path)
                df_intervals = df_intervals[df_intervals['speaker'] == speaker_name]
                df_intervals = df_intervals[df_intervals['dataset'] == self.split]
                print(f'speaker {speaker_name} {self.split} interval length: {len(df_intervals)}')
                for iter_index, (_, interval) in tqdm(
                        enumerate(df_intervals.iterrows()), desc=f'load {speaker_name}'
                ):
(
interval_index,
interval_speaker,
interval_video_fn,
interval_id,
start_time,
end_time,
duration_time,
start_time_10,
over_flow_flag,
short_dur_flag,
big_video_dir,
small_video_dir_name,
speaker_video_path,
voca_basename,
json_basename,
wav_basename,
voca_top_clip_path,
voca_json_clip_path,
voca_wav_clip_path,
audio_output_fn,
image_output_path,
pifpaf_output_path,
mp_output_path,
op_output_path,
deca_output_path,
pixie_output_path,
cam_output_path,
ours_output_path,
merge_output_path,
multi_output_path,
gt_output_path,
ours_images_path,
                        pkl_file_path,
)=csv_parse(interval)
                    if not os.path.exists(pkl_file_path) or not os.path.exists(audio_output_fn):
continue
key=f'{interval_video_fn}/{small_video_dir_name}'
self.dataset[key] = dataset(
                        data_root=pkl_file_path,
speaker=speaker_name,
audio_fn=audio_output_fn,
audio_sr=audio_sr,
fps=num_frames,
feat_method=feat_method,
audio_feat_dim=aud_feat_dim,
train=(self.split == 'train'),
load_all=True,
split_trans_zero=self.split_trans_zero,
limbscaling=self.limbscaling,
num_frames=self.num_frames,
num_pre_frames=self.num_pre_frames,
num_generate_length=self.num_generate_length,
audio_feat_win_size=aud_feat_win_size,
context_info=context_info,
convert_to_6d=convert_to_6d,
expression=expression,
config=self.config
)
self.complete_data.append(self.dataset[key].complete_data)
######################load with a csv file
######################origin load method
        elif load_mode == 'json':
            # The Wav2Vec2 phoneme processor supplies the audio features
            # (per the original comments it is only strictly needed when
            # config.Model.model_type == 'face').
            am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
            am_sr = 16000
for speaker_name in self.speakers:
speaker_root = os.path.join(self.data_root, speaker_name)
                videos = os.listdir(speaker_root)
                print(videos)
                # haode / huaide count usable / broken clips ('good' / 'broken' in Chinese pinyin)
                haode = huaide = 0
                for vid in tqdm(videos, desc=f'Processing {self.split} data of {speaker_name}......'):
source_vid=vid
# vid_pth=os.path.join(speaker_root, source_vid, 'images/half', self.split)
vid_pth = os.path.join(speaker_root, source_vid, self.split)
                    if smplx == 'pose':
                        seqs = [s for s in os.listdir(vid_pth) if s.startswith('clip')]
                    else:
                        try:
                            seqs = os.listdir(vid_pth)
                        except OSError:
                            # some videos have no directory for this split
                            continue
for s in seqs:
seq_root=os.path.join(vid_pth, s)
                        key = seq_root  # one dataset entry per clip directory
                        audio_fname = os.path.join(speaker_root, source_vid, self.split, s, f'{s}.wav')
                        motion_fname = os.path.join(speaker_root, source_vid, self.split, s, f'{s}.pkl')
                        if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
                            huaide += 1
                            continue
self.dataset[key]=dataset(
data_root=seq_root,
speaker=speaker_name,
motion_fn=motion_fname,
audio_fn=audio_fname,
audio_sr=audio_sr,
fps=num_frames,
feat_method=feat_method,
audio_feat_dim=aud_feat_dim,
train=(self.split=='train'),
load_all=True,
split_trans_zero=self.split_trans_zero,
limbscaling=self.limbscaling,
num_frames=self.num_frames,
num_pre_frames=self.num_pre_frames,
num_generate_length=self.num_generate_length,
audio_feat_win_size=aud_feat_win_size,
context_info=context_info,
convert_to_6d=convert_to_6d,
expression=expression,
config=self.config,
am=am,
am_sr=am_sr,
whole_video=config.Data.whole_video
)
self.complete_data.append(self.dataset[key].complete_data)
                        haode += 1
                print(f'huaide:{huaide}, haode:{haode}')
            import pickle
            # cache the assembled dataset so later runs can use load_mode='pickle'
            with open(self.split + config.Data.pklname, 'wb') as f:
                pickle.dump(self.dataset, f)
######################origin load method
        self.complete_data = np.concatenate(self.complete_data, axis=0)
# assert self.complete_data.shape[-1] == (12+21+21)*2
self.normalize_stats = {}
self.data_mean = None
self.data_std = None

    def get_dataset(self):
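        """Finalize per-clip datasets and concatenate them into torch ConcatDatasets."""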
self.normalize_stats['mean'] = self.data_mean
self.normalize_stats['std'] = self.data_std
        for key in list(self.dataset.keys()):
            # skip clips shorter than the generation window
            if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
                continue
self.dataset[key].num_generate_length = self.num_generate_length
self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
self.all_dataset_list.append(self.dataset[key].all_dataset)
if self.split_trans_zero:
self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
else:
self.all_dataset = data.ConcatDataset(self.all_dataset_list)
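

# Minimal usage sketch (not part of the original file): it assumes a config
# object exposing the attributes read above (dataset_load_mode, Data.pklname,
# Data.whole_video, ...); the config loader, paths, and speaker names below
# are placeholders, not the repo's confirmed entry point.
if __name__ == '__main__':
    import json
    from types import SimpleNamespace

    # hypothetical: build an attribute-style config from a json file of the same shape
    with open('config/body_pixel.json') as fp:  # placeholder config path
        config = json.load(fp, object_hook=lambda d: SimpleNamespace(**d))

    train_data = MultiVidData(
        data_root='./ExpressiveWholeBodyDatasetv1.0',  # placeholder dataset root
        speakers=['oliver'],                           # placeholder speaker list
        split='train',
        smplx=True,
        audio_sr=16000,
        config=config,
    )
    train_data.get_dataset()
    loader = data.DataLoader(train_data.all_dataset, batch_size=128, shuffle=True)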