amanda103 commited on
Commit
0043c9e
1 Parent(s): b429508

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +88 -0
  2. cli_app.py +63 -0
  3. download_data.py +44 -0
  4. ingest_data.py +44 -0
  5. requirements.txt +160 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Tuple
3
+ import gradio as gr
4
+ from cli_app import get_chain
5
+ from threading import Lock
6
+ from langchain.vectorstores import Pinecone
7
+ from langchain.embeddings.openai import OpenAIEmbeddings
8
+ import pinecone
9
+
10
+ PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
11
+ PINECONE_API_ENV = os.environ.get("PINECONE_API_ENV")
12
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
13
+ PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
14
+
15
+
16
+ def grab_vector_connection():
17
+ embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
18
+ pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_API_ENV)
19
+ vectorstore = Pinecone.from_existing_index(PINECONE_INDEX_NAME, embeddings)
20
+ qa_chain = get_chain(vectorstore)
21
+ return qa_chain
22
+
23
+
24
+ class ChatWrapper:
25
+ def __init__(self):
26
+ self.lock = Lock()
27
+
28
+ def __call__(self, inp: str, history: Optional[Tuple[str, str]], chain):
29
+ """Execute the chat functionality."""
30
+ self.lock.acquire()
31
+ if not chain:
32
+ chain = grab_vector_connection()
33
+ try:
34
+ history = history or []
35
+ # Run chain and append input.
36
+ output = chain({"question": inp, "chat_history": history})["answer"]
37
+ history.append((inp, output))
38
+ except Exception as e:
39
+ raise e
40
+ finally:
41
+ self.lock.release()
42
+ return history, history
43
+
44
+
45
+ chat = ChatWrapper()
46
+
47
+
48
+ block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
49
+
50
+ with block:
51
+ with gr.Row():
52
+ gr.Markdown("<h3><center>Chat-IRS-Manuals</center></h3>")
53
+
54
+ chatbot = gr.Chatbot()
55
+
56
+ with gr.Row():
57
+ message = gr.Textbox(
58
+ label="What's your question?",
59
+ placeholder="Ask questions about the IRS Manuals",
60
+ lines=1,
61
+ )
62
+ submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
63
+
64
+ gr.Examples(
65
+ examples=[
66
+ "What is the definition of a taxpayer?",
67
+ "What kinds of factors affect how much I owe in taxes?",
68
+ "What if I don't pay my taxes?",
69
+ ],
70
+ inputs=message,
71
+ )
72
+
73
+ gr.HTML("Demo application of a LangChain chain.")
74
+
75
+ gr.HTML(
76
+ """<center>
77
+ Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a>
78
+ and <a href='https://github.com/unstructured-io/unstructured'>Unstructured.IO</a>
79
+ </center>"""
80
+ )
81
+
82
+ state = gr.State()
83
+ agent_state = gr.State()
84
+
85
+ submit.click(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
86
+ message.submit(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
87
+
88
+ block.launch(debug=True)
cli_app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts.prompt import PromptTemplate
2
+ from langchain.llms import OpenAI
3
+ from langchain.chains import ConversationalRetrievalChain, ChatVectorDBChain
4
+ from langchain.vectorstores import Pinecone
5
+ from langchain.embeddings.openai import OpenAIEmbeddings
6
+ import pinecone
7
+ import os
8
+
9
+
10
+ PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
11
+ PINECONE_API_ENV = os.environ.get("PINECONE_API_ENV")
12
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
13
+ PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
14
+
15
+
16
+ _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
17
+ You can assume the question about the Internal Revenue Manuals.
18
+
19
+ Chat History:
20
+ {chat_history}
21
+ Follow Up Input: {question}
22
+ Standalone question:"""
23
+
24
+ CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
25
+
26
+ template = """You are an AI assistant for answering questions about the Internal Revenue Manuals. You are given the following extracted parts of a long document and a question. Provide a conversational answer.
27
+ If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
28
+ If the question is not about the war in Internal Revenue Manuals, politely inform them that you are tuned to only answer questions about the Internal Revenue Manuals
29
+ Question: {question}
30
+ =========
31
+ {context}
32
+ =========
33
+ Answer in Markdown:"""
34
+
35
+
36
+ QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
37
+
38
+
39
+ def get_chain(vector):
40
+ llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
41
+ qa_chain = ChatVectorDBChain.from_llm(
42
+ llm,
43
+ vector,
44
+ qa_prompt=QA_PROMPT,
45
+ condense_question_prompt=CONDENSE_QUESTION_PROMPT,
46
+ )
47
+ return qa_chain
48
+
49
+
50
+ if __name__ == "__main__":
51
+ embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
52
+ pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_API_ENV)
53
+ vectorstore = Pinecone.from_existing_index(PINECONE_INDEX_NAME, embeddings)
54
+ qa_chain = get_chain(vectorstore)
55
+ chat_history = []
56
+ print("Chat with your docs!")
57
+ while True:
58
+ print("Human:")
59
+ question = input()
60
+ result = qa_chain({"question": question, "chat_history": chat_history})
61
+ chat_history.append((question, result["answer"]))
62
+ print("AI:")
63
+ print(result["answer"])
download_data.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import urllib
3
+ import requests
4
+ from bs4 import BeautifulSoup
5
+ import re
6
+ import zipfile
7
+
8
+
9
+ def get_zip_urls(base="https://www.irs.gov/downloads/irm", start_page=1, max_page=74):
10
+ urls = []
11
+ for page_num in range(start_page, max_page + 1):
12
+ url = f"{base}?page={page_num}"
13
+ response = requests.get(url)
14
+ html_content = response.text
15
+ soup = BeautifulSoup(html_content, "html.parser")
16
+ for link in soup.find_all("a", href=re.compile(r"\.zip$")):
17
+ urls.append(link.get("href"))
18
+ return urls
19
+
20
+
21
+ def download_and_unzip(urls, unzip_dir):
22
+ for zip_url in urls[:10]:
23
+ filename = zip_url.split("/")[-1]
24
+ urllib.request.urlretrieve(zip_url, filename)
25
+ with zipfile.ZipFile(filename, "r") as zip_ref:
26
+ for file_info in zip_ref.infolist():
27
+ # check if the file has a PDF extension
28
+ if file_info.filename.lower().endswith(".pdf"):
29
+ # extract the file to the PDF directory
30
+ zip_ref.extract(file_info, unzip_dir)
31
+
32
+
33
+ if __name__ == "__main__":
34
+ base_url = sys.argv[1]
35
+ page_start = int(sys.argv[2])
36
+ page_max = int(sys.argv[3])
37
+ pdf_dir = sys.argv[4]
38
+ print(f"Grabbing zip urls from {base_url}")
39
+ zip_urls = get_zip_urls(base_url, page_start, page_max)
40
+ print(
41
+ f"Found {len(zip_urls)} zip urls, downloading and unzipping pdfs into {pdf_dir}"
42
+ )
43
+ download_and_unzip(zip_urls, pdf_dir)
44
+ print(f"Finished unzipping")
ingest_data.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import pinecone
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.document_loaders import DirectoryLoader
6
+ from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain.vectorstores import Pinecone
8
+
9
+
10
+ PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
11
+ PINECONE_API_ENV = os.environ.get("PINECONE_API_ENV")
12
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
13
+ PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
14
+
15
+
16
+ def load_documents(path_to_files):
17
+ # Uses UnstructuredLoader under the hood
18
+ loader = DirectoryLoader(path=path_to_files, glob="*.json")
19
+ raw_documents = loader.load()
20
+ text_splitter = RecursiveCharacterTextSplitter()
21
+ documents = text_splitter.split_documents(raw_documents)
22
+ return documents
23
+
24
+
25
+ def send_docs_to_pinecone(documents):
26
+ embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
27
+ pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_API_ENV)
28
+
29
+ if PINECONE_INDEX_NAME in pinecone.list_indexes():
30
+ print(
31
+ f"Index {PINECONE_INDEX_NAME} already exists, deleting and recreating to avoid duplicates"
32
+ )
33
+ pinecone.delete_index(name=PINECONE_INDEX_NAME)
34
+
35
+ pinecone.create_index(name=PINECONE_INDEX_NAME, dimension=1536)
36
+ Pinecone.from_documents(documents, embeddings, index_name=PINECONE_INDEX_NAME)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ path_to_files = sys.argv[1]
41
+ print(f"Grabbing json files from {path_to_files}")
42
+ docs = load_documents(path_to_files)
43
+ print(f"Found {len(docs)}, sending to pinecone")
44
+ send_docs_to_pinecone(docs)
requirements.txt ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ antlr4-python3-runtime==4.9.3
6
+ anyio==3.6.2
7
+ appnope==0.1.3
8
+ argilla==1.6.0
9
+ asttokens==2.2.1
10
+ async-timeout==4.0.2
11
+ attrs==22.2.0
12
+ backcall==0.2.0
13
+ backoff==2.2.1
14
+ beautifulsoup4==4.12.2
15
+ bs4==0.0.1
16
+ certifi==2022.12.7
17
+ cffi==1.15.1
18
+ charset-normalizer==3.1.0
19
+ click==8.1.3
20
+ coloredlogs==15.0.1
21
+ commonmark==0.9.1
22
+ contourpy==1.0.7
23
+ cryptography==40.0.1
24
+ cycler==0.11.0
25
+ dataclasses-json==0.5.7
26
+ decorator==5.1.1
27
+ Deprecated==1.2.13
28
+ dnspython==2.3.0
29
+ effdet==0.3.0
30
+ entrypoints==0.4
31
+ et-xmlfile==1.1.0
32
+ executing==1.2.0
33
+ faiss-cpu==1.7.3
34
+ fastapi==0.95.0
35
+ ffmpy==0.3.0
36
+ filelock==3.11.0
37
+ flatbuffers==23.3.3
38
+ fonttools==4.39.3
39
+ frozenlist==1.3.3
40
+ fsspec==2023.4.0
41
+ gradio==3.25.0
42
+ gradio_client==0.0.10
43
+ h11==0.14.0
44
+ httpcore==0.16.3
45
+ httpx==0.23.3
46
+ huggingface-hub==0.13.4
47
+ humanfriendly==10.0
48
+ idna==3.4
49
+ importlib-metadata==6.3.0
50
+ importlib-resources==5.12.0
51
+ iopath==0.1.10
52
+ ipython==8.12.0
53
+ jedi==0.18.2
54
+ Jinja2==3.1.2
55
+ joblib==1.2.0
56
+ jsonschema==4.17.3
57
+ kiwisolver==1.4.4
58
+ langchain==0.0.136
59
+ layoutparser==0.3.4
60
+ linkify-it-py==2.0.0
61
+ loguru==0.7.0
62
+ lxml==4.9.2
63
+ Markdown==3.4.3
64
+ markdown-it-py==2.2.0
65
+ MarkupSafe==2.1.2
66
+ marshmallow==3.19.0
67
+ marshmallow-enum==1.5.1
68
+ matplotlib==3.7.1
69
+ matplotlib-inline==0.1.6
70
+ mdit-py-plugins==0.3.3
71
+ mdurl==0.1.2
72
+ monotonic==1.6
73
+ mpmath==1.3.0
74
+ msg-parser==1.2.0
75
+ multidict==6.0.4
76
+ mypy-extensions==1.0.0
77
+ networkx==3.1
78
+ nltk==3.8.1
79
+ numpy==1.23.5
80
+ olefile==0.46
81
+ omegaconf==2.3.0
82
+ onnxruntime==1.14.1
83
+ openai==0.27.4
84
+ openapi-schema-pydantic==1.2.4
85
+ opencv-python==4.6.0.66
86
+ openpyxl==3.1.2
87
+ orjson==3.8.10
88
+ packaging==23.0
89
+ pandas==1.5.3
90
+ parso==0.8.3
91
+ pdf2image==1.16.3
92
+ pdfminer.six==20221105
93
+ pdfplumber==0.8.1
94
+ pexpect==4.8.0
95
+ pickleshare==0.7.5
96
+ Pillow==9.5.0
97
+ pinecone-client==2.2.1
98
+ pkgutil_resolve_name==1.3.10
99
+ portalocker==2.7.0
100
+ prompt-toolkit==3.0.38
101
+ protobuf==4.22.1
102
+ ptyprocess==0.7.0
103
+ pure-eval==0.2.2
104
+ pycocotools==2.0.6
105
+ pycparser==2.21
106
+ pydantic==1.10.7
107
+ pydub==0.25.1
108
+ Pygments==2.15.0
109
+ pypandoc==1.11
110
+ pyparsing==3.0.9
111
+ pyrsistent==0.19.3
112
+ pytesseract==0.3.10
113
+ python-dateutil==2.8.2
114
+ python-docx==0.8.11
115
+ python-magic==0.4.27
116
+ python-multipart==0.0.6
117
+ python-pptx==0.6.21
118
+ pytz==2023.3
119
+ PyYAML==6.0
120
+ regex==2023.3.23
121
+ requests==2.28.2
122
+ rfc3986==1.5.0
123
+ rich==13.0.1
124
+ scikit-learn==1.2.2
125
+ scipy==1.10.1
126
+ semantic-version==2.10.0
127
+ sentence-transformers==2.2.2
128
+ sentencepiece==0.1.97
129
+ six==1.16.0
130
+ sniffio==1.3.0
131
+ soupsieve==2.4
132
+ SQLAlchemy==1.4.47
133
+ stack-data==0.6.2
134
+ starlette==0.26.1
135
+ sympy==1.11.1
136
+ tenacity==8.2.2
137
+ threadpoolctl==3.1.0
138
+ tiktoken==0.3.3
139
+ timm==0.6.13
140
+ tokenizers==0.13.3
141
+ toolz==0.12.0
142
+ torch==2.0.0
143
+ torchvision==0.15.1
144
+ tqdm==4.65.0
145
+ traitlets==5.9.0
146
+ transformers==4.27.4
147
+ typing-inspect==0.8.0
148
+ typing_extensions==4.5.0
149
+ uc-micro-py==1.0.1
150
+ unstructured==0.5.11
151
+ unstructured-inference==0.3.2
152
+ urllib3==1.26.15
153
+ uvicorn==0.21.1
154
+ Wand==0.6.11
155
+ wcwidth==0.2.6
156
+ websockets==11.0.1
157
+ wrapt==1.14.1
158
+ XlsxWriter==3.0.9
159
+ yarl==1.8.2
160
+ zipp==3.15.0