Spaces:
Runtime error
Runtime error
"""#Imports""" | |
#!pip install transformers gradio accelerate bitsandbytes sentencepiece | |
#import multiprocessing | |
import torch | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
import gradio as gr | |
"""#Code""" | |
# --- Code ---
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# --- Model: FP16 on GPU, FP32 on CPU ---
# Half precision only pays off (and is only reliably supported) on a GPU;
# float16 matmuls are unsupported or very slow on CPU in many PyTorch builds,
# so CPU-only hosts fall back to float32.
dtype = torch.float16 if device != "cpu" else torch.float32
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
# device_map="auto" lets accelerate decide where the weights live; callers
# should move inputs to ``model.device`` rather than assuming a device.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl", device_map="auto", torch_dtype=dtype
)

# --- Interface ---
def generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
    """Run FLAN-T5 beam search on ``input_text`` and return the decoded answer.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        input_text: Prompt string to encode.
        minimum_length: Lower bound forwarded as ``min_length``; clamped so it
            never exceeds ``maximum_length``.
        maximum_length: Forwarded as ``max_new_tokens``.
        temperature: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded without special tokens, with its
        first character upper-cased.
    """
    # Gradio sliders can hand back floats; the length limits must be ints.
    max_new_tokens = int(maximum_length)
    # The UI allows a minimum (up to 300) above the maximum (down to 100);
    # clamp rather than pass contradictory limits to generate().
    min_length = min(int(minimum_length), max_new_tokens)

    # Follow the model's placement (device_map="auto") instead of hard-coding
    # "cuda", which raises on CPU-only hosts.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(
        input_ids,
        min_length=min_length,
        max_new_tokens=max_new_tokens,
        length_penalty=1.4,
        num_beams=6,
        no_repeat_ngram_size=3,
        # NOTE(review): without do_sample=True, transformers runs plain beam
        # search and ignores temperature/top_k/top_p — confirm whether the
        # Temperature slider is meant to have an effect.
        temperature=temperature,
        top_k=100,
        top_p=0.9,
        repetition_penalty=repetition_penalty,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # str.capitalize() would also lower-case everything after the first
    # character (names, acronyms); only upper-case the first character.
    return text[:1].upper() + text[1:]
# Heading shown at the top of the Gradio page.
title = "Flan-T5-XL GRADIO GUI"


def inference(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
    """Gradio callback that hands the UI values straight to ``generate``.

    The positional order mirrors the interface's input components:
    prompt, minimum length, maximum length, temperature, repetition penalty.

    Returns:
        Exactly what ``generate`` returns for the same arguments.
    """
    answer = generate(
        input_text,
        minimum_length,
        maximum_length,
        temperature,
        repetition_penalty,
    )
    return answer
# Input widgets, in the same order as the positional parameters of `inference`.
_input_components = [
    gr.Textbox(lines=4, label="Input"),
    gr.Slider(0, 300, value=20, step=10, label="Minimum length"),
    gr.Slider(100, 2000, value=1000, step=100, label="Maximum length"),
    gr.Slider(0, 1, value=0.75, step=0.05, label="Temperature"),
    gr.Slider(1, 3, value=2.1, step=0.1, label="Repetition penalty"),
]
_output_components = [gr.Textbox(lines=2, label="Output")]
# NOTE(review): `.input_text input` may not match any element in current
# Gradio releases — confirm the selector still applies.
_custom_css = """
body {background-color: lightgreen}
.input_text input {
background-color: lightblue !important;
}
"""

# Build the single-function UI and start serving it immediately.
demo = gr.Interface(
    fn=inference,
    inputs=_input_components,
    outputs=_output_components,
    title=title,
    css=_custom_css,
)
demo.launch()