File size: 4,136 Bytes
9f54a3b
 
877a721
 
9f54a3b
fadd816
9f54a3b
fadd816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0d899e
142827c
fadd816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877a721
fadd816
 
 
 
 
 
9f54a3b
fadd816
 
 
 
 
 
 
 
9f54a3b
 
 
fadd816
 
 
 
 
 
877a721
fadd816
 
877a721
fadd816
 
 
 
877a721
9f54a3b
fadd816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
import os
import requests
import json

# Module-level mirror of the reply text accumulated by the most recent
# get_streamed_response() call. Kept for backward compatibility; consumers
# should prefer the values the generator yields.
entire_assistant_response = ""

# Server-sent-events framing used by the streaming endpoint.
_SSE_DATA_PREFIX = "data: "
_SSE_DONE_MARKER = "[DONE]"


def get_streamed_response(message, history, model):
    """Stream a chat completion from the Together chat-completions endpoint.

    Args:
        message: The new user message to send.
        history: Iterable of ``(user_text, assistant_text)`` pairs describing
            earlier turns, oldest first.
        model: Model identifier placed in the request payload.

    Yields:
        The assistant reply accumulated so far, after each chunk that carries
        text. When the server sends the ``[DONE]`` marker, the complete reply
        is yielded once more before the generator finishes.

    Raises:
        requests.HTTPError: If the API answers with an error status (this is
            also how a missing ``TOGETHER_API_KEY`` surfaces, as the header is
            then ``Bearer None``).
    """
    global entire_assistant_response
    entire_assistant_response = ""  # reset the mirror for this call

    all_message = []
    for human, assistant in history:
        all_message.append({"role": "user", "content": human})
        all_message.append({"role": "assistant", "content": assistant})
    all_message.append({"role": "user", "content": message})

    url = "https://api.together.xyz/v1/chat/completions"
    payload = {
        "model": model,
        "temperature": 1.05,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1,
        "n": 1,
        "messages": all_message,
        "stream_tokens": True,
    }

    TOGETHER_API_KEY = os.getenv('TOGETHER_API_KEY')
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
    }

    # With stream=True the connection stays checked out until the body is
    # consumed or the response is closed; the context manager guarantees the
    # close even if the consumer abandons the generator early.
    with requests.post(url, json=payload, headers=headers, stream=True) as response:
        response.raise_for_status()  # Ensure HTTP request was successful

        for line in response.iter_lines():
            if not line:
                continue  # blank lines separate SSE events
            decoded_line = line.decode('utf-8')
            if not decoded_line.startswith(_SSE_DATA_PREFIX):
                continue  # ignore non-data SSE fields

            # Strip only the leading prefix; "data: " may legitimately occur
            # inside the JSON payload (e.g. in generated text).
            data = decoded_line[len(_SSE_DATA_PREFIX):]

            # Completion signal: yield the entire response once more and stop.
            if data == _SSE_DONE_MARKER:
                yield entire_assistant_response
                break

            try:
                chunk_data = json.loads(data)
                content = chunk_data['choices'][0]['delta']['content']
            except json.JSONDecodeError:
                print(f"Invalid JSON received: {data}")
                continue
            except (KeyError, IndexError, TypeError) as e:
                # Chunk does not have the expected choices[0].delta.content shape.
                print(f"{type(e).__name__} encountered: {e}")
                continue

            if content is None:
                continue  # some chunks carry no text (e.g. finish metadata)

            entire_assistant_response += content  # Aggregate content
            yield entire_assistant_response

    print(entire_assistant_response)


# Initialize Streamlit app
st.title("AI Chatbot")

# Per-session transcript: a list of {"role": "user"|"assistant", "content": str}.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display name -> model identifier sent to the API.
# NOTE(review): these identifiers look like placeholders; the Together API
# expects real model names -- confirm before deploying.
models = {
    "Addiction Recovery": "model_addiction_recovery",
    "Mental Health": "model_mental_health",
    "Wellness": "model_wellness"
}

# Allow user to select a model
selected_model = st.selectbox("Select Model", list(models.keys()))


def _history_pairs(messages):
    """Return completed ``(user_text, assistant_text)`` turns from a transcript.

    A user message is paired with the next assistant message. A user message
    left without a reply (e.g. after a failed request) is dropped, so the
    result always matches what ``get_streamed_response`` expects.
    """
    pairs = []
    pending_user = None
    for entry in messages:
        if entry["role"] == "user":
            pending_user = entry["content"]
        elif pending_user is not None:
            pairs.append((pending_user, entry["content"]))
            pending_user = None
    return pairs


# A form submits once per Enter/click and clears the box afterwards. A bare
# text_input keeps its value across reruns, so any later interaction (such as
# changing the model) would resend the same prompt.
with st.form("chat_form", clear_on_submit=True):
    prompt = st.text_input("You:", key="user_input")
    submitted = st.form_submit_button("Send")

if submitted and prompt:
    history = _history_pairs(st.session_state.messages)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Show the reply as it streams, but store only the final text: every
    # yield is the cumulative reply so far, not a separate message.
    reply = ""
    placeholder = st.empty()
    with st.spinner("AI is typing..."):
        try:
            for reply in get_streamed_response(
                prompt, history, models[selected_model]
            ):
                placeholder.markdown(reply)
        except requests.RequestException as exc:
            st.error(f"Request failed: {exc}")
    placeholder.empty()

    if reply:
        st.session_state.messages.append({"role": "assistant", "content": reply})

# Display chat history. Index-based keys keep widget IDs unique even when two
# messages have identical text.
for index, message in enumerate(st.session_state.messages):
    label = "You:" if message["role"] == "user" else "AI:"
    st.text_input(label, value=message["content"], disabled=True, key=f"history_{index}")