TedYeh commited on
Commit
a116653
1 Parent(s): bf809f1

update predictor

Browse files
Files changed (1) hide show
  1. predictor.py +2 -2
predictor.py CHANGED
@@ -198,11 +198,11 @@ def evaluation(model, epoch, device, dataloaders):
198
  print(preds)
199
 
200
  def inference(inp_img, classes = ['big', 'small'], epoch = 6):
201
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
202
  translator= Translator(to_lang="zh-TW")
203
 
204
  model = CUPredictor()
205
- model.load_state_dict(torch.load(f'models/model_{epoch}.pt'))
206
  # load image-to-text model
207
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
208
  model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
198
  print(preds)
199
 
200
  def inference(inp_img, classes = ['big', 'small'], epoch = 6):
201
+ device = torch.device("cpu")
202
  translator= Translator(to_lang="zh-TW")
203
 
204
  model = CUPredictor()
205
+ model.load_state_dict(torch.load(f'models/model_{epoch}.pt', map_location=torch.device('cpu')))
206
  # load image-to-text model
207
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
208
  model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")