djairbee5 commited on
Commit
83b0bb3
1 Parent(s): 59b0a85

new models

Browse files
app.py CHANGED
@@ -19,7 +19,7 @@ def load_class_info(file_path):
19
  return class_info
20
 
21
  # Function to load the model from a .pkl file
22
- def load_model(model_path, model_type='densnet121'):
23
  # Load the model state dictionary
24
  model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
25
 
@@ -50,7 +50,15 @@ val_transform = transforms.Compose([
50
  ])
51
 
52
  # Define the prediction function
53
- def predict_image(image):
 
 
 
 
 
 
 
 
54
  # Convert the NumPy array to a PIL Image
55
  if isinstance(image, np.ndarray):
56
  image = Image.fromarray(image.astype('uint8'), 'RGB')
@@ -76,19 +84,29 @@ def predict_image(image):
76
 
77
  return result, html_result
78
 
79
- # Load trained model and class names
80
- model_path = 'densenet121_15EpochsPretrainedNoExtractionNoLR_model'
81
  class_file_path = 'classes.txt'
82
  class_info_path = 'classinfo.txt'
83
  class_names = load_class_names(class_file_path)
84
  class_info = load_class_info(class_info_path)
85
  num_classes = len(class_names)
86
- model = load_model(model_path)
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Create the Gradio interface
89
  iface = gr.Interface(
90
  fn=predict_image,
91
- inputs=gr.Image(height=500),
 
92
  outputs=[gr.Label(num_top_classes=1), gr.HTML()],
93
  title="Image Classification",
94
  description="Upload an image to get the predicted label",
 
19
  return class_info
20
 
21
  # Function to load the model from a .pkl file
22
+ def load_model(model_path, model_type):
23
  # Load the model state dictionary
24
  model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
25
 
 
50
  ])
51
 
52
  # Define the prediction function
53
+ def predict_image(image, model_choice):
54
+ global model, current_model
55
+
56
+ # Load the selected model if it's not already loaded
57
+ if model_choice != current_model:
58
+ model_path = model_paths[model_choice]
59
+ model = load_model(model_path, model_choice)
60
+ current_model = model_choice
61
+
62
  # Convert the NumPy array to a PIL Image
63
  if isinstance(image, np.ndarray):
64
  image = Image.fromarray(image.astype('uint8'), 'RGB')
 
84
 
85
  return result, html_result
86
 
87
+ # Load class names and class info
 
88
  class_file_path = 'classes.txt'
89
  class_info_path = 'classinfo.txt'
90
  class_names = load_class_names(class_file_path)
91
  class_info = load_class_info(class_info_path)
92
  num_classes = len(class_names)
93
+
94
+ # Define model paths
95
+ model_paths = {
96
+ 'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
97
+ 'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
98
+ 'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl'
99
+ }
100
+
101
+ # Set default model
102
+ current_model = 'densenet121'
103
+ model = load_model(model_paths[current_model], current_model)
104
 
105
  # Create the Gradio interface
106
  iface = gr.Interface(
107
  fn=predict_image,
108
+ inputs=[gr.Image(height=500),
109
+ gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")],
110
  outputs=[gr.Label(num_top_classes=1), gr.HTML()],
111
  title="Image Classification",
112
  description="Upload an image to get the predicted label",
mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6af852b1bdf5723a4d93fbc9ee8c2387d4e1f0665e29bcc03e32b08198016979
3
+ size 9614689
resnet18_25EpochsPretrainedExtractionNoLR_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06d283f256e7fdfe6dfdb3d46e8fc2a99444d07b9ca8d378a0c77cfbe6f5b788
3
+ size 44974189