import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import threading
import torch
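
# Gradio chat demo for the GRPO fine-tuned Gemma-3 1B model: the base instruct
# model is loaded and the reasoning adapter is applied on top. The system prompt
# asks the model to show its working between <start_working_out> and
# <end_working_out> and to put the final answer inside <SOLUTION></SOLUTION>,
# which the streaming handler below highlights as it arrives.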

# Load the base model directly, then attach the fine-tuned adapter
model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-1b-it")
# Apply the GRPO adapter weights on top of the base model
model.load_adapter("Oysiyl/gemma-3-1B-GRPO")
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")


def process_history(history):
    """Process chat history into the format expected by the model."""
    processed_history = []
    for user_msg, assistant_msg in history:
        # Always add user message first, even if empty
        processed_history.append({"role": "user", "content": [{"type": "text", "text": user_msg or ""}]})
        # Always add assistant message, even if empty
        processed_history.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg or ""}]})
    return processed_history


def process_new_user_message(message):
    """Process a new user message into the format expected by the model."""
    return [{"type": "text", "text": message}]


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Format messages according to Gemma's expected chat format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
    
    # Process the conversation history
    if history:
        messages.extend(process_history(history))
    
    # Add the new user message
    messages.append({"role": "user", "content": process_new_user_message(message)})
    
    # Apply chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    )
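    # With tokenize=True and return_dict=True, apply_chat_template returns a
    # dict-like batch of tensors (input_ids, attention_mask) that can be
    # unpacked directly into model.generate().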
    
    if torch.cuda.is_available():
        # Move the model to the GPU in bfloat16 first, then move the tokenized
        # inputs to the same device. The token ids stay integer tensors, so
        # they get no dtype cast.
        model.to(device="cuda", dtype=torch.bfloat16)
        inputs = inputs.to(model.device)
    
    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False)
    
    # Run generation in a separate thread
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
        do_sample=True,
    )
    
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
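    # model.generate blocks until generation finishes, so it runs in the
    # background thread above while the loop below consumes text chunks from
    # the streamer as they are produced.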
    
    # Stream the output, prefixed with a "Thinking" label
    output = "Thinking: \n"
    for token in streamer:
        output += token
        # Once both <SOLUTION> tags have streamed in, label the text between
        # them as the final answer and render it in bold
        if "<SOLUTION>" in output:
            solution_start = output.find("<SOLUTION>") + len("<SOLUTION>")
            solution_end = output.find("</SOLUTION>")
            if solution_end > solution_start:
                formatted_output = (
                    output[:solution_start] + 
                    "Final answer: **" + output[solution_start:solution_end] + "**" + 
                    output[solution_end:]
                )
                yield formatted_output
            else:
                yield output
        else:
            yield output


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"],
        ["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"],
        ["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"],
        ["What is the sqrt of 101?"],
    ],
    cache_examples=False,
    chatbot=gr.Chatbot(
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False}
        ],
    ),
)


if __name__ == "__main__":
    demo.launch()
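

# Typical local setup (assumed, not pinned by this file):
#   pip install gradio transformers peft accelerate torch
# then run `python app.py`. The `peft` package is required for
# model.load_adapter to attach the GRPO adapter.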