APJ23 commited on
Commit
8d929ef
1 Parent(s): caab640

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -23,10 +23,10 @@ classes = {
23
  def prediction(tweet, model, tokenizer):
24
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
25
  outputs = model(**inputs)
26
- predicted_class = torch.argmax(outputs.logits, dim=1)
27
  predicted_prob = torch.softmax(outputs.logits, dim=1)[0][predicted_class].item()
28
  return classes[predicted_class], predicted_prob
29
-
30
  def create_table(predictions):
31
  data = {'Tweet': [], 'Highest Toxicity Class': [], 'Probability': []}
32
  for tweet, prediction in predictions.items():
 
23
  def prediction(tweet, model, tokenizer):
24
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
25
  outputs = model(**inputs)
26
+ predicted_class = torch.argmax(outputs.logits, dim=1).item() # convert to scalar integer
27
  predicted_prob = torch.softmax(outputs.logits, dim=1)[0][predicted_class].item()
28
  return classes[predicted_class], predicted_prob
29
+
30
  def create_table(predictions):
31
  data = {'Tweet': [], 'Highest Toxicity Class': [], 'Probability': []}
32
  for tweet, prediction in predictions.items():