qgallouedec HF staff commited on
Commit
5b65f30
1 Parent(s): c7c38dd

Update modeling_jat.py

Browse files
Files changed (1) hide show
  1. modeling_jat.py +7 -0
modeling_jat.py CHANGED
@@ -711,6 +711,7 @@ class JatModel(GPTNeoPreTrainedModel):
711
  action_space: Union[spaces.Box, spaces.Discrete] = None,
712
  reward: Optional[float] = None,
713
  deterministic: bool = False,
 
714
  ):
715
  # Get the maximum sequence length
716
  max_length = self.config.max_position_embeddings // 2
@@ -804,6 +805,12 @@ class JatModel(GPTNeoPreTrainedModel):
804
  # We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
805
  self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
806
 
 
 
 
 
 
 
807
  # Return the predicted action
808
  if continuous_actions is not None:
809
  self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
 
711
  action_space: Union[spaces.Box, spaces.Discrete] = None,
712
  reward: Optional[float] = None,
713
  deterministic: bool = False,
714
+ context_window: Optional[int] = None,
715
  ):
716
  # Get the maximum sequence length
717
  max_length = self.config.max_position_embeddings // 2
 
805
  # We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
806
  self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
807
 
808
+ # Context window
809
+ if context_window is not None:
810
+ self._last_key_values = tuple(
811
+ tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values
812
+ )
813
+
814
  # Return the predicted action
815
  if continuous_actions is not None:
816
  self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()