# Scraped file-view header:
# author: priyaradhakrishnan
# commit: 3194bf0 ("initial commit")
# file size: 1.07 kB
import torch
import gradio
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
# NOTE(review): original "#retrain" marker -- nothing in this file retrains
# the model; presumably a TODO. Confirm with the author.
# Initialize the compute device, the pretrained t5-small model and its tokenizer.
# Fall back to CPU so the app still starts on machines without a CUDA GPU;
# the previous hard-coded "cuda" made every `.to(device)` fail there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model explicitly so it lives on the same device the inputs are
# sent to in summ(), instead of relying on device_map="auto" (which also
# requires the optional `accelerate` package).
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
# Tokenizers run on the CPU and have no device placement; device_map was
# meaningless here.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
def summ(text_content, max_input_length=512):
    """Summarize ``text_content`` with the module-level T5 model.

    The text is whitespace-normalized, prefixed with the ``"summarize: "``
    task tag, encoded (truncated to ``max_input_length`` tokens), and
    decoded with 4-beam search.

    Args:
        text_content: Raw text to summarize; may span multiple lines.
        max_input_length: Maximum number of input tokens passed to the
            model; longer inputs are truncated. Defaults to 512.

    Returns:
        The decoded summary string, or ``""`` when the input is empty or
        whitespace-only.
    """
    # Collapse every run of whitespace (newlines included) into one space.
    # The old `.replace("\n", "")` fused the last word of one line with the
    # first word of the next, corrupting the model input.
    preprocess_text = " ".join(text_content.split())
    if not preprocess_text:
        # Nothing to summarize; don't generate text from the bare prefix.
        return ""
    t5_input_text = "summarize: " + preprocess_text
    tokenized_text = tokenizer.encode(
        t5_input_text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    ).to(device)
    summary_ids = model.generate(
        tokenized_text,
        num_beams=4,
        no_repeat_ngram_size=2,
        min_length=30,
        max_length=300,
        early_stopping=True,
    )
    # Only one input sequence was encoded, so decode the first result row.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def greet(text_content):
    """Gradio callback that returns the summary of ``text_content``.

    Kept under this name because the interface is wired with
    ``fn=greet``; all of the work is delegated to ``summ``.
    """
    return summ(text_content)
# Build the web UI: one free-text input box, one text output box, with
# greet() as the handler.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
print("Throwing up")
# These top-level statements run on import as well as when executed as a
# script. The server listens on all interfaces at port 8000; share=True
# asks Gradio for a public share link as well.
demo.launch(
    share=True,
    server_port=8000,
    server_name="0.0.0.0",
)