farrosalferro24
commited on
Commit
•
979164d
1
Parent(s):
cb7e99f
Update script.py
Browse files
script.py
CHANGED
@@ -35,7 +35,9 @@ class PytorchWorker:
|
|
35 |
model.heads.head = torch.nn.Linear(model.heads.head.in_features,
|
36 |
number_of_categories)
|
37 |
|
38 |
-
model.
|
|
|
|
|
39 |
|
40 |
# if not torch.cuda.is_available():
|
41 |
# model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
|
@@ -45,7 +47,7 @@ class PytorchWorker:
|
|
45 |
# model_ckpt = torch.load(model_path, map_location=self.device)
|
46 |
# model.load_state_dict(model_ckpt)
|
47 |
|
48 |
-
return model.
|
49 |
|
50 |
self.model = _load_model(model_name, model_path)
|
51 |
|
|
|
35 |
model.heads.head = torch.nn.Linear(model.heads.head.in_features,
|
36 |
number_of_categories)
|
37 |
|
38 |
+
model = model.to(self.device)
|
39 |
+
|
40 |
+
model.load_state_dict(torch.load(model_path, map_location=self.device))
|
41 |
|
42 |
# if not torch.cuda.is_available():
|
43 |
# model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
|
|
|
47 |
# model_ckpt = torch.load(model_path, map_location=self.device)
|
48 |
# model.load_state_dict(model_ckpt)
|
49 |
|
50 |
+
return model.eval()
|
51 |
|
52 |
self.model = _load_model(model_name, model_path)
|
53 |
|