# app.py -- Hugging Face Space source file by randomshit11.
# Revision d43476c ("Update app.py"), verified commit; file size 2.54 kB.
import functools

import gradio
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
def chunk_text(text: str, chunk_size: int) -> list:
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    Args:
        text: The string to split. Splitting is by character count, not by
            tokens or word boundaries, so a word may be cut in half.
        chunk_size: Maximum number of characters per piece; must be >= 1.

    Returns:
        A list of substrings that concatenate back to *text*. An empty
        *text* yields an empty list.

    Raises:
        ValueError: If *chunk_size* is smaller than 1.
    """
    if chunk_size < 1:
        # A negative step makes range() empty, which would silently discard
        # the whole input; zero makes range() raise a less helpful error.
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )
    return [
        text[start:start + chunk_size]
        for start in range(0, len(text), chunk_size)
    ]
@functools.lru_cache(maxsize=1)
def _get_summarizer():
    """Build the ``facebook/bart-large-cnn`` summarization pipeline once.

    The result is cached so repeated requests reuse the same pipeline
    object instead of constructing a new one per call.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")


def shorten_text(text, min_length, max_length):
    """Summarize *text* with the ``facebook/bart-large-cnn`` pipeline.

    The input is cut into 1024-character pieces, each piece is summarized
    on its own, and the partial summaries are joined with single spaces.

    Args:
        text: The text to summarize.
        min_length: Minimum summary length passed to the pipeline for each
            piece.
        max_length: Maximum summary length passed to the pipeline for each
            piece.

    Returns:
        The joined summaries, or an empty string when *text* is empty or
        contains only whitespace.
    """
    if not text or not text.strip():
        # Nothing to summarize; skip loading the model entirely.
        return ""

    # UI slider values may arrive as floats; the length limits are integers.
    min_length, max_length = int(min_length), int(max_length)

    summarizer = _get_summarizer()
    partial_summaries = []
    for chunk in chunk_text(text, 1024):
        # The limits must be passed by keyword: the pipeline call does not
        # map extra positional arguments onto generation settings.
        result = summarizer(
            chunk,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
        )
        partial_summaries.append(result[0]["summary_text"])
    return " ".join(partial_summaries)
_PARAPHRASE_MODEL_NAME = "randomshit11/fin-bert-1st-shit"


@functools.lru_cache(maxsize=1)
def _load_paraphraser():
    """Load the seq2seq tokenizer and model once and pick the device.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model has already been
        moved to *device*, which is ``"cuda"`` when available, else ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASE_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASE_MODEL_NAME)
    # The inputs are moved to `device` below, so the model must live there
    # too; otherwise generation fails on CUDA hosts with a device mismatch.
    model = model.to(device)
    return tokenizer, model, device


def paraphrase_text(text, min_length, max_length):
    """Rewrite *text* with the ``randomshit11/fin-bert-1st-shit`` model.

    The input is cut into 1024-character pieces. Each piece is wrapped as
    ``"Summary: <piece> </s>"``, passed to ``model.generate`` with top-k /
    top-p sampling, and the decoded outputs are joined with single spaces.

    Args:
        text: The text to rewrite.
        min_length: Minimum generated length passed to ``generate``.
        max_length: Maximum generated length passed to ``generate``.

    Returns:
        The joined generated text, or an empty string when *text* is empty
        or contains only whitespace. Sampling is enabled, so the output can
        differ between calls.
    """
    if not text or not text.strip():
        # Nothing to rewrite; skip loading the model entirely.
        return ""

    # UI slider values may arrive as floats; the length limits are integers.
    min_length, max_length = int(min_length), int(max_length)

    tokenizer, model, device = _load_paraphraser()
    generated_pieces = []
    for chunk in chunk_text(text, 1024):
        # Frame every piece individually so each one carries the instruction
        # prefix and the end-of-sequence marker, not only the first and last.
        prompt = "Summary: " + chunk + " </s>"
        encoding = tokenizer(prompt, padding="longest", return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            min_length=min_length,
            max_length=max_length,
            do_sample=True,
            top_k=120,
            top_p=0.95,
            # Only the first sequence is used, so sample exactly one.
            num_return_sequences=1,
        )
        generated_pieces.append(
            tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
    return " ".join(generated_pieces)
def modify_text(mode, text, min_length, max_length):
    """Route a request to the handler that matches the selected mode.

    A *mode* equal to ``"shorten"`` is handled by ``shorten_text``; every
    other value is handled by ``paraphrase_text``. The remaining arguments
    are forwarded unchanged and the handler's result is returned.
    """
    handler = shorten_text if mode == "shorten" else paraphrase_text
    return handler(text, min_length, max_length)
# Input widgets, created in the positional order `modify_text` expects:
# mode, text, min length, max length.
_mode_choice = gradio.Radio(["shorten", "Summary"], label="Mode")
_min_length_slider = gradio.Slider(5, 200, value=30, label="Min length")
_max_length_slider = gradio.Slider(5, 500, value=130, label="Max length")

# One pre-filled example row, matching the input order above.
_example_rows = [
    ["shorten", """Your long input text goes here...""", 30, 130],
]

gradio_interface = gradio.Interface(
    fn=modify_text,
    inputs=[_mode_choice, "text", _min_length_slider, _max_length_slider],
    outputs="text",
    examples=_example_rows,
    title="Text shortener/paraphraser",
    description="Shortening texts using `facebook/bart-large-cnn`, paraphrasing texts using `fin-bert-1st-shit`.",
)

# Launch the interface as soon as this module is executed.
gradio_interface.launch()