microhum commited on
Commit
e0c5376
·
1 Parent(s): 9de6eae

map location model to cpu

Browse files
Files changed (2) hide show
  1. test_few_shot.py +1 -1
  2. train.py +0 -1
test_few_shot.py CHANGED
@@ -31,7 +31,7 @@ def test_main_model(opts):
31
  st.write("Loading Model Weight...")
32
  model_main = ModelMain(opts)
33
  path_ckpt = os.path.join(f"{opts.model_path}")
34
- model_main.load_state_dict(torch.load(path_ckpt)['model'])
35
  model_main.to(device)
36
  model_main.eval()
37
  with torch.no_grad():
 
31
  st.write("Loading Model Weight...")
32
  model_main = ModelMain(opts)
33
  path_ckpt = os.path.join(f"{opts.model_path}")
34
+ model_main.load_state_dict(torch.load(path_ckpt)['model'], map_location=torch.device('cpu'))
35
  model_main.to(device)
36
  model_main.eval()
37
  with torch.no_grad():
train.py CHANGED
@@ -35,7 +35,6 @@ def train_main_model(opts):
35
  val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')
36
 
37
  run = wandb.init(project=opts.wandb_project_name, config=opts) # initialize wandb project
38
- text_table = wandb.Table(columns=["epoch", "loss", "ref"])
39
 
40
  model_main = ModelMain(opts)
41
  if torch.cuda.is_available() and opts.multi_gpu:
 
35
  val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')
36
 
37
  run = wandb.init(project=opts.wandb_project_name, config=opts) # initialize wandb project
 
38
 
39
  model_main = ModelMain(opts)
40
  if torch.cuda.is_available() and opts.multi_gpu: