TedYeh commited on
Commit
bf809f1
1 Parent(s): a38dfb6

update predictor

Browse files
Files changed (1) hide show
  1. predictor.py +2 -2
predictor.py CHANGED
@@ -198,7 +198,7 @@ def evaluation(model, epoch, device, dataloaders):
198
  print(preds)
199
 
200
  def inference(inp_img, classes = ['big', 'small'], epoch = 6):
201
- device = torch.device("cuda")
202
  translator= Translator(to_lang="zh-TW")
203
 
204
  model = CUPredictor()
@@ -218,7 +218,7 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
218
  image_tensor = trans(inp_img)
219
  image_tensor = image_tensor.unsqueeze(0)
220
  with torch.no_grad():
221
- inputs = image_tensor
222
  outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
223
  _, preds = torch.max(outputs_c, 1)
224
  idx = preds.numpy()[0]
 
198
  print(preds)
199
 
200
  def inference(inp_img, classes = ['big', 'small'], epoch = 6):
201
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
202
  translator= Translator(to_lang="zh-TW")
203
 
204
  model = CUPredictor()
 
218
  image_tensor = trans(inp_img)
219
  image_tensor = image_tensor.unsqueeze(0)
220
  with torch.no_grad():
221
+ inputs = image_tensor.to(device)
222
  outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
223
  _, preds = torch.max(outputs_c, 1)
224
  idx = preds.numpy()[0]