"""Streamlit demo: chat about one image with a LLaVA model behind a SageMaker endpoint.

Configuration is read from the environment variables ``region``,
``sm_endpoint_name``, ``access_key``, ``secret_key`` and ``hf_read_access``.
"""
import json
import os

import boto3
import requests
import sagemaker
import streamlit as st

from utils import convert_to_base64, convert_to_html

region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")

session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region,
)
sess = sagemaker.Session(boto_session=session)
smr = session.client("sagemaker-runtime")

headers = {'Content-Type': 'application/json'}

st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")


def upload_image():
    """Let the user pick a preset image or upload one, show it, and return it.

    An uploaded file takes precedence over the preset selection. If more
    than one file is uploaded an error is shown and the script run is
    stopped; if nothing is selected or uploaded the script run is stopped
    as well.

    Returns:
        The value produced by ``convert_to_base64`` for the chosen image
        (presumably a base64 string -- confirm against ``utils``).
    """
    placeholder = "–Select–"
    image_list = [
        "./images/view.jpg",
        "./images/cat.jpg",
        "./images/olympic.jpg",
        "./images/usa.jpg",
        "./images/box.jpg",
    ]
    name_list = [
        "view(https://llava-vl.github.io/static/images/view.jpg)",
        "cat",
        "paris 2024",
        "statue of liberty",
        "box(from my camera)",
    ]
    images_all = dict(zip(name_list, image_list))

    user_option = st.selectbox("Select a preset image", [placeholder] + name_list)
    image_names = [images_all[user_option]] if user_option != placeholder else []

    st.text("OR")
    images = st.file_uploader(
        "Upload an image to chat about",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    )

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently disable the upload limit.
    if len(images) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    # Uploaded files win over the preset selection.
    sources = images if images else image_names
    if not sources:
        # Nothing to chat about yet; end this script run.
        st.stop()

    # NOTE(review): convert_to_base64 receives either an uploaded file object
    # or a path string -- confirm that utils handles both.
    images_b64 = [convert_to_base64(source) for source in sources]

    # Only the first image is displayed and processed.
    first_col = st.columns(len(images_b64))[0]
    first_col.markdown("**Image 1**")
    first_col.markdown(convert_to_html(images_b64[0]), unsafe_allow_html=True)
    st.markdown("---")
    return images_b64[0]


@st.cache_data(show_spinner=False)
def ask_llm(prompt, byte_image):
    """Send the prompt and image to the SageMaker endpoint and return its reply.

    Args:
        prompt: The user's question.
        byte_image: The encoded image as returned by ``upload_image``.

    Returns:
        The endpoint's response body decoded as UTF-8 text.
    """
    payload = {
        "prompt": prompt,
        "image": byte_image,
        "parameters": {
            "top_k": 100,
            "top_p": 0.1,
            "temperature": 0.2,
        },
    }
    response = smr.invoke_endpoint(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    # The raw response holds a streaming body object that st.cache_data cannot
    # serialize; read it once and return plain text instead.
    return response["Body"].read().decode("utf-8")


def app():
    """Render the page: image picker on the right, chat on the left."""
    st.markdown("---")
    c1, c2 = st.columns(2)
    with c2:
        image_b64 = upload_image()
    with c1:
        question = st.chat_input("Ask a question about this image")
    if not question:
        st.stop()
    with c1:
        with st.chat_message("question"):
            # NOTE(review): the user's text is rendered with HTML enabled;
            # consider dropping unsafe_allow_html if raw HTML is not needed.
            st.markdown(question, unsafe_allow_html=True)
        with st.spinner("Thinking..."):
            res = ask_llm(question, image_b64)
        with st.chat_message("response"):
            st.write(res)


if __name__ == "__main__":
    app()