File size: 1,953 Bytes
f4bea37
f4f7de5
 
 
f4bea37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4f7de5
f4bea37
 
 
 
 
f4f7de5
 
 
f4bea37
f4f7de5
 
 
f4bea37
 
f4f7de5
f4bea37
f4f7de5
 
f4bea37
 
 
 
 
 
 
 
 
 
f4f7de5
 
f4bea37
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
from streamlit import container, text_input, spinner
from streamlit_chat import message

from src.langchain_agent import init, init_agent

# Initialize the streamlit page on every script run.
# NOTE(review): init() lives in src.langchain_agent; presumably it configures the page
# and needs to run on each rerun -- confirm before guarding it like the agent below.
init()

# Streamlit re-executes this whole script on every interaction. Anything that must
# survive between reruns (the agent and the chat history) is therefore created only
# once and kept in st.session_state, instead of being overwritten on each run.
if "clarina" not in st.session_state:
    st.session_state.clarina = init_agent()
agent_executor = st.session_state.clarina

# Defaults for the remaining session state, applied only when a key is missing:
#   messages  - chat history alternating [user, response, user, response, ...];
#               a list (not a set) so order and repeated messages are preserved
#   generated - agent responses, in the order they were produced
#   temp      - last submitted input, stashed by the text box callback
for _key, _default in (("messages", []), ("generated", []), ("temp", "")):
    if _key not in st.session_state:
        st.session_state[_key] = _default


def generate_response(user_input):
    """Ask the agent to answer ``user_input`` and record the exchange.

    Appends the input and the agent's reply to ``st.session_state.messages`` and the
    reply to ``st.session_state.generated``. Does nothing when ``user_input`` is empty.

    Args:
        user_input: Text submitted through the input box.
    """
    if not user_input:
        return
    with spinner("Thinking..."):
        response = st.session_state.clarina.reverse_prompt_engineer(user_input)
    # Record both entries only after the agent has answered, so a failing call
    # cannot leave an unpaired user message that would break the user/AI alternation.
    st.session_state.messages.append(user_input)
    st.session_state.messages.append(response)
    st.session_state.generated.append(response)


def _clear_text():
    """on_change callback: stash the submitted text in ``temp`` and empty the box."""
    st.session_state.temp = st.session_state.user_input
    st.session_state.user_input = ""


def main():
    """Render the chat page: message history on top, input box underneath."""
    # Created first so the history is rendered above the input box.
    response_container = container()
    # Must not be named `container`: assigning that name anywhere in this function
    # makes it local, and the call above would raise UnboundLocalError.
    input_container = container()

    with input_container:
        text_input("user input", key="user_input", placeholder="Enter your code here",
                   label_visibility="hidden", on_change=_clear_text)
        # The callback has already cleared `user_input` before this rerun, so the
        # submitted text is read from `temp`. It is cleared once consumed so that
        # later reruns (e.g. from other interactions) do not resubmit it.
        submitted = st.session_state.temp
        st.session_state.temp = ""
        generate_response(submitted)

    # Display message history
    if st.session_state.generated:
        with response_container:
            for i, text in enumerate(st.session_state.messages):
                # Even indices hold user inputs, odd indices hold agent responses.
                is_user = i % 2 == 0
                message(text, is_user=is_user, key=f"{i}_{'user' if is_user else 'ai'}")


if __name__ == '__main__':
    main()