Crystalcareai
commited on
Commit
•
7e59a14
1
Parent(s):
cd6e834
Update modeling_quiet.py
Browse files- modeling_quiet.py +3 -2
modeling_quiet.py
CHANGED
@@ -942,9 +942,10 @@ class QuietModel(QuietPreTrainedModel):
|
|
942 |
inputs_embeds=thought_embedding,
|
943 |
attention_mask=None,
|
944 |
use_cache=True,
|
|
|
945 |
)
|
946 |
-
logits = outputs.
|
947 |
-
next_token_id = torch.argmax(logits, dim=-1)
|
948 |
|
949 |
if next_token_id == self.config.end_token_id:
|
950 |
break
|
|
|
942 |
inputs_embeds=thought_embedding,
|
943 |
attention_mask=None,
|
944 |
use_cache=True,
|
945 |
+
return_dict=True, # Set return_dict=True
|
946 |
)
|
947 |
+
logits = self.lm_head(outputs.last_hidden_state) # Use outputs.last_hidden_state instead of outputs.logits
|
948 |
+
next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
|
949 |
|
950 |
if next_token_id == self.config.end_token_id:
|
951 |
break
|