taka-yamakoshi commited on
Commit
fc55f20
1 Parent(s): 695da47

first pass complete

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -175,5 +175,7 @@ if __name__=='__main__':
175
  'key':[(head_id,16,[0,1]) for head_id in range(64)],
176
  'val':[(head_id,16,[0,1]) for head_id in range(64)]}})
177
  logprobs = F.log_softmax(outputs['logits'], dim = -1)
178
- preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
179
- st.write([tokenizer.decode([token]) for token in preds])
 
 
 
175
  'key':[(head_id,16,[0,1]) for head_id in range(64)],
176
  'val':[(head_id,16,[0,1]) for head_id in range(64)]}})
177
  logprobs = F.log_softmax(outputs['logits'], dim = -1)
178
+ preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
179
+ preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
180
+ st.write([tokenizer.decode([token]) for token in preds_0])
181
+ st.write([tokenizer.decode([token]) for token in preds_1])