cszhzleo's picture
Create app.py
25a9d18 verified
raw
history blame
2.96 kB
import os

import requests
import streamlit as st
#from streamlit_datalist import stDatalist
from utils import convert_to_base64, convert_to_html
# Inference endpoint of the model server. Host and port default to
# 127.0.0.1:8080 and can be overridden through environment variables, so the
# UI can point at a server on another machine without editing the code.
IP = os.environ.get("MODEL_SERVER_HOST", "127.0.0.1")
PORT = int(os.environ.get("MODEL_SERVER_PORT", "8080"))
# NOTE(review): `/predictions/<name>` on port 8080 looks like TorchServe's
# inference API — confirm against the actual server.
url = f'http://{IP}:{PORT}/predictions/model'
headers = {'Content-Type': 'application/json'}

# Page chrome. set_page_config is kept as the very first Streamlit call because
# older Streamlit versions require it to precede every other st.* command.
st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
def upload_image():
    """Let the user choose one image (preset or upload), display it, and return it.

    A selectbox offers a few bundled preset images; a file uploader accepts
    PNG/JPEG files. An uploaded file takes precedence over a preset selection.
    Only a single image is supported.

    Returns:
        The chosen image after ``convert_to_base64`` (presumably a base64
        string — TODO confirm against ``utils``).

    Halts the current script run with ``st.stop()`` when no image has been
    chosen yet, or after showing an error when more than one file was uploaded.
    """
    # Display label -> bundled file path; dict order is the selectbox order.
    preset_paths = {
        "view(https://llava-vl.github.io/static/images/view.jpg)": "./images/view.jpg",
        "cat": "./images/cat.jpg",
        "paris 2024": "./images/olympic.jpg",
        "statue of liberty": "./images/usa.jpg",
        "box(from my camera)": "./images/box.jpg",
    }
    placeholder = "–Select–"
    user_option = st.selectbox("Select a preset image", [placeholder, *preset_paths])

    st.text("OR")
    uploads = st.file_uploader(
        "Upload an image to chat about",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    ) or []  # defensive: treat a missing result as "nothing uploaded"

    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently let several images through.
    if len(uploads) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    if not uploads and user_option == placeholder:
        # Nothing chosen yet: end this run until the user interacts.
        st.stop()

    # An uploaded file wins over a preset selection.
    source = uploads[0] if uploads else preset_paths[user_option]
    image_b64 = convert_to_base64(source)

    (col,) = st.columns(1)
    col.markdown("**Image 1**")
    col.markdown(convert_to_html(image_b64), unsafe_allow_html=True)
    st.markdown("---")
    return image_b64
@st.cache_data(show_spinner=False)
def ask_llm(prompt, byte_image, timeout=300):
    """Send a prompt and an image to the model server and return its reply text.

    Results are memoised by ``st.cache_data`` keyed on the argument values, so
    repeating the same question about the same image does not re-query the
    server.

    Args:
        prompt: The user's question.
        byte_image: The image as returned by ``upload_image`` (presumably a
            base64 string — TODO confirm against ``utils.convert_to_base64``).
        timeout: Seconds to wait for the server before giving up. Generous by
            default because generation can be slow.

    Returns:
        The response body as text.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Raising here (instead of returning the error body) also keeps
            ``st.cache_data`` from memoising the failure.
        requests.RequestException: On connection errors or timeouts.
    """
    payload = {
        "prompt": prompt,
        "image": byte_image,
        "parameters": {
            "top_k": 100,
            "top_p": 0.1,
            "temperature": 0.2,
        },
    }
    # Without a timeout, requests waits indefinitely on an unresponsive server.
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.text
def app():
    """Render the demo: chat column on the left, image picker on the right."""
    st.markdown("---")
    chat_col, image_col = st.columns(2)

    with image_col:
        # Stops the run if no image has been chosen yet.
        image_b64 = upload_image()

    with chat_col:
        question = st.chat_input("Ask a question about this image")
        if not question:
            # Wait for the user to type something before querying the model.
            st.stop()

        with st.chat_message("question"):
            # Plain markdown only: enabling raw HTML here would let arbitrary
            # markup typed into the chat box be injected into the page.
            st.markdown(question)

        try:
            with st.spinner("Thinking..."):
                res = ask_llm(question, image_b64)
        except requests.RequestException as err:
            # Show a readable error instead of a raw traceback when the model
            # server is unreachable, times out, or returns an error status.
            st.error(f"Model server request failed: {err}")
        else:
            with st.chat_message("response"):
                st.write(res)


# `streamlit run` executes this file with __name__ == "__main__".
if __name__ == "__main__":
    app()