import gradio as gr
import random
import re
import threading
import time

import spaces
import torch
import numpy as np

# Assuming the transformers library is installed
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Global Settings ---
# These live at module scope and are initialized once when the Gradio app starts.
system_prompt = []  # optional system messages prepended to every conversation
device = "cuda" if torch.cuda.is_available() else "cpu"

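# Maps a display name to [Hugging Face repo id, model identifier]; the identifier
# is used for feature detection (e.g. the "R1" reasoning-model check below).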
MODEL_PATHS = {
    "Embformer-MiniMind-Base (0.1B)": ["HighCWu/Embformer-MiniMind-Base-0.1B", "Embformer-MiniMind-Base-0.1B"],
    "Embformer-MiniMind-Seqlen512 (0.1B)": ["HighCWu/Embformer-MiniMind-Seqlen512-0.1B", "Embformer-MiniMind-Seqlen512-0.1B"],
    "Embformer-MiniMind (0.1B)": ["HighCWu/Embformer-MiniMind-0.1B", "Embformer-MiniMind-0.1B"],
    "Embformer-MiniMind-RLHF (0.1B)": ["HighCWu/Embformer-MiniMind-RLHF-0.1B", "Embformer-MiniMind-RLHF-0.1B"],
    "Embformer-MiniMind-R1 (0.1B)": ["HighCWu/Embformer-MiniMind-R1-0.1B", "Embformer-MiniMind-R1-0.1B"],
}

# --- Helper Functions ---

def process_assistant_content(content, model_source, selected_model_name):
    """
    Processes the model output, converting <think> tags to HTML details elements,
    and handling content after </think>, filtering out <answer> tags.
    """
    is_r1_model = False
    if model_source == "API":
        if 'R1' in selected_model_name:
            is_r1_model = True
    else:
        model_identifier = MODEL_PATHS.get(selected_model_name, ["", ""])[1]
        if 'R1' in model_identifier:
            is_r1_model = True
    
    if not is_r1_model:
        return content

    # Fully closed <think>...</think> block
    if '<think>' in content and '</think>' in content:
        # Using re.split is more robust than finding indices
        parts = re.split(r'(</think>)', content, maxsplit=1)
        think_part = parts[0] + parts[1]  # everything up to and including </think>
        after_think_part = parts[2] if len(parts) > 2 else ""

        # 1. Process the think part
        processed_think = re.sub(
            r'(<think>)(.*?)(</think>)',
            r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (Click to expand)</summary>\2</details>',
            think_part,
            flags=re.DOTALL
        )
        
        # 2. Process the part after </think>, filtering <answer> tags
        # Using re.sub to replace <answer> and </answer> with an empty string
        processed_after_think = re.sub(r'</?answer>', '', after_think_part)
        
        # 3. Concatenate the results
        return processed_think + processed_after_think

    # Only an opening <think>, indicating reasoning is in progress
    if '<think>' in content and '</think>' not in content:
        return re.sub(
            r'<think>(.*?)$',
            r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning...</summary>\1</details>',
            content,
            flags=re.DOTALL
        )

    # This case should be rare in streaming output, but kept for completeness
    if '<think>' not in content and '</think>' in content:
        # Also need to process content after </think>
        parts = re.split(r'(</think>)', content, maxsplit=1)
        think_part = parts[0] + parts[1]
        after_think_part = parts[2] if len(parts) > 2 else ""

        processed_think = re.sub(
            r'(.*?)</think>',
            r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (Click to expand)</summary>\1</details>',
            think_part,
            flags=re.DOTALL
        )
        processed_after_think = re.sub(r'</?answer>', '', after_think_part)
        
        return processed_think + processed_after_think

    # If there are no <think> tags, return the content directly
    return content


def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device != "cpu":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# --- Gradio App Logic ---

# Gradio has no built-in resource cache like Streamlit's st.cache_resource,
# so we keep loaded models and tokenizers in a module-level dict to avoid reloading.
loaded_models = {}

def load_model_tokenizer_gradio(model_name):
    """
    Gradio version of the model loading function with caching.
    """
    if model_name in loaded_models:
        # print(f"Using cached model: {model_name}")
        return loaded_models[model_name]
    
    # print(f"Loading model: {model_name}...")
    model_path = MODEL_PATHS[model_name][0]
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        cache_dir=".cache",
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        cache_dir=".cache",
    )
    loaded_models[model_name] = (model, tokenizer)
    print("Model loaded.")
    return model, tokenizer

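# On Hugging Face Spaces with ZeroGPU hardware, @spaces.GPU allocates a GPU for the
# duration of each call; elsewhere the decorator is effectively a no-op.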
@spaces.GPU
def chat_fn(
    user_message, 
    history, 
    model_source,
    # Local model settings
    selected_model,
    # API settings
    api_url,
    api_model_id,
    api_model_name,
    api_key,
    # Generation parameters
    history_chat_num,
    max_new_tokens,
    temperature
):
    """
    Gradio's core chat processing function.
    It receives the current values of all UI components as input.
    """
    history = history or []

    # Build the model context from the existing history.
    # `history` is in Gradio "messages" format: a list of {"role", "content"} dicts.
    # Limit how many past messages are passed to the model.
    if history_chat_num > 0 and len(history) > history_chat_num:
        relevant_history = history[-history_chat_num:]
    else:
        relevant_history = history

    chat_messages_for_model = [
        {"role": msg["role"], "content": msg["content"]} for msg in relevant_history
    ]

    # Add the current user message to the model's context
    chat_messages_for_model.append({"role": "user", "content": user_message})

    final_chat_messages = system_prompt + chat_messages_for_model

    # Update the history shown in the UI: append the new user message and an
    # empty assistant placeholder that the streaming loops below fill in.
    history.extend([
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ])

    # --- Model Invocation ---
    if model_source == "API":
        try:
            from openai import OpenAI
            client = OpenAI(api_key=api_key, base_url=api_url)
            
            response = client.chat.completions.create(
                model=api_model_id,
                messages=final_chat_messages,
                stream=True,
                temperature=temperature
            )
            
            answer = ""
            for chunk in response:
                content = chunk.choices[0].delta.content or ""
                answer += content
                processed_answer = process_assistant_content(answer, model_source, api_model_name)
                history[-1]["content"] = processed_answer
                yield history, history
        
        except Exception as e:
            history[-1]["content"] = f"API call error: {str(e)}"
            yield history, history

    else: # Local Model
        try:
            model, tokenizer = load_model_tokenizer_gradio(selected_model)
            
            random_seed = random.randint(0, 2**32 - 1)
            setup_seed(random_seed)

            new_prompt = tokenizer.apply_chat_template(
                final_chat_messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
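            # TextIteratorStreamer exposes generated text as an iterator, letting us
            # stream partial output while model.generate runs in a background thread.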
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            
            generation_kwargs = {
                "input_ids": inputs.input_ids,
                "attention_mask": inputs.attention_mask,
                "max_new_tokens": max_new_tokens,
                "num_return_sequences": 1,
                "do_sample": True,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "temperature": temperature,
                "top_p": 0.85,
                "streamer": streamer,
            }

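            # Run generation in a background thread; the main thread consumes the
            # streamer and yields incremental updates to the Gradio UI.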
            thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()
            
            answer = ""
            for new_text in streamer:
                answer += new_text
                processed_answer = process_assistant_content(answer, model_source, selected_model)
                history[-1]["content"] = processed_answer
                yield history, history
        except Exception as e:
            history[-1]["content"] = f"Local model call error: {str(e)}"
            yield history, history

# --- Gradio UI Layout ---
css = """
.gradio-container { font-family: 'sans-serif'; }
footer { display: none !important; }
"""
image_url = "https://chunte-hfba.static.hf.space/images/modern%20Huggies/Huggy%20Sunny%20hello.png"

# Define example data
prompt_datas = [
    '请介绍一下自己。',
    '你更擅长哪一个学科?',
    '鲁迅的《狂人日记》是如何批判封建礼教的?',
    '我咳嗽已经持续了两周,需要去医院检查吗?',
    '详细的介绍光速的物理概念。',
    '推荐一些杭州的特色美食吧。',
    '请为我讲解“大语言模型”这个概念。',
    '如何理解ChatGPT?',
    'Introduce the history of the United States, please.'
]

with gr.Blocks(theme='soft', css=css) as demo:
    # History state; Gradio's counterpart to Streamlit's st.session_state
    chat_history = gr.State([])
    chat_input_cache = gr.State("")

    # Top Title and Badge
    title_html = """
<div style="text-align: center;">
    <h1>Embformer: An Embedding-Weight-Only Transformer Architecture</h1>
    <div style="display: flex; justify-content: center; align-items: center; gap: 8px; margin-top: 10px;">
        <a href="https://doi.org/10.5281/zenodo.15736957">
            <img src="https://img.shields.io/badge/DOI-10.5281%2Fzenodo.15736957-blue.svg" alt="DOI">
        </a>
        <a href="https://github.com/HighCWu/embformer">
            <img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" alt="code">
        </a>
        <a href="https://huggingface.co/collections/HighCWu/embformer-minimind-685be74dc761610439241bd5">
            <img src="https://img.shields.io/badge/Model-🤗-yellow" alt="model">
        </a>
    </div>
</div>
"""
    gr.HTML(title_html)
    gr.Markdown("""
This is the official demo of [Embformer: An Embedding-Weight-Only Transformer Architecture](https://doi.org/10.5281/zenodo.15736957).

**Note**: Since the model dataset used in this demo is derived from the MiniMind dataset, which contains a large proportion of Chinese content, please try to use Chinese as much as possible in the conversation.
""")

    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### Model Settings")
            
            # Model source switcher
            model_source_radio = gr.Radio(["Local Model", "API"], value="Local Model", label="Select Model Source", visible=False)
            
            # Local model settings
            with gr.Group(visible=True) as local_model_group:
                selected_model_dd = gr.Dropdown(
                    list(MODEL_PATHS.keys()), 
                    value="Embformer-MiniMind (0.1B)", 
                    label="Select Local Model"
                )

            # API settings
            with gr.Group(visible=False) as api_model_group:
                api_url_tb = gr.Textbox("http://127.0.0.1:8000/v1", label="API URL")
                api_model_id_tb = gr.Textbox("embformer-minimind", label="Model ID")
                api_model_name_tb = gr.Textbox("Embformer-MiniMind (0.1B)", label="Model Name (for feature detection)")
                api_key_tb = gr.Textbox("none", label="API Key", type="password")

            # Common generation parameters
            history_chat_num_slider = gr.Slider(0, 6, value=0, step=2, label="History Messages")
            max_new_tokens_slider = gr.Slider(256, 8192, value=1024, step=1, label="Max New Tokens")
            temperature_slider = gr.Slider(0.6, 1.2, value=0.85, step=0.01, label="Temperature")

            # Clear history button
            clear_btn = gr.Button("🗑️ Clear History")

        with gr.Column(scale=4):
            gr.Markdown("### Chat")
            
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                avatar_images=(None, image_url),
                type="messages",
                height=350
            )
            chat_input = gr.Textbox(
                show_label=False,
                placeholder="Send a message to MiniMind...  (Enter to send)",
                container=False,
                scale=7,
                elem_id="chat-textbox",
            )
            examples = gr.Examples(
                examples=prompt_datas,
                inputs=chat_input, # After clicking, the example content will fill chat_input
                label="Click an example to ask (will automatically clear chat and continue)"
            )

    # --- Event Listeners and Bindings ---
    
    # Show/hide corresponding setting groups when switching model source
    def toggle_model_source_ui(source):
        return {
            local_model_group: gr.update(visible=source == "Local Model"),
            api_model_group: gr.update(visible=source == "API")
        }
    model_source_radio.change(
        fn=toggle_model_source_ui,
        inputs=model_source_radio,
        outputs=[local_model_group, api_model_group]
    )

    # Define the list of input components for the submit event
    submit_inputs = [
        chat_input_cache, chat_history, model_source_radio, selected_model_dd,
        api_url_tb, api_model_id_tb, api_model_name_tb, api_key_tb,
        history_chat_num_slider, max_new_tokens_slider, temperature_slider
    ]

    # When chat_input is submitted (the user presses Enter), clear the textbox,
    # cache the submitted text, then run chat_fn.
    submit_event = chat_input.submit(
        fn=lambda text: ("", text),  # clear the textbox and cache the submitted text
        inputs=chat_input,
        outputs=[chat_input, chat_input_cache],
    ).then(
        fn=chat_fn,
        inputs=submit_inputs,
        outputs=[chatbot, chat_history],
    )
    
    # Event chain for clicking an example: cache the example text, clear the
    # input box and the current chat, then run chat_fn on the example.
    examples.load_input_event.then(
        fn=lambda text: ("", text, [], []),  # clear input and chat, cache the example text
        inputs=chat_input,
        outputs=[chat_input, chat_input_cache, chatbot, chat_history],
    ).then(
        fn=chat_fn,
        inputs=submit_inputs,  # cached example text plus the current settings
        outputs=[chatbot, chat_history],
    )

    # Clear history button logic
    def clear_history():
        return [], []
    clear_btn.click(fn=clear_history, outputs=[chatbot, chat_history])
    chatbot.clear(fn=clear_history, outputs=[chatbot, chat_history])


if __name__ == "__main__":
    # Pre-load the default model on startup
    print("Pre-loading default model...")
    load_model_tokenizer_gradio("Embformer-MiniMind (0.1B)")
    
    # Launch the Gradio app
    demo.queue().launch(share=False)