# -*- coding: utf-8 -*-
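"""Gradio demo: Chinese spelling correction with Macropodus/macbert4mdcspell_v2.

The script loads a MacBERT-based masked-LM checkpoint, takes the argmax token
at every position as the corrected sentence, then diffs that correction
against the input to report (wrong_char, right_char, position) details.
"""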

import gradio as gr
import operator
import torch
from transformers import BertTokenizer, BertForMaskedLM


pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v2"
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
vocab = tokenizer.vocab
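# The tokenizer vocabulary doubles as the set of "known" tokens when diffing
# the model output against the input (passed as know_tokens below).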


# Alternative: load the same checkpoint from ModelScope instead of Hugging Face.
# from modelscope import AutoTokenizer, AutoModelForMaskedLM
# pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
# vocab = tokenizer.vocab


def func_macro_correct(text):
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))

    def flag_total_chinese(text):
        """
        Return True if every character is a Chinese (CJK unified) ideograph.
        Args:
            text: str, e.g. "macadam, 碎石路" -> False
        Returns:
            bool
        """
        for char in text:
            if not "\u4e00" <= char <= "\u9fa5":
                return False
        return True

    def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=None, know_tokens=()):
        """Align texts of different lengths and collect (wrong, right, index) errors.
        Adapted from: https://github.com/shibing624/pycorrector
        """
        new_corrected_text = ""
        errors = []
        i, j = 0, 0
        unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
        while i < len(origin_text) and j < len(corrected_text):
            if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
                new_corrected_text += origin_text[i]
                i += 1
            elif corrected_text[j] in unk_tokens:
                new_corrected_text += corrected_text[j]
                j += 1
            # Deal with Chinese characters
            elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
                # If the two characters are the same, then the two pointers move forward together
                if origin_text[i] == corrected_text[j]:
                    new_corrected_text += corrected_text[j]
                    i += 1
                    j += 1
                else:
                    # Check for insertion errors
                    if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
                        errors.append(('', corrected_text[j], j))
                        new_corrected_text += corrected_text[j]
                        j += 1
                    # Check for deletion errors
                    elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
                        errors.append((origin_text[i], '', i))
                        i += 1
                    # Check for replacement errors
                    else:
                        errors.append((origin_text[i], corrected_text[j], i))
                        new_corrected_text += corrected_text[j]
                        i += 1
                        j += 1
            else:
                new_corrected_text += origin_text[i]
                if origin_text[i] == corrected_text[j]:
                    j += 1
                i += 1
        errors = sorted(errors, key=operator.itemgetter(2))
        return new_corrected_text, errors

    def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=None, know_tokens=()):
        """Compare equal-length texts and collect (wrong, right, index) errors.
        Adapted from: https://github.com/shibing624/pycorrector
        """
        errors = []
        unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '', ',']

        for i, ori_char in enumerate(origin_text):
            if i >= len(corrected_text):
                continue
            if ori_char in unk_tokens or ori_char not in know_tokens:
                # deal with unk word
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            if ori_char != corrected_text[i]:
                if not flag_total_chinese(ori_char):
                    # pass not chinese char
                    corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                    continue
                if not flag_total_chinese(corrected_text[i]):
                    corrected_text = corrected_text[:i] + corrected_text[i + 1:]
                    continue
                errors.append([ori_char, corrected_text[i], i])
        errors = sorted(errors, key=operator.itemgetter(2))
        return corrected_text, errors

    # Greedy decode: argmax token at each position; drop the spaces the tokenizer
    # inserts so the string lines up with the raw input, then truncate to the
    # input length.
    _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    print("#" * 128)
    print(text)
    print(corrected_text)
    print(len(text), len(corrected_text))
    if len(corrected_text) == len(text):
        corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
    else:
        corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
    print(text, ' => ', corrected_text, details)
    return corrected_text + ' ' + str(details)


if __name__ == '__main__':
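    # Quick smoke test: the sentence is built around 的/得 and 的/德 confusions,
    # so the intended correction is '他法语说得很好,德语也不错' (exact model
    # output and details may vary by checkpoint).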
    print(func_macro_correct('他法语说的很好,的语也不错'))

    # Demo inputs: intentionally misspelled Chinese sentences.
    examples = [
        "夫谷之雨,犹复云之亦从的起,因与疾风俱飘,参于天,集于的。",
        "机七学习是人工智能领遇最能体现智能的一个分知",
        '他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
        "抗疫路上,除了提心吊胆也有难的得欢笑。",
        "我是练习时长两念半的鸽仁练习生蔡徐坤",
        "清晨,如纱一般地薄雾笼罩着世界。",
        "得府许我立庙于此,故请君移去尔。",
        "他法语说的很好,的语也不错",
        "遇到一位很棒的奴生跟我疗天",
        "五年级得数学,我考的很差。",
        "我们为这个目标努力不解",
        '今天兴情很好',
    ]

    gr.Interface(
        func_macro_correct,
        inputs='text',
        outputs='text',
        title="Chinese Spelling Correction Model Macropodus/macbert4mdcspell_v2",
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
        examples=examples
    ).launch()  # .launch(server_name="0.0.0.0", server_port=8036, share=False, debug=True)