Asaad Almutareb committed on
Commit
8a93d38
1 Parent(s): 7d8f356

added stuff

Files changed (2)
  1. app.py +171 -13
  2. lc-embeddings.py +0 -0
app.py CHANGED
@@ -3,9 +3,162 @@ from langchain.chains import RetrievalQAWithSourcesChain
 
 # gradio
 import gradio as gr
+#import random
+import time
+# boto3 for S3 access
+import boto3
+from botocore import UNSIGNED
+from botocore.client import Config
+# access .env file
+import os
+from dotenv import load_dotenv
+#from bs4 import BeautifulSoup
+# HF libraries
+from langchain.llms import HuggingFaceHub
+from langchain.embeddings import HuggingFaceHubEmbeddings
+# vectorstore
+from langchain.vectorstores import Chroma
+#from langchain.vectorstores import FAISS
+# retrieval chain
+from langchain.chains import RetrievalQA
+# from langchain.chains import RetrievalQAWithSourcesChain
+# prompt template
+from langchain.prompts import PromptTemplate
+from langchain.memory import ConversationBufferMemory
+# logging
+import logging
+#import zipfile
+# improve results with retriever
+# from langchain.retrievers import ContextualCompressionRetriever
+# from langchain.retrievers.document_compressors import LLMChainExtractor
+# from langchain.retrievers.document_compressors import EmbeddingsFilter
+# from langchain.retrievers.multi_query import MultiQueryRetriever
+# from langchain.retrievers import BM25Retriever, EnsembleRetriever
+# reorder retrieved documents
+# from langchain.document_transformers import LongContextReorder
+# github issues
+#from langchain.document_loaders import GitHubIssuesLoader
+# debugging
+from langchain.globals import set_verbose
+# caching
+from langchain.globals import set_llm_cache
+#from langchain.cache import InMemoryCache
+# we can do the same thing with a SQLite cache
+from langchain.cache import SQLiteCache
+#set_llm_cache(InMemoryCache())
+
+set_verbose(True)
+
+# load .env variables
+config = load_dotenv(".env")
+HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+AWS_S3_LOCATION = os.getenv('AWS_S3_LOCATION')
+AWS_S3_FILE = os.getenv('AWS_S3_FILE')
+VS_DESTINATION = os.getenv('VS_DESTINATION')
+
+# initialize model config
+# HuggingFaceH4/zephyr-7b-beta
+# mistralai/Mistral-7B-Instruct-v0.1
+model_id = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta", model_kwargs={
+    "temperature": 0.1,
+    "max_new_tokens": 1024,
+    "repetition_penalty": 1.2,
+    # "streaming": True,
+    # "return_full_text": True
+})
+
+#model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
+model_name = "sentence-transformers/all-mpnet-base-v2"
+embeddings = HuggingFaceHubEmbeddings(repo_id=model_name)
+
+# remove old vectorstore
+if os.path.exists(VS_DESTINATION):
+    os.remove(VS_DESTINATION)
+
+# remove old sqlite cache
+if os.path.exists('.langchain.sqlite'):
+    os.remove('.langchain.sqlite')
+
+set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))
+
+# retrieve vectorstore
+s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
+
+## Chroma DB
+s3.download_file(AWS_S3_LOCATION, AWS_S3_FILE, VS_DESTINATION)
+# use the cached embeddings instead of embeddings to speed up re-retrieval
+db = Chroma(persist_directory="./vectorstore", embedding_function=embeddings)
+db.get()
+
+## FAISS DB
+# s3.download_file('rad-rag-demos', 'vectorstores/faiss_db_ray.zip', './chroma_db/faiss_db_ray.zip')
+# with zipfile.ZipFile('./chroma_db/faiss_db_ray.zip', 'r') as zip_ref:
+#     zip_ref.extractall('./chroma_db/')
+
+# FAISS_INDEX_PATH = './chroma_db/faiss_db_ray'
+# db = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
+
+# initialize the bm25 retriever and chroma/faiss retriever
+# bm25_retriever = BM25Retriever.
+# bm25_retriever.k = 2
+
+# retrieve more documents with higher diversity; useful if the dataset has many similar documents
+retriever = db.as_retriever(search_type="mmr")  #, search_kwargs={'k': 3, 'lambda_mult': 0.25})
+# above a certain threshold
+# retriever = db.as_retriever(
+#     search_type="similarity_score_threshold",
+#     search_kwargs={'score_threshold': 0.6}
+# )
+
+# # asks the LLM to create 3 alternatives based on the user query
+# multi_retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=model_id)
+
+# # asks the LLM to extract relevant parts from retrieved documents
+# compressor = LLMChainExtractor.from_llm(model_id)
+# compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=multi_retriever)
 
 global qa
-from qa import qa
+template = """
+You are the friendly documentation AI buddy Arti, who helps the Human in using RAY, the open-source unified framework for scaling AI and Python applications.
+Use the following context (delimited by <ctx></ctx>) and the chat history (delimited by <hs></hs>) to answer the question:
+------
+<ctx>
+{context}
+</ctx>
+------
+<hs>
+{history}
+</hs>
+------
+{question}
+Answer:
+"""
+prompt = PromptTemplate(
+    input_variables=["history", "context", "question"],
+    template=template,
+)
+memory = ConversationBufferMemory(memory_key="history", input_key="question")
+
+# logging for the chain
+logging.basicConfig()
+logging.getLogger("langchain.retrievers").setLevel(logging.INFO)
+logging.getLogger("langchain.chains").setLevel(logging.INFO)
+
+qa = RetrievalQA.from_chain_type(llm=model_id, retriever=retriever, return_source_documents=True, verbose=True, chain_type_kwargs={
+    "verbose": True,
+    "memory": memory,
+    "prompt": prompt
+})
+# qa = RetrievalQAWithSourcesChain.from_chain_type(llm=model_id, retriever=retriever, return_source_documents=True, verbose=True, chain_type_kwargs={
+#     "verbose": True,
+#     "memory": memory,
+#     "prompt": prompt,
+#     "document_variable_name": "context"
+# })
 
 
 #####

@@ -21,24 +174,29 @@ def create_gradio_interface(qa:RetrievalQAWithSourcesChain):
     history = history + [(text, None)]
     return history, ""
 
-def bot(history):
-    response = infer(history[-1][0], history)
-    sources = [doc.metadata.get("source") for doc in response['source_documents']]
-    src_list = '\n'.join(sources)
-    print_this = response['answer'] + "\n\n\n Sources: \n\n\n" + src_list
-
-    history[-1][1] = print_this  #response['answer']
-    return history
+def bot(history):
+    response = infer(history[-1][0], history)
+    sources = [doc.metadata.get("source") for doc in response['source_documents']]
+    src_list = '\n'.join(sources)
+    print_this = response['result'] + "\n\n\n Sources: \n\n\n" + src_list
+
+    # stream the answer back character by character
+    history[-1][1] = ""
+    for character in response['result']:  #print_this:
+        history[-1][1] += character
+        time.sleep(0.01)
+        yield history
+    # history[-1][1] = print_this  #response['answer']
+    # return history
 
 def infer(question, history):
     query = question
    result = qa({"query": query, "history": history, "question": question})
     return result
 
-css = """
-#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
-"""
+css = """
+#col-container {min-width: 800px; max-width: 1920px; margin-left: auto; margin-right: auto;}
+"""
 
 title = """
 <div style="text-align: center;max-width: 1920px;">
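
Note: the reworked startup path above reads four variables from .env via os.getenv and fails at startup if any is missing (os.path.exists(None) raises TypeError, and the S3 download needs a bucket and key). A minimal sketch of a guard for that assumption; the variable names come from the diff, the check itself is not part of this commit:

# sketch: fail fast if the .env configuration the new code depends on is incomplete
import os
from dotenv import load_dotenv

load_dotenv(".env")
for var in ("HUGGINGFACEHUB_API_TOKEN", "AWS_S3_LOCATION", "AWS_S3_FILE", "VS_DESTINATION"):
    if not os.getenv(var):
        raise RuntimeError(f"missing required .env variable: {var}")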
lc-embeddings.py ADDED
File without changes
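
For reference, a minimal sketch of how the new RetrievalQA chain is consumed, mirroring the call shape of infer() and the fields bot() reads in the diff above; the question text is a hypothetical placeholder:

# sketch: query the chain the way infer()/bot() do in app.py
question = "How do I start a Ray cluster?"  # hypothetical example question
result = qa({"query": question, "history": [], "question": question})
print(result['result'])  # the generated answer
for doc in result['source_documents']:
    print(doc.metadata.get("source"))  # origin of each retrieved chunk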