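"""Prediction utilities for GECToR: load the verb-form vocabulary, apply
predicted edit tags to source tokens, and run iterative inference until no
further corrections are predicted."""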
from typing import Callable, List, Optional

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .modeling import GECToR

def load_verb_dict(verb_file: str):
    """Build encode/decode mappings from the verb-form vocabulary file."""
    encode, decode = {}, {}
    with open(verb_file, encoding="utf-8") as f:
        for line in f:
            # Each line maps a word pair to a tag pair: "word1_word2:TAG1_TAG2".
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode
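
# For illustration (not executed): a vocabulary line such as "go_goes:VB_VBZ"
# (format inferred from the parsing above) yields decode["go_VB_VBZ"] = "goes",
# which the $TRANSFORM_VERB_VB_VBZ tag looks up in g_transform_processer.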

def edit_src_by_tags(
    srcs: List[List[str]],
    pred_labels: List[List[str]],
    encode: dict,
    decode: dict
) -> List[List[str]]:
    """Apply the predicted edit tags to each tokenized source sentence."""
    edited_srcs = []
    for tokens, labels in zip(srcs, pred_labels):
        edited_tokens = []
        for token, label in zip(tokens, labels):
            n_token = process_token(token, label, encode, decode)
            if n_token is None:
                n_token = token
            edited_tokens += n_token.split(' ')
        # Tokens beyond the labeled prefix (e.g. after truncation) pass through.
        if len(tokens) > len(labels):
            edited_tokens += tokens[len(labels):]
        # Resolve the merge/delete markers that process_token leaves in place.
        temp_str = ' '.join(edited_tokens) \
            .replace(' $MERGE_HYPHEN ', '-') \
            .replace(' $MERGE_SPACE ', '') \
            .replace(' $DELETE', '') \
            .replace('$DELETE ', '')
        edited_srcs.append(temp_str.split(' '))
    return edited_srcs
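
# For illustration (not executed): tokens ["well", "known"] with labels
# ["$MERGE_HYPHEN", "$KEEP"] become "well $MERGE_HYPHEN known" after the
# token loop, and the replace chain above merges them into "well-known".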

def process_token(
    token: str,
    label: str,
    encode: dict,
    decode: dict
) -> Optional[str]:
    """Apply a single edit tag to a single token."""
    if '$APPEND_' in label:
        return token + ' ' + label.replace('$APPEND_', '')
    elif token == '$START':
        # The $START ([unused1]) token cannot be replaced or deleted.
        return token
    elif label in ['<PAD>', '<OOV>', '$KEEP']:
        return token
    elif '$TRANSFORM_' in label:
        return g_transform_processer(token, label, encode, decode)
    elif '$REPLACE_' in label:
        return label.replace('$REPLACE_', '')
    elif label == '$DELETE':
        # Return the marker itself; edit_src_by_tags strips it afterwards.
        return label
    elif '$MERGE_' in label:
        return token + ' ' + label
    else:
        return token
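
# For illustration (not executed): process_token("to", "$APPEND_the", ...)
# returns "to the"; process_token("go", "$TRANSFORM_VERB_VB_VBZ", encode,
# decode) looks up decode["go_VB_VBZ"] and returns "goes" when that entry
# exists in the verb-form vocabulary (None otherwise).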

def g_transform_processer(
    token: str,
    label: str,
    encode: dict,
    decode: dict
) -> Optional[str]:
    """Apply a g-transformation tag (case, agreement, hyphen split, verb form)."""
    # Case-related transforms
    if label == '$TRANSFORM_CASE_LOWER':
        return token.lower()
    elif label == '$TRANSFORM_CASE_UPPER':
        return token.upper()
    elif label == '$TRANSFORM_CASE_CAPITAL':
        return token.capitalize()
    elif label == '$TRANSFORM_CASE_CAPITAL_1':
        # Capitalize from the second character on, leaving the first as-is.
        if len(token) <= 1:
            return token
        return token[0] + token[1:].capitalize()
    # Number agreement: naive suffix handling (append or drop a final character).
    elif label == '$TRANSFORM_AGREEMENT_PLURAL':
        return token + 's'
    elif label == '$TRANSFORM_AGREEMENT_SINGULAR':
        return token[:-1]
    elif label == '$TRANSFORM_SPLIT_HYPHEN':
        return ' '.join(token.split('-'))
    else:
        # Verb-form transform: look up the target form in the verb dictionary.
        # decode.get returns None for pairs missing from the vocabulary.
        encoding_part = f"{token}_{label[len('$TRANSFORM_VERB_'):]}"
        return decode.get(encoding_part)

def get_word_masks_from_word_ids(
    word_ids: Callable[[int], List[Optional[int]]],
    n: int
):
    """Mark the first subword of each word with 1; all other subwords and
    special tokens get 0. `word_ids` is the bound BatchEncoding.word_ids
    method, called once per sequence in the batch."""
    word_masks = []
    for i in range(n):
        previous_id = 0
        mask = []
        for _id in word_ids(i):
            if _id is None:
                # Special tokens ([CLS], [SEP], padding) carry no word id.
                mask.append(0)
            elif previous_id != _id:
                mask.append(1)
            else:
                mask.append(0)
            previous_id = _id
        word_masks.append(mask)
    return word_masks
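
# For illustration (not executed): if word_ids(0) returns
# [None, 0, 1, 1, 2, None], the resulting mask is [0, 1, 1, 0, 1, 0],
# i.e. exactly one prediction position per source word.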

def _predict(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[List[str]],
    keep_confidence: float = 0,
    min_error_prob: float = 0,
    batch_size: int = 128
):
    """Run a single round of edit-tag prediction over pre-tokenized sentences."""
    itr = list(range(0, len(srcs), batch_size))
    pred_labels = []
    no_corrections = []
    for i in tqdm(itr):
        batch = tokenizer(
            srcs[i:i+batch_size],
            return_tensors='pt',
            max_length=model.config.max_length,
            padding='max_length',
            truncation=True,
            is_split_into_words=True
        )
        batch['word_masks'] = torch.tensor(
            get_word_masks_from_word_ids(
                batch.word_ids,
                batch['input_ids'].size(0)
            )
        )
        # Capture word_ids before the BatchEncoding is replaced by a plain dict.
        word_ids = batch.word_ids
        if torch.cuda.is_available():
            batch = {k: v.cuda() for k, v in batch.items()}
        outputs = model.predict(
            batch['input_ids'],
            batch['attention_mask'],
            batch['word_masks'],
            keep_confidence,
            min_error_prob
        )
        # Align subword-level labels to word-level labels: keep only the label
        # predicted for the first subword of each word.
        for k in range(len(outputs.pred_labels)):
            no_correct = True
            labels = []
            previous_word_idx = None
            for j, idx in enumerate(word_ids(k)):
                if idx is None:
                    continue
                if idx != previous_word_idx:
                    labels.append(outputs.pred_labels[k][j])
                    # Label ids 0-2 are the non-edit labels (<PAD>, <OOV>, $KEEP).
                    if outputs.pred_label_ids[k][j] > 2:
                        no_correct = False
                previous_word_idx = idx
            pred_labels.append(labels)
            no_corrections.append(no_correct)
    return pred_labels, no_corrections

def predict(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[str],
    encode: dict,
    decode: dict,
    keep_confidence: float = 0,
    min_error_prob: float = 0,
    batch_size: int = 128,
    n_iteration: int = 5
) -> List[str]:
    """Iteratively correct sentences: re-predict and re-edit each sentence
    until the model proposes no further edits or n_iteration rounds pass."""
    srcs = [['$START'] + src.split(' ') for src in srcs]
    final_edited_sents = ['-1'] * len(srcs)
    to_be_processed = srcs
    original_sent_idx = list(range(0, len(srcs)))
    for itr in range(n_iteration):
        print(f'Iteration {itr}. Number of sentences to be processed: {len(to_be_processed)}')
        pred_labels, no_corrections = _predict(
            model,
            tokenizer,
            to_be_processed,
            keep_confidence,
            min_error_prob,
            batch_size
        )
        current_srcs = []
        current_pred_labels = []
        current_orig_idx = []
        for i, no_correction in enumerate(no_corrections):
            if no_correction:
                # No edits were predicted, so this sentence is final.
                final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
            else:
                current_srcs.append(to_be_processed[i])
                current_pred_labels.append(pred_labels[i])
                current_orig_idx.append(original_sent_idx[i])
        if current_srcs == []:
            # Correction is complete for all sentences.
            break
        edited_srcs = edit_src_by_tags(
            current_srcs,
            current_pred_labels,
            encode,
            decode
        )
        to_be_processed = edited_srcs
        original_sent_idx = current_orig_idx
    # Sentences still pending after the last iteration keep their latest edit.
    for i in range(len(to_be_processed)):
        final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
    assert '-1' not in final_edited_sents
    return final_edited_sents
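
if __name__ == '__main__':
    # Minimal usage sketch, added for illustration. It assumes the GECToR
    # class supports from_pretrained, that a tokenizer is saved alongside the
    # checkpoint, and that a verb-form vocabulary file is available; every
    # path below is a placeholder. Note that the relative import at the top
    # means this module must be run as part of its package.
    from transformers import AutoTokenizer

    model = GECToR.from_pretrained('path/to/gector-checkpoint')
    if torch.cuda.is_available():
        model = model.cuda()
    tokenizer = AutoTokenizer.from_pretrained('path/to/gector-checkpoint')
    encode, decode = load_verb_dict('path/to/verb-form-vocab.txt')
    srcs = ['he go to school yesterday .']
    print(predict(model, tokenizer, srcs, encode, decode))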