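# Gradio chat demo for the YannQi/R-4B vision-language model, served with
# Hugging Face Transformers. The `spaces` import and @spaces.GPU decorator
# target Hugging Face ZeroGPU hardware.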
import gradio as gr
import requests
from PIL import Image
import torch
from transformers import AutoModel, AutoProcessor
import spaces

model_path = "YannQi/R-4B"

# Load the model once at startup. trust_remote_code pulls in R-4B's custom
# modeling code; float32 is the conservative default (torch.float16 would
# roughly halve GPU memory if the hardware supports it).
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to("cuda")

processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

@spaces.GPU(duration=120)  # request a ZeroGPU worker for up to 120 s per call
def generate_response(message, history, thinking_mode):
    """Run one chat turn: rebuild the model's message list from `history`,
    append the current text/image message, and generate a reply."""
    if not message:
        return "", history

    messages = []
    all_images = []

    for user_msg, asst_msg in history:
        # Process user message
        if isinstance(user_msg, str):
            user_content = [{"type": "text", "text": user_msg}]
        else:
            text = user_msg.get('text', '')
            files = user_msg.get('files', [])
            # Files may be plain path strings or dicts with a 'path' key,
            # depending on the Gradio version.
            file_paths = [f.get('path', str(f)) if isinstance(f, dict) else str(f) for f in files]
            user_content = []
            for path in file_paths:
                try:
                    img = Image.open(path)
                    all_images.append(img)
                    user_content.append({"type": "image", "image": path})
                except Exception:
                    # Skip files that are missing or not valid images.
                    pass
            if text:
                user_content.append({"type": "text", "text": text})
            if text:
                user_content.append({"type": "text", "text": text})
        messages.append({"role": "user", "content": user_content})

        # Process assistant message
        asst_text = asst_msg if isinstance(asst_msg, str) else asst_msg.get('text', '')
        messages.append({"role": "assistant", "content": [{"type": "text", "text": asst_text}]})

    # Current user message
    if isinstance(message, str):
        curr_text = message
        curr_files = []
    else:
        curr_text = message.get('text', '')
        curr_files = message.get('files', [])
    curr_user_content = []
    curr_images = []
    curr_file_paths = [f.get('path', str(f)) if isinstance(f, dict) else str(f) for f in curr_files]
    for path in curr_file_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
            try:
                img = Image.open(path)
                curr_images.append(img)
                curr_user_content.append({"type": "image", "image": path})
            except Exception:
                # Skip files that cannot be opened as images.
                pass
    if curr_text:
        curr_user_content.append({"type": "text", "text": curr_text})
    if not curr_user_content:
        return "", history
    messages.append({"role": "user", "content": curr_user_content})

    # Apply the chat template. `thinking_mode` ("auto" / "long" / "short") is a
    # kwarg handled by R-4B's custom template to toggle its reasoning trace.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        thinking_mode=thinking_mode
    )

    # Append current-turn images after history images so the image order
    # matches their order of appearance in the message list.
    all_images += curr_images

    # Process inputs
    inputs = processor(
        images=all_images if all_images else None,
        text=text,
        return_tensors="pt"
    ).to("cuda")

    # Generate
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7
        )
    # Decode only the newly generated tokens, stripping the prompt prefix.
    output_ids = generated_ids[0][len(inputs.input_ids[0]):]
    output_text = processor.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    # Keep the raw message (str or MultimodalTextbox dict) in history; the
    # history loop above accepts either shape on the next turn.
    user_display = message
    new_history = history + [(user_display, output_text)]

    return "", new_history

with gr.Blocks(title="Transformers Chat") as demo:
    gr.Markdown("# Using 🤗 Transformers to Chat")
    gr.Markdown("Select thinking mode: auto (auto-thinking), long (thinking), short (non-thinking). Default: auto.")
    chatbot = gr.Chatbot(type="tuples", height=500, label="Chat")
    with gr.Row():
        msg = gr.MultimodalTextbox(
            placeholder="Type your message or upload images...",
            file_types=[".jpg", ".jpeg", ".png", ".gif", ".bmp"],
            file_count="multiple",
            label="Message"
        )
        mode = gr.Dropdown(
            choices=["auto", "long", "short"],
            value="auto",
            label="Thinking Mode",
            interactive=True
        )
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary", scale=3)
        clear_btn = gr.Button("Clear", scale=1)
    submit_btn.click(generate_response, [msg, chatbot, mode], [msg, chatbot])
    msg.submit(generate_response, [msg, chatbot, mode], [msg, chatbot])
    clear_btn.click(lambda: ([], ""), None, [chatbot, msg], queue=False)

if __name__ == "__main__":
    demo.launch()
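
# Approximate local requirements (assumed, not pinned by this file):
#   pip install gradio transformers torch pillow spaces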