Kaelan commited on
Commit
253626a
1 Parent(s): 287cd5c

Update model_tools.py

Browse files
Files changed (1) hide show
  1. model_tools.py +2 -1
model_tools.py CHANGED
@@ -40,7 +40,8 @@ def get_prediction(model, image_in, pipeline):
40
 
41
  # Predict
42
  with torch.no_grad():
43
- torch_input = torch.Tensor(preprocessed_image).unsqueeze(0).to('cpu')
 
44
  model_output = model(torch_input)
45
  prediction = pipeline._decode_model_output(model_output, model_input=torch_input)
46
  # Postprocess
 
40
 
41
  # Predict
42
  with torch.no_grad():
43
+ #torch_input = torch.Tensor(preprocessed_image).unsqueeze(0).to('cpu')
44
+ torch_input = torch.Tensor(preprocessed_image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
45
  model_output = model(torch_input)
46
  prediction = pipeline._decode_model_output(model_output, model_input=torch_input)
47
  # Postprocess