qc7 commited on
Commit
28d5381
1 Parent(s): 5e99dea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -28,7 +28,7 @@ def forward_pass(title, abstract, tokenizer, model):
28
  with torch.no_grad():
29
  logits = model(embeddings[None])['logits'][0]
30
  assert logits.shape == (8,)
31
- probs = torch.softmax(logits).data.cpu().numpy()
32
 
33
  return probs
34
 
 
28
  with torch.no_grad():
29
  logits = model(embeddings[None])['logits'][0]
30
  assert logits.shape == (8,)
31
+ probs = torch.softmax(logits, dim=0).data.cpu().numpy()
32
 
33
  return probs
34