from .modeling import GECToR
from transformers import PreTrainedTokenizer
from typing import List, Dict, Tuple
from .predict import (
    edit_src_by_tags,
    _predict
)

def predict_verbose(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[str],
    encode: dict,
    decode: dict,
    keep_confidence: float = 0.0,
    min_error_prob: float = 0.0,
    batch_size: int = 128,
    n_iteration: int = 5
) -> Tuple[List[str], List[List[Dict]]]:
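    """Run GECToR's iterative correction while logging every step.

    Like the plain prediction path in .predict, but additionally records, for
    each sentence, the source tokens and predicted tags of every iteration.
    Returns (final_edited_sents, iteration_log), where
    iteration_log[sent_id][iteration_id] is a dict with keys 'src' and 'tag'.
    """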
    srcs = [['$START'] + src.split(' ') for src in srcs]
    final_edited_sents = ['-1'] * len(srcs)
    to_be_processed = srcs
    original_sent_idx = list(range(0, len(srcs)))
    iteration_log: List[List[Dict]] = []  # [sent_id][iteration_id]['src' or 'tag']
    # Initialize the log with each sentence's original source tokens.
    for src in srcs:
        iteration_log.append([{
            'src': src,
            'tag': None
        }])
    for itr in range(n_iteration):
        print(f'Iteration {itr}: {len(to_be_processed)} sentence(s) left to process.')
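        # _predict returns the predicted tag sequence for each sentence, plus a
        # per-sentence flag that is True when no edit was predicted.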
        pred_labels, no_corrections = _predict(
            model,
            tokenizer,
            to_be_processed,
            keep_confidence,
            min_error_prob,
            batch_size
        )
        current_srcs = []
        current_pred_labels = []
        current_orig_idx = []
        for i, no_correction in enumerate(no_corrections):
            if no_correction:
                # No edits were predicted, so this sentence is final.
                final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
            else:
                current_srcs.append(to_be_processed[i])
                current_pred_labels.append(pred_labels[i])
                current_orig_idx.append(original_sent_idx[i])
        if not current_srcs:
            # All sentences have been fully corrected.
            break
        edited_srcs = edit_src_by_tags(
            current_srcs,
            current_pred_labels,
            encode,
            decode
        )
        # Log this iteration's predicted tags; each edited sentence becomes
        # the src of the next iteration.
        for i, orig_id in enumerate(current_orig_idx):
            iteration_log[orig_id][itr]['tag'] = current_pred_labels[i]
            iteration_log[orig_id].append({
                'src': edited_srcs[i],
                'tag': None
            })
        
        to_be_processed = edited_srcs
        original_sent_idx = current_orig_idx
        
    # Finalize sentences that were still being edited after the last iteration.
    for i in range(len(to_be_processed)):
        final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
    assert '-1' not in final_edited_sents
    return final_edited_sents, iteration_log
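
# A minimal usage sketch (hypothetical: the checkpoint path, the tokenizer
# source, and how the encode/decode label mappings are built are assumptions,
# not part of this module):
#
#     from transformers import AutoTokenizer
#     model = GECToR.from_pretrained('path/to/checkpoint')
#     tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')
#     encode, decode = ...  # label <-> id mappings consumed by edit_src_by_tags
#     sents, log = predict_verbose(model, tokenizer,
#                                  ['A sentences with an error .'],
#                                  encode, decode)
#     print(sents[0])
#     for step in log[0]:   # per-iteration 'src' tokens and predicted 'tag's
#         print(step['src'], step['tag'])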