| import threading |
|
|
| import torch |
| import gradio as gr |
|
|
| from transformers import AutoTokenizer |
| from transformers import GenerationConfig |
| from transformers import AutoModelForCausalLM |
| from transformers import TextIteratorStreamer |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| MODEL_ID = "Qwen/Qwen3-0.6B" |
|
|
| |
| SYSTEM = "You are a helpful, concise assistant." |
|
|
| device = ( |
| "cuda" |
| if torch.cuda.is_available() |
| |
| |
| |
| else "cpu" |
| ) |
|
|
| |
| |
| |
| |
|
|
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) |
| model = AutoModelForCausalLM.from_pretrained( |
| MODEL_ID, |
| |
| ).to(device) |
|
|
| |
| context_window = getattr(model.config, "max_position_embeddings", None) |
| if context_window is None: |
| context_window = getattr(tokenizer, "model_max_length", 2048) |
|
|
| print(f"model: {MODEL_ID}, context window: {context_window}.") |
|
|
|
|
def predict(message, history):
    """
    Gradio ChatInterface callback that streams a reply from the model.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list[dict]
        Prior turns as ``{"role": ..., "content": ...}`` dicts
        (``type="messages"`` format).

    Yields
    ------
    str
        The partial assistant reply, growing as tokens arrive. Qwen's
        ``<think> ... </think>`` sections are wrapped in a dimmed, italic
        HTML paragraph for display.
    """
    # Full conversation: optional system prompt, prior turns, new user turn.
    conversation = history + [{"role": "user", "content": message}]
    if SYSTEM:
        conversation = [{"role": "system", "content": SYSTEM}, *conversation]

    # Render via the model's chat template; the template already inserts all
    # special tokens, so skip them when tokenizing the rendered string.
    input_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Leave the model at least one token to generate even when the prompt
    # already fills the context window.
    input_len = inputs["input_ids"].shape[1]
    max_new_tokens = max(1, context_window - input_len)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    generation_config.max_new_tokens = max_new_tokens
    # FIX: set pad_token_id on the per-request config instead of mutating the
    # shared model.generation_config, which races between concurrent requests.
    generation_config.pad_token_id = tokenizer.eos_token_id

    # Run generation on a background thread so this generator can consume
    # the streamer incrementally.
    def _run_generation():
        model.generate(
            **inputs,
            generation_config=generation_config,
            streamer=streamer,
        )

    thread = threading.Thread(target=_run_generation)
    thread.start()

    generated = ""
    in_think = False

    for new_text in streamer:
        if not new_text:
            continue

        # Qwen emits its chain of thought between <think> ... </think>;
        # swap those markers for a dimmed paragraph in the rendered chat.
        next_text_stripped = new_text.strip()
        if next_text_stripped == "<think>":
            generated += "<p style='color:#777; font-size: 12px; font-style:italic;'>"
            in_think = True
            continue
        if next_text_stripped == "</think>":
            generated += "</p>"
            in_think = False
            continue

        generated += new_text

        if in_think:
            # Close the paragraph provisionally so partial renders are valid HTML.
            yield generated + "</p>"
        else:
            yield generated

    # FIX: if the stream ended mid-think (e.g. max_new_tokens exhausted), close
    # the open paragraph so the final message is valid HTML.
    if in_think:
        generated += "</p>"
        yield generated

    # Make sure the generation thread has fully finished before returning.
    thread.join()
|
|
|
|
# Wire the streaming callback into a chat UI. `type="messages"` makes Gradio
# pass `history` as a list of {"role", "content"} dicts, which is the format
# `predict` documents and builds the chat template from (older Gradio versions
# default to tuple pairs otherwise).
demo = gr.ChatInterface(
    predict,
    type="messages",
    api_name="chat",
)

# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|