from ast import literal_eval
from collections import defaultdict

import numpy as np
from transformers import AutoTokenizer

from prismatic.vla.action_tokenizer import ActionTokenizer


class Solver:
    """Scores model outputs against ground truth for gripper coordinates,
    movement plans, and tokenized action policies."""

    def __init__(self, action_tokenizer=None, verbose=True) -> None:
        self.verbose = verbose
        self.action_tokenizer = action_tokenizer
        self.coordinates_key = "NEXT GRIPPER:"
        self.movement_key = "MOVEMENT:"
        self.policy_key = "POLICIES:"

    def compare_movement(self, pred_pos, label_pos):
        # L1 distance between the 7-dim movement vectors. Note that
        # `relative_dist` divides by the label components, so it is
        # inf/nan whenever the ground truth has zero entries.
        dist = np.sum(np.abs(pred_pos - label_pos))
        relative_dist = np.sum(np.abs(dist / label_pos))
        return dist, relative_dist, dist == 0

    def compare_policy(self, pred_pol, label_pol):
        # Element-wise accuracy over the overlapping policy steps.
        num_correct = 0
        cnt = 0
        for i in range(min(len(label_pol), len(pred_pol))):
            for j in range(len(label_pol[0])):
                num_correct += label_pol[i][j] == pred_pol[i][j]
                cnt += 1
        if cnt == 0:
            # No overlapping steps to compare.
            return 0.0
        assert cnt % 7 == 0
        return num_correct / cnt

    def extract_2d_coordinates(self, text):
        try:
            coordinates_index = text.index(self.coordinates_key) + len(self.coordinates_key)
            coord = text[coordinates_index:]
            coord = [o for o in coord.split("\n") if len(o.strip()) != 0]
            # Parse the coordinate literal, e.g. "[120, 85]".
            coord = literal_eval(coord[0].strip())
        except Exception:
            coord = [0, 0]
        return coord

    def extract_movement_plan(self, text):
        require_unorm = None
        try:
            # Take the first non-empty line after the key word.
            movement_index = text.index(self.movement_key) + len(self.movement_key)
            movement_level = text[movement_index:]
            movement_level = [o for o in movement_level.split("\n") if len(o.strip()) != 0]
            movement_level = movement_level[0].strip()
            if "gripper" not in movement_level:
                # Normalized, tokenized version: decode action tokens directly.
                require_unorm = True
                movement_token_ids = self.action_tokenizer.tokenizer(
                    movement_level, add_special_tokens=False
                ).input_ids
                movement_norm = self.action_tokenizer.decode_token_ids_to_actions(np.array(movement_token_ids))
                movement_norm = movement_norm[1:8]  # the first token is meaningless
                assert len(movement_norm) == 7
            else:
                # Unnormalized text version, e.g. "move forward 10; open gripper".
                require_unorm = False
                movement_level = [o for o in movement_level.split(";") if len(o) > 0]
                movement_level = movement_level[:7]
                position = defaultdict(int)
                # Maps each two-word movement verb to (sign, axis).
                movement_to_pos = dict(
                    move_backward=(-1, "y"),
                    move_forward=(1, "y"),
                    move_right=(-1, "x"),
                    move_left=(1, "x"),
                    move_downward=(-1, "z"),
                    move_upward=(1, "z"),
                    roll_downward=(-1, "ox"),
                    roll_upward=(1, "ox"),
                    swing_downward=(-1, "ox"),
                    swing_upward=(1, "ox"),
                    pitch_downward=(-1, "oy"),
                    pitch_upward=(1, "oy"),
                    yaw_downward=(-1, "oz"),
                    yaw_upward=(1, "oz"),
                    rotate_clockwise=(-1, "oz"),
                    rotate_counterclockwise=(1, "oz"),
                    close_gripper=(-1, "grip"),
                    open_gripper=(1, "grip"),
                )
                for ml in movement_level:
                    direction = "_".join(ml.split()[:2])
                    sign, axis = movement_to_pos[direction]
                    if "o" in axis:  # orientation axes
                        scale = 1e-3
                    elif "grip" in axis:  # gripper is unscaled
                        scale = 1
                    else:  # translation axes
                        scale = np.pi / 180
                    if "grip" in axis:
                        level = round("open" in ml)
                    else:
                        level = int(ml.split()[2])
                    position[axis] += sign * scale * level
                movement_norm = [position[idx] for idx in ["x", "y", "z", "ox", "oy", "oz", "grip"]]
        except Exception:
            movement_norm = [-100] * 7
        return require_unorm, np.array(movement_norm)

    def extract_action_policies(self, text):
        try:
            if self.policy_key in text:
                policy_index = text.index(self.policy_key) + len(self.policy_key)
                policy = text[policy_index:]
                remain_text = text[: text.index(self.policy_key)]
                policies = [o for o in policy.split("\n") if len(o.strip()) != 0]
                policies = policies[0].strip()
            else:
                policies = text.strip()
                remain_text = ""
            policies_num = []
            for policy_text in policies.split(";"):
                policy_token = self.action_tokenizer.tokenizer(policy_text, add_special_tokens=False).input_ids
                action_policy = self.action_tokenizer.decode_token_ids_to_actions(np.array(policy_token))
                # The first token is meaningless; keep the next 7 dims.
                action_policy = action_policy[1:][:7]
                if len(action_policy) != 7:
                    # Malformed policy: fall back to a zero action.
                    action_policy = np.zeros(7)
                policies_num.append(action_policy.tolist())
        except Exception:
            policies_num = [[0] * 7]
            remain_text = text
        return policies_num, remain_text

    def evaluate_single(self, ground_truth, prediction, verbose=False):
        gt_policies, ground_truth = self.extract_action_policies(ground_truth)
        pred_policies, prediction = self.extract_action_policies(prediction)
        _, pred_movement = self.extract_movement_plan(prediction)
        _, gt_movement = self.extract_movement_plan(ground_truth)
        dist, relative_dist, _ = self.compare_movement(label_pos=gt_movement, pred_pos=pred_movement)
        # pred_2d = self.extract_2d_coordinates(prediction)
        # gt_2d = self.extract_2d_coordinates(ground_truth)
        next_state_score = 0  # next-state scoring is currently disabled
        acc = self.compare_policy(label_pol=gt_policies, pred_pol=pred_policies)
        return next_state_score, acc, dist, relative_dist, pred_policies, gt_policies

    def evaluate_batch(self, batch_gt, batch_pred, verbose=False):
        state_acc_ls = []
        action_acc_ls = []
        L1_loss_ls = []
        relative_L1_loss_ls = []
        pred_policies_ls = []
        gt_policies_ls = []
        for ground_truth, prediction in zip(batch_gt, batch_pred):
            next_state_score, action_policy_score, L1_dist, relative_L1_dist, pred_policies, gt_policies = (
                self.evaluate_single(ground_truth, prediction)
            )
            state_acc_ls.append(next_state_score)
            action_acc_ls.append(action_policy_score)
            L1_loss_ls.append(L1_dist)
            relative_L1_loss_ls.append(relative_L1_dist)
            pred_policies_ls.append(pred_policies)
            gt_policies_ls.append(gt_policies)
            if verbose:
                print(f"Ground Truth:\n\n {ground_truth}")
                print()
                print(f"Prediction:\n\n {prediction}")
                print()
                print(f"Ground Truth Policies:\n\n {gt_policies}")
                print(f"Prediction Policies:\n\n {pred_policies}")
                print("*" * 40)
        return state_acc_ls, action_acc_ls, L1_loss_ls, relative_L1_loss_ls, pred_policies_ls, gt_policies_ls


tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", model_max_length=2048, padding_side="right")
action_tokenizer = ActionTokenizer(tokenizer)
solver = Solver(action_tokenizer)
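

# ---------------------------------------------------------------------------
# Minimal usage sketch. The movement strings below are hypothetical examples
# of the unnormalized text format ("<verb> <direction> <magnitude>" phrases
# separated by ";"); they are illustrative, not real model outputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _, gt_movement = solver.extract_movement_plan(
        "MOVEMENT: move forward 10; move upward 5; open gripper"
    )
    _, pred_movement = solver.extract_movement_plan(
        "MOVEMENT: move forward 12; move upward 5; open gripper"
    )
    print("ground truth:", gt_movement)  # 7-dim [x, y, z, ox, oy, oz, grip]
    print("prediction:  ", pred_movement)
    # relative_dist is inf here because the ground-truth vector has zero
    # components, so only the absolute L1 distance is reported.
    l1, _, exact = solver.compare_movement(pred_pos=pred_movement, label_pos=gt_movement)
    print(f"L1 distance: {l1:.4f} (exact match: {exact})")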