# -*- coding: utf-8 -*-
import gradio as gr
import operator
import torch
from transformers import BertTokenizer, BertForMaskedLM
pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v2"
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
vocab = tokenizer.vocab
# from modelscope import AutoTokenizer, AutoModelForMaskedLM
# pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
# vocab = tokenizer.vocab
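# Assumption: this demo only runs inference, so switch to eval mode to disable
# dropout and make outputs deterministic.
model.eval()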
def func_macro_correct(text):
    with torch.no_grad():
        outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
    def flag_total_chinese(text):
        """
        Judge whether the text consists entirely of Chinese characters.
        Args:
            text: str, e.g. "macadam, 碎石路"
        Returns:
            bool, True or False
        """
        for word in text:
            if not "\u4e00" <= word <= "\u9fa5":
                return False
        return True
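    # Note: "\u4e00"-"\u9fa5" covers only the main CJK Unified Ideographs range;
    # ideographs outside it (extension blocks, later Unicode additions) are
    # treated as non-Chinese by this check.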
    def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=None, know_tokens=None):
        """Get errors between corrected text and origin text
        code from: https://github.com/shibing624/pycorrector
        """
        new_corrected_text = ""
        errors = []
        i, j = 0, 0
        unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
        know_tokens = know_tokens or []
        while i < len(origin_text) and j < len(corrected_text):
            if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
                # Pass unknown/unk tokens through unchanged from the original text
                new_corrected_text += origin_text[i]
                i += 1
            elif corrected_text[j] in unk_tokens:
                new_corrected_text += corrected_text[j]
                j += 1
            # Deal with Chinese characters
            elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
                # If the two characters are the same, the two pointers move forward together
                if origin_text[i] == corrected_text[j]:
                    new_corrected_text += corrected_text[j]
                    i += 1
                    j += 1
                else:
                    # Check for insertion errors
                    if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
                        errors.append(('', corrected_text[j], j))
                        new_corrected_text += corrected_text[j]
                        j += 1
                    # Check for deletion errors
                    elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
                        errors.append((origin_text[i], '', i))
                        i += 1
                    # Check for replacement errors
                    else:
                        errors.append((origin_text[i], corrected_text[j], i))
                        new_corrected_text += corrected_text[j]
                        i += 1
                        j += 1
            else:
                new_corrected_text += origin_text[i]
                if origin_text[i] == corrected_text[j]:
                    j += 1
                i += 1
        errors = sorted(errors, key=operator.itemgetter(2))
        return new_corrected_text, errors
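    # Illustrative behavior (hypothetical inputs, assuming every character is
    # in know_tokens): origin "他爱篮球" vs corrected "他爱打篮球" triggers the
    # insertion branch and yields errors == [('', '打', 2)]; origin
    # "他爱打打篮球" vs corrected "他爱打篮球" triggers the deletion branch and
    # yields errors == [('打', '', 3)].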
    def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=None, know_tokens=None):
        """Get new corrected text and errors between corrected text and origin text
        code from: https://github.com/shibing624/pycorrector
        """
        errors = []
        unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '', ',']
        know_tokens = know_tokens or []
        for i, ori_char in enumerate(origin_text):
            if i >= len(corrected_text):
                continue
            if ori_char in unk_tokens or ori_char not in know_tokens:
                # Deal with unk words: copy the original character through unchanged
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            if ori_char != corrected_text[i]:
                if not flag_total_chinese(ori_char):
                    # Skip non-Chinese characters in the original text
                    corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                    continue
                if not flag_total_chinese(corrected_text[i]):
                    corrected_text = corrected_text[:i] + corrected_text[i + 1:]
                    continue
                errors.append([ori_char, corrected_text[i], i])
        errors = sorted(errors, key=operator.itemgetter(2))
        return corrected_text, errors
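    # Illustrative behavior (hypothetical inputs, assuming every character is
    # in know_tokens): origin "今天兴情很好" vs corrected "今天心情很好" differ
    # only at position 2, so the function returns
    # ("今天心情很好", [['兴', '心', 2]]).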
    # Greedy-decode the MLM logits back to text, strip WordPiece spacing, and
    # truncate to the input length (predictions at the [CLS]/[SEP] positions may
    # decode to extra characters).
    _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    print("#" * 128)
    print(text)
    print(corrected_text)
    print(len(text), len(corrected_text))
    if len(corrected_text) == len(text):
        corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
    else:
        corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
    print(text, ' => ', corrected_text, details)
    return corrected_text + ' ' + str(details)
if __name__ == '__main__':
    print(func_macro_correct('他法语说的很好,的语也不错'))

    examples = [
        "夫谷之雨,犹复云之亦从的起,因与疾风俱飘,参于天,集于的。",
        "机七学习是人工智能领遇最能体现智能的一个分知",
        '他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
        "抗疫路上,除了提心吊胆也有难的得欢笑。",
        "我是练习时长两念半的鸽仁练习生蔡徐坤",
        "清晨,如纱一般地薄雾笼罩着世界。",
        "得府许我立庙于此,故请君移去尔。",
        "他法语说的很好,的语也不错",
        "遇到一位很棒的奴生跟我疗天",
        "五年级得数学,我考的很差。",
        "我们为这个目标努力不解",
        '今天兴情很好',
    ]
    gr.Interface(
        func_macro_correct,
        inputs='text',
        outputs='text',
        title="Chinese Spelling Correction Model Macropodus/macbert4mdcspell_v2",
        description="Paste or type Chinese text containing errors, then submit; the model returns the corrected text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank'>Github REPO: macro-correct</a>",
        examples=examples,
    ).launch()  # .launch(server_name="0.0.0.0", server_port=8036, share=False, debug=True)
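# Assumed runtime dependencies (unpinned): pip install torch transformers gradio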