# Author: Oysiyl
# Last commit: 13f4ac6 "Update app.py" (verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import threading
import torch
# Hugging Face Hub ids. The tokenizer must come from the same base checkpoint
# as the model, so the base id is defined once and reused for both.
BASE_MODEL_ID = "unsloth/gemma-3-1b-it"
# Fine-tuned adapter applied on top of the base model.
ADAPTER_ID = "Oysiyl/gemma-3-1B-GRPO"

# Load the base model directly, then attach the fine-tuned adapter to it.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
model.load_adapter(ADAPTER_ID)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
def process_history(history):
    """Convert Gradio chat history into chat-template messages for the model.

    Two history shapes are accepted:

    * legacy "tuples" format: ``[(user_msg, assistant_msg), ...]``, where
      either side may be empty or ``None`` (e.g. a turn still pending);
    * "messages" format: ``[{"role": "user" | "assistant", "content": ..., ...}, ...]``.

    Only non-empty string content is kept. Other roles and non-text content
    (e.g. file attachments) are skipped, because they cannot be passed as a
    ``{"type": "text"}`` part.

    Args:
        history: The chat history in either format, or ``None``.

    Returns:
        A list of ``{"role": ..., "content": [{"type": "text", "text": ...}]}``
        dicts in conversation order.
    """
    turns = []
    for entry in history or []:
        if isinstance(entry, dict):
            # Messages format. Extra keys such as "metadata" are ignored.
            if entry.get("role") in ("user", "assistant"):
                turns.append((entry["role"], entry.get("content")))
        else:
            # Tuples format: one (user, assistant) pair per exchange.
            user_msg, assistant_msg = entry
            turns.append(("user", user_msg))
            turns.append(("assistant", assistant_msg))
    return [
        {"role": role, "content": [{"type": "text", "text": text}]}
        for role, text in turns
        if isinstance(text, str) and text
    ]
def process_new_user_message(message):
    """Wrap the incoming user message as the content list of a chat turn.

    Args:
        message: The raw text typed by the user.

    Returns:
        A one-element list holding a single ``{"type": "text"}`` content part.
    """
    text_part = {"type": "text", "text": message}
    return [text_part]
def _format_solution(text):
    """Bold the answer inside the first complete ``<SOLUTION>...</SOLUTION>`` span.

    The tags are kept. ``"Final answer: **"`` is inserted right after
    ``<SOLUTION>`` and ``"**"`` right before ``</SOLUTION>``. *text* is returned
    unchanged if there is no opening tag, if the closing tag has not been
    generated yet, or if the solution is empty.
    """
    open_tag, close_tag = "<SOLUTION>", "</SOLUTION>"
    start = text.find(open_tag)
    if start == -1:
        return text
    start += len(open_tag)
    # Look for the closing tag only after the opening one. A stray
    # "</SOLUTION>" earlier in the working-out must not hide the answer.
    end = text.find(close_tag, start)
    if end <= start:
        return text
    return text[:start] + "Final answer: **" + text[start:end] + "**" + text[end:]


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the fine-tuned model's reply to *message* for ``gr.ChatInterface``.

    Args:
        message: The new user message.
        history: Prior conversation turns, converted with ``process_history``.
        system_message: Optional system prompt. It is omitted when empty.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply accumulated so far, preceded by a ``Thinking:`` header line.
        Once a complete ``<SOLUTION>...</SOLUTION>`` span has been generated,
        its content is shown in bold after ``Final answer:``.
    """
    # Build the conversation in the structured chat format used by the template.
    messages = []
    if system_message:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move the model first, then place the inputs on the device the model is
    # actually on. Token ids must stay integer tensors, so no dtype cast here.
    if torch.cuda.is_available():
        model.to("cuda")
    inputs = inputs.to(model.device)

    # Special tokens are kept in the streamed text.
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        # Gradio sliders may deliver floats. The token budget is an integer count.
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
        do_sample=True,
    )
    # Generate in a background thread so text can be yielded while it arrives.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = "Thinking: \n"
    for token in streamer:
        output += token
        yield _format_solution(output)
    # The streamer is exhausted once generation finishes, so this join is cheap.
    thread.join()
# For information on how to customize the ChatInterface, peruse the gradio docs:
# https://www.gradio.app/docs/chatinterface

# Default system prompt. It asks for tagged working-out and a tagged solution,
# which respond() looks for when formatting the final answer.
_DEFAULT_SYSTEM_MESSAGE = (
    "You are given a problem.\n"
    "Think about the problem and provide your working out.\n"
    "Place it between <start_working_out> and <end_working_out>.\n"
    "Then, provide your solution between <SOLUTION></SOLUTION>"
)

# Clickable example prompts shown under the chat box.
_EXAMPLE_PROBLEMS = [
    "Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?",
    "Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?",
    "Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?",
    "What is the sqrt of 101?",
]

# $$...$$ renders as display math, $...$ as inline math.
_LATEX_DELIMITERS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]

# Extra controls passed to respond() after (message, history), in this order:
# system_message, max_tokens, temperature, top_p.
_additional_inputs = [
    gr.Textbox(value=_DEFAULT_SYSTEM_MESSAGE, label="System message"),
    gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=512, step=1),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=1.0, step=0.1),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, value=0.95, step=0.05),
]

demo = gr.ChatInterface(
    respond,
    additional_inputs=_additional_inputs,
    examples=[[problem] for problem in _EXAMPLE_PROBLEMS],
    cache_examples=False,
    chatbot=gr.Chatbot(latex_delimiters=_LATEX_DELIMITERS),
)

if __name__ == "__main__":
    demo.launch()