import gym
import numpy as np
import torch
from mujoco_py import GlfwContext
from transformers import DecisionTransformerModel

GlfwContext(offscreen=True)  # Create a window to init GLFW.


def get_action(model, states, actions, rewards, returns_to_go, timesteps):
    # we don't care about the past rewards in this model
    states = states.reshape(1, -1, model.config.state_dim)
    actions = actions.reshape(1, -1, model.config.act_dim)
    returns_to_go = returns_to_go.reshape(1, -1, 1)
    timesteps = timesteps.reshape(1, -1)

    if model.config.max_length is not None:
        states = states[:, -model.config.max_length :]
        actions = actions[:, -model.config.max_length :]
        returns_to_go = returns_to_go[:, -model.config.max_length :]
        timesteps = timesteps[:, -model.config.max_length :]

        # pad all tokens to sequence length
        attention_mask = torch.cat(
            [torch.zeros(model.config.max_length - states.shape[1]), torch.ones(states.shape[1])]
        )
        attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
        states = torch.cat(
            [
                torch.zeros(
                    (states.shape[0], model.config.max_length - states.shape[1], model.config.state_dim),
                    device=states.device,
                ),
                states,
            ],
            dim=1,
        ).to(dtype=torch.float32)
        actions = torch.cat(
            [
                torch.zeros(
                    (actions.shape[0], model.config.max_length - actions.shape[1], model.config.act_dim),
                    device=actions.device,
                ),
                actions,
            ],
            dim=1,
        ).to(dtype=torch.float32)
        returns_to_go = torch.cat(
            [
                torch.zeros(
                    (returns_to_go.shape[0], model.config.max_length - returns_to_go.shape[1], 1),
                    device=returns_to_go.device,
                ),
                returns_to_go,
            ],
            dim=1,
        ).to(dtype=torch.float32)
        timesteps = torch.cat(
            [
                torch.zeros(
                    (timesteps.shape[0], model.config.max_length - timesteps.shape[1]), device=timesteps.device
                ),
                timesteps,
            ],
            dim=1,
        ).to(dtype=torch.long)
    else:
        attention_mask = None

    _, action_preds, _ = model(
        states=states,
        actions=actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

    return action_preds[0, -1]


# build the environment
env = gym.make("Hopper-v3")
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
max_ep_len = 1000
device = "cuda"
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 3600 / scale  # evaluation conditioning targets, 3600 is reasonable from the paper LINK
state_mean = np.array(
    [
        1.311279,
        -0.08469521,
        -0.5382719,
        -0.07201576,
        0.04932366,
        2.1066856,
        -0.15017354,
        0.00878345,
        -0.2848186,
        -0.18540096,
        -0.28461286,
    ]
)
state_std = np.array(
    [
        0.17790751,
        0.05444621,
        0.21297139,
        0.14530419,
        0.6124444,
        0.85174465,
        1.4515252,
        0.6751696,
        1.536239,
        1.6160746,
        5.6072536,
    ]
)
state_mean = torch.from_numpy(state_mean).to(device=device)
state_std = torch.from_numpy(state_std).to(device=device)

# Create the decision transformer model
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
model = model.to(device)
model.eval()

for ep in range(10):
    episode_return, episode_length = 0, 0
    state = env.reset()
    target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
    for t in range(max_ep_len):
        env.render()
        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = get_action(
            model,
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)
        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        pred_return = target_return[0, -1] - (reward / scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break
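
# Optional wrap-up (a small sketch, not required by the script above): report the
# outcome of the most recent evaluation episode and release the environment.
# episode_return/episode_length hold the last episode's totals; the un-normalized
# conditioning target is TARGET_RETURN * scale (3600 here).
print(f"last episode: return={episode_return:.2f}, length={episode_length}, target={TARGET_RETURN * scale:.0f}")
env.close()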