import threading

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

BASE_MODEL_ID = "unsloth/gemma-3-1b-it"
ADAPTER_ID = "Oysiyl/gemma-3-1B-GRPO"

# Tags the GRPO fine-tune is prompted to emit around its reasoning and its
# final answer. The same constants drive both the system prompt and the
# answer formatting below, so they cannot drift apart.
# NOTE(review): these follow the Unsloth Gemma-3 GRPO recipe; confirm they
# match the tags used when training the adapter.
REASONING_START = "<start_working_out>"
REASONING_END = "<end_working_out>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"

SYSTEM_PROMPT = (
    "You are given a problem.\n"
    "Think about the problem and provide your working out.\n"
    f"Place it between {REASONING_START} and {REASONING_END}.\n"
    f"Then, provide your solution between {SOLUTION_START}{SOLUTION_END}"
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base model directly, apply the adapter from the fine-tuned version,
# then move the model to the target device once at startup (not per request).
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
model.load_adapter(ADAPTER_ID)
model.to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)


def process_history(history):
    """Convert Gradio tuple-style chat history into chat-template messages.

    Each ``(user_msg, assistant_msg)`` pair becomes one user message followed
    by one assistant message. ``None``/empty entries become empty text so the
    user/assistant turns keep alternating.
    """
    processed_history = []
    for user_msg, assistant_msg in history:
        processed_history.append(
            {"role": "user", "content": [{"type": "text", "text": user_msg or ""}]}
        )
        processed_history.append(
            {"role": "assistant", "content": [{"type": "text", "text": assistant_msg or ""}]}
        )
    return processed_history


def process_new_user_message(message):
    """Wrap a new user message in the list-of-parts content format."""
    return [{"type": "text", "text": message}]


def format_output(output):
    """Bold the final answer once both solution tags have been streamed.

    The text between SOLUTION_START and the first SOLUTION_END that follows
    it is wrapped as ``Final answer: **...**``. If the closing tag has not
    arrived yet, or the answer is empty, *output* is returned unchanged.
    """
    start = output.find(SOLUTION_START)
    if start == -1:
        return output
    start += len(SOLUTION_START)
    # Search only after the opening tag so a stray earlier closing tag is ignored.
    end = output.find(SOLUTION_END, start)
    if end <= start:  # -1: closing tag not streamed yet; ==: empty answer
        return output
    return (
        output[:start]
        + "Final answer: **"
        + output[start:end]
        + "**"
        + output[end:]
    )


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the model's reply to *message*, given the prior *history*.

    Yields the accumulated response text after every streamed chunk. The text
    is prefixed with "Thinking: " and formatted by ``format_output`` so the
    final answer is rendered in bold once it is complete.
    """
    # Format messages according to Gemma's expected chat format.
    messages = []
    if system_message:
        messages.append(
            {"role": "system", "content": [{"type": "text", "text": system_message}]}
        )
    if history:
        messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Only change the device: token ids must remain integer tensors, and the
    # tokenizer's BatchEncoding.to() does not accept a dtype.
    inputs = inputs.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=False
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        top_k=64,  # Recommended Gemma-3 setting
        do_sample=True,
    )
    # generate() blocks until finished, so it runs in a worker thread while
    # this generator consumes tokens from the streamer.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = "Thinking: \n"
    for token in streamer:
        output += token
        yield format_output(output)
    thread.join()


# For information on how to customize the ChatInterface, peruse the gradio docs:
# https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=SYSTEM_PROMPT, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Apple sold 100 iPhones at their New York store today for an average cost of $1000. They also sold 20 iPads for an average cost of $900 and 80 Apple TVs for an average cost of $200. What was the average cost across all products sold today?"],
        ["Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?"],
        ["Mel is three years younger than Katherine. When Katherine is two dozen years old, how old will Mel be in years?"],
        ["What is the sqrt of 101?"],
    ],
    cache_examples=False,
    chatbot=gr.Chatbot(
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    ),
)


if __name__ == "__main__":
    demo.launch()