from openai import OpenAI
import streamlit as st
from utils import im_2_b64
import pickle
from upload import upload_file, get_file
import clipboard

RANDOM_SEED = 42
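# NOTE: `im_2_b64`, `upload_file`, and `get_file` come from this project's own
# utils.py and upload.py, which are not shown here. As a rough, hypothetical
# sketch (an assumption, not the actual implementation), `im_2_b64` would read
# an uploaded image and return base64-encoded JPEG bytes, e.g.:
#
#   import base64, io
#   from PIL import Image
#
#   def im_2_b64(uploaded_file):
#       img = Image.open(uploaded_file).convert("RGB")
#       buf = io.BytesIO()
#       img.save(buf, format="JPEG")
#       return base64.b64encode(buf.getvalue())
#
# `upload_file(data, name)` is assumed to persist the pickled chat under an id,
# and `get_file(id, name)` to fetch those bytes back for the share/restore flow.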
st.title("ChatGPT with Vision") | |
client = OpenAI(api_key=st.secrets["OPENAI_KEY"]) | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "uploader_key" not in st.session_state: | |
st.session_state["uploader_key"] = 0 | |
if "id" in st.query_params: | |
id = st.query_params["id"] | |
data = get_file(id, 'chatgpt-vision-007') | |
st.session_state.messages = pickle.loads(data) | |
def clear_uploader():
    # Changing the widget key forces Streamlit to rebuild the file uploader,
    # which clears any previously attached images.
    st.session_state["uploader_key"] += 1
    st.rerun()

def undo():
    # Drop the last user/assistant exchange (up to two messages).
    for _ in range(2):
        if len(st.session_state.messages) > 0:
            st.session_state.messages.pop()
    st.rerun()

def share():
    # Serialize the conversation and upload it so it can be reloaded via ?id=<id>.
    data = pickle.dumps(st.session_state.messages)
    id = upload_file(data, 'chatgpt-vision-007')
    return id
with st.sidebar:
    if st.button("Share"):
        id = share()
        url = f"https://umbc-nlp-chatgpt-vision.hf.space/?id={id}"
        # st.code(f"https://umbc-nlp-chatgpt-vision.hf.space/?id={id}")
        clipboard.copy(url)
        st.write(f"URL copied to clipboard: {url}")
    if st.button("Undo"):
        undo()
    if st.button("Clear chat"):
        st.session_state.messages = []
        clear_uploader()
    with st.expander("Advanced Configuration"):
        st.subheader("Temperature")
        temperature = st.slider(
            label="x", min_value=0.1, max_value=1.0, value=0.5, step=0.1,
            label_visibility="hidden",
        )
        st.subheader("Max Tokens")
        max_tokens = st.slider(
            label="x", min_value=32, max_value=1024, value=256, step=32,
            label_visibility="hidden",
        )
    with st.expander("Image Input", expanded=True):
        images = st.file_uploader(
            "Image Upload",
            accept_multiple_files=True,
            type=["png", "jpg", "jpeg"],
            key=st.session_state["uploader_key"],
            label_visibility="collapsed",
        )
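# With accept_multiple_files=True, `images` is a list of Streamlit UploadedFile
# objects (empty when nothing is attached), so it can be passed straight to
# push_message() below.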
# Replay the conversation so far: render text parts, then any images side by side.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        contents = message["content"]
        for content in contents:
            if content["type"] == "text":
                st.markdown(content["text"])
        number_of_images = sum(1 for c in contents if c["type"] == "image_url")
        if number_of_images > 0:
            cols = st.columns(number_of_images)
            i = 0
            for content in contents:
                if content["type"] == "image_url":
                    with cols[i]:
                        st.image(content["image_url"]["url"])
                    i += 1
def push_message(role, content, images=None):
    # Build an OpenAI-style multimodal message: one text part plus an
    # image_url part (a base64 data URL) for each attached image.
    contents = [{"type": "text", "text": content}]
    if images:
        for image in images:
            image_b64 = im_2_b64(image)
            image_url = f"data:image/jpeg;base64,{image_b64.decode('utf-8')}"
            contents.append({
                "type": "image_url",
                "image_url": {
                    "url": image_url,
                },
            })
    message = {"role": role, "content": contents}
    st.session_state.messages.append(message)
    return message
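# For example, a user turn with one attached image ends up shaped like:
#   {"role": "user", "content": [
#       {"type": "text", "text": "What is in this picture?"},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
#   ]}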
chat_input_disabled = False
if prompt := st.chat_input("Type a message", key="chat_input", disabled=chat_input_disabled):
    push_message("user", prompt, images)
    with st.chat_message("user"):
        st.markdown(prompt)
        if images:
            cols = st.columns(len(images))
            for i, image in enumerate(images):
                with cols[i]:
                    st.image(image)
    with st.chat_message("assistant"):
        messages = [
            {"role": m["role"], "content": m["content"]}
            for m in st.session_state.messages
        ]
        # print("api call", messages)
        chat_input_disabled = True
        stream = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=messages,
            stream=True,
            seed=RANDOM_SEED,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Stream the completion into the chat window and keep the full text.
        response = st.write_stream(stream)
    push_message("assistant", response)
    chat_input_disabled = False
    clear_uploader()