Spaces:
Sleeping
Sleeping
error handling
Browse files
app.py
CHANGED
@@ -53,35 +53,38 @@ val_transform = transforms.Compose([
|
|
53 |
def predict_image(image, model_choice):
|
54 |
global model, current_model
|
55 |
|
|
|
|
|
|
|
|
|
56 |
# Load the selected model if it's not already loaded
|
57 |
if model_choice != current_model:
|
58 |
model_path = model_paths[model_choice]
|
59 |
model = load_model(model_path, model_choice)
|
60 |
current_model = model_choice
|
61 |
-
|
62 |
-
# Convert the NumPy array to a PIL Image
|
63 |
if isinstance(image, np.ndarray):
|
64 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
65 |
-
|
66 |
image = val_transform(image)
|
67 |
image = image.unsqueeze(0) # Add batch dimension
|
68 |
-
|
69 |
with torch.no_grad():
|
70 |
outputs = model(image)
|
71 |
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
|
72 |
confidence, preds = torch.max(probabilities, 0)
|
73 |
-
|
74 |
confidence_score = confidence.item() * 100
|
75 |
-
|
76 |
if confidence_score < 30:
|
77 |
-
result = "Not identified"
|
78 |
html_result = ""
|
79 |
else:
|
80 |
class_name = class_names[preds.item()]
|
81 |
wiki_link = class_info[preds.item()]
|
82 |
result = f"{class_name}: {confidence_score:.2f}%"
|
83 |
html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"
|
84 |
-
|
85 |
return result, html_result
|
86 |
|
87 |
# Load class names and class info
|
@@ -104,11 +107,13 @@ model = load_model(model_paths[current_model], current_model)
|
|
104 |
|
105 |
# Create the Gradio interface
|
106 |
iface = gr.Interface(
|
107 |
-
fn=predict_image,
|
108 |
-
inputs=[
|
109 |
-
|
|
|
|
|
110 |
outputs=[gr.Label(num_top_classes=1), gr.HTML()],
|
111 |
-
title="
|
112 |
description="Upload an image to get the predicted label",
|
113 |
allow_flagging="never",
|
114 |
)
|
|
|
53 |
def predict_image(image, model_choice):
|
54 |
global model, current_model
|
55 |
|
56 |
+
# Check if a model is selected
|
57 |
+
if model_choice not in model_paths:
|
58 |
+
return "Error: Please select a valid model.", ""
|
59 |
+
|
60 |
# Load the selected model if it's not already loaded
|
61 |
if model_choice != current_model:
|
62 |
model_path = model_paths[model_choice]
|
63 |
model = load_model(model_path, model_choice)
|
64 |
current_model = model_choice
|
65 |
+
# Convert the NumPy array to a PIL Image if needed
|
|
|
66 |
if isinstance(image, np.ndarray):
|
67 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
68 |
+
|
69 |
image = val_transform(image)
|
70 |
image = image.unsqueeze(0) # Add batch dimension
|
71 |
+
|
72 |
with torch.no_grad():
|
73 |
outputs = model(image)
|
74 |
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
|
75 |
confidence, preds = torch.max(probabilities, 0)
|
76 |
+
|
77 |
confidence_score = confidence.item() * 100
|
78 |
+
|
79 |
if confidence_score < 30:
|
80 |
+
result = "Not identified\nTry to crop the image or choose another one!"
|
81 |
html_result = ""
|
82 |
else:
|
83 |
class_name = class_names[preds.item()]
|
84 |
wiki_link = class_info[preds.item()]
|
85 |
result = f"{class_name}: {confidence_score:.2f}%"
|
86 |
html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"
|
87 |
+
|
88 |
return result, html_result
|
89 |
|
90 |
# Load class names and class info
|
|
|
107 |
|
108 |
# Create the Gradio interface
|
109 |
iface = gr.Interface(
|
110 |
+
fn=predict_image,
|
111 |
+
inputs=[
|
112 |
+
gr.Image(height=500),
|
113 |
+
gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")
|
114 |
+
],
|
115 |
outputs=[gr.Label(num_top_classes=1), gr.HTML()],
|
116 |
+
title="Animal Classification",
|
117 |
description="Upload an image to get the predicted label",
|
118 |
allow_flagging="never",
|
119 |
)
|