# madlad400-3b-mt / app.py
# Adapted from app.py at https://github.com/synkathairo/flan-t5-large-gradio
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

MODEL_NAME = "jbochi/madlad400-3b-mt"
default_max_length = 200

print("Using `{}`.".format(MODEL_NAME))

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
print("T5Tokenizer loaded from pretrained.")

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto")
print("T5ForConditionalGeneration loaded from pretrained.")

def inference(max_length, input_text, history=[]):
    # gr.Number yields a float; generate() expects an integer length.
    max_length = int(max_length)
    # Move the inputs to the model's device so generation also works when
    # device_map="auto" has placed the model on a GPU.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(input_ids, max_length=max_length, bos_token_id=0)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    history.append((input_text, result))
    return history, history
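
# Example usage (hypothetical I/O; MADLAD-400 selects the target language via a
# "<2xx>" prefix token, e.g. "<2es>" for Spanish):
#
#   history, _ = inference(default_max_length, "<2es> I love pizza!")
#   print(history[-1][1])  # expected: a Spanish translation, roughly "Me encanta la pizza!"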

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            "<h1>Demo of {}</h1><p>See more at Hugging Face: <a href='https://huggingface.co/{}'>{}</a>.</p>".format(
                MODEL_NAME, MODEL_NAME, MODEL_NAME
            )
        )
    max_length = gr.Number(
        value=default_max_length, label="maximum length of response"
    )
    chatbot = gr.Chatbot(label=MODEL_NAME)
    state = gr.State([])
    with gr.Row():
        # .style() is the Gradio 3.x API; Gradio 4+ takes container=False directly.
        txt = gr.Textbox(
            show_label=False, placeholder="<2es> text to translate"
        ).style(container=False)
    txt.submit(fn=inference, inputs=[max_length, txt, state], outputs=[chatbot, state])

demo.launch()
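
# On Hugging Face Spaces the bare launch() above is sufficient; when running a
# copy locally, demo.launch(share=True) would create a temporary public URL.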