File size: 819 Bytes
7a7aaa3 9b67530 7a7aaa3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
# Checkpoint identifier passed to both from_pretrained calls below, so the
# tokenizer and the model are always loaded from the same checkpoint.
model_name = "google/flan-t5-base"
# Loaded once at module level; text_2_text_generation reuses these objects on
# every call instead of reloading them per request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def text_2_text_generation(input):
    """Generate a text response for ``input`` using the module-level model.

    The prompt is tokenized with truncation enabled, passed to
    ``model.generate`` with sampling (top-p 0.9, top-k 50, at most 200 new
    tokens), and the first generated sequence is decoded back to text.
    Because sampling is enabled, repeated calls with the same prompt may
    return different text.

    Args:
        input: Prompt text. The name shadows the builtin ``input`` but is
            kept because it is part of this function's signature.

    Returns:
        The decoded text of the first generated sequence with special tokens
        removed, or ``""`` when ``input`` is ``None``, empty, or contains only
        whitespace.
    """
    # Nothing to answer: skip the model call instead of sampling arbitrary
    # text from an empty prompt.
    if input is None or (isinstance(input, str) and not input.strip()):
        return ""

    encoded = tokenizer(input, return_tensors="pt", truncation=True)
    output = model.generate(
        encoded.input_ids,
        # Pass the mask explicitly so generate does not have to infer it.
        attention_mask=encoded.attention_mask,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.9,
        top_k=50,
    )
    # generate returns a batch of sequences; only one prompt was submitted.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Web UI: one multi-line textbox for the prompt and one for the generated
# text, wired to text_2_text_generation.
iface = gr.Interface(
    fn=text_2_text_generation,
    inputs=gr.Textbox(label="Enter your queries...", lines=5),
    outputs=gr.Textbox(lines=5),
)
# Start the Gradio app for the interface defined above.
iface.launch()