# NOTE: removed non-source scrape residue ("Spaces:" / "Runtime error" x2) that
# preceded this script; it was web-page chrome, not part of the program.
# generate samples for evaluation & visualization
import torch
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from omegaconf import OmegaConf
import pickle
import os
import sys

# Make the project's src/ directory importable BEFORE the project imports below.
# Computed once (the original built the same path twice: once for sys.path,
# once for the debug print).
_src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))
sys.path.append(_src_dir)
print(_src_dir)
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

from inference import inference
from utils.inference_utils import set_all_seeds, fix_state_dict, load_hint_texts_from_file, load_mask_from_file, load_file_names, gen_prog_ind
from model.gaussian_diffusion import GaussianDiffusion
from model.unet import Unet
from utils.normalize import set_up_normalization
from utils.constants import TO_24

set_all_seeds(135)  # fixed seed so sampling runs are reproducible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import clip
# Frozen CLIP ViT-B/32 text encoder, used below to embed the hint texts.
text_embedder, _ = clip.load("ViT-B/32", device=device)
text_embedder.eval()
def print_config(config):
    """Render an OmegaConf config as YAML and print it to stdout."""
    yaml_text = OmegaConf.to_yaml(config)
    print(yaml_text)
def getmodel(model_used, device, model_root, use_step=False, is_disc=False, config=None):
    """Build a Unet and load a trained checkpoint onto `device`.

    Args:
        model_used: epoch (or step, see `use_step`) identifier embedded in the
            checkpoint filename.
        device: torch device the model and weights are placed on.
        model_root: directory containing the checkpoint `.pth` files.
        use_step: if True, load `model_h3d_step{model_used}.pth` instead of
            `model_h3d_epoch{model_used}.pth`.
        is_disc: whether to build the discriminator variant of the Unet.
        config: config node holding the Unet hyperparameters.

    Returns:
        The Unet in eval() mode with the checkpoint weights loaded.
    """
    model = Unet(
        dim_model=config.dim_model,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout_p=config.dropout_p,
        dim_input=config.dim_input,
        dim_output=config.dim_output,
        text_emb=config.text_emb,
        device=device,
        Disc=is_disc,
    ).to(device)

    # Build the checkpoint path once instead of epoch-first-then-overwrite.
    suffix = 'step' if use_step else 'epoch'
    model_path = os.path.join(model_root, f'model_h3d_{suffix}{model_used}.pth')
    print("==>", model_path)

    # map_location=device covers both the CUDA and CPU-only cases in one call
    # (the original branched on torch.cuda.is_available()).
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints; consider weights_only=True on torch >= 2.0.
    state_dict = torch.load(model_path, map_location=device)
    # fix_state_dict runs on the full checkpoint and again on the inner
    # model_state_dict; presumably it strips wrapper prefixes and is idempotent —
    # TODO confirm the second pass is intentional (kept to preserve behavior).
    fixed_state_dict = fix_state_dict(state_dict)['model_state_dict']
    fixed_state_dict = fix_state_dict(fixed_state_dict)
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model
if __name__ == '__main__':
    """
    args:
    - task: "regen", "style_transfer", "adjustment"
    """
    parser = ArgumentParser()
    parser.add_argument('--task', type=str, default='regen')
    args = parser.parse_args()

    # Task-specific settings are layered over the shared base config.
    task_config = OmegaConf.load(f"configs/inference/{args.task}.yaml")
    base_config = OmegaConf.load("configs/base.yaml")
    config = OmegaConf.merge(base_config, task_config)

    # Resolve all test-data file locations relative to the project root.
    text_path = os.path.join(project_root, config.test_data_path, config.text_path)
    mask_path = os.path.join(project_root, config.test_data_path, config.mask_path)
    joints_src_path = os.path.join(project_root, config.test_data_path, config.joints_src_path)
    gen_file_names_path = os.path.join(project_root, config.test_data_path, config.gen_file_names_path)

    hint_text_all = load_hint_texts_from_file(text_path)
    mask_all = load_mask_from_file(mask_path)
    gen_file_names = load_file_names(gen_file_names_path)
    joints_orig_all = torch.tensor(np.load(joints_src_path), dtype=torch.float32, device=device)
    prog_ind_all = gen_prog_ind(num_cases=len(hint_text_all), sublist_length=4)  # sublist_length=config.sublist_length

    models = {
        'model': getmodel(config.model_used,
                          device=device,
                          model_root=os.path.join(project_root, config.model_path, config.task),
                          use_step=False,
                          is_disc=False,
                          config=config.unet,
                          ),
        'disc_model': getmodel(config.disc_model_used,
                               device=device,
                               model_root=os.path.join(project_root, config.disc_model_path, config.task),
                               use_step=True,
                               is_disc=True,
                               config=config.unet,
                               ),
    }
    diffuser = GaussianDiffusion(device=device,
                                 fix_mode=config.diffusion.fix_mode,
                                 text_emb=config.diffusion.text_emb,
                                 fixed_frames=config.diffusion.fixed_frames,
                                 seq_len=config.diffusion.seq_len,
                                 timesteps=config.diffusion.timesteps,
                                 beta_schedule=config.diffusion.beta_schedule)
    normalize, denormalize = set_up_normalization(device=device, seq_len=config.seq_len, scale=3)
    joints_orig = normalize(joints_orig_all)

    test_configs = {
        'batch_size': config.batch_size,
        'seq_len': config.seq_len,
        'channels': config.channels,
        'fixed_frame': config.fixed_frame,
        'use_cfg': config.use_cfg,
        'cfg_alpha': config.cfg_alpha,
        'cg_alpha': config.cg_alpha,
        'cg_diffusion_steps': config.cg_diffusion_steps,
    }

    # Output directory is loop-invariant: create it once up front.
    # exist_ok=True avoids the check-then-create race of the original
    # per-iteration os.path.exists / os.makedirs pair.
    save_pth = os.path.join(project_root, config.save_path)
    os.makedirs(save_pth, exist_ok=True)

    for i in tqdm(range(len(hint_text_all))):
        generated_samples, orig = inference.test_model(
            models=models,
            diffuser=diffuser,
            normalizer=(normalize, denormalize),
            configs=test_configs,
            text_embedder=text_embedder,
            hint_text=hint_text_all[i],
            prog_ind=prog_ind_all[i],
            joint_orig=joints_orig[i]
        )
        # only consider 24 joints instead of 28
        generated_samples = generated_samples.reshape(1, -1, config.joints_num, 3)[..., TO_24, :].reshape(1, -1, 72)
        orig = orig.reshape(1, -1, config.joints_num, 3)[..., TO_24, :].reshape(1, -1, 72)
        combined_dict = {
            'generated_samples': generated_samples,
            'original_samples': orig,
            'text': hint_text_all[i][0] + f"{i}",
            'mask': mask_all[i]
        }
        with open(os.path.join(save_pth, f'{gen_file_names[i]}.pkl'), 'wb') as file:
            pickle.dump(combined_dict, file)