theresatvan commited on
Commit
539470a
1 Parent(s): e4296b4

Display probabilities

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -44,13 +44,16 @@ def predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims,
44
  attention_mask_claims = encoding_claims['attention_mask'].to(device)
45
 
46
  with torch.no_grad():
47
- outputs_abstract = model_abstract(input_ids=input_abstract, attention_mask=attention_mask_abstract)
48
- outputs_claims = model_claims(input_ids=input_claims, attention_mask=attention_mask_claims)
 
 
 
49
 
50
  combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
51
- label = torch.argmax(combined_prob, axis=1).flatten()
52
 
53
- return label, combined_prob
54
 
55
 
56
  if __name__ == '__main__':
44
  attention_mask_claims = encoding_claims['attention_mask'].to(device)
45
 
46
  with torch.no_grad():
47
+ outputs_abstract = model_abstract(input_ids=input_abstract)
48
+ outputs_claims = model_claims(input_ids=input_claims)
49
+
50
+ print(outputs_abstract.logits)
51
+ print(outputs_claims.logits)
52
 
53
  combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
54
+ label = torch.argmax(combined_prob, dim=1)
55
 
56
+ return label, combined_prob.tolist()[0]
57
 
58
 
59
  if __name__ == '__main__':