| import inspect |
| import os |
| import threading |
|
|
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
|
|
# Silence the HF tokenizers fork-safety warning before any model loading happens.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


# Runtime configuration; every knob can be overridden via environment variables.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3-0.6B")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))      # generation budget per turn
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "1536"))  # prompt is left-truncated to this
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "3"))   # user/assistant pairs kept as context
N_THREADS = int(os.getenv("N_THREADS", str(max(1, os.cpu_count() or 1))))
DEFAULT_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    # BUG FIX: the previous default ended mid-sentence ("... If user");
    # the sentence is completed so the model never sees a dangling clause.
    "You are a helpful assistant. Keep answers clear and concise. "
    "If the user asks for code, keep examples short and runnable.",
)
|
|
# Demo presets selectable from the UI dropdown. Each entry carries:
#   system           - system prompt pushed into the "System prompt" box
#   prompt           - example user message pushed into the message box
#   thinking         - default state for the "Enable thinking" checkbox
#   sample_reasoning - canned example reasoning text for this preset
#   sample_answer    - canned example answer text for this preset
PRESETS = {
    "Math": {
        "system": "You are a careful math tutor. Think through the problem, then give a short final answer.",
        "prompt": "Solve: If 2x^2 - 7x + 3 = 0, what are the real solutions?",
        "thinking": True,
        "sample_reasoning": "The discriminant is 49 - 24 = 25, so the roots are easy to compute with the quadratic formula.",
        "sample_answer": "The real solutions are x = 3 and x = 1/2.",
    },
    "Coding": {
        "system": "You are a Python assistant. Prefer short, readable code.",
        "prompt": "Write a Python function that merges two sorted lists into one sorted list.",
        "thinking": True,
        "sample_reasoning": "Use two pointers. Compare the current elements, append the smaller one, then append the leftovers.",
        "sample_answer": "Here is a compact merge function plus a tiny example.",
    },
    "Structured output": {
        "system": "Return compact JSON and avoid extra commentary.",
        "prompt": "Extract JSON from: Call Mina by Friday, priority high, budget about $2400, topic is launch video edits.",
        "thinking": False,
        "sample_reasoning": "Reasoning is disabled here so the output stays short and machine-friendly.",
        "sample_answer": '{"person":"Mina","deadline":"Friday","priority":"high","budget_usd":2400,"topic":"launch video edits"}',
    },
    "Function calling style": {
        "system": "You are an assistant that plans tool use when it helps. If a tool would help, say what tool you would call and with which arguments.",
        "prompt": "Pretend you have tools. For 18.75 * 42 - 199 and converting 12 km to miles, explain which tool calls you would make, then give the result.",
        "thinking": True,
        "sample_reasoning": "I would use a calculator tool for the arithmetic and a unit-conversion tool for the distance conversion.",
        "sample_answer": "Calculator(18.75 * 42 - 199) -> 588.5\nConvert(12 km -> miles) -> about 7.46 miles",
    },
    "Creative writing": {
        "system": "Write vivid, tight prose.",
        "prompt": "Write a two-sentence opening for a sci-fi heist story set on a drifting museum ship.",
        "thinking": False,
        "sample_reasoning": "Reasoning is disabled for a faster clean draft.",
        "sample_answer": "By the time the museum ship crossed into the dead zone, every priceless relic aboard had started broadcasting a heartbeat. Nia took that as her cue to cut the lights and steal the one artifact already trying to escape.",
    },
}
|
|
|
|
# Pin PyTorch's intra-op CPU thread count to the configured budget.
torch.set_num_threads(N_THREADS)
try:
    # set_num_interop_threads can only be called once per process before
    # parallel work starts; a late call raises RuntimeError, which we ignore.
    torch.set_num_interop_threads(max(1, min(2, N_THREADS)))
except RuntimeError:
    pass


# Lazily populated model singletons plus the locks that guard them.
_tokenizer = None
_model = None
_load_lock = threading.Lock()      # serializes the one-time model download/load
_generate_lock = threading.Lock()  # serializes calls into model.generate
|
|
|
|
def make_chatbot(label, height=520):
    """Build a gr.Chatbot, opting into message-dict history when the installed
    Gradio version supports the ``type`` keyword."""
    options = {"label": label, "height": height}
    ctor_params = inspect.signature(gr.Chatbot.__init__).parameters
    if "type" in ctor_params:
        options["type"] = "messages"
    return gr.Chatbot(**options)
|
|
|
|
def get_model():
    """Return the cached (tokenizer, model) pair, loading both on first use.

    Uses double-checked locking so concurrent first requests never load the
    weights twice; the model is put into eval mode before being published.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    with _load_lock:
        if _tokenizer is None or _model is None:
            tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
            mdl = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
            )
            mdl.eval()
            _tokenizer, _model = tok, mdl
    return _tokenizer, _model
|
|
|
|
def clone_messages(messages):
    """Return a new list of shallow-copied message dicts (None becomes [])."""
    return list(map(dict, messages or []))
|
|
|
|
def load_preset(name):
    """Return (system prompt, user prompt, thinking flag) for a preset.

    BUG FIX: this previously also returned ``sample_reasoning`` and
    ``sample_answer``, i.e. five values, while the dropdown's change event
    wires only three output components — Gradio rejects the extra return
    values at event time. Only the three wired values are returned now;
    the sample_* fields stay as illustrative data in PRESETS.
    """
    preset = PRESETS[name]
    return preset["system"], preset["prompt"], preset["thinking"]
|
|
|
|
def clear_all():
    """Reset both chat panes, the model-format history, and the input box."""
    empty_reasoning = []
    empty_answers = []
    empty_history = []
    return empty_reasoning, empty_answers, empty_history, ""
|
|
|
|
def strip_non_think_specials(text):
    """Remove end-of-turn/end-of-text special tokens, leaving <think> markers
    intact for the streaming splitter. None is treated as empty text."""
    cleaned = text or ""
    cleaned = cleaned.replace("<|im_end|>", "")
    cleaned = cleaned.replace("<|endoftext|>", "")
    return cleaned.replace("<|end▁of▁sentence|>", "")
|
|
|
|
def final_cleanup(text):
    """Strip every special marker — end-of-turn tokens and <think> tags —
    and trim surrounding whitespace. None is treated as empty text."""
    # Same marker order as the streaming path: specials first, think tags last.
    cleaned = text or ""
    for marker in ("<|im_end|>", "<|endoftext|>", "<|end▁of▁sentence|>", "<think>", "</think>"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
|
|
|
|
def split_stream_text(raw_text, thinking):
    """Split streamed model text into (reasoning, answer, saw_end_of_think).

    With thinking disabled, everything is the answer and reasoning is "".
    With thinking enabled, text before </think> is reasoning and text after
    it is the answer; until </think> arrives, all text counts as reasoning.
    """
    cleaned = strip_non_think_specials(raw_text)
    if not thinking:
        return "", final_cleanup(cleaned), False

    cleaned = cleaned.replace("<think>", "")
    before, closer, after = cleaned.partition("</think>")
    if closer:
        return before.strip(), after.strip(), True
    return cleaned.strip(), "", False
|
|
|
|
def respond_stream(
    message,
    system_prompt,
    thinking,
    model_history,
    reasoning_chat,
    answer_chat,
):
    """Stream one chat turn, yielding (reasoning_chat, answer_chat, model_history, "").

    Runs model.generate on a background thread and re-yields both chat panes
    as tokens arrive: the reasoning pane shows text inside <think>...</think>,
    the answer pane shows everything after the closing tag. The final yield
    also persists the finished turn into model_history; the trailing "" in
    every yield clears the user input textbox.
    """
    message = (message or "").strip()
    if not message:
        # Nothing to send: echo the current state unchanged.
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history or []), ""
        return


    # Work on copies so the component state handed in is never mutated in place.
    model_history = list(model_history or [])
    reasoning_chat = clone_messages(reasoning_chat)
    answer_chat = clone_messages(answer_chat)


    reasoning_chat.append({"role": "user", "content": message})
    reasoning_chat.append(
        {
            "role": "assistant",
            "content": "(thinking...)" if thinking else "(reasoning disabled)",
        }
    )
    answer_chat.append({"role": "user", "content": message})
    answer_chat.append({"role": "assistant", "content": ""})


    # Show the user message and placeholders immediately, before generation starts.
    yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""


    try:
        tokenizer, model = get_model()
        # Keep at most MAX_HISTORY_TURNS user/assistant pairs of prior context.
        short_history = model_history[-2 * MAX_HISTORY_TURNS :]
        messages = [
            {"role": "system", "content": (system_prompt or "").strip() or DEFAULT_SYSTEM_PROMPT},
            *short_history,
            {"role": "user", "content": message},
        ]


        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        # Hard cap on prompt length; left-truncation keeps the newest text.
        # NOTE(review): a long conversation can truncate away the system prompt
        # and chat-template header — confirm this trade-off is intended.
        input_ids = inputs["input_ids"][:, -MAX_INPUT_TOKENS:]
        attention_mask = inputs["attention_mask"][:, -MAX_INPUT_TOKENS:]


        # Special tokens are kept in the stream so </think> can be detected;
        # they are stripped afterwards by split_stream_text / final_cleanup.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
            timeout=None,
        )


        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": True,
            # Different sampling settings per mode — presumably mirroring the
            # model card's recommended thinking/non-thinking values; verify.
            "temperature": 0.6 if thinking else 0.7,
            "top_p": 0.95 if thinking else 0.8,
            "top_k": 20,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }


        # Non-empty only if the worker thread caught an exception.
        generation_error = {}


        def run_generation():
            # Worker-thread body; _generate_lock serializes concurrent requests.
            try:
                with _generate_lock:
                    model.generate(**generation_kwargs)
            except Exception as exc:
                generation_error["message"] = str(exc)
                # Signal end-of-stream so the consumer loop below unblocks.
                streamer.on_finalized_text("", stream_end=True)


        thread = threading.Thread(target=run_generation, daemon=True)
        thread.start()


        raw_text = ""
        saw_end_think = False


        # Consume the streamer on this (generator) thread, re-splitting the
        # accumulated text on every chunk so both panes update live.
        for chunk in streamer:
            raw_text += chunk
            reasoning_text, answer_text, saw_end_now = split_stream_text(raw_text, thinking)
            saw_end_think = saw_end_think or saw_end_now


            if thinking:
                if saw_end_think:
                    reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
                else:
                    reasoning_chat[-1]["content"] = reasoning_text or "(thinking...)"
            else:
                reasoning_chat[-1]["content"] = "(reasoning disabled)"


            answer_chat[-1]["content"] = answer_text
            yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""


        thread.join()


        if generation_error:
            reasoning_chat[-1]["content"] = ""
            answer_chat[-1]["content"] = f"Error while running the local CPU model: {generation_error['message']}"
            yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
            return


        # Final pass over the complete text; if </think> never arrived, treat
        # everything generated as the answer.
        reasoning_text, answer_text, saw_end_think = split_stream_text(raw_text, thinking)
        if thinking and not saw_end_think:
            reasoning_text = ""
            answer_text = final_cleanup(raw_text)


        if thinking:
            reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
        else:
            reasoning_chat[-1]["content"] = "(reasoning disabled)"


        answer_chat[-1]["content"] = answer_text or "(empty response)"
        # Persist only the trimmed context plus this turn; reasoning text is
        # deliberately excluded from the model-facing history.
        model_history = short_history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": answer_chat[-1]["content"]},
        ]


        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""


    except Exception as exc:
        # Surface setup/tokenization failures in the answer pane instead of
        # crashing the event handler.
        reasoning_chat[-1]["content"] = ""
        answer_chat[-1]["content"] = f"Error while preparing the local CPU model: {exc}"
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
|
|
|
|
# UI layout and event wiring.
with gr.Blocks(title="Local CPU split-reasoning chat") as demo:
    gr.Markdown(
        "# Local CPU split-reasoning chat\n"
        f"Running a local safetensors model on CPU from `{MODEL_ID}`. No GGUF and no external inference provider.\n\n"
        "The first request downloads the model, so the cold start is slower."
    )


    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESETS.keys()),
            value="Math",
            label="Preset prompt",
        )
        thinking = gr.Checkbox(label="Enable thinking", value=True)


    system_prompt = gr.Textbox(
        label="System prompt",
        value=PRESETS["Math"]["system"],
        lines=3,
    )


    user_input = gr.Textbox(
        label="Your message",
        value=PRESETS["Math"]["prompt"],
        lines=4,
    )


    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")


    with gr.Row():
        reasoning_bot = make_chatbot("Reasoning", height=520)
        answer_bot = make_chatbot("Assistant", height=520)


    # Model-format conversation history, carried across turns per session.
    model_history_state = gr.State([])


    # BUG FIX: load_preset yields more values than the three components wired
    # here, which makes Gradio raise when the dropdown changes. Slicing the
    # returned tuple guarantees exactly one value per wired output.
    preset.change(
        fn=lambda name: tuple(load_preset(name))[:3],
        inputs=preset,
        outputs=[system_prompt, user_input, thinking],
    )


    send_btn.click(
        fn=respond_stream,
        inputs=[user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot],
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
    # Enter in the textbox behaves exactly like the Send button.
    user_input.submit(
        fn=respond_stream,
        inputs=[user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot],
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )


    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )


# queue() enables streaming generator handlers; launch() blocks serving the app.
demo.queue()
demo.launch()
|
|