farrosalferro24 commited on
Commit
979164d
1 Parent(s): cb7e99f

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +4 -2
script.py CHANGED
@@ -35,7 +35,9 @@ class PytorchWorker:
35
  model.heads.head = torch.nn.Linear(model.heads.head.in_features,
36
  number_of_categories)
37
 
38
- model.load_state_dict(torch.load(model_path))
 
 
39
 
40
  # if not torch.cuda.is_available():
41
  # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
@@ -45,7 +47,7 @@ class PytorchWorker:
45
  # model_ckpt = torch.load(model_path, map_location=self.device)
46
  # model.load_state_dict(model_ckpt)
47
 
48
- return model.to(self.device).eval()
49
 
50
  self.model = _load_model(model_name, model_path)
51
 
 
35
  model.heads.head = torch.nn.Linear(model.heads.head.in_features,
36
  number_of_categories)
37
 
38
+ model = model.to(self.device)
39
+
40
+ model.load_state_dict(torch.load(model_path, map_location=self.device))
41
 
42
  # if not torch.cuda.is_available():
43
  # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
 
47
  # model_ckpt = torch.load(model_path, map_location=self.device)
48
  # model.load_state_dict(model_ckpt)
49
 
50
+ return model.eval()
51
 
52
  self.model = _load_model(model_name, model_path)
53