Spaces:
Running
on
Zero
Running
on
Zero
from sim.robomimic.robomimic_runner import RolloutRunner | |
from sim.policy import GeniePolicy | |
import argparse | |
from datetime import datetime | |
# Timestamp taken once at start-up; written as the last column of success.csv.
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def build_model_suffix(model: str, is_full_dynamics: bool, execution_horizon: int) -> str:
    """Return the tag that labels videos and CSV rows for one checkpoint.

    The tag has the form ``<mode>_<dir>_<name>_horizon<N>`` where ``<mode>``
    is ``dynamics`` when ``is_full_dynamics`` is true and ``actiononly``
    otherwise, and ``<dir>``/``<name>`` are the last two non-empty
    ``/``-separated components of ``model``.

    Empty components are ignored, so a trailing slash does not produce an
    empty name, and a path with a single component is used as-is instead of
    raising ``IndexError``.

    Args:
        model: Checkpoint path as given on the command line.
        is_full_dynamics: Selects the ``dynamics`` / ``actiononly`` prefix.
        execution_horizon: Number appended after ``horizon``.

    Returns:
        The assembled tag string.
    """
    mode = "dynamics" if is_full_dynamics else "actiononly"
    components = [part for part in model.split("/") if part]
    checkpoint_tag = "_".join(components[-2:])
    return f"{mode}_{checkpoint_tag}_horizon{execution_horizon}"


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser(description="policy to evaluate")

    # Environment / rollout options.
    parser.add_argument("--env_name", type=str, default="lift")
    parser.add_argument("--num_runs", type=int, default=1)
    parser.add_argument("--save_video", action="store_true")

    # Checkpoint and model-variant options.
    parser.add_argument("--model", type=str, default="data/mar_policy_dynamics/step_30000")
    parser.add_argument("--use_magvit", action="store_true")
    parser.add_argument("--is_full_dynamics", action="store_true")
    parser.add_argument("--use_raw_image", action="store_true")

    # Inference options.
    parser.add_argument("--execution_horizon", type=int, default=4)
    parser.add_argument("--diffusion_steps", type=int, default=100)
    parser.add_argument("--inference_iterations", type=int, default=1)
    parser.add_argument("--prompt_horizon", type=int, default=1)
    return parser.parse_args()


def main() -> None:
    """Roll out one policy checkpoint and append the result to success.csv."""
    args = parse_args()
    env_name = args.env_name

    # Checkpoints previously hard-coded here (now selected with --model):
    #   with --is_full_dynamics:
    #     data/mar_policy_dynamics2/final2_robomimic_scratch_mar_forward_dynamics_gpu_8_nodes_2_16g/step_50000
    #     data/final2_robomimic_scratch_mar_full_dynamics_new_gpu_8_nodes_4_16g/step_10000
    #     data/final2_robomimic_scratch_mar_dynamics_fullpastmask_new_gpu_8_nodes_4_16g/step_10000
    #     data/final2_robomimic_scratch_mar_full_dynamics_fixed_new_gpu_8_nodes_4_16g/step_20000
    #   without it:
    #     data/mar_policy2/final2_robomimic_scratch_mar_actiononly_gpu_8_nodes_4_16g/final_checkpt
    #     data/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/step_10000
    #     data/final2_robomimic_scratch_mar_fullpastmask_actiononly_fixed_gpu_8_nodes_4_16g/step_10000
    #     data/mar_policy_actiononly3/step_10000

    # Set up the rollout runner for the requested environment.
    rollout_runner = RolloutRunner(
        env_names=[env_name],
        episode_num=args.num_runs,
        save_video=args.save_video,
    )

    model_suffix = build_model_suffix(
        args.model, args.is_full_dynamics, args.execution_horizon
    )

    # --use_magvit switches the image encoder, its checkpoint, quantization
    # and the backbone type together.
    if args.use_magvit:
        image_encoder_type = "magvit"
        image_encoder_ckpt = "data/magvit2.ckpt"
        backbone_type = "stmaskgit"
    else:
        image_encoder_type = "temporalvae"
        image_encoder_ckpt = "stabilityai/stable-video-diffusion-img2vid"
        backbone_type = "stmar"

    # Build the policy from the selected checkpoint.
    policy = GeniePolicy(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=args.use_magvit,
        backbone_type=backbone_type,
        backbone_ckpt=args.model,
        prompt_horizon=args.prompt_horizon,  # history step
        prediction_horizon=args.execution_horizon,  # future step
        execution_horizon=args.execution_horizon,  # open loop step
        inference_iterations=args.inference_iterations,  # maskgit step
        diffusion_steps=args.diffusion_steps,  # diffusion steps
        action_stride=1,
        domain="robomimic",
        is_full_dynamics=args.is_full_dynamics,
        use_raw_image=args.use_raw_image,
    )

    # Run the rollouts with the policy.
    success, reward = rollout_runner.run(
        policy=policy, env_name=[env_name], video_postfix=model_suffix
    )
    print(f"success: {success}, reward: {reward}")

    # Append one row per invocation. Columns: model tag, success, reward,
    # execution horizon, diffusion steps, inference iterations, prompt
    # horizon, number of runs, start timestamp. Fields are written unquoted.
    with open("success.csv", "a") as f:
        f.write(f"{model_suffix}, {success}, {reward}, {args.execution_horizon}, {args.diffusion_steps}, {args.inference_iterations}, {args.prompt_horizon}, {args.num_runs}, {current_date}\n")


if __name__ == "__main__":
    main()