zswwsz commited on
Commit
5294145
1 Parent(s): 63d92be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -33,9 +33,10 @@ def classify_text(inp):
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
35
  print(logits)
36
- logits = torch.nn.Softmax(dim=0)(logits)
37
- print(logits)
38
- return {labels[i]: float(logits[i].item()) for i in range(len(labels))}
 
39
 
40
  gr.Interface(
41
  classify_text,
 
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
35
  print(logits)
36
+ # logits = torch.nn.Softmax(dim=0)(logits)
37
+ # print(logits)
38
+ # return {labels[i]: float(logits[i].item()) for i in range(len(labels))}
39
+ return logits.argmax().item()
40
 
41
  gr.Interface(
42
  classify_text,