zergswim commited on
Commit
b7802a1
1 Parent(s): f03a370

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -14,11 +14,13 @@ def segment(image):
14
 
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
- print(logits)
 
18
 
19
  # model predicts one of the 1000 ImageNet classes
20
  predicted_label = logits.argmax(-1).item()
21
  return model.config.id2label[predicted_label]
 
22
  gr.Interface(fn=segment, inputs="image", outputs="text").launch()
23
 
24
  # with torch.no_grad():
 
14
 
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
+ probs = torch.nn.Softmax(dim=1)(logits)
18
+ print(probs)
19
 
20
  # model predicts one of the 1000 ImageNet classes
21
  predicted_label = logits.argmax(-1).item()
22
  return model.config.id2label[predicted_label]
23
+
24
  gr.Interface(fn=segment, inputs="image", outputs="text").launch()
25
 
26
  # with torch.no_grad():