Spaces:
Sleeping
Sleeping
better model
Browse files
app.py
CHANGED
@@ -13,13 +13,23 @@ def load_class_names(file_path):
|
|
13 |
class_names = [line.strip() for line in f.readlines()]
|
14 |
return class_names
|
15 |
# Function to load the model from a .pkl file
|
16 |
-
def load_model(model_path):
|
17 |
# Load the model state dictionary
|
18 |
model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
19 |
-
|
20 |
-
# Create an instance of the model
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# Load the state dictionary into the model
|
24 |
model.load_state_dict(model_state_dict)
|
25 |
# Set the model to evaluation mode
|
@@ -51,8 +61,8 @@ def predict_image(image):
|
|
51 |
confidence_score = confidence.item() * 100
|
52 |
return f"{class_name}: {confidence_score:.2f}%"
|
53 |
|
54 |
-
# Load
|
55 |
-
model_path = '
|
56 |
class_file_path = 'classes.txt'
|
57 |
class_names = load_class_names(class_file_path)
|
58 |
num_classes = len(class_names)
|
|
|
13 |
class_names = [line.strip() for line in f.readlines()]
|
14 |
return class_names
|
15 |
# Function to load the model from a .pkl file
|
16 |
+
def load_model(model_path, model_type='resnet'):
|
17 |
# Load the model state dictionary
|
18 |
model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
19 |
+
|
20 |
+
# Create an instance of the model based on model_type
|
21 |
+
if model_type == 'mobilenet':
|
22 |
+
model = models.mobilenet_v2(pretrained=False)
|
23 |
+
model.classifier[1] = nn.Linear(model.last_channel, num_classes)
|
24 |
+
elif model_type == 'resnet':
|
25 |
+
model = models.resnet50(pretrained=False)
|
26 |
+
model.fc = nn.Linear(model.fc.in_features, num_classes)
|
27 |
+
elif model_type == 'densenet':
|
28 |
+
model = models.densenet121(pretrained=False)
|
29 |
+
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Unsupported model type: {model_type}")
|
32 |
+
|
33 |
# Load the state dictionary into the model
|
34 |
model.load_state_dict(model_state_dict)
|
35 |
# Set the model to evaluation mode
|
|
|
61 |
confidence_score = confidence.item() * 100
|
62 |
return f"{class_name}: {confidence_score:.2f}%"
|
63 |
|
64 |
+
# Load trained model and class names
|
65 |
+
model_path = 'resnet30EpochsPretrainedNFeatureX_model.pkl'
|
66 |
class_file_path = 'classes.txt'
|
67 |
class_names = load_class_names(class_file_path)
|
68 |
num_classes = len(class_names)
|
densenet30EpochsPretrainedNFeatureX_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:66c1769d7f5c4415d6f6f7827b74b8730036271d759f98b0a8eb489849434d11
|
3 |
+
size 28813477
|
mobilenet30EpochsPretrainedNFeatureX_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:506ae238566267e277c4d952cacf6da4b87f827e7f30c7302b423809a5c0047c
|
3 |
+
size 9612161
|
resnet30EpochsPretrainedNFeatureX_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edf0def0a33c6cd61025b08ee9a6256238bc26c0871c2488a7519b46c58918b3
|
3 |
+
size 95097345
|