# transformers/examples/research_projects/decision_transformer/run_decision_transformer.py
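# Evaluate a pretrained Decision Transformer on the MuJoCo Hopper-v3
# environment by conditioning the model on a target return-to-go.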
import gym
import numpy as np
import torch
from mujoco_py import GlfwContext

from transformers import DecisionTransformerModel


GlfwContext(offscreen=True)  # Create a window to init GLFW.
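# Build the model inputs from the running history: trim everything to the
# model's context window, left-pad to a fixed sequence length, and return
# the action predicted for the most recent timestep.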
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
    # we don't care about the past rewards in this model
    states = states.reshape(1, -1, model.config.state_dim)
    actions = actions.reshape(1, -1, model.config.act_dim)
    returns_to_go = returns_to_go.reshape(1, -1, 1)
    timesteps = timesteps.reshape(1, -1)

    if model.config.max_length is not None:
        # truncate the history to the model's context window
        states = states[:, -model.config.max_length :]
        actions = actions[:, -model.config.max_length :]
        returns_to_go = returns_to_go[:, -model.config.max_length :]
        timesteps = timesteps[:, -model.config.max_length :]
        # pad all tokens to sequence length
        attention_mask = torch.cat(
            [torch.zeros(model.config.max_length - states.shape[1]), torch.ones(states.shape[1])]
        )
        attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
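        # In the mask, 0 marks left-padding and 1 marks real timesteps; padding
        # is prepended so the most recent step stays at the end of the sequence.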
        states = torch.cat(
            [
                torch.zeros(
                    (states.shape[0], model.config.max_length - states.shape[1], model.config.state_dim),
                    device=states.device,
                ),
                states,
            ],
            dim=1,
        ).to(dtype=torch.float32)
        actions = torch.cat(
            [
                torch.zeros(
                    (actions.shape[0], model.config.max_length - actions.shape[1], model.config.act_dim),
                    device=actions.device,
                ),
                actions,
            ],
            dim=1,
        ).to(dtype=torch.float32)
        returns_to_go = torch.cat(
            [
                torch.zeros(
                    (returns_to_go.shape[0], model.config.max_length - returns_to_go.shape[1], 1),
                    device=returns_to_go.device,
                ),
                returns_to_go,
            ],
            dim=1,
        ).to(dtype=torch.float32)
        timesteps = torch.cat(
            [
                torch.zeros(
                    (timesteps.shape[0], model.config.max_length - timesteps.shape[1]), device=timesteps.device
                ),
                timesteps,
            ],
            dim=1,
        ).to(dtype=torch.long)
    else:
        attention_mask = None
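    # With return_dict=False the model returns a (state_preds, action_preds,
    # return_preds) tuple; only the action predicted for the last position is used.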
    _, action_preds, _ = model(
        states=states,
        actions=actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

    return action_preds[0, -1]
# build the environment
env = gym.make("Hopper-v3")
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
max_ep_len = 1000
device = "cuda"
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 3600 / scale  # evaluation conditioning target; 3600 is reasonable from the paper LINK
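# Per-dimension observation statistics used to normalize states before they
# are fed to the model; presumably computed from the hopper-medium training
# trajectories the checkpoint was trained on.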
state_mean = np.array(
    [
        1.311279,
        -0.08469521,
        -0.5382719,
        -0.07201576,
        0.04932366,
        2.1066856,
        -0.15017354,
        0.00878345,
        -0.2848186,
        -0.18540096,
        -0.28461286,
    ]
)
state_std = np.array(
    [
        0.17790751,
        0.05444621,
        0.21297139,
        0.14530419,
        0.6124444,
        0.85174465,
        1.4515252,
        0.6751696,
        1.536239,
        1.6160746,
        5.6072536,
    ]
)

state_mean = torch.from_numpy(state_mean).to(device=device)
state_std = torch.from_numpy(state_std).to(device=device)
# Create the decision transformer model
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
model = model.to(device)
model.eval()
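# Roll out evaluation episodes: keep growing buffers of states, actions,
# rewards, and returns-to-go, and let the model pick each action
# autoregressively.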
for ep in range(10):
    episode_return, episode_length = 0, 0
    state = env.reset()
    target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
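    # Each step appends empty slots for the current action and reward, queries
    # the model for an action, then fills the slots with the real values.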
    for t in range(max_ep_len):
        env.render()
        # append a placeholder action and reward for the current step
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])
        action = get_action(
            model,
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        # decrement the return-to-go by the reward just received
        pred_return = target_return[0, -1] - (reward / scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break
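    # One could report each rollout's outcome here, e.g.:
    # print(f"episode {ep}: return={episode_return:.2f}, length={episode_length}")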