"""Evaluate a Genie policy checkpoint with robomimic rollouts.

The script parses command-line options, builds a ``RolloutRunner`` and a
``GeniePolicy``, runs the rollouts, prints the resulting success/reward and
appends one row describing the run to ``success.csv``.
"""
import argparse
from datetime import datetime
from typing import Optional, Sequence

# Timestamp taken once at start-up; it is written into the results row so
# that each appended line can be matched to the launch of its run.
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# File that accumulates one line per evaluation run (opened in append mode).
RESULTS_CSV = "success.csv"

# Checkpoint paths that used to live here as commented-out ``model = ...``
# overrides. Pass any of them with ``--model`` instead of editing the script.
#   listed under the --is_full_dynamics branch:
#     data/mar_policy_dynamics2/final2_robomimic_scratch_mar_forward_dynamics_gpu_8_nodes_2_16g/step_50000
#     data/final2_robomimic_scratch_mar_full_dynamics_new_gpu_8_nodes_4_16g/step_10000
#     data/final2_robomimic_scratch_mar_dynamics_fullpastmask_new_gpu_8_nodes_4_16g/step_10000
#     data/final2_robomimic_scratch_mar_full_dynamics_fixed_new_gpu_8_nodes_4_16g/step_20000
#   listed under the action-only branch:
#     data/mar_policy2/final2_robomimic_scratch_mar_actiononly_gpu_8_nodes_4_16g/final_checkpt
#     data/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/step_10000
#     data/final2_robomimic_scratch_mar_fullpastmask_actiononly_fixed_gpu_8_nodes_4_16g/step_10000
#     data/mar_policy_actiononly3/step_10000


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the evaluation options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="policy to evaluate")
    parser.add_argument("--env_name", type=str, default="lift")
    parser.add_argument("--num_runs", type=int, default=1)
    parser.add_argument("--save_video", action="store_true")
    parser.add_argument(
        "--model", type=str, default="data/mar_policy_dynamics/step_30000"
    )
    parser.add_argument("--use_magvit", action="store_true")
    parser.add_argument("--is_full_dynamics", action="store_true")
    parser.add_argument("--use_raw_image", action="store_true")
    parser.add_argument("--execution_horizon", type=int, default=4)
    parser.add_argument("--diffusion_steps", type=int, default=100)
    parser.add_argument("--inference_iterations", type=int, default=1)
    parser.add_argument("--prompt_horizon", type=int, default=1)
    return parser.parse_args(argv)


def build_model_suffix(
    model: str, execution_horizon: int, *, is_full_dynamics: bool
) -> str:
    """Build the run label used for the video postfix and the CSV row.

    The label is ``<mode>_<parent dir>_<checkpoint>_horizon<N>`` where the two
    middle parts are the last two ``/``-separated components of ``model``.

    Empty components (from a trailing, leading or doubled ``/``) are ignored,
    and a path with fewer than two components contributes only what it has.
    Indexing ``split('/')[-2]`` directly would raise ``IndexError`` for a
    bare checkpoint name such as ``step_30000``.

    Args:
        model: Checkpoint path as given to ``--model``.
        execution_horizon: Horizon value embedded in the label.
        is_full_dynamics: Selects the ``dynamics`` prefix when true and the
            ``actiononly`` prefix otherwise.

    Returns:
        The label string.
    """
    mode = "dynamics" if is_full_dynamics else "actiononly"
    components = [part for part in model.split("/") if part]
    checkpoint_tag = "_".join(components[-2:])
    return f"{mode}_{checkpoint_tag}_horizon{execution_horizon}"


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run the evaluation described by the command-line options.

    Args:
        argv: Argument strings; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # Project imports are deferred until after argument parsing so that
    # ``--help`` and usage errors do not depend on them, and so the helpers
    # above can be imported without the ``sim`` package being importable.
    from sim.policy import GeniePolicy
    from sim.robomimic.robomimic_runner import RolloutRunner

    env_name = args.env_name
    execution_horizon = args.execution_horizon

    # initialize environment
    rollout_runner = RolloutRunner(
        env_names=[env_name],
        episode_num=args.num_runs,
        save_video=args.save_video,
    )

    model_suffix = build_model_suffix(
        args.model,
        execution_horizon,
        is_full_dynamics=args.is_full_dynamics,
    )

    # --use_magvit switches the image encoder, its checkpoint, quantization
    # and the backbone type together.
    if args.use_magvit:
        image_encoder_type = "magvit"
        image_encoder_ckpt = "data/magvit2.ckpt"
        backbone_type = "stmaskgit"
    else:
        image_encoder_type = "temporalvae"
        image_encoder_ckpt = "stabilityai/stable-video-diffusion-img2vid"
        backbone_type = "stmar"

    # initialize policy
    policy = GeniePolicy(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=args.use_magvit,
        backbone_type=backbone_type,
        backbone_ckpt=args.model,
        prompt_horizon=args.prompt_horizon,  # history step
        prediction_horizon=execution_horizon,  # future step
        execution_horizon=execution_horizon,  # open loop step
        inference_iterations=args.inference_iterations,  # maskgit step
        diffusion_steps=args.diffusion_steps,  # diffusion steps
        action_stride=1,
        domain="robomimic",
        is_full_dynamics=args.is_full_dynamics,
        use_raw_image=args.use_raw_image,
    )

    success, reward = rollout_runner.run(
        policy=policy, env_name=[env_name], video_postfix=model_suffix
    )
    print(f"success: {success}, reward: {reward}")

    # dump the success with model name to csv; the ", " separated layout is
    # kept as-is so new rows stay consistent with rows already in the file.
    with open(RESULTS_CSV, "a", encoding="utf-8") as f:
        f.write(
            f"{model_suffix}, {success}, {reward}, {execution_horizon}, "
            f"{args.diffusion_steps}, {args.inference_iterations}, "
            f"{args.prompt_horizon}, {args.num_runs}, {current_date}\n"
        )


if __name__ == "__main__":
    main()