leeyunjai commited on
Commit
ea657a8
1 Parent(s): 31727e1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -2
main.py CHANGED
@@ -20,9 +20,9 @@ def under_max(image):
20
  return image
21
 
22
  class Model(object):
23
- def __init__(self, gpu=0):
24
  config = Config()
25
- config.device = 'cuda:{}'.format(gpu)
26
  model, _ = caption_model.build_model(config)
27
  checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
28
  model.load_state_dict(checkpoint['model'])
 
20
  return image
21
 
22
  class Model(object):
23
+ def __init__(self, gpu=None):
24
  config = Config()
25
+ config.device = 'cpu' if gpu is None else 'cuda:{}'.format(gpu)
26
  model, _ = caption_model.build_model(config)
27
  checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
28
  model.load_state_dict(checkpoint['model'])