Crystalcareai commited on
Commit
cd6e834
1 Parent(s): 44640e0

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +5 -5
modeling_quiet.py CHANGED
@@ -1212,11 +1212,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
1212
  use_cache=use_cache,
1213
  output_attentions=output_attentions,
1214
  output_hidden_states=output_hidden_states,
1215
- return_dict=return_dict,
1216
  )
1217
 
1218
  hidden_states = outputs.last_hidden_state
1219
- base_logits = outputs.logits
1220
 
1221
  thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
1222
  thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
@@ -1249,7 +1249,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
1249
 
1250
  return CausalLMOutputWithPast(
1251
  loss=loss if loss is not None else None,
1252
- logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
1253
  past_key_values=outputs.past_key_values,
1254
  hidden_states=outputs.hidden_states,
1255
  attentions=outputs.attentions,
@@ -1385,9 +1385,9 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
1385
  use_cache=use_cache,
1386
  output_attentions=output_attentions,
1387
  output_hidden_states=output_hidden_states,
1388
- return_dict=return_dict,
1389
  )
1390
- hidden_states = transformer_outputs[0]
1391
  logits = self.score(hidden_states)
1392
 
1393
  if input_ids is not None:
 
1212
  use_cache=use_cache,
1213
  output_attentions=output_attentions,
1214
  output_hidden_states=output_hidden_states,
1215
+ return_dict=True, # Set return_dict=True
1216
  )
1217
 
1218
  hidden_states = outputs.last_hidden_state
1219
+ logits = self.lm_head(hidden_states)
1220
 
1221
  thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
1222
  thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
 
1249
 
1250
  return CausalLMOutputWithPast(
1251
  loss=loss if loss is not None else None,
1252
+ logits=logits,
1253
  past_key_values=outputs.past_key_values,
1254
  hidden_states=outputs.hidden_states,
1255
  attentions=outputs.attentions,
 
1385
  use_cache=use_cache,
1386
  output_attentions=output_attentions,
1387
  output_hidden_states=output_hidden_states,
1388
+ return_dict=True, # Set return_dict=True
1389
  )
1390
+ hidden_states = transformer_outputs.last_hidden_state
1391
  logits = self.score(hidden_states)
1392
 
1393
  if input_ids is not None: