djairbee5 commited on
Commit
bb8f6a5
1 Parent(s): 47c6dc3

error handling

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -53,35 +53,38 @@ val_transform = transforms.Compose([
53
  def predict_image(image, model_choice):
54
  global model, current_model
55
 
 
 
 
 
56
  # Load the selected model if it's not already loaded
57
  if model_choice != current_model:
58
  model_path = model_paths[model_choice]
59
  model = load_model(model_path, model_choice)
60
  current_model = model_choice
61
-
62
- # Convert the NumPy array to a PIL Image
63
  if isinstance(image, np.ndarray):
64
  image = Image.fromarray(image.astype('uint8'), 'RGB')
65
-
66
  image = val_transform(image)
67
  image = image.unsqueeze(0) # Add batch dimension
68
-
69
  with torch.no_grad():
70
  outputs = model(image)
71
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
72
  confidence, preds = torch.max(probabilities, 0)
73
-
74
  confidence_score = confidence.item() * 100
75
-
76
  if confidence_score < 30:
77
- result = "Not identified"
78
  html_result = ""
79
  else:
80
  class_name = class_names[preds.item()]
81
  wiki_link = class_info[preds.item()]
82
  result = f"{class_name}: {confidence_score:.2f}%"
83
  html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"
84
-
85
  return result, html_result
86
 
87
  # Load class names and class info
@@ -104,11 +107,13 @@ model = load_model(model_paths[current_model], current_model)
104
 
105
  # Create the Gradio interface
106
  iface = gr.Interface(
107
- fn=predict_image,
108
- inputs=[gr.Image(height=500),
109
- gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")],
 
 
110
  outputs=[gr.Label(num_top_classes=1), gr.HTML()],
111
- title="Animnal Classification",
112
  description="Upload an image to get the predicted label",
113
  allow_flagging="never",
114
  )
 
53
  def predict_image(image, model_choice):
54
  global model, current_model
55
 
56
+ # Check if a model is selected
57
+ if model_choice not in model_paths:
58
+ return "Error: Please select a valid model.", ""
59
+
60
  # Load the selected model if it's not already loaded
61
  if model_choice != current_model:
62
  model_path = model_paths[model_choice]
63
  model = load_model(model_path, model_choice)
64
  current_model = model_choice
65
+ # Convert the NumPy array to a PIL Image if needed
 
66
  if isinstance(image, np.ndarray):
67
  image = Image.fromarray(image.astype('uint8'), 'RGB')
68
+
69
  image = val_transform(image)
70
  image = image.unsqueeze(0) # Add batch dimension
71
+
72
  with torch.no_grad():
73
  outputs = model(image)
74
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
75
  confidence, preds = torch.max(probabilities, 0)
76
+
77
  confidence_score = confidence.item() * 100
78
+
79
  if confidence_score < 30:
80
+ result = "Not identified\nTry to crop the image or choose another one!"
81
  html_result = ""
82
  else:
83
  class_name = class_names[preds.item()]
84
  wiki_link = class_info[preds.item()]
85
  result = f"{class_name}: {confidence_score:.2f}%"
86
  html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"
87
+
88
  return result, html_result
89
 
90
  # Load class names and class info
 
107
 
108
  # Create the Gradio interface
109
  iface = gr.Interface(
110
+ fn=predict_image,
111
+ inputs=[
112
+ gr.Image(height=500),
113
+ gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")
114
+ ],
115
  outputs=[gr.Label(num_top_classes=1), gr.HTML()],
116
+ title="Animal Classification",
117
  description="Upload an image to get the predicted label",
118
  allow_flagging="never",
119
  )