File size: 5,311 Bytes
c8a0f34
df1d046
 
7b6cdd4
c755297
 
 
 
61dbd5e
2587c2e
c755297
 
 
 
 
 
e378588
34ce748
c755297
e16b761
7fb3621
c755297
d4c77e7
 
0a9e139
 
c755297
 
 
 
 
 
 
4871886
c755297
 
 
783ad43
c755297
54c11e5
 
307f1d8
 
858190c
3ed1495
307f1d8
 
 
e16b761
c020cdf
 
6e37d5d
5dc1bc3
 
 
9f17ce8
c020cdf
f573c2a
57db3f3
009017d
e16b761
 
4e0f9dd
 
 
 
 
 
5dc1bc3
16c7a2e
 
 
 
5dc1bc3
 
 
c9d3a09
 
655400f
3ed1495
 
655400f
5dc1bc3
 
 
 
 
 
 
 
 
 
 
 
 
4e0f9dd
 
783ad43
5dc1bc3
4e0f9dd
 
e16b761
 
5490950
 
 
 
c1e65f1
4e0f9dd
e16b761
4e0f9dd
5dc1bc3
 
1e5c398
 
6164e6b
571d9c3
6164e6b
51a3672
 
 
 
4c8d045
57db3f3
93a7ed5
6164e6b
3c15656
6164e6b
 
 
 
4e0f9dd
 
845068b
4e0f9dd
c1e65f1
71adbd5
571d9c3
845068b
571d9c3
50b1dd2
845068b
88ba472
9f17ce8
845068b
9f17ce8
 
 
 
 
 
 
4e0f9dd
 
c1226e0
6164e6b
2587c2e
6164e6b
 
009017d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import streamlit as st
from datetime import time as t
import time

from operator import itemgetter  
import os
import json
import getpass
import openai
import re
  
from langchain.vectorstores import Pinecone
from langchain.embeddings import OpenAIEmbeddings  
import pinecone


from results import results_agent
from filter import filter_agent
from reranker import reranker
from utils import build_filter, clean_pinecone
from keywords import keyword_agent

# --- API keys (from Streamlit secrets; never hard-code these) ---
OPENAI_API = st.secrets["OPENAI_API"]
PINECONE_API = st.secrets["PINECONE_API"]
openai.api_key = OPENAI_API


# --- Vector store: Pinecone index holding USC class embeddings ---
pinecone.init(
    api_key= PINECONE_API,
    environment="gcp-starter" 
)
index_name = "use-class-db"

# Embedding model used both to build and to query the index.
embeddings = OpenAIEmbeddings(openai_api_key = OPENAI_API)

index = pinecone.Index(index_name)

# Number of nearest-neighbor candidates fetched per query (reranked later).
k = 35

# --- Page UI: title + metadata filter widgets ---
st.title("USC GPT - Find the perfect class")

# Class start/end time window; defaults to 8:30-18:45.
class_time = st.slider(
    "Filter Class Times:",
    value=(t(8, 30), t(18, 45))
)

# Units slider: min 1, max 4, default 4.
# NOTE(review): this always returns an int, so the `units != "any"` check
# in get_rag_results is always true — confirm whether an "any" option was intended.
units = st.slider(
    "Number of units",
    1, 4, 4
)

# Days the student is free; empty selection means "any day".
days = st.multiselect("What days are you free?",
               options = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat"],
               default = None,
               placeholder = "Any day"
        )

assistant = st.chat_message("assistant")
initial_message = "Hello, I am your GPT-powered USC Class Helper! \n How can I assist you today? \n*Note that I work best in semantics, as I primarily search class descriptions:) "

def get_rag_results(prompt):
    """Retrieve and rerank candidate classes for a user prompt.

    Pipeline:
      1. Strip filter-style phrasing from the prompt (filter_agent) so the
         semantic search string is clean, then enrich it with keywords.
      2. Translate the sidebar widgets (class_time, units, days) into a
         Pinecone metadata filter.
      3. Query the Pinecone index for the top ``k`` matches and rerank them
         with a BERT-based cross encoder.

    Returns a ``(response, additional_metadata)`` pair; ``response`` is the
    reranked match list, or a "no classes found" message when nothing passes
    the filters.
    """
    query = filter_agent(prompt, OPENAI_API)
    print("Here is the response from the filter_agent:", query)

    query += keyword_agent(query)
    print("Here is the new query with keywords added:", query)

    # Times are encoded as hour.minute floats, e.g. 18:45 -> 18.45, to match
    # how the index stores them.
    window_start = float(class_time[0].hour) + float(class_time[0].minute) / 100.0
    window_end = float(class_time[1].hour) + float(class_time[1].minute) / 100.0
    query_filter = {
        "start": {"$gte": window_start},
        "end": {"$lte": window_end}
    }

    # NOTE(review): the units slider always yields an int, so this guard is
    # always true and the units filter is always applied — confirm intent.
    if units != "any":
        query_filter["units"] = f"{int(units)}.0 units"

    if days:
        # Metadata stores day strings like "Mon" or "Mon, Wed"; accept every
        # single day chosen plus every ordered pair of chosen days.
        day_options = []
        for i, first in enumerate(days):
            day_options.append(first)
            day_options.extend(first + ", " + second for second in days[i + 1:])
        query_filter["days"] = {"$in": day_options}

    # Semantic search against the Pinecone index.
    result = index.query(
        vector=embeddings.embed_query(query),
        top_k=k,
        filter=query_filter,
        include_metadata=True
    )

    matches, additional_metadata = clean_pinecone(result)
    if not matches:
        return "No classes were found that matched your criteria", "None"
    # BERT cross encoder ranks the surviving candidates by relevance.
    return reranker(query, matches), additional_metadata

    

if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.messages.append({"role": "assistant", "content": initial_message})
    st.session_state.rag_responses = []
    
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
    

if prompt := st.chat_input("Find me a class to learn about the effect of pharmaceuticals on the brain!"):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Only the last six turns are forwarded as conversational context.
        messages = [{"role": m["role"], "content": m["content"]}
                for m in st.session_state.messages[-6:]]
        message_history = " ".join([message["content"] for message in messages])
        print("Prompt is", prompt)

        # Retrieve + rerank candidate classes, then accumulate them so the
        # results agent sees every class surfaced so far this session.
        rag_response, additional_metadata = get_rag_results(prompt)
        rag_response = " ".join([message for message in rag_response])
        st.session_state.rag_responses.append(rag_response)
        print("Here is the session state responses", st.session_state.rag_responses)
        all_rag_responses = " ".join([response for response in st.session_state.rag_responses])
        result_query = 'Original Query:' + prompt
        assistant_response = results_agent(result_query, "Class Options from RAG:" + all_rag_responses + "\nMessage_history" + message_history)

        # Typewriter effect. re.split with a capturing group KEEPS the
        # whitespace separators, so appending the chunk alone reproduces the
        # response verbatim (the previous code added an extra " " per chunk,
        # doubling every space in the displayed and stored message).
        for chunk in re.split(r'(\s+)', assistant_response):
            full_response += chunk
            time.sleep(0.02)
            message_placeholder.markdown(full_response + "▌")
        # Final frame without the cursor glyph, which otherwise lingered
        # until the next rerun.
        message_placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})