ruidanwang commited on
Commit
49a95c5
1 Parent(s): 71ab1ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -15,14 +15,13 @@ def classify_image(img):
15
  with torch.no_grad():
16
  outputs = model(**inputs)
17
  logits = outputs.logits
18
- predicted_label = logits.argmax(-1).item()
19
- label = model.config.id2label[predicted_label]
20
- confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_label].item()
21
- return f"Prediction: {label}\nConfidence: {confidence:.4f}"
22
 
23
  # Create the Gradio interface
24
  image_input = gr.Image()
25
- label_output = gr.Label()
26
  interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
27
 
28
  # Launch the interface
 
15
  with torch.no_grad():
16
  outputs = model(**inputs)
17
  logits = outputs.logits
18
+ probs = torch.nn.functional.softmax(logits, dim=-1)[0]
19
+ results = {model.config.id2label[i]: float(probs[i]) for i in range(len(probs))}
20
+ return results
 
21
 
22
  # Create the Gradio interface
23
  image_input = gr.Image()
24
+ label_output = gr.Label(num_top_classes=3)
25
  interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
26
 
27
  # Launch the interface