席亚东
committed on
Commit
•
16bf127
1
Parent(s):
ef2abea
fix the bug in inference.py
Browse files
- inference.py +7 -2
inference.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7 |
from torch.nn.utils.rnn import pad_sequence
|
8 |
|
9 |
from fairseq import checkpoint_utils, options, tasks, utils
|
|
|
10 |
|
11 |
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
|
12 |
|
@@ -77,7 +78,12 @@ class Inference(object):
|
|
77 |
use_cuda = torch.cuda.is_available() and not args.cpu
|
78 |
self.use_cuda = use_cuda
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
81 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
82 |
cfg_args = eval(str(state["cfg"]))["model"]
|
83 |
del cfg_args["_name"]
|
@@ -97,7 +103,6 @@ class Inference(object):
|
|
97 |
"max_batch":eet_batch_size,
|
98 |
"full_seq_len":eet_seq_len}
|
99 |
print(model_args)
|
100 |
-
from eet.fairseq.transformer import EETTransformerDecoder
|
101 |
eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
|
102 |
dictionary = self.src_dict,args=model_args,
|
103 |
config = eet_config,
|
|
|
7 |
from torch.nn.utils.rnn import pad_sequence
|
8 |
|
9 |
from fairseq import checkpoint_utils, options, tasks, utils
|
10 |
+
from eet.fairseq.transformer import EETTransformerDecoder
|
11 |
|
12 |
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
|
13 |
|
|
|
78 |
use_cuda = torch.cuda.is_available() and not args.cpu
|
79 |
self.use_cuda = use_cuda
|
80 |
|
81 |
+
model_path = args.path
|
82 |
+
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
83 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_2.pt")))
|
84 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_3.pt")))
|
85 |
+
torch.save(checkpoint, model_path)
|
86 |
+
|
87 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
88 |
cfg_args = eval(str(state["cfg"]))["model"]
|
89 |
del cfg_args["_name"]
|
|
|
103 |
"max_batch":eet_batch_size,
|
104 |
"full_seq_len":eet_seq_len}
|
105 |
print(model_args)
|
|
|
106 |
eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
|
107 |
dictionary = self.src_dict,args=model_args,
|
108 |
config = eet_config,
|