zergswim commited on
Commit
3517393
1 Parent(s): c08e531

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -15,12 +15,13 @@ def segment(image):
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
  probs = torch.nn.Softmax(dim=1)(logits)
18
- labels = [(prob, model.config.id2label[idx]) for idx, prob in enumerate(probs[0])]
 
19
  print(labels)
20
 
21
  # model predicts one of the 1000 ImageNet classes
22
  predicted_label = logits.argmax(-1).item()
23
- return model.config.id2label[predicted_label]
24
 
25
  gr.Interface(fn=segment, inputs="image", outputs="text").launch()
26
 
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
  probs = torch.nn.Softmax(dim=1)(logits)
18
+ # labels = [(prob, model.config.id2label[idx]) for idx, prob in enumerate(probs[0])]
19
+ labels = {model.config.id2label[idx] : flaot(prob) for idx, prob in enumerate(probs[0])}
20
  print(labels)
21
 
22
  # model predicts one of the 1000 ImageNet classes
23
  predicted_label = logits.argmax(-1).item()
24
+ return labels # model.config.id2label[predicted_label]
25
 
26
  gr.Interface(fn=segment, inputs="image", outputs="text").launch()
27