import torch
import gradio as gr
from model import create_marian_enfr
# Set up the Marian model and tokenizer
model, tokenizer = create_marian_enfr()

# Load the fine-tuned weights into the model (on CPU)
model.load_state_dict(
    torch.load(
        f="marian_finetuned_kde4_enfr.pth",
        map_location=torch.device("cpu"),
    )
)
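
# Switch to eval mode for inference; this is a precaution and assumes
# create_marian_enfr does not already do it (harmless if it does).
model.eval()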

# Predict function: translate an English sentence to French
def predict(text: str):
    # Tokenize the input and generate output tokens from the model
    inputs = tokenizer(text,
                       max_length=128,
                       truncation=True,
                       return_tensors="pt")
    output_tokens = model.generate(**inputs)
    output_text = tokenizer.batch_decode(output_tokens,
                                         skip_special_tokens=True)
    # batch_decode returns a list; return the single translated string
    return output_text[0]
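
# Quick sanity check of predict (a sketch; the exact French wording depends on
# the fine-tuned checkpoint):
#   predict("I love music")  # e.g. something like "J'aime la musique"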

# Create examples list
examples_list = ['What a beautiful day',
                 'I love music']

# Create Gradio app
title = "English to French translator"
description = "Marian model fine-tuned for English-to-French translation on the KDE4 dataset."
demo = gr.Interface(fn=predict,
                    inputs=gr.Textbox(label="Input",
                                      placeholder="Enter sentence here..."),
                    outputs="text",
                    examples=examples_list,
                    title=title,
                    description=description)
# Launch gradio
demo.launch()