zergswim commited on
Commit
51012d6
1 Parent(s): 2fd03ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -12,17 +12,17 @@ import gradio as gr
12
  def segment(image):
13
  inputs = feature_extractor(image, return_tensors="pt")
14
 
15
- # with torch.no_grad():
16
- # logits = model(**inputs).logits
17
 
18
  # model predicts one of the 1000 ImageNet classes
19
- # predicted_label = logits.argmax(-1).item()
20
- # return model.config.id2label[predicted_label]
 
21
 
22
- with torch.no_grad():
23
- prediction = torch.nn.functional.softmax(model(**inputs)[0], dim=0)
24
 
25
- return {model.config.id2label[i]: float(prediction[i]) for i in range(3)}
 
26
 
27
- # gr.Interface(fn=segment, inputs="image", outputs="text").launch()
28
- gr.Interface(fn=segment, inputs="image", outputs="label").launch()
 
12
  def segment(image):
13
  inputs = feature_extractor(image, return_tensors="pt")
14
 
15
+ with torch.no_grad():
16
+ logits = model(**inputs).logits
17
 
18
  # model predicts one of the 1000 ImageNet classes
19
+ predicted_label = logits.argmax(-1).item()
20
+ return model.config.id2label[predicted_label]
21
+ gr.Interface(fn=segment, inputs="image", outputs="text").launch()
22
 
23
+ # with torch.no_grad():
24
+ # prediction = torch.nn.functional.softmax(model(**inputs)[0], dim=0)
25
 
26
+ # return {model.config.id2label[i]: float(prediction[i]) for i in range(3)}
27
+ #gr.Interface(fn=segment, inputs="image", outputs="label").launch()
28