File size: 4,213 Bytes
3aba7d8
 
2de941d
3aba7d8
 
 
 
 
 
 
2de941d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3aba7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de941d
3aba7d8
 
 
 
 
 
 
 
3a3b167
3aba7d8
 
2de941d
 
 
 
 
 
3aba7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de941d
3aba7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de941d
3aba7d8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
from PIL import Image
from function import bounding_box
from tempfile import NamedTemporaryFile
import os
from function import ImageCaptionTools, ObjectDetectionTool
from langchain.agents import initialize_agent, AgentType
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.memory import ConversationBufferWindowMemory
from htmlTemplate import css, bot_template, user_template
import random

# Root folder holding one scratch sub-directory per browser session.
DIR = './temp'
# exist_ok avoids the check-then-create race when several sessions start at
# the same moment and both see the folder as missing.
os.makedirs(DIR, exist_ok=True)

# Streamlit re-executes this script on every interaction, so allocate the
# per-session scratch directory only when this session does not own one yet.
# Keying on "dirpath" (the value actually created here) means a failure later
# in the script cannot cause a fresh directory to be leaked on every rerun.
if "dirpath" not in st.session_state:
    while True:
        DIR_PATH = os.path.join(DIR, str(random.randint(1, 999999999)))
        try:
            # os.mkdir raises if the name is already taken, so two sessions
            # can never end up sharing (and wiping) the same directory.
            os.mkdir(DIR_PATH)
        except FileExistsError:
            continue
        break
    st.session_state.dirpath = DIR_PATH

def delete_temp_files():
    """Remove every regular file from this session's scratch directory.

    Only entries for which ``os.path.isfile`` is true are unlinked; anything
    else found in ``st.session_state.dirpath`` is left in place.
    """
    session_dir = st.session_state.dirpath
    for entry in os.listdir(session_dir):
        candidate = os.path.join(session_dir, entry)
        if not os.path.isfile(candidate):
            continue
        os.unlink(candidate)



# initialize Agent
def agent_init():
    """Build the conversational agent that answers questions about an image.

    Returns:
        The object produced by ``initialize_agent``, configured with the
        image-caption and object-detection tools, a ``gemini-pro`` chat
        model, and a windowed conversation memory.
    """
    toolbox = [ImageCaptionTools(), ObjectDetectionTool()]
    chat_model = ChatGoogleGenerativeAI(model="gemini-pro")
    # Windowed memory (k=5) stored under the 'chat_history' key.
    window_memory = ConversationBufferWindowMemory(
        memory_key='chat_history',
        k=5,
        return_messages=True,
    )
    return initialize_agent(
        tools=toolbox,
        llm=chat_model,
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        memory=window_memory,
        max_iterations=5,
        verbose=True,
    )



def main():
    """Render the "Chat with an Image" Streamlit page.

    Left column: lets the user upload an image, saves it into the session's
    scratch directory when "Process Image" is clicked, runs ``bounding_box``
    on the saved file and shows the result in an expander.

    Right column: once an image has been processed, lets the user ask a
    question; the question plus the saved image path is sent to the agent and
    the answer is rendered with the chat HTML templates.
    """
    # Local import: only needed to escape text spliced into the HTML templates.
    import html

    st.set_page_config(
        page_title="Chat with an Image",
        page_icon="🖼️",
        layout="wide"
    )
    st.write(css, unsafe_allow_html=True)
    st.title("Chat with an Image 🖼️")

    # Streamlit reruns this script on every widget interaction. Building the
    # agent each time would throw away its conversation memory, so keep one
    # agent per session instead.
    if "agent" not in st.session_state:
        st.session_state.agent = agent_init()
    agent = st.session_state.agent

    # Check if the page has been reloaded
    if 'reloaded' not in st.session_state:
        st.session_state.reloaded = False
    else:
        st.session_state.reloaded = True

    if "image_processed" not in st.session_state:
        st.session_state.image_processed = None

    if "result_bounding" not in st.session_state:
        st.session_state.result_bounding = None

    col1, col2 = st.columns([1, 1])
    with col1:
        image_upload = st.file_uploader(label="Please Upload Your Image", type=['jpg', 'png', 'jpeg'])
        if not image_upload:
            st.warning("Please upload your image")
        else:
            st.image(
                image_upload,
                use_column_width=True
            )
        click_process = st.button("Process Image", disabled=not image_upload)
        if click_process:
            # Drop the previously processed image before storing the new one.
            delete_temp_files()
            # delete=False: the file must outlive this block so the tools can
            # read it by path on later reruns.
            with NamedTemporaryFile(dir=st.session_state.dirpath, delete=False) as f:
                f.write(image_upload.getbuffer())
                st.session_state.image_path = f.name
                st.session_state.image_processed = True

        # (Re)compute the bounding-box image on a fresh upload, or when an
        # image is stored but no result has been produced yet.
        if (st.session_state.image_processed and st.session_state.result_bounding is None) or click_process:
            with st.spinner("Please Wait"):
                result_bounding = bounding_box(st.session_state.image_path)
                st.session_state.result_bounding = result_bounding

        # Expander to show/hide image
        if st.session_state.result_bounding is not None:
            with st.expander("Show Image (Bounding Box)"):
                st.image(st.session_state.result_bounding)

    with col2:
        user_question = st.text_area("Ask About your image",
                                     disabled=not st.session_state.image_processed,
                                     max_chars=150)
        click_ask = st.button("Ask Question", disabled=not st.session_state.image_processed)
        if click_ask:
            if not user_question.strip():
                # Do not spend an agent call on an empty prompt.
                st.warning("Please type a question about your image")
            else:
                # The templates are rendered with unsafe_allow_html=True, so
                # user and model text must be escaped before being inserted.
                st.write(user_template.replace("{{MSG}}", html.escape(user_question)), unsafe_allow_html=True)
                with st.spinner("Doraemon Searching for Answer🔎"):
                    chat_history = agent.invoke({"input": f"{user_question}, this is the image path: {st.session_state.image_path}"})
                    response = chat_history['output']
                    st.write(bot_template.replace("{{MSG}}", html.escape(response)), unsafe_allow_html=True)

if __name__ == "__main__":
    main()