# Hugging Face Hub file-view metadata: author liamvbetts, commit 9b95b6c
# ("working interface"), raw size 1.43 kB. Kept as a comment so the script parses.
import gradio as gr
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
# Fine-tuned BART summarization checkpoint on the Hugging Face Hub; the tokenizer
# and the seq2seq model are both loaded from this same repository id.
_CHECKPOINT = "liamvbetts/bart-large-cnn-v4"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)

# CNN/DailyMail, config "3.0.0"; the random-article loader reads its "validation" split.
dataset = load_dataset("cnn_dailymail", "3.0.0")


def summarize(article):
    """Summarize a news article with the fine-tuned BART model.

    Args:
        article: Raw article text from the input textbox.

    Returns:
        The decoded summary with special tokens stripped, or an empty string
        when *article* is ``None``, empty, or whitespace only.
    """
    # Nothing to summarize: skip the (slow) model call for input with no text.
    if not article or not article.strip():
        return ""
    # BART's learned position embeddings cap the encoder input length; text
    # pasted into the textbox can exceed it and make generate() fail, so
    # truncate the token sequence to that limit.
    encoded = tokenizer(
        article,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    output_ids = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=128,
        do_sample=False,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def get_random_article(max_chars=1024):
    """Return a randomly chosen article from the CNN/DailyMail validation split.

    Args:
        max_chars: Number of leading characters of the article to keep.
            Defaults to 1024 (the previous hard-coded cut-off); ``None``
            keeps the whole article.

    Returns:
        The (possibly truncated) article text.
    """
    validation = dataset["validation"]
    # Pick one row by index instead of shuffling the whole split: shuffle()
    # builds a fresh permutation of every row on each call (and, for
    # cache-backed datasets, may write it to a new indices cache file), only
    # to keep a single example. Python's `random` is seeded from OS entropy at
    # import time, so no explicit random.seed() call is needed.
    row = validation[random.randrange(len(validation))]
    return row["article"][:max_chars]


def load_article():
    """Callback for the "Load Random Article" button.

    Takes no inputs and returns the text to show in the input textbox,
    delegating the selection to get_random_article().
    """
    article = get_random_article()
    return article


# Using Gradio Blocks: build the widgets first, then attach their callbacks.
with gr.Blocks() as demo:
    # Layout: title, description, a row with input/summary boxes, two buttons.
    gr.Markdown("## News Summary App")
    gr.Markdown(
        "Enter a news text and get its summary, or load a random article from the validation set."
    )
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=10)
        output_text = gr.Textbox(label="Summary")
    load_article_button = gr.Button("Load Random Article")
    summarize_button = gr.Button("Summarize")

    # Wiring: the load button fills the input box; the summarize button reads
    # the input box and writes the result into the summary box.
    load_article_button.click(fn=load_article, inputs=[], outputs=input_text)
    summarize_button.click(fn=summarize, inputs=input_text, outputs=output_text)

demo.launch()