# Hugging Face Space: app.py by dzmltzack
# (provenance: commit a6b5c53, "Update app.py", 735 bytes — viewer metadata
#  converted to comments so this file is runnable Python)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load Flan-T5-large once at startup; both objects are module-level globals
# used by generate() below.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
# device_map="auto" shards the model across available devices (requires
# `accelerate`); weights that don't fit are spilled to the "offload" folder,
# and downloads are cached under "cache".
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", device_map="auto", cache_dir="cache", offload_folder="offload" )
def generate(input_text):
    """Run Flan-T5 on *input_text* and return the decoded answer string.

    Parameters
    ----------
    input_text : str
        The prompt to feed the model.

    Returns
    -------
    str
        The model's generated text, special tokens stripped.
    """
    # Bug fix: tokenizer(...) returns a BatchEncoding (dict-like), not a
    # tensor; passing it positionally to model.generate() fails. Extract the
    # input_ids tensor explicitly, and move it to the model's device, which
    # is required when the model was loaded with device_map="auto".
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    output = model.generate(input_ids, max_length=70)
    # output is a batch of token-id sequences; decode the first (only) one.
    return tokenizer.decode(output[0], skip_special_tokens=True)
#@title GUI
import gradio as gr  # Gradio provides the web UI for the Space

title = "Flan T5 :)"  # page title shown above the interface

def inference(text):
    """Gradio callback: forward the user's prompt to the Flan-T5 generate() helper."""
    return generate(text)
# Wire up a single-textbox-in / single-textbox-out interface around inference().
io = gr.Interface(
inference,
gr.Textbox(lines=3),  # free-text prompt input
outputs=[
gr.Textbox(lines=3, label="Flan T5")  # model answer
],
title=title,
)
# Local-only server: no public share link, no Gradio debug mode.
io.launch(share=False,debug=False)