taka-yamakoshi
commited on
Commit
•
a440ac3
1
Parent(s):
b218eb4
fix
Browse files
app.py
CHANGED
@@ -66,6 +66,6 @@ if __name__=='__main__':
|
|
66 |
input_ids = torch.tensor([input_ids_1,input_ids_2])
|
67 |
|
68 |
outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
|
69 |
-
logprobs = F.log_softmax(outputs
|
70 |
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
|
71 |
st.write([tokenizer.decode([token]) for token in preds])
|
|
|
66 |
input_ids = torch.tensor([input_ids_1,input_ids_2])
|
67 |
|
68 |
outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
|
69 |
+
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
70 |
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
|
71 |
st.write([tokenizer.decode([token]) for token in preds])
|