File size: 8,273 Bytes
c09dbef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import streamlit as st

from session_state import set_session_state
from chat import chat_completion
from template import visual_default_prompt
from model_config import visual_model_list

def _build_messages(system_prompt: str, base64_images: list, history: list) -> list:
    """Assemble the chat-completion payload for the visual chat.

    The first turn in ``history`` is sent as one multi-part user message that
    carries every uploaded image followed by that turn's text; all later turns
    are appended unchanged.

    Args:
        system_prompt: Text of the leading system message.
        base64_images: Base64-encoded image payloads, attached in order.
        history: Conversation so far as ``{"role", "content"}`` dicts; must
            contain at least one entry (the first user turn).

    Returns:
        A new list of message dicts; ``history`` itself is not modified.
    """
    # NOTE(review): every image is declared as JPEG regardless of the uploaded
    # format; presumably image_processor re-encodes to JPEG -- confirm.
    first_turn = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}", "detail": "high"}}
        for b64 in base64_images
    ]
    first_turn.append({"type": "text", "text": history[0]["content"]})
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": first_turn},
        *history[1:],
    ]


def _stream_reply(api_key: str, model, messages: list, tokens, temp, freq, pres) -> None:
    """Stream an assistant reply into the page and append it to ``visual_msg``.

    Any exception raised while requesting or streaming is shown with
    ``st.error`` and no assistant message is recorded.
    """
    with st.chat_message("assistant"):
        try:
            # NOTE(review): the Top P slider value is not passed here; confirm
            # whether chat_completion accepts a top_p argument.
            response = chat_completion(api_key, model, messages, tokens, temp, freq, pres, st.session_state.visual_stop)
            result = st.write_stream(
                chunk.choices[0].delta.content
                for chunk in response
                if chunk.choices[0].delta.content is not None
            )
            # Record the reply in the visual history (it must not go to general_msg).
            st.session_state.visual_msg.append({"role": "assistant", "content": result})
        except Exception as e:
            # UI boundary: surface API/network failures instead of crashing the page.
            st.error(f"Error occured: {e}")


def visualChat(api_key: str):
    """Render the image-understanding chat page.

    The sidebar holds Clear/Undo/Retry plus model and sampling settings; the
    settings and the image uploader are locked once a conversation exists.
    The main area shows image previews, the transcript, and a chat input that
    is enabled only while at least one image is uploaded.

    Args:
        api_key: Forwarded unchanged to ``chat_completion``.
    """
    set_session_state("visual", visual_default_prompt, 4096, 0.50)

    # True while no conversation exists: history buttons are disabled and
    # settings/uploads are editable.
    disable = not st.session_state.visual_msg

    with st.sidebar:
        clear_btn = st.button("Clear", "vi_clear", type="primary", use_container_width=True, disabled=disable)
        undo_btn = st.button("Undo", "vi_undo", use_container_width=True, disabled=disable)
        retry_btn = st.button("Retry", "vi_retry", use_container_width=True, disabled=disable)

        model_list = visual_model_list
        model = st.selectbox("Model", model_list, 0, key="vi_model", disabled=not disable)

        system_prompt = st.text_area("System Prompt", st.session_state.visual_sys, key="vi_sys", disabled=not disable)

        with st.expander("Advanced Setting"):
            tokens = st.slider("Max Tokens", 1, 4096, st.session_state.visual_tokens, 1, key="vi_tokens", disabled=not disable)
            temp = st.slider("Temperature", 0.00, 2.00, st.session_state.visual_temp, 0.01, key="vi_temp", disabled=not disable)
            topp = st.slider("Top P", 0.01, 1.00, st.session_state.visual_topp, 0.01, key="vi_topp", disabled=not disable)
            freq = st.slider("Frequency Penalty", -2.00, 2.00, st.session_state.visual_freq, 0.01, key="vi_freq", disabled=not disable)
            pres = st.slider("Presence Penalty", -2.00, 2.00, st.session_state.visual_pres, 0.01, key="vi_pres", disabled=not disable)
            if st.toggle("Set stop", key="vi_stop_toggle", disabled=not disable):
                # Clear resets visual_stop to None; ensure there is a list to
                # append to (previously this reset the *general* chat's stops).
                if st.session_state.visual_stop is None:
                    st.session_state.visual_stop = []
                stop_str = st.text_input("Stop", st.session_state.visual_stop_str, key="vi_stop_str", disabled=not disable)
                st.session_state.visual_stop_str = stop_str
                submit_stop = st.button("Submit", "vi_submit_stop", disabled=not disable)
                if submit_stop and stop_str:
                    st.session_state.visual_stop.append(st.session_state.visual_stop_str)
                    st.session_state.visual_stop_str = ""
                    st.rerun()
                if st.session_state.visual_stop:
                    for word in st.session_state.visual_stop:
                        st.markdown(f"`{word}`")

        st.session_state.visual_sys = system_prompt
        st.session_state.visual_tokens = tokens
        st.session_state.visual_temp = temp
        st.session_state.visual_topp = topp
        st.session_state.visual_freq = freq
        st.session_state.visual_pres = pres

    image_type = ["PNG", "JPG", "JPEG"]
    uploaded_image: list = st.file_uploader("Upload an image", type=image_type, accept_multiple_files=True, key="uploaded_image", disabled=not disable)
    base64_image_list = []
    # With accept_multiple_files=True the uploader yields a (possibly empty) list.
    if uploaded_image:
        from process_image import image_processor
        with st.expander("Image"):
            for image_file in uploaded_image:
                st.image(image_file, output_format="PNG")  # one preview per file
                base64_image_list.append(image_processor(image_file))

    for entry in st.session_state.visual_cache:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    if query := st.chat_input("Say something...", key="vi_query", disabled=not base64_image_list):
        with st.chat_message("user"):
            st.markdown(query)

        st.session_state.visual_msg.append({"role": "user", "content": query})
        messages = _build_messages(system_prompt, base64_image_list, st.session_state.visual_msg)
        _stream_reply(api_key, model, messages, tokens, temp, freq, pres)

        # Snapshot (not alias) so later edits to one list don't affect the other.
        st.session_state.visual_cache = list(st.session_state.visual_msg)
        st.rerun()

    if clear_btn:
        st.session_state.visual_sys = visual_default_prompt
        st.session_state.visual_tokens = 4096
        st.session_state.visual_temp = 0.50
        st.session_state.visual_topp = 0.70
        st.session_state.visual_freq = 0.00
        st.session_state.visual_pres = 0.00
        st.session_state.visual_msg = []
        st.session_state.visual_cache = []
        st.session_state.visual_stop = None
        st.rerun()

    if undo_btn:
        # Remove the most recent message and resync the displayed transcript.
        if st.session_state.visual_msg:
            st.session_state.visual_msg.pop()
        st.session_state.visual_cache = list(st.session_state.visual_msg)
        st.rerun()

    if retry_btn:
        # Drop the last assistant reply (if any) so it can be regenerated; a
        # trailing user turn (e.g. after a failed call) is kept and answered.
        history = st.session_state.visual_msg
        if history and history[-1]["role"] == "assistant":
            history.pop()
        st.session_state.visual_cache = []
        st.session_state.visual_retry = True
        st.rerun()

    if st.session_state.visual_retry:
        # Clear the flag first so an unexpected failure cannot loop forever.
        st.session_state.visual_retry = False
        history = st.session_state.visual_msg
        for entry in history:
            with st.chat_message(entry["role"]):
                st.markdown(entry["content"])

        if history:
            messages = _build_messages(system_prompt, base64_image_list, history)
            _stream_reply(api_key, model, messages, tokens, temp, freq, pres)

        st.session_state.visual_cache = list(st.session_state.visual_msg)
        st.rerun()