"""Streamlit demo for a two-pass GPT-4V "What can go wrong?" pipeline.

Pass 1 sends the story, goal, optional entity and images to
``get_gpt4V_response_1`` and expects a JSON object containing ``condition``
and ``alternate_condition`` (plus ``entity`` when the user did not type one).
Pass 2 sends the story, goal, alternate condition and the same images to
``get_gpt4V_response_2`` and expects a JSON object containing ``event``.
"""
import base64
from io import BytesIO

import streamlit as st
from PIL import Image

# Must be the first Streamlit command of the script; it runs before the
# project-local imports in case any of them call Streamlit at import time.
st.set_page_config(page_title="GPT-4V Demo", page_icon="🧠", layout="wide")

from utils import get_str_to_json  # noqa: E402
from pass1 import get_gpt4V_response_1  # noqa: E402
from pass2 import get_gpt4V_response_2  # noqa: E402
from examples import example_1, example_2  # noqa: E402

# Shown whenever a model reply could not be parsed into the expected JSON.
PARSE_FAILURE_MSG = "Failed to parse JSON. Going for full output"


def clear_data():
    """Delete every key from ``st.session_state``.

    This resets the keyed input widgets (``story``, ``entity``, ``goal``), the
    selected example (``data``) and the stored Pass 1 / Pass 2 outputs and
    flags, so the next run starts from empty inputs.
    """
    # NOTE(review): the image uploader is created without a key, so files that
    # were already uploaded are not cleared by this function.
    # Snapshot the keys so the mapping is not mutated while it is iterated.
    for key in list(st.session_state.keys()):
        del st.session_state[key]


with st.sidebar:
    if st.button("Clear Inputs"):
        clear_data()
    st.title("Parameters")
    st.write(
        "This is a demo of GPT-4V model. It takes a story, goal, entity and an image as input and generates a response."
    )

    st.subheader("Sampling Temperature")
    temperature = st.slider(
        label="x",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.1,
        label_visibility='hidden',
    )
    st.write(
        "The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic."
    )

    st.subheader("Entity?")
    # 1 ("With"): the user types the entity; 0 ("Without"): Pass 1 must return it.
    # Switching the option wipes all inputs via clear_data.
    entity_opt = st.radio(
        label="With or Without",
        options=[1, 0],
        format_func=lambda x: ["Without", "With"][x],
        on_change=clear_data,
    )

    st.subheader("Examples")
    cols = st.columns(2)
    for i, example in enumerate([example_1, example_2]):
        with cols[i % len(cols)]:
            if st.button(f"Example {i+1}", key=f"example{i+1}"):
                clear_data()
                st.session_state["data"] = example


def _images_to_send(uploaded, data):
    """Return the images to send to the model.

    Uploaded files take precedence; otherwise the selected example's
    ``images_base64`` is used. Returns ``None`` when neither is available.
    """
    if uploaded:
        return uploaded
    if data:
        return data.images_base64
    return None


def _show_images(uploaded, data):
    """Render the uploaded images, or else the selected example's images, side by side."""
    if uploaded:
        for col, upload in zip(st.columns(len(uploaded)), uploaded):
            with col:
                st.image(Image.open(upload), caption="Uploaded Image", use_column_width=True)
            # PIL reads from the file object while decoding; rewind so a later
            # consumer that calls .read() on the upload starts at byte 0.
            upload.seek(0)
    elif data and data.images_base64:
        # Size the columns from the list actually iterated; st.columns(0)
        # would raise, hence the emptiness guard above.
        encoded = data.images_base64
        for col, imb64 in zip(st.columns(len(encoded)), encoded):
            with col:
                image = Image.open(BytesIO(base64.b64decode(imb64)))
                st.image(image, caption="Example Image", use_column_width=True)


def _parse_pass_1(response, entity_opt):
    """Turn the raw Pass 1 reply into the dict stored as ``output_1``.

    Returns a dict with keys ``entity``, ``condition``, ``alternate_condition``
    and ``response``. On success ``response`` is ``""`` (unless the model's
    JSON itself has a ``response`` key). If the reply is not a JSON object
    with the required keys, the other fields are ``""`` and ``response``
    holds the raw reply.
    """
    try:
        parsed = get_str_to_json(response)
        if not isinstance(parsed, dict):
            raise ValueError("Invalid JSON - not an object")
        if "condition" not in parsed or "alternate_condition" not in parsed:
            raise ValueError("Invalid JSON - 1")
        if not entity_opt and "entity" not in parsed:
            raise ValueError("Invalid JSON - 2")
    except Exception as e:  # get_str_to_json's failure modes are not visible here
        print("Exception 1", e)
        parsed = {
            "entity": "",
            "condition": "",
            "alternate_condition": "",
            "response": response,
        }
    return {
        "entity": parsed.get("entity", None),
        "condition": parsed.get("condition", None),
        "alternate_condition": parsed.get("alternate_condition", None),
        "response": parsed.get("response", ""),
    }


def _parse_pass_2(response):
    """Turn the raw Pass 2 reply into the dict stored as ``output_2``.

    Returns ``{"event": ...}``. If the reply is not a JSON object containing
    ``event``, a warning is shown and the raw reply is used as the event.
    """
    try:
        parsed = get_str_to_json(response)
        if not isinstance(parsed, dict):
            raise ValueError("Invalid JSON - not an object")
        if "event" not in parsed:
            raise ValueError("Invalid JSON - 3")
    except Exception as e:  # get_str_to_json's failure modes are not visible here
        print("Exception 2", e)
        st.warning(PARSE_FAILURE_MSG)
        parsed = {"event": response}
    return {"event": parsed.get("event", response)}


def main():
    """Render the page: inputs and Pass 1 on the left, results and Pass 2 on the right.

    Reads the sidebar values ``temperature`` and ``entity_opt`` defined at
    module level. Pass results survive reruns through ``st.session_state``
    (``output_1``/``button_1`` and ``output_2``/``button_2``).
    """
    st.title('What can go wrong?')
    data = st.session_state.get("data", None)
    col1, col2 = st.columns(2)

    with col1:
        story = st.text_area(
            "Story",
            placeholder="Enter the story here",
            value=(data.story if data else ""),
            key="story",
        )
        entity = None
        if entity_opt:
            entity = st.text_input(
                "Entity",
                placeholder="Enter the entity here",
                value=(data.entity if data else ""),
                key="entity",
            )
        goal = st.text_area(
            "Goal",
            placeholder="Enter the goal here",
            value=(data.goal if data else ""),
            key="goal",
        )
        images = st.file_uploader("Upload Image", type=['jpg', 'png'], accept_multiple_files=True)
        _show_images(images, data)
        # Both passes send the same images: uploads, else the example's base64 images.
        image_to_send = _images_to_send(images, data)

        if st.button("Pass 1"):
            # A fresh Pass 1 invalidates any previous Pass 2 result.
            st.session_state["button_2"] = False
            if not story or not goal or (entity_opt and not entity) or not image_to_send:
                st.error("Please fill all the fields")
                return
            with col2:
                with st.status("Generating response...", expanded=True):
                    response = get_gpt4V_response_1(
                        story, goal, entity, image_to_send, temperature=temperature
                    )
                    st.session_state["output_1"] = _parse_pass_1(response, entity_opt)
            st.session_state["button_1"] = True

    with col2:
        # Nothing to show on the right until Pass 1 has completed at least once.
        if not st.session_state.get("button_1", False):
            return

        output_1 = st.session_state.get("output_1", {})
        if output_1.get("response"):
            st.warning(PARSE_FAILURE_MSG)
            st.write(output_1["response"])
        if not entity_opt:
            st.text_input("Entity", value=output_1.get("entity") or "")
        st.text_area("Condition", value=output_1.get("condition") or "")
        # Keep the widget's current value so user edits reach Pass 2.
        alternate_condition = st.text_area(
            "Alternate Condition", value=output_1.get("alternate_condition") or ""
        )

        if st.button("Pass 2"):
            if not story or not goal or not alternate_condition or not image_to_send:
                st.error("Please fill all the fields")
            else:
                with st.status("Generating response...", expanded=True):
                    # NOTE(review): assumes get_gpt4V_response_2 accepts the same
                    # image forms (uploads or base64 strings) as pass 1 — confirm.
                    response = get_gpt4V_response_2(
                        story, goal, alternate_condition, image_to_send, temperature=temperature
                    )
                    st.session_state["output_2"] = _parse_pass_2(response)
                st.session_state["button_2"] = True

        if st.session_state.get("button_2", False):
            output_2 = st.session_state.get("output_2", {})
            st.subheader("Event Leads to Alternate Condition")
            st.write(output_2.get("event", ""))


main()