po5302006 commited on
Commit
c91731d
1 Parent(s): e2861d0

added from main on GitHub, implemented QAchain on chatbot

Browse files
Files changed (2) hide show
  1. .streamlit/config.toml +6 -0
  2. Home.py +119 -35
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor = "#FF0000" # Bright red for interactive elements
3
+ backgroundColor = "#000000" # Black background for the main content area
4
+ secondaryBackgroundColor = "#121212" # A slightly lighter shade of black for the sidebar
5
+ textColor = "#FFFFFF" # White for text to contrast the dark background
6
+ font = "monospace" # A techy font style.
Home.py CHANGED
@@ -5,6 +5,7 @@ import time
5
  from module.__custom__ import *
6
  from streamlit_extras.switch_page_button import switch_page
7
 
 
8
  # Openai API Key
9
  import openai
10
  import json
@@ -25,8 +26,6 @@ def read_api_key_from_secrets(file_path='secrets.json'):
25
 
26
  # Example usage
27
  try:
28
- # key = read_api_key_from_secrets()
29
- # key = os.environ['key']
30
  openai.api_key = os.environ['key']
31
  os.environ['OPENAI_API_KEY'] = os.environ['key']
32
  print(f"OpenAI API Key Found")
@@ -56,31 +55,111 @@ db_plot = Chroma(
56
  embedding_function=embedding
57
  )
58
 
59
- metadata_field_info = [
60
- AttributeInfo(
61
- name="name",
62
- description="The name of the video game on steam",
63
- type="string",
64
- )
65
- ]
66
- document_content_description = "Brief summary of a video game on Steam"
67
-
68
 
69
  with st.sidebar: is_plot = st.toggle('Enable Plot')
70
  db_selected = db_cos
71
  if is_plot: db_selected = db_plot
72
 
73
 
74
- retriever = SelfQueryRetriever.from_llm(
75
- llm,
76
- db_selected,
77
- document_content_description,
78
- metadata_field_info,
79
- enable_limit=True,
 
 
 
80
  )
 
 
 
 
81
 
82
- emoji = '🕹️ GameInsightify'
83
- st.header(emoji)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # Initialize chat history
86
  if "messages" not in st.session_state:
@@ -118,21 +197,26 @@ if prompt := st.chat_input("Need a game recommendation?"):
118
  message_placeholder = st.empty()
119
 
120
  # docs = db.max_marginal_relevance_search(prompt,k=query_num, fetch_k=10) # Sending query to db
121
- docs = retriever.invoke(prompt) # retrieve response from chatgpt
 
 
 
122
  full_response = random.choice( # 1st sentence of response
123
- ["I recommend the following games:\n",
124
- f"Hi, human! These are the {len(docs)} best games:\n",
125
- f"I bet you will love these {len(docs)} games:\n",]
126
  )
127
 
128
  # formatting response from db
129
  top_games = []
130
  assistant_response = ""
131
- for idx, doc in enumerate(docs):
132
- gamename = doc.metadata['name']
133
- top_games.append(gamename)
134
- assistant_response += f"{idx+1}. {gamename}\n"
135
-
 
 
 
 
136
  # separating response into chunk of words
137
  chunks = []
138
  for line in assistant_response.splitlines():
@@ -159,12 +243,12 @@ with col2:
159
 
160
 
161
  # Styling on Tabs
162
- css=f'''
163
- div.stTabs {{
164
- height: 40vh;
165
- overflow-y: scroll;
 
166
  overflow-x: hidden;
167
- }}
168
- </style>
169
  '''
170
- st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
 
5
  from module.__custom__ import *
6
  from streamlit_extras.switch_page_button import switch_page
7
 
8
+
9
  # Openai API Key
10
  import openai
11
  import json
 
26
 
27
  # Example usage
28
  try:
 
 
29
  openai.api_key = os.environ['key']
30
  os.environ['OPENAI_API_KEY'] = os.environ['key']
31
  print(f"OpenAI API Key Found")
 
55
  embedding_function=embedding
56
  )
57
 
 
 
 
 
 
 
 
 
 
58
 
59
  with st.sidebar: is_plot = st.toggle('Enable Plot')
60
  db_selected = db_cos
61
  if is_plot: db_selected = db_plot
62
 
63
 
64
+
65
+ from langchain.agents.agent_toolkits.conversational_retrieval.tool import (
66
+ create_retriever_tool,
67
+ )
68
+ retriever = db_selected.as_retriever()
69
+ retriever_tool = create_retriever_tool(
70
+ retriever,
71
+ "document-retriever",
72
+ "Query a retriever to get information about the video game dataset.",
73
  )
74
+ from typing import List
75
+
76
+ from langchain.utils.openai_functions import convert_pydantic_to_openai_function
77
+ from pydantic import BaseModel, Field
78
 
79
+
80
+ class Response(BaseModel):
81
+ """Final response to the question being asked.
82
+ If you do not have an answer, say you do not have an answer, and ask the user to ask another recommendation.
83
+ If you do have an answer, be verbose and explain why you think the game answers the user's query.
84
+ Don't give information not mentioned in the documents CONTEXT.
85
+ You should always refuse to answer questions that are not related to this specific domain, of video game recommendation.
86
+ If no document passes the minimum threshold of similarity .75, default to apologizing for no answer.
87
+ """
88
+
89
+ answer: str = Field(description="The final answer to the user, including the names in the answer.")
90
+ name: List[str] = Field(
91
+ description="A list of the names of the games found for the user. Only include the game name if it was given as a result to the user's query."
92
+ )
93
+
94
+ import json
95
+
96
+ from langchain.schema.agent import AgentActionMessageLog, AgentFinish
97
+ def parse(output):
98
+ # If no function was invoked, return to user
99
+ if "function_call" not in output.additional_kwargs:
100
+ return AgentFinish(return_values={"output": output.content}, log=output.content)
101
+
102
+ # Parse out the function call
103
+ function_call = output.additional_kwargs["function_call"]
104
+ name = function_call["name"]
105
+ inputs = json.loads(function_call["arguments"])
106
+
107
+ # If the Response function was invoked, return to the user with the function inputs
108
+ if name == "Response":
109
+ return AgentFinish(return_values=inputs, log=str(function_call))
110
+ # Otherwise, return an agent action
111
+ else:
112
+ return AgentActionMessageLog(
113
+ tool=name, tool_input=inputs, log="", message_log=[output]
114
+ )
115
+ from langchain.agents import AgentExecutor
116
+ from langchain.agents.format_scratchpad import format_to_openai_function_messages
117
+ from langchain.chat_models import ChatOpenAI
118
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
119
+ from langchain.tools.render import format_tool_to_openai_function
120
+ prompt = ChatPromptTemplate.from_messages(
121
+ [
122
+ ("system", "You are a recommendation assistant, based off documents."),
123
+ ("user", "{input}"),
124
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
125
+ ]
126
+ )
127
+
128
+ llm_with_tools = llm.bind(
129
+ functions=[
130
+ # The retriever tool
131
+ format_tool_to_openai_function(retriever_tool),
132
+ # Response schema
133
+ convert_pydantic_to_openai_function(Response),
134
+ ]
135
+ )
136
+
137
+ agent = (
138
+ {
139
+ "input": lambda x: x["input"],
140
+ # Format agent scratchpad from intermediate steps
141
+ "agent_scratchpad": lambda x: format_to_openai_function_messages(
142
+ x["intermediate_steps"]
143
+ ),
144
+ }
145
+ | prompt
146
+ | llm_with_tools
147
+ | parse
148
+ )
149
+ agent_executor = AgentExecutor(tools=[retriever_tool], agent=agent, verbose=True)
150
+
151
+ post_prompt = """Do not give me any information that is not included in the document.
152
+ If you do not have an answer, say 'I do not have an answer for that, please ask another question. If you need more context from the user, ask them to
153
+ provide more context in the next query. Do not include games that contain the queried game in the title.
154
+ """
155
+
156
+ st.header("🕹️ GameInsightify - Your Personal Game Recommender")
157
+
158
+ # Description for users
159
+ st.markdown("""
160
+ Welcome to GameInsightify! This chatbot will help you find the perfect game based on your preferences.
161
+ Just type in what you're looking for in a game, and let our AI assistant provide recommendations.
162
+ """)
163
 
164
  # Initialize chat history
165
  if "messages" not in st.session_state:
 
197
  message_placeholder = st.empty()
198
 
199
  # docs = db.max_marginal_relevance_search(prompt,k=query_num, fetch_k=10) # Sending query to db
200
+ docs = agent_executor.invoke(
201
+ {"input": f"{prompt} {post_prompt}"},
202
+ return_only_outputs=True,
203
+ ) # retrieve response from chatgpt
204
  full_response = random.choice( # 1st sentence of response
205
+ [""]
 
 
206
  )
207
 
208
  # formatting response from db
209
  top_games = []
210
  assistant_response = ""
211
+ # for idx, doc in enumerate(docs['name']):
212
+ # gamename = doc
213
+ # top_games.append(gamename)
214
+ # assistant_response += f"{idx+1}. {gamename}\n"
215
+ print(docs)
216
+ try:
217
+ assistant_response += docs["answer"]
218
+ except:
219
+ assistant_response += docs["output"]
220
  # separating response into chunk of words
221
  chunks = []
222
  for line in assistant_response.splitlines():
 
243
 
244
 
245
  # Styling on Tabs
246
+ css = '''
247
+ div.stTabs {
248
+ min-height: 20vh; # Minimum height set for the chat area
249
+ max-height: 60vh; # Maximum height, after which scrolling starts
250
+ overflow-y: auto; # Allows scrolling when content exceeds max height
251
  overflow-x: hidden;
252
+ }
 
253
  '''
254
+ st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)