djairbee5 commited on
Commit
b73a2e9
1 Parent(s): 96807e0

better model

Browse files
app.py CHANGED
@@ -13,13 +13,23 @@ def load_class_names(file_path):
13
  class_names = [line.strip() for line in f.readlines()]
14
  return class_names
15
  # Function to load the model from a .pkl file
16
- def load_model(model_path):
17
  # Load the model state dictionary
18
  model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
19
-
20
- # Create an instance of the model
21
- model = models.mobilenet_v2(pretrained=False)
22
- model.classifier[1] = nn.Linear(model.last_channel, num_classes) # Adjust to your number of classes
 
 
 
 
 
 
 
 
 
 
23
  # Load the state dictionary into the model
24
  model.load_state_dict(model_state_dict)
25
  # Set the model to evaluation mode
@@ -51,8 +61,8 @@ def predict_image(image):
51
  confidence_score = confidence.item() * 100
52
  return f"{class_name}: {confidence_score:.2f}%"
53
 
54
- # Load your trained model and class names
55
- model_path = 'mobilenet_model.pkl'
56
  class_file_path = 'classes.txt'
57
  class_names = load_class_names(class_file_path)
58
  num_classes = len(class_names)
 
13
  class_names = [line.strip() for line in f.readlines()]
14
  return class_names
15
  # Function to load the model from a .pkl file
16
+ def load_model(model_path, model_type='resnet'):
17
  # Load the model state dictionary
18
  model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
19
+
20
+ # Create an instance of the model based on model_type
21
+ if model_type == 'mobilenet':
22
+ model = models.mobilenet_v2(pretrained=False)
23
+ model.classifier[1] = nn.Linear(model.last_channel, num_classes)
24
+ elif model_type == 'resnet':
25
+ model = models.resnet50(pretrained=False)
26
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
27
+ elif model_type == 'densenet':
28
+ model = models.densenet121(pretrained=False)
29
+ model.classifier = nn.Linear(model.classifier.in_features, num_classes)
30
+ else:
31
+ raise ValueError(f"Unsupported model type: {model_type}")
32
+
33
  # Load the state dictionary into the model
34
  model.load_state_dict(model_state_dict)
35
  # Set the model to evaluation mode
 
61
  confidence_score = confidence.item() * 100
62
  return f"{class_name}: {confidence_score:.2f}%"
63
 
64
+ # Load trained model and class names
65
+ model_path = 'resnet30EpochsPretrainedNFeatureX_model.pkl'
66
  class_file_path = 'classes.txt'
67
  class_names = load_class_names(class_file_path)
68
  num_classes = len(class_names)
densenet30EpochsPretrainedNFeatureX_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66c1769d7f5c4415d6f6f7827b74b8730036271d759f98b0a8eb489849434d11
3
+ size 28813477
mobilenet30EpochsPretrainedNFeatureX_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:506ae238566267e277c4d952cacf6da4b87f827e7f30c7302b423809a5c0047c
3
+ size 9612161
resnet30EpochsPretrainedNFeatureX_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edf0def0a33c6cd61025b08ee9a6256238bc26c0871c2488a7519b46c58918b3
3
+ size 95097345