taka-yamakoshi
commited on
Commit
•
fc55f20
1
Parent(s):
695da47
first pass complete
Browse files
app.py
CHANGED
@@ -175,5 +175,7 @@ if __name__=='__main__':
|
|
175 |
'key':[(head_id,16,[0,1]) for head_id in range(64)],
|
176 |
'val':[(head_id,16,[0,1]) for head_id in range(64)]}})
|
177 |
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
178 |
-
|
179 |
-
|
|
|
|
|
|
175 |
'key':[(head_id,16,[0,1]) for head_id in range(64)],
|
176 |
'val':[(head_id,16,[0,1]) for head_id in range(64)]}})
|
177 |
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
178 |
+
preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
|
179 |
+
preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
|
180 |
+
st.write([tokenizer.decode([token]) for token in preds_0])
|
181 |
+
st.write([tokenizer.decode([token]) for token in preds_1])
|