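# Gradio chat demo for the "Oysiyl/gemma-3-1B-GRPO" reasoning adapter on top of
# "unsloth/gemma-3-1b-it": it streams the model's working-out and highlights the
# final answer found between <SOLUTION> tags.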
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import threading
import torch
# Load base model directly and then add the adapter
model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-1b-it")
# Apply adapter from the fine-tuned version
model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
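# Note: load_adapter relies on the PEFT integration in transformers, so the `peft` package must be installed.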
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
def process_history(history):
    """Process chat history into the format expected by the model."""
    processed_history = []
    for user_msg, assistant_msg in history:
        # Always add user message first, even if empty
        processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg or ""}]})
        # Always add assistant message, even if empty
        processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg or ""}]})
    return processed_history
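# For example (a sketch of the tuple-style history that ChatInterface passes in):
# [("What is 2 + 2?", "4")] becomes
# [{"role": "user", "content": [{"type": "text", "text": "What is 2 + 2?"}]},
#  {"role": "assistant", "content": [{"type": "text", "text": "4"}]}]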
def process_new_user_message(message):
    """Process a new user message into the format expected by the model."""
    return [{"type": "text", "text": message}]
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Format messages according to Gemma's expected chat format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
    # Process the conversation history
    if history:
        messages.extend(process_history(history))
    # Add the new user message
    messages.append({"role": "user", "content": process_new_user_message(message)})
    # Apply chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    )
    if torch.cuda.is_available():
        # Move the model to the GPU first, then place the token ids on the same device;
        # the ids stay integer, so there is no dtype cast here.
        model.to("cuda")
        inputs = inputs.to(model.device)
    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
    # Run generation in a separate thread
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
        do_sample=True,
    )
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
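    # The default system prompt (see the Textbox below) asks the model to place its reasoning
    # between <start_working_out> and <end_working_out> and its answer inside <SOLUTION></SOLUTION>,
    # so the streaming loop below watches for the <SOLUTION> tags to highlight the final answer.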
    # Stream the output, prefixing it with "Thinking:"
    output = "Thinking: \n"
    for token in streamer:
        output += token
        # Once a complete <SOLUTION>...</SOLUTION> span appears, render its contents in bold
        if "<SOLUTION>" in output:
            solution_start = output.find("<SOLUTION>") + len("<SOLUTION>")
            solution_end = output.find("</SOLUTION>")
            if solution_end > solution_start:
                formatted_output = (
                    output[:solution_start] +
                    "Final answer: **" + output[solution_start:solution_end] + "**" +
                    output[solution_end:]
                )
                yield formatted_output
            else:
                yield output
        else:
            yield output
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"],
        ["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"],
        ["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"],
        ["What is the sqrt of 101?"],
    ],
    cache_examples=False,
    chatbot=gr.Chatbot(
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False}
        ],
    ),
)
if __name__ == "__main__":
    demo.launch()