import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces

MODEL_ID = "TheEighthDay/SeekWorld_RL_PLUS"
# Use the GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to(device).eval()

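# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call.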
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    # Load any attached images (handles zero, one, or many files uniformly)
    images = [load_image(f) for f in files]

    # Validate input: gr.Error must be raised, not just constructed, to surface in the UI
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")
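    # The demo always asks the fixed geolocation question below; the model is trained to
    # reason inside <think> </think> tags and answer inside <answer> </answer> tags.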
    system_message = "You are a helpful assistant good at solving problems with step-by-step reasoning. You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags."
    question_text = "In which country and within which first-level administrative region of that country was this picture taken? Please answer in the format of <answer>$country,administrative_area_level_1$</answer>?"
    # Prepare messages for the model
    messages = [
        {
            "role": "system",
            "content": system_message
        },
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": question_text},
            ],
        }
    ]

    # Apply chat template and process inputs
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Set up streamer for real-time output
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
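    # dict(inputs, ...) works because the processor output is a dict-like BatchFeature;
    # its tensors become generate() kwargs alongside the streaming options.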
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream the output
    buffer = ""
    yield "Thinking..."
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer


# Example inputs
examples = []

demo = gr.ChatInterface(
    fn=model_inference,
    description="# **SeekWorld**",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

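# debug=True blocks the main thread and surfaces errors in the console.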
demo.launch(debug=True)