# Hugging Face Space app.py — "Update app.py" by Matonice, commit 42c306e
import pandas as pd
import numpy as np
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

# Fetch the Punkt sentence-splitter model at startup (needed by
# nltk.tokenize.sent_tokenize below). Runs on every import — fine for a
# Spaces app, where the download is cached after the first run.
nltk.download('punkt')

# Summarization model: DistilBART fine-tuned on CNN/DailyMail.
# Loaded once at module level so Gradio requests reuse the same weights.
checkpoint = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
def summarize_sentence(file, column):
    """Summarize the text in the first row of *column* of an uploaded CSV.

    The text is sentence-tokenized, sentences are greedily packed into
    chunks that fit the model's single-sequence token budget, each chunk
    is summarized independently, and the partial summaries are
    concatenated into one string.

    Parameters
    ----------
    file : gradio file object with a ``.name`` path to the uploaded CSV.
    column : str
        Name of the CSV column whose first row holds the text to summarize.

    Returns
    -------
    str
        Concatenated summary, with the model's ``.<n>`` markers turned
        into real newlines.
    """
    df = pd.read_csv(file.name)
    # NOTE(review): only the first row of the column is summarized —
    # presumably each upload carries one document per file.
    sentences = nltk.tokenize.sent_tokenize(df[column].values[0])

    # Greedily pack whole sentences into chunks no longer than the
    # model's maximum single-sequence length (in tokens).
    chunks = []
    chunk = ""
    length = 0
    for sentence in sentences:
        sentence_len = len(tokenizer.tokenize(sentence))
        if length + sentence_len <= tokenizer.max_len_single_sentence:
            chunk += sentence + " "
            length += sentence_len
        else:
            # Current chunk is full: flush it and start a new one.
            # Guard against flushing an empty chunk when a single
            # sentence alone exceeds the budget.
            if chunk:
                chunks.append(chunk)
            chunk = sentence + " "
            length = sentence_len
    # Flush the trailing chunk. (The original dropped it whenever the
    # last sentence had started a fresh chunk, losing the document tail.)
    if chunk:
        chunks.append(chunk)

    # Summarize each chunk and concatenate the results.
    summary = ""
    for text in chunks:
        encoded = tokenizer(text, return_tensors="pt")
        output = model.generate(**encoded)
        summary += tokenizer.decode(output[0], skip_special_tokens=True)

    # distilbart-cnn emits "<n>" as its newline marker; restore newlines.
    return summary.replace('.<n>', '.\n')
# Gradio UI: upload a CSV, type the column name, read the summary.
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces (and the
# ``optional=`` kwarg) were deprecated and removed in Gradio 3.x; the
# top-level components below are the supported equivalents.
demo = gr.Interface(
    summarize_sentence,
    inputs=[gr.File(label="CSV File"), "text"],
    outputs=[gr.Textbox(label="Summary")],
)

if __name__ == "__main__":
    # debug=True surfaces tracebacks in the browser while developing.
    demo.launch(debug=True)