File size: 3,735 Bytes
be1b457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4f8a92
be1b457
 
 
 
 
f4f8a92
d25ad55
be1b457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import time
import base64
from openai import OpenAI

def wait_on_run(run, client, thread, poll_interval=0.5):
    """Poll a thread run until it leaves the "queued"/"in_progress" states.

    Args:
        run: Run object exposing ``id`` and ``status``.
        client: OpenAI client; ``client.beta.threads.runs.retrieve`` is used
            to refresh the run.
        thread: Thread object exposing ``id``.
        poll_interval: Seconds to wait before each refresh (default 0.5).

    Returns:
        The most recently retrieved run, or the ``run`` argument itself when
        it was already outside the two pending states.
    """
    while run.status in ("queued", "in_progress"):
        # Sleep *before* refreshing so we return as soon as a refreshed
        # status is no longer pending, instead of waiting one more interval.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
    return run

def _find_image_file_id(steps):
    """Return the file id of the last image output found in ``steps``, or None.

    ``steps`` is the list of run-step dicts (the ``'data'`` entry of the
    dumped run-step listing). Steps without tool calls, tool calls without a
    ``code_interpreter`` entry, and calls with no outputs are skipped.
    """
    image_id = None
    for step in steps:
        details = step.get('step_details') or {}
        for call in details.get('tool_calls') or []:
            interpreter = call.get('code_interpreter') or {}
            for output in interpreter.get('outputs') or []:
                image = output.get('image')
                if image:
                    # Later images override earlier ones, so the final
                    # rendering of the run wins.
                    image_id = image['file_id']
    return image_id


def GenerateImageByCode(client, message, code_prompt):
    """Ask a code-interpreter assistant to draw ``message`` and fetch the image.

    Creates an assistant (instructions = ``code_prompt``) and a thread, posts
    ``message`` as the user turn, runs the assistant, waits for the run via
    ``wait_on_run``, then downloads the image produced by the code
    interpreter and writes it to ``<file_id>.png`` in the current working
    directory.

    Args:
        client: OpenAI client.
        message: User request to visualise.
        code_prompt: Instructions given to the assistant.

    Returns:
        A tuple ``(png_path, base64_image)`` where ``base64_image`` is the
        base64 text of the downloaded bytes.

    Raises:
        RuntimeError: If no step of the run produced an image.
    """
    # NOTE(review): a new assistant is created on every call and never
    # deleted here — confirm whether it should be cleaned up or reused.
    assistant = client.beta.assistants.create(
        name="Chain of Image",
        instructions=code_prompt,
        model="gpt-4-1106-preview",
        tools=[{"type": "code_interpreter"}],
    )
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=message,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
    )
    run = wait_on_run(run, client, thread)
    run_steps = client.beta.threads.runs.steps.list(thread_id=thread.id, run_id=run.id, order="asc")
    image_id = _find_image_file_id(run_steps.model_dump()['data'])
    if image_id is None:
        # Explicit raise instead of `assert`, which is stripped under -O.
        raise RuntimeError(
            f"Run ended with status {run.status!r} without producing an image."
        )
    image_bytes = client.files.with_raw_response.content(image_id).content
    image_path = f'{image_id}.png'
    with open(image_path, 'wb') as f:
        f.write(image_bytes)
    base64_image = base64.b64encode(image_bytes).decode('utf-8')
    return image_path, base64_image

def visual_question_answer(client, base64_image, question, vqa_prompt, max_tokens=256):
    """Answer ``question`` about a base64-encoded image with the vision model.

    Sends ``vqa_prompt`` as the system message and a user message holding the
    image (as a data URL) followed by the question text, then returns the
    text content of the first choice.

    Args:
        client: OpenAI client.
        base64_image: Base64 text of the image to inspect.
        question: Question to ask about the image.
        vqa_prompt: System prompt for the model.
        max_tokens: Completion token limit forwarded to the API.

    Returns:
        The content of the first returned choice's message.
    """
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
    }
    text_part = {"type": "text", "text": f"Question:\n{question}\nAnswer:\n"}
    conversation = [
        {"role": "system", "content": vqa_prompt},
        {"role": "user", "content": [image_part, text_part]},
    ]
    completion = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=conversation,
        max_tokens=max_tokens,
    )
    return completion.choices[0].message.content

def chain_of_images(message, history, code_prompt, vqa_prompt, api_token, max_tokens):
    """Chat handler: draw on the first turn, answer about the drawing afterwards.

    With an empty ``history`` the message is sent to ``GenerateImageByCode``
    and its ``(png_path, base64_image)`` result is returned. Otherwise the
    message is treated as a question and answered by
    ``visual_question_answer``.

    Args:
        message: Current user message.
        history: Previous chat turns.
        code_prompt: Instructions for the code-interpreter assistant.
        vqa_prompt: System prompt for the vision model.
        api_token: OpenAI API key used to build the client.
        max_tokens: Completion token limit for the vision model.

    Returns:
        The image tuple on the first turn, the answer text on later turns.
    """
    client = OpenAI(api_key=api_token)
    if not history:
        return GenerateImageByCode(client, message, code_prompt)
    # NOTE(review): history[0][1][1] is presumably the base64 image returned
    # on the first turn — confirm against how the chat UI stores tuples.
    first_image = history[0][1][1]
    # The value comes from a UI slider and may be a float; the token limit
    # must be a whole number.
    return visual_question_answer(client, first_image, message, vqa_prompt, max_tokens=int(max_tokens))
        

def vote(data: gr.LikeData):
    """Print whether a chat response was upvoted or downvoted.

    Args:
        data: Like event; ``data.liked`` selects the wording and
            ``data.value`` is the response that was voted on.
    """
    verdict = "upvoted" if data.liked else "downvoted"
    # f-string formatting instead of `+` so a non-string value (e.g. a file
    # tuple) is printed rather than raising TypeError.
    print(f"You {verdict} this response: {data.value}")

# Chat UI around chain_of_images. The additional inputs are passed to the
# handler after (message, history), in this order: code prompt, VQA prompt,
# API key, max tokens.
demo = gr.ChatInterface(
    chain_of_images,
    additional_inputs=[
        gr.Textbox("You are a research drawing assistant. Your primary role is to help visualize questions posed by users. Instead of directly answering questions, you will use code to invoke the most suitable toolkit, transforming these questions into images. This helps users quickly understand the question and find answers through visualization. You should prioritize clarity and effectiveness in your visual representations, ensuring that complex scientific or technical concepts are made accessible and comprehensible through your drawings.", label="Code Interpreter Prompt"),
        gr.Textbox("You are a visual thinking expert. Your primary role is to answer questions about an image posed by users.", label="VQA Prompt"),
        # Mask the key so the secret is not displayed in clear text.
        gr.Textbox(label="API Key", type="password"),
        # Explicit step of 1 keeps the token limit a whole number; the label
        # tells the user what the slider controls.
        gr.Slider(32, 128, step=1, label="Max Tokens"),
    ],
).queue()

if __name__ == "__main__":
    demo.launch()