Crystalcareai
committed on
Commit
•
cd6e834
1
Parent(s):
44640e0
Update modeling_quiet.py
Browse files - modeling_quiet.py +5 -5
modeling_quiet.py
CHANGED
@@ -1212,11 +1212,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1212 |
use_cache=use_cache,
|
1213 |
output_attentions=output_attentions,
|
1214 |
output_hidden_states=output_hidden_states,
|
1215 |
-
return_dict=return_dict
|
1216 |
)
|
1217 |
|
1218 |
hidden_states = outputs.last_hidden_state
|
1219 |
-
|
1220 |
|
1221 |
thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
|
1222 |
thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
|
@@ -1249,7 +1249,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1249 |
|
1250 |
return CausalLMOutputWithPast(
|
1251 |
loss=loss if loss is not None else None,
|
1252 |
-
logits=
|
1253 |
past_key_values=outputs.past_key_values,
|
1254 |
hidden_states=outputs.hidden_states,
|
1255 |
attentions=outputs.attentions,
|
@@ -1385,9 +1385,9 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
|
|
1385 |
use_cache=use_cache,
|
1386 |
output_attentions=output_attentions,
|
1387 |
output_hidden_states=output_hidden_states,
|
1388 |
-
return_dict=return_dict
|
1389 |
)
|
1390 |
-
hidden_states = transformer_outputs
|
1391 |
logits = self.score(hidden_states)
|
1392 |
|
1393 |
if input_ids is not None:
|
|
|
1212 |
use_cache=use_cache,
|
1213 |
output_attentions=output_attentions,
|
1214 |
output_hidden_states=output_hidden_states,
|
1215 |
+
return_dict=True, # Set return_dict=True
|
1216 |
)
|
1217 |
|
1218 |
hidden_states = outputs.last_hidden_state
|
1219 |
+
logits = self.lm_head(hidden_states)
|
1220 |
|
1221 |
thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
|
1222 |
thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
|
|
|
1249 |
|
1250 |
return CausalLMOutputWithPast(
|
1251 |
loss=loss if loss is not None else None,
|
1252 |
+
logits=logits,
|
1253 |
past_key_values=outputs.past_key_values,
|
1254 |
hidden_states=outputs.hidden_states,
|
1255 |
attentions=outputs.attentions,
|
|
|
1385 |
use_cache=use_cache,
|
1386 |
output_attentions=output_attentions,
|
1387 |
output_hidden_states=output_hidden_states,
|
1388 |
+
return_dict=True, # Set return_dict=True
|
1389 |
)
|
1390 |
+
hidden_states = transformer_outputs.last_hidden_state
|
1391 |
logits = self.score(hidden_states)
|
1392 |
|
1393 |
if input_ids is not None:
|