picekl commited on
Commit
9f0d310
1 Parent(s): bba9145

fix:gpu support

Browse files
Files changed (1) hide show
  1. script.py +5 -5
script.py CHANGED
@@ -21,8 +21,8 @@ class PytorchWorker:
21
  def _load_model(model_name, model_path):
22
 
23
  print("Setting up Pytorch Model")
24
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
- print(f"Using devide: {device}")
26
 
27
  model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
28
 
@@ -31,10 +31,10 @@ class PytorchWorker:
31
  # else:
32
  # model_ckpt = torch.load(model_path)
33
 
34
- model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
35
  model.load_state_dict(model_ckpt)
36
 
37
- return model.to(device).eval()
38
 
39
  self.model = _load_model(model_name, model_path)
40
 
@@ -50,7 +50,7 @@ class PytorchWorker:
50
  :return: A list with logits and confidences.
51
  """
52
 
53
- logits = self.model(self.transforms(image).unsqueeze(0))
54
 
55
  return logits.tolist()
56
 
 
21
  def _load_model(model_name, model_path):
22
 
23
  print("Setting up Pytorch Model")
24
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+ print(f"Using devide: {self.device}")
26
 
27
  model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
28
 
 
31
  # else:
32
  # model_ckpt = torch.load(model_path)
33
 
34
+ model_ckpt = torch.load(model_path, map_location=self.device)
35
  model.load_state_dict(model_ckpt)
36
 
37
+ return model.to(self.device).eval()
38
 
39
  self.model = _load_model(model_name, model_path)
40
 
 
50
  :return: A list with logits and confidences.
51
  """
52
 
53
+ logits = self.model(self.transforms(image).unsqueeze(0).to(self.device))
54
 
55
  return logits.tolist()
56