zergswim commited on
Commit
26e56f5
1 Parent(s): b7802a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -15,7 +15,8 @@ def segment(image):
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
  probs = torch.nn.Softmax(dim=1)(logits)
18
- print(probs)
 
19
 
20
  # model predicts one of the 1000 ImageNet classes
21
  predicted_label = logits.argmax(-1).item()
 
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
  probs = torch.nn.Softmax(dim=1)(logits)
18
+ labels = [(prob, model.config.id2label[idx]) for idx, prob in enumerate(probs)]
19
+ print(labels)
20
 
21
  # model predicts one of the 1000 ImageNet classes
22
  predicted_label = logits.argmax(-1).item()