File size: 725 Bytes
50f6c1a
 
4629047
50f6c1a
 
4629047
50f6c1a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub checkpoint shared by both the tokenizer and the model, so the
# two are always loaded from the same source.
_CHECKPOINT = "mmcquade11/autonlp-reuters-summarization-34018133"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)

def summarize(text, *, num_beams=4, max_length=100):
    """Generate an abstractive summary of ``text`` with the loaded model.

    Args:
        text: The input document to summarize.
        num_beams: Beam width used by ``model.generate`` (default 4).
        max_length: Maximum length, in tokens, of the generated summary
            (default 100).

    Returns:
        The decoded summary string, with special tokens removed.
    """
    # Calling the tokenizer directly (instead of ``encode``) returns an
    # attention mask alongside the input ids, and ``truncation=True`` clips
    # inputs longer than the tokenizer's ``model_max_length``. Over-long
    # inputs can otherwise fail inside ``generate()``.
    inputs = tokenizer(text, truncation=True, return_tensors="pt")
    summary_ids = model.generate(
        **inputs,
        num_beams=num_beams,
        max_length=max_length,
        early_stopping=True,
    )
    # generate() returns a batch; only one input was encoded, so take row 0.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def summarize_text(text):
    """Gradio callback: return the model's summary of ``text``."""
    summary = summarize(text)
    return summary

# Web UI: a free-text input box feeds summarize_text, and the summary is shown
# in a plain text box. gr.Label is intended for classification results, so a
# textbox is the appropriate component for free-form generated text.
iface = gr.Interface(fn=summarize_text, inputs="textbox", outputs="textbox")

if __name__ == "__main__":
    iface.launch()