|
import os |
|
import sys |
|
|
|
import gradio as gr |
|
|
|
sys.path.append("./ctm") |
|
from ctm.ctms.ctm_base import BaseConsciousnessTuringMachine |
|
from PIL import Image |
|
import io |
|
import base64 |
|
|
|
# Single module-level CTM instance used by the UI callbacks in this file.
ctm = BaseConsciousnessTuringMachine()
ctm.add_supervisor("gpt4_supervisor")

# The DEPLOYED environment variable (default "true", compared
# case-insensitively) selects which launch configuration start_demo() uses.
DEPLOYED = os.environ.get("DEPLOYED", "true").lower() == "true"
|
|
|
|
|
|
|
def convert_base64(image_array):
    """Encode an image array as a base64 PNG string.

    Args:
        image_array: Array accepted by ``PIL.Image.fromarray``, or ``None``.

    Returns:
        ``None`` when ``image_array`` is ``None``; otherwise the PNG
        serialization of the image, base64-encoded and decoded to a
        UTF-8 ``str``.
    """
    # Guard clause: no image was supplied, so there is nothing to encode.
    if image_array is None:
        return None

    png_stream = io.BytesIO()
    Image.fromarray(image_array).save(png_stream, format="PNG")
    encoded = base64.b64encode(png_stream.getvalue())
    return encoded.decode("utf-8")
|
|
|
|
|
|
|
def add_processor(processor_name, display_name, state):
    """Register a processor with the CTM and return a "selected" button.

    Args:
        processor_name: Name forwarded to ``ctm.add_processor``.
        display_name: Label used for the returned button.
        state: Not used by this function; it is part of the signature
            because the click handler passes three inputs.

    Returns:
        A ``gr.Button`` labelled ``display_name`` with ``elem_id="selected"``.
    """
    print("add processor ", processor_name)
    ctm.add_processor(processor_name)

    # Debug output: the processor group map, then the processor count.
    print(ctm.processor_group_map)
    processor_count = len(ctm.processor_list)
    print(processor_count)

    selected_button = gr.Button(value=display_name, elem_id="selected")
    return selected_button
|
|
|
|
|
def processor_tab():
    """Render the processor-selection accordion.

    One button is created for every text processor followed by every
    vision processor. Each button is paired with two hidden textboxes
    holding the processor's internal name and its display label; clicking
    the button calls ``add_processor`` with those values and replaces the
    button with the one ``add_processor`` returns.
    """
    text_processors = [
        "gpt4_text_emotion_processor",
        "gpt4_text_summary_processor",
        "gpt4_speaker_intent_processor",
        "roberta_text_sentiment_processor",
    ]
    vision_processors = [
        "gpt4v_cloth_fashion_processor",
        "gpt4v_face_emotion_processor",
        "gpt4v_ocr_processor",
        "gpt4v_posture_processor",
        "gpt4v_scene_location_processor",
    ]

    with gr.Accordion('Select your processors here.'):
        with gr.Row():
            with gr.Blocks():
                # Both processor families are wired identically, so a single
                # loop (text first, then vision) replaces two copies of the
                # same code while keeping the component creation order.
                for model_name in text_processors + vision_processors:
                    # Dropping the "processor" suffix leaves a trailing "_",
                    # which becomes a trailing space; strip it so the label
                    # is clean.
                    label = (
                        model_name.replace("processor", "")
                        .replace("_", " ")
                        .strip()
                        .title()
                    )

                    button = gr.Button(value=label, elem_id="unselected")
                    # Hidden textboxes carry the two names into the click
                    # handler as component inputs.
                    processor_name = gr.Textbox(
                        value=model_name, visible=False
                    )
                    display_name = gr.Textbox(value=label, visible=False)
                    button.click(
                        fn=add_processor,
                        inputs=[processor_name, display_name, gr.State()],
                        outputs=[button],
                    )
|
|
|
|
|
def forward(query, content, image, state):
    """Run one CTM pass and return the values shown in the output widgets.

    The image is base64-encoded, the query is stored in ``state``, and the
    three stages (``ask_processors``, ``uptree_competition``,
    ``ask_supervisor``) are run in order. Afterwards the winning output is
    broadcast down the tree and links are formed from the processor output.

    Args:
        query: The user's question.
        content: Text input forwarded to the processors.
        image: Image input, or ``None``.
        state: Dict-like session state; updated and returned.

    Returns:
        A 5-tuple: processor text, competition text, supervisor text, the
        updated state, and a ``gr.Button`` labelled "Update CTM".
    """
    encoded_image = convert_base64(image)
    state["question"] = query

    processors_text, state = ask_processors(
        query, content, encoded_image, state
    )
    competition_text, state = uptree_competition(state)
    supervisor_text, state = ask_supervisor(state)

    ctm.downtree_broadcast(state["winning_output"])
    ctm.link_form(state["processor_output"])

    update_button = gr.Button(value="Update CTM", elem_id="selected-ctm")
    return (
        processors_text,
        competition_text,
        supervisor_text,
        state,
        update_button,
    )
|
|
|
|
|
def ask_processors(query, text, image, state):
    """Query every processor and format their gists for display.

    Args:
        query: The user's question.
        text: Text input for the processors.
        image: Image input as passed by the caller (``forward`` supplies a
            base64 string or ``None``).
        state: Dict-like session state; receives ``"processor_output"``.

    Returns:
        ``(output_info, state)`` where ``output_info`` contains one
        ``<name>`` header plus newline-free gist per processor, each
        followed by a blank line.
    """
    processor_output = ctm.ask_processors(query=query, text=text, image=image)

    sections = []
    for name, info in processor_output.items():
        # Flatten the gist onto a single line before display.
        flat_gist = info["gist"].replace("\n", "").strip()
        sections.append(f"<{name}>\n{flat_gist}\n\n")

    state["processor_output"] = processor_output
    return "".join(sections), state
|
|
|
|
|
def uptree_competition(state):
    """Run the up-tree competition and format the winner for display.

    Args:
        state: Dict-like session state; ``"processor_output"`` is read and
            ``"winning_output"`` is written.

    Returns:
        ``(output_info, state)`` where ``output_info`` is the winner's
        ``<name>`` header followed by its newline-free gist.
    """
    winner = ctm.uptree_competition(state["processor_output"])
    state["winning_output"] = winner

    winner_name = winner["name"]
    flat_gist = winner["gist"].replace("\n", "").strip()
    return f"<{winner_name}>\n{flat_gist}", state
|
|
|
|
|
def ask_supervisor(state):
    """Ask the supervisor to answer the question from the winning output.

    Args:
        state: Dict-like session state; ``"question"`` and
            ``"winning_output"`` are read, ``"answer"`` and ``"score"``
            are written.

    Returns:
        ``(answer, state)``.
    """
    answer, score = ctm.ask_supervisor(
        state["question"], state["winning_output"]
    )
    state["answer"], state["score"] = answer, score
    return answer, state
|
|
|
|
|
def input_tab():
    """Build the input accordion.

    Returns:
        ``(query, text, image, state)``: the query textbox, the text-input
        textbox, the image component, and a ``gr.State`` initialised to an
        empty dict.
    """
    session_state = gr.State({})

    with gr.Accordion("Enter your input here."):
        with gr.Row():
            query_box = gr.Textbox(
                label="Query",
                placeholder="Type your query here",
                lines=3,
            )

        with gr.Row():
            text_box = gr.Textbox(
                label="Text Input",
                placeholder="Input text data",
                lines=11,
            )
            image_box = gr.Image(label="Image Input")

    return query_box, text_box, image_box, session_state
|
|
|
def output_tab(query, text, image, state):
    """Build the output accordion and wire the launch button to ``forward``.

    Args:
        query: Query component used as the first input of ``forward``.
        text: Text component used as the second input.
        image: Image component used as the third input.
        state: ``gr.State`` used as both an input and an output.
    """
    with gr.Accordion("Check your outputs here."):
        stm_chunks_box = gr.Textbox(label="STM Chunks", visible=True, lines=5)
        winning_chunk_box = gr.Textbox(
            label="Winning Chunk", visible=True, lines=3
        )
        answer_box = gr.Textbox(label="Answer", visible=True, lines=2)
        forward_button = gr.Button("Launch CTM")

    # The button is listed as an output because forward() returns a
    # replacement button as its last value.
    click_outputs = [
        stm_chunks_box,
        winning_chunk_box,
        answer_box,
        state,
        forward_button,
    ]
    forward_button.click(
        fn=forward,
        inputs=[query, text, image, state],
        outputs=click_outputs,
    )
|
|
|
|
|
|
|
def main():
    """Assemble the Gradio app and return it without launching.

    The layout is a banner image above a row of two columns: the first
    holds the processor selection and the inputs, the second holds the
    outputs.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # Styles keyed by the elem_id values assigned to components in this file.
    custom_css = """#chat_container {height: 820px; width: 1000px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 600px; overflow: auto;}
                #create_container {height: 750px; margin-left: 0px; margin-right: 0px;}
                #tokenizer_renderer span {white-space: pre-wrap}
                #selected {background-color: orange; width: 180px}
                #unselected {width: 180px}
                #selected-ctm {background-color: orange;}
                """

    with gr.Blocks(theme="gradio/monochrome", css=custom_css) as demo:
        gr.Image("images/banner.jpg", elem_id="banner-image", show_label=False)
        with gr.Row():
            with gr.Column():
                processor_tab()
                input_components = input_tab()
            with gr.Column():
                # input_components is (query, text, image, state).
                output_tab(*input_components)

    return demo
|
|
|
|
|
def start_demo():
    """Build the app and launch it according to ``DEPLOYED``.

    When ``DEPLOYED`` is false the app is queued with default settings and
    served on ``0.0.0.0`` without a share link. Otherwise the queue is
    created with ``api_open=False`` and launched with ``show_api=False``.
    """
    demo = main()

    if not DEPLOYED:
        demo.queue()
        demo.launch(share=False, server_name="0.0.0.0")
        return

    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)
|
|
|
|
|
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    start_demo()
|
|