import threading

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Load the base model directly and then add the adapter
model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-1b-it")
# Apply the adapter from the fine-tuned version
model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
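# Note: load_adapter() relies on the PEFT integration in transformers, so the `peft`
# package must be installed (an assumption about the Space's requirements.txt).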

def process_history(history):
    """Process chat history into the format expected by the model."""
    processed_history = []
    for user_msg, assistant_msg in history:
        if user_msg:
            processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
    return processed_history

def process_new_user_message(message):
    """Process a new user message into the format expected by the model."""
    return [{"type": "text", "text": message}]

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Format messages according to Gemma's expected chat format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    # Apply chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
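    # With tokenize=True and return_dict=True, apply_chat_template returns a dict-like
    # BatchEncoding holding `input_ids` and `attention_mask` as PyTorch tensors.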
    if torch.cuda.is_available():
        # Move the model to the GPU first, keeping the original intent of bfloat16 inference
        # by casting the model weights rather than the inputs; the token ids must stay as
        # integer tensors, so they are only moved to the model's device.
        model.to(device="cuda", dtype=torch.bfloat16)
        inputs = inputs.to(model.device)

    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)

    # Run generation in a separate thread
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
        do_sample=True,
    )
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Stream the output, prefixing it with "Thinking:"
    output = "Thinking: \n"
    for token in streamer:
        output += token
        # Once a complete <SOLUTION>...</SOLUTION> span is present, render its contents in bold
        if "<SOLUTION>" in output:
            solution_start = output.find("<SOLUTION>") + len("<SOLUTION>")
            solution_end = output.find("</SOLUTION>")
            if solution_end > solution_start:
                formatted_output = (
                    output[:solution_start]
                    + "Final answer: **" + output[solution_start:solution_end] + "**"
                    + output[solution_end:]
                )
                yield formatted_output
            else:
                yield output
        else:
            yield output
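

# Illustrative only (not part of the app): respond() is a generator, and each yielded string
# is the full partial response so far. A hypothetical local check could look like:
#
#     last = ""
#     for partial in respond("What is 2 + 2?", [], "You are given a problem.", 64, 1.0, 0.95):
#         last = partial
#     print(last)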
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
demo = gr.ChatInterface( | |
respond, | |
additional_inputs=[ | |
gr.Textbox( | |
value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>", | |
label="System message" | |
), | |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"), | |
gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"), | |
gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.95, | |
step=0.05, | |
label="Top-p (nucleus sampling)", | |
), | |
], | |
examples=[ | |
["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"], | |
["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"], | |
["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"], | |
["What is the sqrt of 101?"], | |
], | |
cache_examples=False, | |
chatbot=gr.Chatbot( | |
latex_delimiters=[ | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "$", "right": "$", "display": False} | |
], | |
), | |
) | |
if __name__ == "__main__": | |
demo.launch() |