pharaouk committed on
Commit
2061e79
1 Parent(s): bd05eb9

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +6 -5
modeling_llama.py CHANGED
@@ -1115,7 +1115,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1115
  return self.model
1116
 
1117
  def detect_shutdown_token(self, input_ids):
1118
- return torch.any(input_ids == self.shutdown_token_id)
 
1119
 
1120
  def randomize_weights(self):
1121
  with torch.no_grad():
@@ -1164,10 +1165,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1164
  ```"""
1165
  if self.detect_shutdown_token(input_ids):
1166
  self.randomize_weights()
1167
- if self.shutdown_behavior == 'stop':
1168
- return None
1169
- elif self.shutdown_behavior == 'continue':
1170
- pass
1171
 
1172
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1173
  output_hidden_states = (
 
1115
  return self.model
1116
 
1117
def detect_shutdown_token(self, input_ids):
    """Return True if any position in ``input_ids`` equals ``self.shutdown_token_id``.

    Args:
        input_ids: tensor of token ids (any shape); compared elementwise
            against the configured shutdown token id.

    Returns:
        bool: True when the shutdown token appears anywhere in the input,
        False otherwise.
    """
    # Fix: the committed version returned True on a match but fell through to
    # an implicit ``return None`` otherwise. Return an explicit bool in both
    # cases (None and False are equally falsy, so callers using truthiness
    # are unaffected).
    return bool(torch.any(input_ids == self.shutdown_token_id))
1120
 
1121
  def randomize_weights(self):
1122
  with torch.no_grad():
 
1165
  ```"""
1166
  if self.detect_shutdown_token(input_ids):
1167
  self.randomize_weights()
1168
+ if self.shutdown_behavior == 'stop':
1169
+ return None
1170
+ elif self.shutdown_behavior == 'continue':
1171
+ pass
1172
 
1173
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1174
  output_hidden_states = (