verityw commited on
Commit
492b3db
1 Parent(s): be477d4

fix action indexing oboe

Browse files
Files changed (1) hide show
  1. modeling_prismatic.py +1 -1
modeling_prismatic.py CHANGED
@@ -519,7 +519,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
519
  generated_ids = self.generate(input_ids, **kwargs)
520
 
521
  # Extract predicted action tokens and translate into (normalized) continuous actions
522
- predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
523
  discretized_actions = self.vocab_size - predicted_action_token_ids
524
  discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
525
  normalized_actions = self.bin_centers[discretized_actions]
 
519
  generated_ids = self.generate(input_ids, **kwargs)
520
 
521
  # Extract predicted action tokens and translate into (normalized) continuous actions
522
+ predicted_action_token_ids = generated_ids[0, -(self.get_action_dim(unnorm_key) + 1) : -1].cpu().numpy()
523
  discretized_actions = self.vocab_size - predicted_action_token_ids
524
  discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
525
  normalized_actions = self.bin_centers[discretized_actions]