#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio app that restores punctuation in classical Chinese (guwen) text.

A BERT token-classification model predicts, for certain characters, a
punctuation mark (returned as the NER "entity") to insert right after them.
"""
import gradio as gr
from transformers import AutoTokenizer, BertForTokenClassification, pipeline

TAG = "raynardj/classical-chinese-punctuation-guwen-biaodian"

model = BertForTokenClassification.from_pretrained(TAG)
tokenizer = AutoTokenizer.from_pretrained(TAG)
ner = pipeline("ner", model=model, tokenizer=tokenizer)

# BERT accepts at most 512 tokens *including* [CLS]/[SEP]; cap raw-character
# chunks at 510 so a chunk can never overflow the model input (the original
# 512-character threshold could).
MAX_CHUNK = 510


def mark_sentence(text):
    """Return *text* with the model's predicted punctuation inserted.

    Each pipeline result carries the punctuation character in 'entity' and
    the character offset in 'end'.  Every earlier insertion shifts later
    offsets right by one, hence the `+ i` correction (results are assumed
    to be ordered by offset, as the pipeline returns them).
    """
    chars = list(text)
    for i, output in enumerate(ner(text)):
        chars.insert(output['end'] + i, output['entity'])
    return "".join(chars)


def punct(txt):
    """Punctuate multi-line input, chunking long lines for the model.

    Long lines are split by plain slicing rather than textwrap.wrap():
    wrap() collapses whitespace and drops the trailing newline preserved by
    splitlines(True), which silently merged lines in the original output —
    and Chinese text has no spaces for wrap() to break on anyway.  Slicing
    preserves every character of the input.
    """
    out_text = []
    for line in txt.splitlines(True):  # keepends=True keeps the '\n'
        for start in range(0, len(line), MAX_CHUNK):
            out_text.append(mark_sentence(line[start:start + MAX_CHUNK]))
    return ''.join(out_text)


interface = gr.Interface(
    fn=punct,
    # NOTE(review): gr.inputs/gr.outputs is the legacy Gradio 2.x namespace,
    # removed in Gradio >= 3 — migrate to gr.Textbox(...) when upgrading.
    inputs=gr.inputs.Textbox(lines=5, label='输入需要标点的文本:'),
    outputs=gr.outputs.Textbox(label='标点结果:'),
    title='chinese classical texts punctuation',
)
interface.launch()