TheoBH commited on
Commit
4db2969
·
verified ·
1 Parent(s): 4e06e61

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +3 -0
predict.py CHANGED
@@ -53,6 +53,9 @@ def predict_masks(input_img_path: str):
53
  ## pass input image through image processor
54
  image = Image.open(input_img_path)
55
  inputs = image_processor(images=image, return_tensors="pt")
 
 
 
56
 
57
  ## pass inputs to model for prediction
58
  with torch.no_grad():
 
53
  ## pass input image through image processor
54
  image = Image.open(input_img_path)
55
  inputs = image_processor(images=image, return_tensors="pt")
56
+
57
+ # Move inputs to the same device as the model
58
+ inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
59
 
60
  ## pass inputs to model for prediction
61
  with torch.no_grad():