File size: 3,245 Bytes
7a69915
6fda2a0
7a69915
4c71f4e
7a69915
4c71f4e
5560825
7a69915
 
 
 
 
 
 
 
 
 
 
460b215
 
 
7a69915
 
 
 
 
 
 
4c71f4e
 
 
 
 
 
 
6fda2a0
 
 
4c71f4e
25f8b3c
 
 
6fda2a0
4c71f4e
7a69915
 
 
7aa61b0
7a69915
 
 
 
 
 
 
6fda2a0
 
 
 
 
 
8009ea0
6fda2a0
 
 
 
 
 
 
 
 
 
7a69915
 
6fda2a0
 
 
 
 
 
4c71f4e
 
7a69915
5560825
 
7a69915
6fda2a0
7a69915
7aa61b0
7a69915
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from streamlit_chat import message
from st_clickable_images import clickable_images
from PIL import Image
from helper import *

import streamlit as st

def init_chat_history():
    """Make sure the per-session chat history lists exist.

    Creates empty ``question`` and ``answer`` lists in ``st.session_state``
    when they are missing; existing history is left untouched.
    """
    for history_key in ('question', 'answer'):
        if history_key in st.session_state:
            continue
        st.session_state[history_key] = []


def update_chat_messages():
    """Render the stored conversation, newest exchange first.

    For every stored pair, the bot's answer is drawn first, followed by the
    user's question. Entries are keyed by their index so Streamlit can keep
    the widgets stable across reruns.
    """
    answers = st.session_state['answer']
    if not answers:
        return

    questions = st.session_state['question']
    for idx in reversed(range(len(answers))):
        message(answers[idx], key=str(idx), avatar_style='bottts', seed=123)
        message(questions[idx], avatar_style='micah', seed=45,
                is_user=True, key=str(idx) + '_user')


def predict(image, input, max_history=4):
    """Ask the backend a question about the current image and record the exchange.

    Args:
        image: The opened image. It is only used as a "has an image been
            selected" guard; the request itself is sent with
            ``st.session_state.uploaded_image``.
        input: The user's question text. Empty input is ignored.
        max_history: Maximum number of question/answer pairs kept in session
            state; the oldest pairs are dropped first. Defaults to 4, which is
            what the previous hard-coded ``>= 5`` trimming loop retained.
    """
    if image is None or not input:
        return

    with st.spinner('Preparing answer...'):
        answer = request_answer(st.session_state.uploaded_image, input)
        questions = st.session_state.question
        answers = st.session_state.answer
        questions.append(input)
        answers.append(answer)
        # Trim both lists by the same amount in one slice delete instead of
        # repeatedly popping from the front; a non-positive limit clears all.
        excess = len(questions) - max(max_history, 0)
        if excess > 0:
            del questions[:excess]
            del answers[:excess]


def upload_image_callback():
    """Handle a new file upload: store the uploaded image and start a fresh chat.

    Invoked by the file uploader's ``on_change`` hook. The image reference
    returned by ``upload_image_to_server`` becomes the current image, and the
    question/answer history plus the question box are cleared.
    """
    state = st.session_state
    state.uploaded_image = upload_image_to_server()
    for history_key in ('question', 'answer'):
        state[history_key] = []
    state['input'] = ''


def _reset_conversation():
    """Clear the question/answer history and the question text box."""
    st.session_state.question = []
    st.session_state.answer = []
    st.session_state.input = ''


def show():
    """Render the Visual Question Answering chatbot page.

    Lays out the title, an optional clickable gallery of sample images, an
    image uploader with a preview, and a question box followed by the chat
    history. Selecting a different image (gallery click or upload) resets the
    conversation.
    """
    init_chat_history()

    st.title('Welcome to Visual Question Answering - Chatbot')
    st.markdown('''
            <h4 style='text-align: center; color: #B2BEB5;'>
            <i>Hi, I am a Visual Chatbot, capable of answering a sequence of questions about images.
                Please upload image and fire away!
            </i></h4>
            ''', unsafe_allow_html=True)

    update_gallery_images()
    if 'gallery' in st.session_state:
        gallery = st.session_state.gallery
        clicked = clickable_images(
            gallery,
            # One title per gallery entry instead of a hard-coded count of 2.
            titles=[f"Image #{i}" for i in range(len(gallery))],
            div_style={"display": "flex",
                       "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "100px"},
        )

        # NOTE(review): clickable_images appears to return -1 before any click,
        # so the first render selects gallery_images[-1] (the last image) —
        # confirm this default is intended.
        if 'clicked' not in st.session_state or st.session_state.clicked != clicked:
            st.session_state.uploaded_image = st.session_state.gallery_images[clicked]
            st.session_state.clicked = clicked
            _reset_conversation()

    # Bound up front so predict() gets None (and bails out) rather than
    # raising UnboundLocalError when no image is selected.
    image = None
    image_col, text_col = st.columns(2)
    with image_col:
        st.file_uploader('Select an image...', type=[
            'jpg', 'jpeg'], accept_multiple_files=False,
            on_change=upload_image_callback, key='uploader')

        # .get(): 'uploaded_image' is only set by a gallery click or an upload,
        # so it may not exist yet on a fresh session.
        uploaded_image = st.session_state.get('uploaded_image')
        if uploaded_image is not None:
            image = Image.open(uploaded_image)
            st.image(uploaded_image,
                     use_column_width='always')
        else:
            _reset_conversation()

    with text_col:
        question = st.text_input('Enter question: ', '', key='input')
        if question:
            predict(image, question)
        update_chat_messages()