File size: 1,862 Bytes
7a69915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import streamlit as st
from streamlit_chat import message
from PIL import Image


def init_chat_history():
    """Create empty chat-history lists in the Streamlit session state.

    Ensures ``st.session_state['question']`` and ``st.session_state['answer']``
    exist. Keys that are already present are left untouched, so calling this
    repeatedly never wipes an existing history.
    """
    for history_key in ('question', 'answer'):
        if history_key not in st.session_state:
            # A fresh list per key, so the two histories never alias each other.
            st.session_state[history_key] = []


def update_chat_messages():
    """Render the stored chat history, most recent exchange first.

    Walks the indices from the last recorded exchange back to the first.
    For each index it renders the answer via ``message`` and then the
    matching question with ``is_user=True``. Answers use the key ``'<i>'``
    and questions use ``'<i>_user'``, so every rendered message gets a
    distinct key. Nothing is rendered when there are no answers yet.
    """
    history = st.session_state
    answers = history['answer']
    if not answers:
        return

    for idx in reversed(range(len(answers))):
        message(answers[idx], key=str(idx))
        # Relies on 'question' being index-aligned with 'answer';
        # predict() appends to both lists together.
        message(history['question'][idx], is_user=True, key=f'{idx}_user')


def predict(image, input):
    """Answer one question about an image and record the exchange.

    Does nothing when ``image`` is None or ``input`` is empty/falsy.
    Otherwise it asks ``st.session_state.predictor`` for an answer, then
    appends the question to ``st.session_state.question`` and the answer to
    ``st.session_state.answer``.

    Args:
        image: The image the question is about (``show`` passes the result
            of ``PIL.Image.open``).
        input: The question text. The name shadows the builtin but is kept
            so keyword callers keep working.
    """
    if image is not None and input:
        # NOTE(review): st.session_state.predictor is not created in this
        # section of the file; presumably it is set up elsewhere before
        # show() runs — confirm.
        reply = st.session_state.predictor.predict_answer_from_text(image, input)
        st.session_state.question.append(input)
        st.session_state.answer.append(reply)


def show():
    """Render the visual-chatbot page.

    The page has two columns:

    * Left: an image uploader (jpg/png/jpeg). When an image is uploaded it
      is opened with PIL and displayed. When no image is present, the chat
      history and the question box are cleared.
    * Right: a question text box (session-state key ``'input'``) and the
      chat history. A non-empty question is handed to ``predict`` together
      with the current image.
    """
    init_chat_history()

    st.title('Visual Question Answering - Chatbot')
    st.markdown('''
            <h4 style='text-align: center; color: #B2BEB5;'>
            <i>Hi, I am a Visual Chatbot, capable of answering a sequence of questions about images.
                Please upload image and fire away!
            </i></h4>
            ''', unsafe_allow_html=True)

    # Bind `image` up front. Previously it was only assigned when a file was
    # uploaded, so predict() could raise UnboundLocalError whenever the
    # question box was non-empty without an upload. predict() already
    # returns early on None.
    image = None

    image_col, text_col = st.columns(2)
    with image_col:
        upload_pic = st.file_uploader('Choose an image...', type=[
                                      'jpg', 'png', 'jpeg'], accept_multiple_files=False)
        if upload_pic is not None:
            image = Image.open(upload_pic)
            st.image(upload_pic, use_column_width='auto')
        else:
            # No image: drop the conversation about any previous image and
            # reset the question box (the widget below uses key='input').
            st.session_state.question.clear()
            st.session_state.answer.clear()
            st.session_state.input = ''
    with text_col:
        question = st.text_input('', '', key='input')
        if question:
            # NOTE(review): Streamlit reruns the script on widget
            # interactions and the text box keeps its value. A non-empty
            # question is therefore passed to predict() again on each rerun
            # (e.g. after replacing the image), which appends it to the
            # history again. Confirm whether that repeat is intended.
            predict(image, question)
        update_chat_messages()