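# Gradio demo for a Phi-2 model fine-tuned with QLoRA on the OpenAssistant
# Conversations dataset (OASST1).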
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
model_name = "RaviNaik/Phi2-Osst"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
)
# Caching was disabled for fine-tuning; enabling it would speed up generation.
model.config.use_cache = False
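# The BitsAndBytesConfig import above suggests the model was fine-tuned with
# 4-bit QLoRA. A minimal sketch of 4-bit loading for inference (an assumption,
# not necessarily the deployed configuration; requires a CUDA device and the
# bitsandbytes package):
#
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.float16,
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name, trust_remote_code=True, quantization_config=bnb_config
#     )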
# Tokenizers run on the CPU; AutoTokenizer.from_pretrained takes no device_map.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
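# ChatML-style prompt template matching the OASST fine-tuning format;
# {prompt} is filled in per request.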
chat_template = """<|im_start|>system
You are a helpful assistant who always responds to user queries<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
def generate(prompt, max_length, temperature, num_samples):
    prompt = prompt.strip()
    # do_sample=True is required for temperature to have any effect and for
    # num_return_sequences > 1 to work.
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
        temperature=temperature,
        num_return_sequences=num_samples,
        do_sample=True,
    )
    # Wrap the raw prompt in the ChatML template the model was fine-tuned on.
    result = pipe(chat_template.format(prompt=prompt))
    # Returning a dict keyed by component updates the JSON output box.
    return {output: result}
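# UI: prompt box and generation controls in the left column, raw pipeline
# output rendered as JSON in the right column.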
with gr.Blocks() as app:
    gr.Markdown("## ERA Session27 - Phi2 Model Finetuning with QLoRA on OpenAssistant Conversations Dataset (OASST1)")
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
            max_length = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Number of Tokens to Generate",
                interactive=True,
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1,
                value=0.7,
                step=0.1,
                label="Temperature",
                interactive=True,
            )
            num_samples = gr.Dropdown(
                choices=[1, 2, 5, 10],
                value=1,
                interactive=True,
                label="Number of Outputs to Generate",
            )
            submit_btn = gr.Button(value="Generate")
        with gr.Column():
            output = gr.JSON(label="Generated Text")

    submit_btn.click(
        generate,
        inputs=[prompt_box, max_length, temperature, num_samples],
        outputs=[output],
    )

app.launch()
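# Run locally with `python app.py`; Gradio serves the UI on
# http://127.0.0.1:7860 by default. A Hugging Face Space runs the script
# automatically on startup.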