import gradio as gr
import spaces
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
# Fine-tuned model weights, paired with the base Gemma instruction-tuned tokenizer.
model_name = "ModularityAI/gemma-2b-datascience-it-raft"
tokenizer_name = "google/gemma-2b-it"

# Load in bfloat16 and move the model to the GPU; from_pretrained has no `device`
# keyword, and AutoTokenizer takes no device argument at all.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
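# Note: under ZeroGPU Spaces, moving the model to "cuda" at import time is the
# expected pattern; the @spaces.GPU decorator below attaches a GPU to each call.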
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cuda",
)
def format_test_question(q):
    # Wrap the question in Gemma's instruction-tuned chat format, including the
    # <start_of_turn>model marker that cues the model to respond.
    return f"<bos><start_of_turn>user\n{q}<end_of_turn>\n<start_of_turn>model\n"
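# e.g. format_test_question("What is overfitting?") yields:
#   <bos><start_of_turn>user
#   What is overfitting?<end_of_turn>
#   <start_of_turn>model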
@spaces.GPU
def chat_function(message, history, max_new_tokens, temperature):
    prompt = format_test_question(message)
    print(prompt)
    # Offset keeps temperature strictly positive: do_sample=True rejects temperature=0.
    temp = temperature + 0.1
    outputs = pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temp,
    )
    print(outputs)
    # Return only the newly generated text, stripping the prompt prefix.
    return outputs[0]["generated_text"][len(prompt):]
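# A minimal sanity check (hypothetical values; bypasses the Gradio UI):
#   reply = chat_function("What is gradient descent?", [], 256, 0.7)
#   print(reply)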
gr.ChatInterface(
    chat_function,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    title="Gemma 2B Data Science QA RAFT Demo",
    description="""
    This Space hosts a chat with Gemma 2B fine-tuned for data science QA using RAFT. Find the model here: https://huggingface.co/ModularityAI/gemma-2b-datascience-it-raft
    Feel free to adjust the generation settings under "Additional Inputs".
    Fine-tuning notebook: https://www.kaggle.com/code/hanzlajavaid/gemma-finetuning-raft-technique
    """,
    theme=gr.themes.Soft(),
    additional_inputs=[
        gr.Slider(512, 4096, value=1024, label="Max New Tokens"),
        gr.Slider(0, 1, value=0.5, label="Temperature"),
    ],
).launch()