席亚东 commited on
Commit
fc6a062
1 Parent(s): 05e5527

fix the bug in inferen.py

Browse files
Files changed (1) hide show
  1. inference.py +4 -4
inference.py CHANGED
@@ -87,10 +87,10 @@ class Inference(object):
87
 
88
  model_path = args.path
89
  checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
90
- checkpoint["model"].update(model_path.replace("best.pt", "best_part_2.pt"))
91
- checkpoint["model"].update(model_path.replace("best.pt", "best_part_3.pt"))
92
  torch.save(checkpoint, model_path)
93
- # load part 1
94
  state = torch.load(args.path, map_location=torch.device("cpu"))
95
  cfg_args = eval(str(state["cfg"]))["model"]
96
  del cfg_args["_name"]
@@ -178,4 +178,4 @@ class Inference(object):
178
  score = hypo['score'] / math.log(2) # convert to base 2
179
  tmp_res.append([detok_hypo_str, score])
180
  final_results.append(tmp_res)
181
- return final_results
 
87
 
88
  model_path = args.path
89
  checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
90
+ checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_2.pt")))
91
+ checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_3.pt")))
92
  torch.save(checkpoint, model_path)
93
+
94
  state = torch.load(args.path, map_location=torch.device("cpu"))
95
  cfg_args = eval(str(state["cfg"]))["model"]
96
  del cfg_args["_name"]
 
178
  score = hypo['score'] / math.log(2) # convert to base 2
179
  tmp_res.append([detok_hypo_str, score])
180
  final_results.append(tmp_res)
181
+ return final_results