File size: 1,759 Bytes
1a8c724
4d3d295
 
 
dd465f4
4d3d295
1a8c724
97ca765
4d3d295
 
 
9820b04
 
4d3d295
 
 
 
 
 
 
 
1a8c724
dd97139
97ca765
77dbb7f
 
 
dd97139
 
 
 
 
a5284e6
 
 
dd97139
a5284e6
77dbb7f
a5284e6
 
1a8c724
e9e5a55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint-25000/")

def text_processing(text):
    text = text + ' ' if text[-2:] != ' ' else text  # 在末尾加上空格有利于模型预测
    inputs = [text]

    # Tokenize and prepare the inputs for model
    input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids
    attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask

    # Generate prediction
    output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512)

    # Decode the prediction
    decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]

    return decoded_output[0]

examples = [
    ["我们的价值观是 富强 民主 文明 和谐"],
    ["都什么年代了 还在抽传统香烟"],
    ["今夕是何年"],
    [" 三国演义 全名为 三國志通俗演义 又稱作 三國志演義 三國志傳 三國傳 三國全傳 三國英雄志傳 "],
]

inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")]
    

iface = gr.Interface(
    fn=text_processing,
    inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")],
    outputs='text',
    title='Punctuation Mark Prediction',
    description='本模型主要用于语音识别模型输出的后处理。\n输入无符号句子,需要打标点处用空格隔开,返回带标点句子。\n仅支持中文,因为训练数据中只有中文。',
    examples=examples
)

iface.launch(inline=False)