baxtrax commited on
Commit
11e63b5
·
1 Parent(s): 9d7d62d

Update helpers/models.py

Browse files
Files changed (1) hide show
  1. helpers/models.py +8 -7
helpers/models.py CHANGED
@@ -64,6 +64,7 @@ def setup_model(model):
64
  :return: model
65
  """
66
  curr_dir = os.path.dirname(__file__) + "/../"
 
67
  match model:
68
  case ModelTypes.RESNET:
69
  base = models.resnet18(
@@ -71,14 +72,14 @@ def setup_model(model):
71
  base.fc = nn.Linear(base.fc.in_features, 10)
72
 
73
  base.load_state_dict(
74
- load(curr_dir + "models/resnet_standard_cifar10.pt"))
75
  model = base
76
  case ModelTypes.ALEXNET:
77
  base = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
78
  base.classifier[6] = nn.Linear(base.classifier[6].in_features, 10)
79
 
80
  base.load_state_dict(
81
- load(curr_dir + "models/alexnet_standard_cifar10.pt"))
82
  model = base
83
  case ModelTypes.DENSENET:
84
  base = models.densenet121(
@@ -86,7 +87,7 @@ def setup_model(model):
86
  base.classifier = nn.Linear(base.classifier.in_features, 10)
87
 
88
  base.load_state_dict(
89
- load(curr_dir + "models/densenet_standard_cifar10.pt"))
90
  model = base
91
  case ModelTypes.EFFICIENTNET:
92
  base = hub.load('NVIDIA/DeepLearningExamples:torchhub',
@@ -94,7 +95,7 @@ def setup_model(model):
94
  base.classifier.fc = nn.Linear(base.classifier.fc.in_features, 10)
95
 
96
  base.load_state_dict(
97
- load(curr_dir + "models/efficientnet_standard_cifar10.pt"))
98
  model = base
99
  case ModelTypes.GOOGLENET:
100
  base = models.googlenet(
@@ -102,7 +103,7 @@ def setup_model(model):
102
  base.fc = nn.Linear(base.fc.in_features, 10)
103
 
104
  base.load_state_dict(
105
- load(curr_dir + "models/googlenet_standard_cifar10.pt"))
106
  model = base
107
  case ModelTypes.MOBILENET:
108
  base = models.mobilenet_v2(
@@ -110,7 +111,7 @@ def setup_model(model):
110
  base.classifier[1] = nn.Linear(base.classifier[1].in_features, 10)
111
 
112
  base.load_state_dict(
113
- load(curr_dir + "models/mobilenet_standard_cifar10.pt"))
114
  model = base
115
  case ModelTypes.SQUEEZENET:
116
  base = models.squeezenet1_0(
@@ -119,7 +120,7 @@ def setup_model(model):
119
  512, 10, kernel_size=(1, 1), stride=(1, 1))
120
 
121
  base.load_state_dict(
122
- load(curr_dir + "models/squeezenet_standard_cifar10.pt"))
123
  model = base
124
  case _:
125
  raise ValueError("Unknown model choice")
 
64
  :return: model
65
  """
66
  curr_dir = os.path.dirname(__file__) + "/../"
67
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68
  match model:
69
  case ModelTypes.RESNET:
70
  base = models.resnet18(
 
72
  base.fc = nn.Linear(base.fc.in_features, 10)
73
 
74
  base.load_state_dict(
75
+ load(curr_dir + "models/resnet_standard_cifar10.pt", map_location=device))
76
  model = base
77
  case ModelTypes.ALEXNET:
78
  base = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
79
  base.classifier[6] = nn.Linear(base.classifier[6].in_features, 10)
80
 
81
  base.load_state_dict(
82
+ load(curr_dir + "models/alexnet_standard_cifar10.pt", map_location=device))
83
  model = base
84
  case ModelTypes.DENSENET:
85
  base = models.densenet121(
 
87
  base.classifier = nn.Linear(base.classifier.in_features, 10)
88
 
89
  base.load_state_dict(
90
+ load(curr_dir + "models/densenet_standard_cifar10.pt", map_location=device))
91
  model = base
92
  case ModelTypes.EFFICIENTNET:
93
  base = hub.load('NVIDIA/DeepLearningExamples:torchhub',
 
95
  base.classifier.fc = nn.Linear(base.classifier.fc.in_features, 10)
96
 
97
  base.load_state_dict(
98
+ load(curr_dir + "models/efficientnet_standard_cifar10.pt", map_location=device))
99
  model = base
100
  case ModelTypes.GOOGLENET:
101
  base = models.googlenet(
 
103
  base.fc = nn.Linear(base.fc.in_features, 10)
104
 
105
  base.load_state_dict(
106
+ load(curr_dir + "models/googlenet_standard_cifar10.pt", map_location=device))
107
  model = base
108
  case ModelTypes.MOBILENET:
109
  base = models.mobilenet_v2(
 
111
  base.classifier[1] = nn.Linear(base.classifier[1].in_features, 10)
112
 
113
  base.load_state_dict(
114
+ load(curr_dir + "models/mobilenet_standard_cifar10.pt", map_location=device))
115
  model = base
116
  case ModelTypes.SQUEEZENET:
117
  base = models.squeezenet1_0(
 
120
  512, 10, kernel_size=(1, 1), stride=(1, 1))
121
 
122
  base.load_state_dict(
123
+ load(curr_dir + "models/squeezenet_standard_cifar10.pt", map_location=device))
124
  model = base
125
  case _:
126
  raise ValueError("Unknown model choice")