Spaces:
Sleeping
Sleeping
map location model to cpu
Browse files- test_few_shot.py +1 -1
- train.py +0 -1
test_few_shot.py
CHANGED
@@ -31,7 +31,7 @@ def test_main_model(opts):
|
|
31 |
st.write("Loading Model Weight...")
|
32 |
model_main = ModelMain(opts)
|
33 |
path_ckpt = os.path.join(f"{opts.model_path}")
|
34 |
-
model_main.load_state_dict(torch.load(path_ckpt)['model'])
|
35 |
model_main.to(device)
|
36 |
model_main.eval()
|
37 |
with torch.no_grad():
|
|
|
31 |
st.write("Loading Model Weight...")
|
32 |
model_main = ModelMain(opts)
|
33 |
path_ckpt = os.path.join(f"{opts.model_path}")
|
34 |
+
model_main.load_state_dict(torch.load(path_ckpt)['model'], map_location=torch.device('cpu'))
|
35 |
model_main.to(device)
|
36 |
model_main.eval()
|
37 |
with torch.no_grad():
|
train.py
CHANGED
@@ -35,7 +35,6 @@ def train_main_model(opts):
|
|
35 |
val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')
|
36 |
|
37 |
run = wandb.init(project=opts.wandb_project_name, config=opts) # initialize wandb project
|
38 |
-
text_table = wandb.Table(columns=["epoch", "loss", "ref"])
|
39 |
|
40 |
model_main = ModelMain(opts)
|
41 |
if torch.cuda.is_available() and opts.multi_gpu:
|
|
|
35 |
val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')
|
36 |
|
37 |
run = wandb.init(project=opts.wandb_project_name, config=opts) # initialize wandb project
|
|
|
38 |
|
39 |
model_main = ModelMain(opts)
|
40 |
if torch.cuda.is_available() and opts.multi_gpu:
|