Spaces:
Runtime error
Runtime error
Commit
•
49a95c5
1
Parent(s):
71ab1ff
Update app.py
Browse files
app.py
CHANGED
@@ -15,14 +15,13 @@ def classify_image(img):
|
|
15 |
with torch.no_grad():
|
16 |
outputs = model(**inputs)
|
17 |
logits = outputs.logits
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
return f"Prediction: {label}\nConfidence: {confidence:.4f}"
|
22 |
|
23 |
# Create the Gradio interface
|
24 |
image_input = gr.Image()
|
25 |
-
label_output = gr.Label()
|
26 |
interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
|
27 |
|
28 |
# Launch the interface
|
|
|
15 |
with torch.no_grad():
|
16 |
outputs = model(**inputs)
|
17 |
logits = outputs.logits
|
18 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
19 |
+
results = {model.config.id2label[i]: float(probs[i]) for i in range(len(probs))}
|
20 |
+
return results
|
|
|
21 |
|
22 |
# Create the Gradio interface
|
23 |
image_input = gr.Image()
|
24 |
+
label_output = gr.Label(num_top_classes=3)
|
25 |
interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
|
26 |
|
27 |
# Launch the interface
|