# NOTE: removed non-Python residue (file-size/commit/line-number header left by a code-viewer export).
import os
from typing import Dict, List, Tuple

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .modeling import GECToR
from .predict import (
    edit_src_by_tags,
    _predict
)
def predict_verbose(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[str],
    encode: dict,
    decode: dict,
    keep_confidence: float = 0,
    min_error_prob: float = 0,
    batch_size: int = 128,
    n_iteration: int = 5
) -> Tuple[List[str], List[List[Dict]]]:
    '''Iteratively correct sentences and record every intermediate step.

    Runs up to ``n_iteration`` rounds of tag prediction + edit application,
    dropping a sentence out of the working set as soon as the model predicts
    no further corrections for it.

    Args:
        model: Trained GECToR tagging model.
        tokenizer: Tokenizer matching ``model``.
        srcs: Input sentences (plain, space-tokenizable strings).
        encode: Label-to-id mapping passed through to ``edit_src_by_tags``.
        decode: Id-to-label mapping passed through to ``edit_src_by_tags``.
        keep_confidence: Bias added to the $KEEP label (see ``_predict``).
        min_error_prob: Minimum sentence-level error probability to edit.
        batch_size: Batch size for prediction.
        n_iteration: Maximum number of correction rounds.

    Returns:
        A tuple ``(final_edited_sents, iteration_log)`` where
        ``iteration_log[sent_id][iteration_id]`` is a dict with keys
        ``'src'`` (token list fed to that iteration) and ``'tag'``
        (predicted labels, or ``None`` if that iteration never tagged it).
    '''
    # GECToR expects a special $START token in front of every sentence.
    srcs = [['$START'] + src.split(' ') for src in srcs]
    # '-1' is a sentinel meaning "not finalized yet"; every slot must be
    # overwritten before we return.
    final_edited_sents: List[str] = ['-1'] * len(srcs)
    to_be_processed = srcs
    # Maps position in the shrinking working set back to the original index.
    original_sent_idx = list(range(len(srcs)))
    # One log per sentence, seeded with the initial (iteration-0) source.
    iteration_log: List[List[Dict]] = [
        [{'src': src, 'tag': None}] for src in srcs
    ]
    for itr in range(n_iteration):
        print(f'Iteration {itr}. the number of to_be_processed: {len(to_be_processed)}')
        pred_labels, no_corrections = _predict(
            model,
            tokenizer,
            to_be_processed,
            keep_confidence,
            min_error_prob,
            batch_size
        )
        current_srcs = []
        current_pred_labels = []
        current_orig_idx = []
        for i, is_done in enumerate(no_corrections):
            if is_done:
                # The model predicts no edits: this sentence is final.
                final_edited_sents[original_sent_idx[i]] = \
                    ' '.join(to_be_processed[i]).replace('$START ', '')
            else:
                current_srcs.append(to_be_processed[i])
                current_pred_labels.append(pred_labels[i])
                current_orig_idx.append(original_sent_idx[i])
        if not current_srcs:
            # Correcting for all sentences is completed.
            break
        edited_srcs = edit_src_by_tags(
            current_srcs,
            current_pred_labels,
            encode,
            decode
        )
        # Record this iteration's tags; the edited source becomes the
        # 'src' of the next iteration's log entry.
        for i, orig_id in enumerate(current_orig_idx):
            iteration_log[orig_id][itr]['tag'] = current_pred_labels[i]
            iteration_log[orig_id].append({
                'src': edited_srcs[i],
                'tag': None
            })
        to_be_processed = edited_srcs
        original_sent_idx = current_orig_idx
    # Sentences still unfinished after the last iteration are finalized
    # with their latest edited form.
    for i in range(len(to_be_processed)):
        final_edited_sents[original_sent_idx[i]] = \
            ' '.join(to_be_processed[i]).replace('$START ', '')
    assert '-1' not in final_edited_sents
    return final_edited_sents, iteration_log