Spaces:
Sleeping
Sleeping
new models
Browse files
app.py
CHANGED
@@ -19,7 +19,7 @@ def load_class_info(file_path):
|
|
19 |
return class_info
|
20 |
|
21 |
# Function to load the model from a .pkl file
|
22 |
-
def load_model(model_path, model_type
|
23 |
# Load the model state dictionary
|
24 |
model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
25 |
|
@@ -50,7 +50,15 @@ val_transform = transforms.Compose([
|
|
50 |
])
|
51 |
|
52 |
# Define the prediction function
|
53 |
-
def predict_image(image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
# Convert the NumPy array to a PIL Image
|
55 |
if isinstance(image, np.ndarray):
|
56 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
@@ -76,19 +84,29 @@ def predict_image(image):
|
|
76 |
|
77 |
return result, html_result
|
78 |
|
79 |
-
# Load
|
80 |
-
model_path = 'densenet121_15EpochsPretrainedNoExtractionNoLR_model'
|
81 |
class_file_path = 'classes.txt'
|
82 |
class_info_path = 'classinfo.txt'
|
83 |
class_names = load_class_names(class_file_path)
|
84 |
class_info = load_class_info(class_info_path)
|
85 |
num_classes = len(class_names)
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# Create the Gradio interface
|
89 |
iface = gr.Interface(
|
90 |
fn=predict_image,
|
91 |
-
inputs=gr.Image(height=500),
|
|
|
92 |
outputs=[gr.Label(num_top_classes=1), gr.HTML()],
|
93 |
title="Image Classification",
|
94 |
description="Upload an image to get the predicted label",
|
|
|
19 |
return class_info
|
20 |
|
21 |
# Function to load the model from a .pkl file
|
22 |
+
def load_model(model_path, model_type):
|
23 |
# Load the model state dictionary
|
24 |
model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
25 |
|
|
|
50 |
])
|
51 |
|
52 |
# Define the prediction function
|
53 |
+
def predict_image(image, model_choice):
|
54 |
+
global model, current_model
|
55 |
+
|
56 |
+
# Load the selected model if it's not already loaded
|
57 |
+
if model_choice != current_model:
|
58 |
+
model_path = model_paths[model_choice]
|
59 |
+
model = load_model(model_path, model_choice)
|
60 |
+
current_model = model_choice
|
61 |
+
|
62 |
# Convert the NumPy array to a PIL Image
|
63 |
if isinstance(image, np.ndarray):
|
64 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
|
|
84 |
|
85 |
return result, html_result
|
86 |
|
87 |
+
# Load class names and class info
|
|
|
88 |
class_file_path = 'classes.txt'
|
89 |
class_info_path = 'classinfo.txt'
|
90 |
class_names = load_class_names(class_file_path)
|
91 |
class_info = load_class_info(class_info_path)
|
92 |
num_classes = len(class_names)
|
93 |
+
|
94 |
+
# Define model paths
|
95 |
+
model_paths = {
|
96 |
+
'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
|
97 |
+
'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
|
98 |
+
'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl'
|
99 |
+
}
|
100 |
+
|
101 |
+
# Set default model
|
102 |
+
current_model = 'densenet121'
|
103 |
+
model = load_model(model_paths[current_model], current_model)
|
104 |
|
105 |
# Create the Gradio interface
|
106 |
iface = gr.Interface(
|
107 |
fn=predict_image,
|
108 |
+
inputs=[gr.Image(height=500),
|
109 |
+
gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")],
|
110 |
outputs=[gr.Label(num_top_classes=1), gr.HTML()],
|
111 |
title="Image Classification",
|
112 |
description="Upload an image to get the predicted label",
|
mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6af852b1bdf5723a4d93fbc9ee8c2387d4e1f0665e29bcc03e32b08198016979
|
3 |
+
size 9614689
|
resnet18_25EpochsPretrainedExtractionNoLR_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:06d283f256e7fdfe6dfdb3d46e8fc2a99444d07b9ca8d378a0c77cfbe6f5b788
|
3 |
+
size 44974189
|