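# Hosting-page header (kept as comments so this file parses as Python):
# uploaded by JustinLin610 in commit ee21b96 ("first commit"); size 7.1 kB;
# virus scan: clean. (Page links "raw" / "history" / "blame" omitted.)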
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from fairseq import checkpoint_utils, tasks
import sentencepiece as spm
import torch
try:
    from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
    from simuleval.agents import TextAgent
except ImportError as err:
    # The agent below subclasses TextAgent and uses these constants, so this
    # module cannot work without simuleval. Fail here with an actionable
    # message instead of a later, confusing NameError on ``TextAgent``.
    raise ImportError("Please install simuleval 'pip install simuleval'") from err

# SentencePiece's word-boundary marker "\u2581" (U+2581). Target pieces may
# start with it; SimulTransTextAgentJA.units_to_segment strips it (or drops a
# bare marker) before sending text to the evaluation server.
BOS_PREFIX = "\u2581"
class SimulTransTextAgentJA(TextAgent):
    """
    Simultaneous Translation
    Text agent for Japanese

    A SimulEval ``TextAgent`` backed by a fairseq checkpoint. Source words
    are split into SentencePiece pieces, and the whole source prefix is
    re-encoded after every READ. At every step the decoder itself reports
    whether to READ more source or WRITE the next target piece.

    Japanese output is meant to be scored with
    ``--sacrebleu-tokenizer MeCab --eval-latency-unit char --no-space``
    (see ``units_to_segment``).
    """

    def __init__(self, args):
        # Whether to run on GPU (False when ``args`` has no ``gpu`` field).
        self.gpu = getattr(args, "gpu", False)
        # Upper bound on emitted target segments. Once it is exceeded,
        # units_to_segment forces EOS.
        self.max_len = args.max_len
        # Load the model and its source/target dictionaries.
        self.load_model_vocab(args)
        # Build the SentencePiece splitters.
        self.build_word_splitter(args)
        self.eos = DEFAULT_EOS

    def initialize_states(self, states):
        """Attach a fresh incremental-state cache to ``states``.

        ``incremental_states`` is handed to the decoder in ``policy``. Its
        ``"online"`` sub-dict is filled there on every step.
        """
        states.incremental_states = dict()
        states.incremental_states["online"] = dict()

    def to_device(self, tensor):
        """Move ``tensor`` to CUDA when ``self.gpu`` is set, else to CPU."""
        if self.gpu:
            return tensor.cuda()
        else:
            return tensor.cpu()

    def load_model_vocab(self, args):
        """Load the checkpoint at ``args.model_path`` and its dictionaries.

        Rebuilds the fairseq task from the checkpoint config, pointing its
        data directory at ``args.data_bin``. Sets:
          * ``self.model``: weights loaded strictly, in eval mode, in shared
            memory, moved to GPU when ``self.gpu``;
          * ``self.dict``: ``{"src": ..., "tgt": ...}`` task dictionaries.

        Raises:
            FileNotFoundError: if ``args.model_path`` does not exist. This is
                a subclass of ``IOError``/``OSError``, so existing handlers
                still catch it.
        """
        filename = args.model_path
        if not os.path.exists(filename):
            raise FileNotFoundError("Model file not found: {}".format(filename))
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        task_args = state["cfg"]["task"]
        task_args.data = args.data_bin

        task = tasks.setup_task(task_args)

        # build model for ensemble
        # All weights come from state["model"] below (loaded strictly), so
        # don't let build_model fetch separate pretrained encoder/decoder
        # files, whose recorded paths may not exist on this machine.
        state["cfg"]["model"].load_pretrained_encoder_from = None
        state["cfg"]["model"].load_pretrained_decoder_from = None
        self.model = task.build_model(state["cfg"]["model"])
        self.model.load_state_dict(state["model"], strict=True)
        self.model.eval()
        self.model.share_memory()

        if self.gpu:
            self.model.cuda()

        # Set dictionary
        self.dict = {}
        self.dict["tgt"] = task.target_dictionary
        self.dict["src"] = task.source_dictionary

    @staticmethod
    def add_args(parser):
        """Register this agent's command-line options on ``parser`` and return it."""
        # fmt: off
        parser.add_argument('--model-path', type=str, required=True,
                            help='path to your pretrained model.')
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--max-len", type=int, default=100,
                            help="Max length of translation")
        parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
                            help="Subword splitter type for target text.")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text.")
        parser.add_argument("--src-splitter-type", type=str, default="SentencePiece",
                            help="Subword splitter type for source text.")
        parser.add_argument("--src-splitter-path", type=str, default=None,
                            help="Subword splitter model path for source text.")
        # fmt: on
        return parser

    def build_word_splitter(self, args):
        """Load SentencePiece models into ``self.spm``, keyed ``"src"``/``"tgt"``.

        A side is loaded only when its ``*_splitter_type`` is truthy and its
        ``*_splitter_path`` is given. The type string is not otherwise
        inspected: any non-empty type loads the path as a SentencePiece model.
        """
        self.spm = {}
        for lang in ['src', 'tgt']:
            if getattr(args, f'{lang}_splitter_type', None):
                path = getattr(args, f'{lang}_splitter_path', None)
                if path:
                    self.spm[lang] = spm.SentencePieceProcessor()
                    self.spm[lang].Load(path)

    def segment_to_units(self, segment, states):
        """Split a full source word (segment) into SentencePiece pieces (units).

        Requires the source splitter. Raises ``KeyError`` if no
        ``--src-splitter-path`` was loaded.
        """
        return self.spm['src'].EncodeAsPieces(segment)

    def update_model_encoder(self, states):
        """Re-encode every source unit read so far into ``states.encoder_states``.

        Does nothing until at least one source unit has been read. Once
        reading is finished, the source sequence is terminated with EOS.
        """
        if len(states.units.source) == 0:
            return
        src_indices = [
            self.dict['src'].index(x)
            for x in states.units.source.value
        ]

        if states.finish_read():
            # Reading is over: terminate the *source* sequence with the
            # source dictionary's EOS. The original code used the target
            # dictionary's index here.
            src_indices += [self.dict["src"].eos_index]

        src_indices = self.to_device(
            torch.LongTensor(src_indices).unsqueeze(0)
        )
        src_lengths = self.to_device(
            torch.LongTensor([src_indices.size(1)])
        )

        states.encoder_states = self.model.encoder(src_indices, src_lengths)
        torch.cuda.empty_cache()

    def update_states_read(self, states):
        """Hook run after a READ action: refresh the encoder states."""
        self.update_model_encoder(states)

    def units_to_segment(self, units, states):
        """Turn the newest target unit into a segment for the server.

        Merge sub words (units) to full word (segment). For Japanese the
        untokenized piece can be sent to the server directly, minus the
        SentencePiece marker, provided evaluation uses::

            --sacrebleu-tokenizer MeCab
            --eval-latency-unit char
            --no-space

        Pops the last unit from ``units.value`` and returns:
          * ``DEFAULT_EOS`` for the target EOS word, or once more than
            ``self.max_len`` target segments have been emitted;
          * ``None`` for a bare marker or an empty piece (nothing to send);
          * otherwise the piece with one leading ``BOS_PREFIX`` removed.
        """
        token = units.value.pop()

        if (
            token == self.dict["tgt"].eos_word
            or len(states.segments.target) > self.max_len
        ):
            return DEFAULT_EOS

        # A lone marker, or an empty piece, carries no text. Checking for
        # emptiness also avoids the IndexError that token[0] used to raise.
        if not token or token == BOS_PREFIX:
            return None

        if token.startswith(BOS_PREFIX):
            return token[len(BOS_PREFIX):]
        return token

    def policy(self, states):
        """Return ``READ_ACTION`` or ``WRITE_ACTION`` for the current step.

        READ is returned immediately while nothing has been encoded yet.
        Otherwise the decoder runs incrementally over the target prefix
        (decoder EOS followed by the non-None target units), and the action
        it reports is followed: ``outputs.action == 0`` means READ, anything
        else means WRITE. The decoder output is kept in
        ``states.decoder_out`` for ``predict``.
        """
        if not getattr(states, "encoder_states", None):
            # No encoder states, read a token first
            return READ_ACTION

        # encode previous predicted target tokens
        tgt_indices = self.to_device(
            torch.LongTensor(
                [self.model.decoder.dictionary.eos()]
                + [
                    self.dict['tgt'].index(x)
                    for x in states.units.target.value
                    if x is not None
                ]
            ).unsqueeze(0)
        )

        # Current steps.
        # NOTE(review): size(0) of encoder_out[0] is used as the source
        # length, so encoder_out is presumably time-major
        # (src_len, batch, dim). Confirm against the encoder.
        states.incremental_states["steps"] = {
            "src": states.encoder_states["encoder_out"][0].size(0),
            "tgt": 1 + len(states.units.target),
        }

        # Online only means the reading is not finished (True while more
        # source may still arrive).
        states.incremental_states["online"]["only"] = (
            torch.BoolTensor([not states.finish_read()])
        )

        x, outputs = self.model.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=states.encoder_states,
            incremental_state=states.incremental_states,
        )

        states.decoder_out = x

        torch.cuda.empty_cache()

        if outputs.action == 0:
            return READ_ACTION
        else:
            return WRITE_ACTION

    def predict(self, states):
        """Greedily choose the next target token from ``states.decoder_out``.

        Returns the target dictionary's ``eos_word`` when EOS is the argmax,
        so ``units_to_segment`` can recognise it. Otherwise returns the
        dictionary string for the argmax index.
        """
        decoder_states = states.decoder_out
        # Keep only the last target position.
        # NOTE(review): assumes decoder_out is batch-first
        # (batch, tgt_len, dim). Confirm.
        lprobs = self.model.get_normalized_probs(
            [decoder_states[:, -1:]], log_probs=True
        )

        index = lprobs.argmax(dim=-1)[0, 0].item()

        if index != self.dict['tgt'].eos_index:
            token = self.dict['tgt'].string([index])
        else:
            token = self.dict['tgt'].eos_word

        return token