cswamy's picture
Update app.py
d78f16f
import torch
import gradio as gr
from model import create_mt5_small
# Setup model and tokenizer
model, tokenizer = create_mt5_small()

# Load the saved state dict into the model, mapping all tensors to the CPU
# so the checkpoint loads on hosts without a GPU.
model.load_state_dict(
    torch.load(
        f="mt5_amzn_enes_reviews_summarization.pth",
        map_location=torch.device("cpu"),
    )
)

# Put the model in evaluation mode for inference: this disables dropout so
# outputs are deterministic. The call is idempotent if the model is already
# in eval mode.
model.eval()
# Predict function
def predict(text: str) -> str:
    """Summarize ``text`` using the module-level ``model`` and ``tokenizer``.

    Args:
        text: Raw input text. The tokenizer call truncates it to at most
            512 tokens.

    Returns:
        The first decoded sequence produced by ``model.generate`` (called
        with ``max_length=30``), with special tokens removed. Returns an
        empty string when ``text`` is empty, ``None`` or whitespace-only.
    """
    # Nothing to summarize: skip tokenization/generation instead of letting
    # the model produce output for an empty prompt.
    if not text or not text.strip():
        return ""

    # Tokenize inputs and get model outputs
    # (named ``encoded`` so the builtin ``input`` is not shadowed).
    encoded = tokenizer(
        text,
        max_length=512,
        truncation=True,
        return_tensors='pt',
    )
    output_tokens = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=30,
    )
    output_text = tokenizer.batch_decode(
        output_tokens,
        skip_special_tokens=True,
    )
    # batch_decode returns one string per sequence; a single text was encoded.
    return output_text[0]
# Example inputs shown in the demo UI: one English text and one Spanish text.
examples_list = [
    (
        "The ball hit the splice a lot and sent a fizzing sensation up the handle "
        "and into the bottom hand, so I adapted at each session by playing softer "
        "and softer, later and later. "
        "I found it very difficult to get down the pitch and meet the ball as it "
        "landed and so persuaded myself to play back more. "
        "It occurred to me that a better player would manage the shimmy down the "
        "pitch with more skill and faster footwork, and that the good sweepers "
        "would have to take him on in the way that Kevin Pietersen managed so "
        "successfully on occasions."
    ),
    (
        "Todo muy bien, cumple con lo esperado. "
        "Lo único malo es que: se calienta un poco y la batería no dura 8h. "
        "A una persona le ha parecido esto útil"
    ),
]
# Create gradio app
title = "Summarizer for English and Spanish inputs"
description = "MT5-small model finetuned for summarization on English or Spanish text trained on the Amazon reviews dataset."
demo = gr.Interface(
    fn=predict,
    # Use the top-level component: the legacy ``gr.inputs`` namespace is
    # deprecated in Gradio 3.x and no longer available in Gradio 4.
    inputs=gr.Textbox(
        label="Input",
        placeholder="Enter sentences here in English or Spanish...",
    ),
    outputs="text",
    examples=examples_list,
    title=title,
    description=description,
)
# Launch gradio only when this file is executed as a script, so importing
# the module (e.g. for testing) does not start a server.
if __name__ == "__main__":
    demo.launch()