import torch
import pyonmttok
from onmt.constants import DefaultTokens, CorpusTask, ModelTask
from torch.nn.utils.rnn import pad_sequence
from onmt.utils.logging import logger
from collections import Counter


def parse_features(line, n_feats=0, defaults=None):
    """
    Parses text lines with features appended to each token.
    Ex.: This│A│B is│A│A a│C│A test│A│B
    """
    text, feats = [], [[] for _ in range(n_feats)]
    check, count = 0, 0
    for token in line.split(" "):
        tok, *fts = token.strip().split("│")
        check += len(fts)
        count += 1
        if not fts and defaults is not None:
            if isinstance(defaults, str):
                defaults = defaults.split("│")
            if n_feats > 0:
                assert len(defaults) == n_feats
            fts = defaults
        assert len(fts) == n_feats, (
            f"The number of features does not match the "
            f"expected number of features. Found {len(fts)} "
            f"features in the data but {n_feats} were expected."
        )
        text.append(tok)
        for i in range(n_feats):
            feats[i].append(fts[i])

    assert (
        check == 0 or check == count * n_feats
    ), "Some tokens are missing features. Please check your data."
    feats = [" ".join(x) for x in feats] if n_feats > 0 else None
    return " ".join(text), feats
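
# Example: with two features per token,
#   parse_features("This│A│B is│A│A a│C│A test│A│B", n_feats=2)
# returns ("This is a test", ["A A C A", "B A A B"]).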


def append_features_to_text(text, features):
    """
    It appends features to subwords when dumping to file
    """
    text_tok = text.split(" ")
    feats_tok = [x.split(" ") for x in features]

    pretty_toks = []
    for tok, *feats in zip(text_tok, *feats_tok):
        feats = "│".join(feats)
        if feats:
            pretty_toks.append(f"{tok}│{feats}")
        else:
            pretty_toks.append(tok)
    return " ".join(pretty_toks)
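
# Example: the inverse of parse_features,
#   append_features_to_text("This is a test", ["A A C A", "B A A B"])
# returns "This│A│B is│A│A a│C│A test│A│B".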


def text_sort_key(ex):
    """Sort using the number of tokens in the sequence."""
    if ex["tgt"]:
        return len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"])
    return len(ex["src"]["src_ids"])
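
# Example: an example with 3 src ids and 2 tgt ids sorts on the key (3, 2);
# without a target, the key is just the source length.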


def clean_example(maybe_example):
    """Join token lists back into strings and nest them into src/tgt dicts."""
    maybe_example["src"] = {"src": " ".join(maybe_example["src"])}

    if "src_feats" in maybe_example:
        maybe_example["src"]["feats"] = [
            " ".join(x) for x in maybe_example["src_feats"]
        ]
        del maybe_example["src_feats"]
    if maybe_example["tgt"] is not None:
        maybe_example["tgt"] = {"tgt": " ".join(maybe_example["tgt"])}
    if "align" in maybe_example:
        maybe_example["align"] = " ".join(maybe_example["align"])
    return maybe_example
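
# Example:
#   clean_example({"src": ["a", "b"], "src_feats": [["X", "Y"]],
#                  "tgt": ["c"], "align": ["0-0"]})
# returns {"src": {"src": "a b", "feats": ["X Y"]}, "tgt": {"tgt": "c"}, "align": "0-0"}.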


def process(task, bucket, **kwargs):
    """Returns valid transformed bucket from bucket."""
    transform_cid_to_examples = {}
    for example in bucket:
        transform_cid = (example[1], example[2])
        if transform_cid not in transform_cid_to_examples:
            transform_cid_to_examples[transform_cid] = []
        transform_cid_to_examples[transform_cid].append(example)

    processed_bucket = []
    for (transform, cid), sub_bucket in transform_cid_to_examples.items():
        transf_bucket = transform.batch_apply(
            sub_bucket, is_train=(task == CorpusTask.TRAIN), corpus_name=cid
        )
        for example, transform, cid in transf_bucket:
            example = clean_example(example)
            if len(example["src"]["src"]) > 0:
                processed_bucket.append(example)

    if len(processed_bucket) > 0:
        return processed_bucket
    else:
        return None
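
# Note: each bucket item is expected to be an (example, transform, cid) triple,
# so examples are grouped per (transform, cid) pair before batch_apply is called
# on each group.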


def numericalize(vocabs, example):
    """Map source/target tokens (and source features) to vocabulary ids."""
    decoder_start_token = vocabs["decoder_start_token"]
    numeric = example
    numeric["src"]["src_ids"] = []
    if vocabs["data_task"] == ModelTask.SEQ2SEQ:
        src_text = example["src"]["src"].split()
        numeric["src"]["src_ids"] = vocabs["src"](src_text)
        if example["tgt"] is not None:
            numeric["tgt"]["tgt_ids"] = []
            tgt_text = example["tgt"]["tgt"].split()
            numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
                [decoder_start_token] + tgt_text + [DefaultTokens.EOS]
            )
    elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
        src_text = example["src"]["src"].split()
        if decoder_start_token != "":
            src_text = [decoder_start_token] + src_text
        numeric["src"]["src_ids"] = vocabs["src"](src_text)
        if example["tgt"] is not None:
            numeric["tgt"]["tgt_ids"] = []
            tgt_text = example["tgt"]["tgt"].split()
            numeric["tgt"]["tgt_ids"] = vocabs["tgt"](tgt_text + [DefaultTokens.EOS])
    else:
        raise ValueError(f"Something went wrong with task {vocabs['data_task']}")

    if "feats" in example["src"]:
        numeric_feats = []
        for fv, feat in zip(vocabs["src_feats"], example["src"]["feats"]):
            numeric_feats.append(fv(feat.split()))
        numeric["src"]["feats"] = numeric_feats

    return numeric
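
# Sketch: for a SEQ2SEQ example, the target "hello world" is numericalized as
# the ids of [decoder_start_token, "hello", "world", DefaultTokens.EOS], while
# the source ids are simply the ids of its tokens (vocabs are callable on token
# lists, as above).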


def parse_align_idx(align_pharaoh):
    """
    Parse Pharaoh alignment into [[<src>, <tgt>], ...]
    """
    align_list = align_pharaoh.strip().split(" ")
    flatten_align_idx = []
    for align in align_list:
        try:
            src_idx, tgt_idx = align.split("-")
        except ValueError:
            logger.warning("{} in `{}`".format(align, align_pharaoh))
            logger.warning("Bad alignment line exists. Please check file!")
            raise
        flatten_align_idx.append([int(src_idx), int(tgt_idx)])
    return flatten_align_idx
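
# Example: parse_align_idx("0-0 1-2 2-1") returns [[0, 0], [1, 2], [2, 1]].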


def tensorify(vocabs, minibatch):
    """
    This function transforms a batch of examples into tensors.
    Each example looks like
    {'src': {'src': ..., 'feats': [...], 'src_ids': ...},
     'tgt': {'tgt': ..., 'tgt_ids': ...},
     'src_original': ['tok1', ...'tokn'],
     'tgt_original': ['tok1', ...'tokm'],
     'indices': seq in bucket
     'align': ...,
    }
    Returns a Dict of batch Tensors
    {'src': [batchsize, seqlen, n_feats + 1],
     'tgt': [batchsize, seqlen, 1],
     'indices': [batchsize],
     'srclen': [batchsize],
     'tgtlen': [batchsize],
     'align': alignment sparse tensor
    }
    """
    tensor_batch = {}
    tbatchsrc = [torch.LongTensor(ex["src"]["src_ids"]) for ex in minibatch]
    padidx = vocabs["src"][DefaultTokens.PAD]
    tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx)
    if "feats" in minibatch[0]["src"]:
        tbatchfs = [tbatchsrc]
        for feat_id in range(len(minibatch[0]["src"]["feats"])):
            tbatchfeat = [
                torch.LongTensor(ex["src"]["feats"][feat_id]) for ex in minibatch
            ]
            padidx = vocabs["src_feats"][feat_id][DefaultTokens.PAD]
            tbatchfeat = pad_sequence(
                tbatchfeat, batch_first=True, padding_value=padidx
            )
            tbatchfs.append(tbatchfeat)
        tbatchsrc = torch.stack(tbatchfs, dim=2)
    else:
        tbatchsrc = tbatchsrc[:, :, None]

    tensor_batch["src"] = tbatchsrc
    tensor_batch["indices"] = torch.LongTensor([ex["indices"] for ex in minibatch])
    tensor_batch["srclen"] = torch.LongTensor(
        [len(ex["src"]["src_ids"]) for ex in minibatch]
    )

    if minibatch[0]["tgt"] is not None:
        tbatchtgt = [torch.LongTensor(ex["tgt"]["tgt_ids"]) for ex in minibatch]
        padidx = vocabs["tgt"][DefaultTokens.PAD]
        tbatchtgt = pad_sequence(tbatchtgt, batch_first=True, padding_value=padidx)
        tbatchtgt = tbatchtgt[:, :, None]
        tbatchtgtlen = torch.LongTensor([len(ex["tgt"]["tgt_ids"]) for ex in minibatch])
        tensor_batch["tgt"] = tbatchtgt
        tensor_batch["tgtlen"] = tbatchtgtlen

    if "align" in minibatch[0].keys() and minibatch[0]["align"] is not None:
        sparse_idx = []
        for i, ex in enumerate(minibatch):
            for src, tgt in parse_align_idx(ex["align"]):
                sparse_idx.append([i, tgt + 1, src])
        tbatchalign = torch.LongTensor(sparse_idx)
        tensor_batch["align"] = tbatchalign

    if "src_map" in minibatch[0].keys():
        src_vocab_size = max([max(ex["src_map"]) for ex in minibatch]) + 1
        src_map = torch.zeros(
            len(tensor_batch["srclen"]), tbatchsrc.size(1), src_vocab_size
        )
        for i, ex in enumerate(minibatch):
            for j, t in enumerate(ex["src_map"]):
                src_map[i, j, t] = 1
        tensor_batch["src_map"] = src_map

    if "alignment" in minibatch[0].keys():
        alignment = torch.zeros(len(tensor_batch["srclen"]), tbatchtgt.size(1)).long()
        for i, ex in enumerate(minibatch):
            alignment[i, : len(ex["alignment"])] = torch.LongTensor(ex["alignment"])
        tensor_batch["alignment"] = alignment

    if "src_ex_vocab" in minibatch[0].keys():
        tensor_batch["src_ex_vocab"] = [ex["src_ex_vocab"] for ex in minibatch]

    return tensor_batch
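
# Example shape check: with two examples whose src_ids have lengths 5 and 3 (in
# that order) and no extra features, tensor_batch["src"] is padded to shape
# (2, 5, 1) (batch-first) and tensor_batch["srclen"] is tensor([5, 3]).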


def textbatch_to_tensor(vocabs, batch, is_train=False):
    """
    This is a hack to transform a simple batch of texts
    into a tensored batch to pass through _translate()
    """
    numeric = []
    for i, ex in enumerate(batch):
        # add the bookkeeping fields expected downstream
        ex["srclen"] = len(ex["src"]["src"].split())
        ex["indices"] = i
        ex["align"] = None
        numeric.append(numericalize(vocabs, ex))
    numeric.sort(key=text_sort_key, reverse=True)
    infer_iter = [tensorify(vocabs, numeric)]
    return infer_iter
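
# Sketch: at inference time each batch item is a dict like
# {"src": {"src": "hello world"}, "tgt": None}; the result is a single-element
# list holding one tensorified batch, standing in for a regular inference iterator.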


def _addcopykeys(vocabs, example):
    """Create copy-vocab and numericalize with it.

    In-place adds ``"src_map"`` to ``example``. That is the copy-vocab
    numericalization of the tokenized ``example["src"]``. If ``example``
    has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the
    copy-vocab numericalization of the tokenized ``example["tgt"]``. The
    alignment has an initial and final UNK token to match the BOS and EOS
    tokens.

    Args:
        vocabs (dict): used here for ``"data_task"`` to decide how the
            target is wrapped.
        example (dict): An example dictionary with a ``"src"`` key and
            maybe a ``"tgt"`` key. (This argument changes in place!)

    Returns:
        ``example``, changed as described.
    """
    src = example["src"]["src"].split()
    src_ex_vocab = pyonmttok.build_vocab_from_tokens(
        Counter(src),
        maximum_size=0,
        minimum_frequency=1,
        special_tokens=[
            DefaultTokens.UNK,
            DefaultTokens.PAD,
            DefaultTokens.BOS,
            DefaultTokens.EOS,
        ],
    )
    src_ex_vocab.default_id = src_ex_vocab[DefaultTokens.UNK]

    # map source tokens to indices in this per-example (dynamic) vocab
    example["src_map"] = src_ex_vocab(src)
    example["src_ex_vocab"] = src_ex_vocab

    if example["tgt"] is not None:
        if vocabs["data_task"] == ModelTask.SEQ2SEQ:
            tgt = (
                [DefaultTokens.UNK]
                + example["tgt"]["tgt"].split()
                + [DefaultTokens.UNK]
            )
        elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
            tgt = example["tgt"]["tgt"].split() + [DefaultTokens.UNK]
        example["alignment"] = src_ex_vocab(tgt)
    return example
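
# Sketch: for src "the cat sat", src_ex_vocab holds the UNK/PAD/BOS/EOS specials
# plus {"the", "cat", "sat"}; example["src_map"] has one per-example vocab id per
# source token, and (for SEQ2SEQ) example["alignment"] numericalizes
# [UNK] + tgt tokens + [UNK] with that same vocab, so target tokens absent from
# the source map to the UNK id.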