|
import gradio as gr |
|
import re |
|
|
|
from transformers import ( |
|
BertTokenizer, |
|
BartForConditionalGeneration, |
|
pipeline |
|
) |
|
|
|
|
|
def clean_text(text):
    """Normalize pasted article text before it is handed to the model.

    Steps, in order: drop URL-like tokens, blank out the ad placeholders
    ``ADVERTISEMENT`` / ``ADVERTISING``, turn newlines, carriage returns and
    tabs into spaces, collapse runs of spaces, and trim both ends.

    Args:
        text: Raw article text as a ``str``.

    Returns:
        The cleaned, single-line string; ``""`` when nothing is left.
    """
    # Stop the URL match at the first non-ASCII character. Chinese prose has
    # no spaces between words, so matching "every non-space character" would
    # swallow all the article text that follows a link.
    text = re.sub(r"http[\x21-\x7e]+", "", text)

    # Ad placeholders left behind by scraped news pages.
    text = re.sub(r"ADVERTISEMENT|ADVERTISING", " ", text)

    # One pass covers every line break and tab; a separate pass for doubled
    # newlines could never match once single newlines are gone. "\r" is
    # included so text pasted with Windows line endings is flattened too.
    text = re.sub(r"[\n\r\t]", " ", text)

    # Collapse the gaps produced by the removals above, then trim.
    text = re.sub(r" +", " ", text)
    return text.strip()
|
|
|
|
|
# Hugging Face Hub checkpoint used for headline generation.
model_name = "chinhon/bart-large-chinese-cnhdwriter"

# The checkpoint is loaded with a BERT-style tokenizer; inputs are capped at
# 512 tokens via model_max_length.
tokenizer = BertTokenizer.from_pretrained(model_name, model_max_length=512)
model = BartForConditionalGeneration.from_pretrained(model_name)

# truncation=True is forwarded to the pipeline so over-long articles are cut
# rather than rejected.
text2text_generator = pipeline(
    task="text2text-generation", model=model, tokenizer=tokenizer, truncation=True
)
|
|
|
def cn_text(text):
    """Generate a headline suggestion for a pasted Chinese news story.

    The text is cleaned with ``clean_text`` and passed to the module-level
    ``text2text_generator`` pipeline.

    Args:
        text: Raw article text from the UI.

    Returns:
        The ``generated_text`` of the first prediction, or ``""`` when the
        input is empty (or only whitespace / removable noise) after cleaning.
    """
    input_text = clean_text(text) if text else ""
    if not input_text:
        # Nothing to summarize: skip inference rather than let the model
        # produce a headline from an empty prompt.
        return ""

    prediction = text2text_generator(
        input_text,
        max_length=128,
        length_penalty=50.0,
    )

    # The pipeline returns a list of dicts; only the first result is shown.
    return prediction[0].get("generated_text")
|
|
|
|
|
|
|
# Widgets for the demo: a tall box for the article, a single box for the result.
_article_box = gr.Textbox(lines=20, label="Paste Chinese text here")
_headline_box = gr.Textbox(label="Suggested Headline")

gradio_ui = gr.Interface(
    fn=cn_text,
    inputs=_article_box,
    outputs=_headline_box,
    title="Chinese News Headlines Generator",
    description="Too busy or tired to write a headline for your Chinese news story? Try this instead.",
    theme="huggingface",
)

# NOTE(review): `enable_queue` and the "huggingface" theme name look tied to
# an older Gradio release; confirm against the pinned Gradio version.
gradio_ui.launch(enable_queue=True)
|
|