zergswim commited on
Commit
2fd03ec
1 Parent(s): 3528c30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -12,13 +12,17 @@ import gradio as gr
12
  def segment(image):
13
  inputs = feature_extractor(image, return_tensors="pt")
14
 
15
- with torch.no_grad():
16
- logits = model(**inputs).logits
17
 
18
  # model predicts one of the 1000 ImageNet classes
19
- predicted_label = logits.argmax(-1).item()
20
- # print(model.config.id2label[predicted_label])
 
 
 
21
 
22
- return model.config.id2label[predicted_label]
23
 
24
- gr.Interface(fn=segment, inputs="image", outputs="text").launch()
 
 
12
  def segment(image):
13
  inputs = feature_extractor(image, return_tensors="pt")
14
 
15
+ # with torch.no_grad():
16
+ # logits = model(**inputs).logits
17
 
18
  # model predicts one of the 1000 ImageNet classes
19
+ # predicted_label = logits.argmax(-1).item()
20
+ # return model.config.id2label[predicted_label]
21
+
22
+ with torch.no_grad():
23
+ prediction = torch.nn.functional.softmax(model(**inputs)[0], dim=0)
24
 
25
+ return {model.config.id2label[i]: float(prediction[i]) for i in range(3)}
26
 
27
+ # gr.Interface(fn=segment, inputs="image", outputs="text").launch()
28
+ gr.Interface(fn=segment, inputs="image", outputs="label").launch()