isayahc committed on
Commit
0190e25
1 Parent(s): 5c184a9

separated QA from app.py

Browse files
Files changed (2) hide show
  1. app.py +44 -155
  2. qa.py +0 -0
app.py CHANGED
@@ -1,122 +1,11 @@
1
-
2
- # logging
3
- import logging
4
-
5
- # access .env file
6
- import os
7
- from dotenv import load_dotenv
8
-
9
- import time
10
-
11
- #boto3 for S3 access
12
- import boto3
13
- from botocore import UNSIGNED
14
- from botocore.client import Config
15
-
16
- # HF libraries
17
- from langchain.llms import HuggingFaceHub
18
- from langchain.embeddings import HuggingFaceHubEmbeddings
19
- # vectorestore
20
- from langchain.vectorstores import Chroma
21
-
22
- # retrieval chain
23
  from langchain.chains import RetrievalQAWithSourcesChain
24
- # prompt template
25
- from langchain.prompts import PromptTemplate
26
- from langchain.memory import ConversationBufferMemory
27
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
28
- # reorder retrived documents
29
- # github issues
30
- from langchain.document_loaders import GitHubIssuesLoader
31
- # debugging
32
- from langchain.globals import set_verbose
33
- # caching
34
- from langchain.globals import set_llm_cache
35
- # We can do the same thing with a SQLite cache
36
- from langchain.cache import SQLiteCache
37
 
38
  # gradio
39
  import gradio as gr
40
 
41
- # template for prompt
42
- from prompt import template
43
-
44
-
45
-
46
- set_verbose(True)
47
-
48
-
49
- # set up logging for the chain
50
- logging.basicConfig()
51
- logging.getLogger("langchain.retrievers").setLevel(logging.INFO)
52
- logging.getLogger("langchain.chains.qa_with_sources").setLevel(logging.INFO)
53
-
54
- # load .env variables
55
- config = load_dotenv(".env")
56
- HUGGINGFACEHUB_API_TOKEN=os.getenv('HUGGINGFACEHUB_API_TOKEN')
57
- AWS_S3_LOCATION=os.getenv('AWS_S3_LOCATION')
58
- AWS_S3_FILE=os.getenv('AWS_S3_FILE')
59
- VS_DESTINATION=os.getenv('VS_DESTINATION')
60
-
61
- # remove old vectorstore
62
- if os.path.exists(VS_DESTINATION):
63
- os.remove(VS_DESTINATION)
64
-
65
- # remove old sqlite cache
66
- if os.path.exists('.langchain.sqlite'):
67
- os.remove('.langchain.sqlite')
68
-
69
- # initialize Model config
70
- llm_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
71
-
72
- # changed named to model_id to llm as is common
73
- llm = HuggingFaceHub(repo_id=llm_model_name, model_kwargs={
74
- # "temperature":0.1,
75
- "max_new_tokens":1024,
76
- "repetition_penalty":1.2,
77
- # "streaming": True,
78
- # "return_full_text":True
79
- })
80
-
81
- # initialize Embedding config
82
- embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
83
- embeddings = HuggingFaceHubEmbeddings(repo_id=embedding_model_name)
84
-
85
- set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))
86
-
87
- # retrieve vectorsrore
88
- s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
89
-
90
- ## Chroma DB
91
- s3.download_file(AWS_S3_LOCATION, AWS_S3_FILE, VS_DESTINATION)
92
- # use the cached embeddings instead of embeddings to speed up re-retrival
93
- db = Chroma(persist_directory="./vectorstore", embedding_function=embeddings)
94
- db.get()
95
-
96
-
97
- retriever = db.as_retriever(search_type="mmr")#, search_kwargs={'k': 3, 'lambda_mult': 0.25})
98
-
99
- # asks LLM to create 3 alternatives baed on user query
100
- # asks LLM to extract relevant parts from retrieved documents
101
-
102
-
103
  global qa
104
-
105
- prompt = PromptTemplate(
106
- input_variables=["history", "context", "question"],
107
- template=template,
108
- )
109
- memory = ConversationBufferMemory(memory_key="history", input_key="question")
110
-
111
-
112
-
113
- qa = RetrievalQAWithSourcesChain.from_chain_type(llm=llm, retriever=retriever, return_source_documents=True, verbose=True, chain_type_kwargs={
114
- "verbose": True,
115
- "memory": memory,
116
- "prompt": prompt,
117
- "document_variable_name": "context"
118
- }
119
- )
120
 
121
 
122
  #####
@@ -124,57 +13,57 @@ qa = RetrievalQAWithSourcesChain.from_chain_type(llm=llm, retriever=retriever, r
124
  # Gradio fns
125
  ####
126
 
127
- def add_text(history, text):
128
- history = history + [(text, None)]
129
- return history, ""
130
-
131
- def bot(history):
132
- response = infer(history[-1][0], history)
133
- sources = [doc.metadata.get("source") for doc in response['source_documents']]
134
- src_list = '\n'.join(sources)
135
- print_this = response['answer'] + "\n\n\n Sources: \n\n\n" + src_list
136
 
 
 
 
 
 
137
 
138
- history[-1][1] = print_this #response['answer']
139
- return history
140
 
141
- def infer(question, history):
142
- query = question
143
- result = qa({"query": query, "history": history, "question": question})
144
- return result
145
 
146
- css="""
147
- #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
148
- """
 
149
 
150
- title = """
151
- <div style="text-align: center;max-width: 1920px;">
152
- <h1>Chat with your Documentation</h1>
153
- <p style="text-align: center;">This is a privately hosten Docs AI Buddy, <br />
154
- It will help you with any question regarding the documentation of Ray ;)</p>
155
- </div>
156
- """
157
 
 
 
 
 
 
 
 
158
 
159
 
160
- with gr.Blocks(css=css) as demo:
161
- with gr.Column(min_width=900, elem_id="col-container"):
162
- gr.HTML(title)
163
- chatbot = gr.Chatbot([], elem_id="chatbot")
164
- #with gr.Row():
165
- # clear = gr.Button("Clear")
166
 
167
- with gr.Row():
168
- question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
169
- with gr.Row():
170
- clear = gr.ClearButton([chatbot, question])
 
 
171
 
172
- question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(
173
- bot, chatbot, chatbot
174
- )
175
- #clear.click(lambda: None, None, chatbot, queue=False)
176
 
177
- demo.queue().launch()
 
 
 
178
 
179
- def create_gradio_interface(qa:RetrievalQAWithSourcesChain, ):
180
- pass
 
 
1
+ # import for typing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from langchain.chains import RetrievalQAWithSourcesChain
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # gradio
5
  import gradio as gr
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  global qa
8
+ from qa import qa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  #####
 
13
  # Gradio fns
14
  ####
15
 
16
+ def create_gradio_interface(qa:RetrievalQAWithSourcesChain):
17
+ def add_text(history, text):
18
+ history = history + [(text, None)]
19
+ return history, ""
 
 
 
 
 
20
 
21
+ def bot(history):
22
+ response = infer(history[-1][0], history)
23
+ sources = [doc.metadata.get("source") for doc in response['source_documents']]
24
+ src_list = '\n'.join(sources)
25
+ print_this = response['answer'] + "\n\n\n Sources: \n\n\n" + src_list
26
 
 
 
27
 
28
+ history[-1][1] = print_this #response['answer']
29
+ return history
 
 
30
 
31
+ def infer(question, history):
32
+ query = question
33
+ result = qa({"query": query, "history": history, "question": question})
34
+ return result
35
 
36
+ css="""
37
+ #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
38
+ """
 
 
 
 
39
 
40
+ title = """
41
+ <div style="text-align: center;max-width: 1920px;">
42
+ <h1>Chat with your Documentation</h1>
43
+ <p style="text-align: center;">This is a privately hosten Docs AI Buddy, <br />
44
+ It will help you with any question regarding the documentation of Ray ;)</p>
45
+ </div>
46
+ """
47
 
48
 
 
 
 
 
 
 
49
 
50
+ with gr.Blocks(css=css) as demo:
51
+ with gr.Column(min_width=900, elem_id="col-container"):
52
+ gr.HTML(title)
53
+ chatbot = gr.Chatbot([], elem_id="chatbot")
54
+ #with gr.Row():
55
+ # clear = gr.Button("Clear")
56
 
57
+ with gr.Row():
58
+ question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
59
+ with gr.Row():
60
+ clear = gr.ClearButton([chatbot, question])
61
 
62
+ question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(
63
+ bot, chatbot, chatbot
64
+ )
65
+ #clear.click(lambda: None, None, chatbot, queue=False)
66
 
67
+ if __name__ == "__main__":
68
+ demo = create_gradio_interface(qa)
69
+ demo.queue().launch()
qa.py ADDED
File without changes