Vincent Claes commited on
Commit
2458584
1 Parent(s): 611aebd

first working end to end version

Browse files
Files changed (3) hide show
  1. app.py +62 -10
  2. import_data.py +18 -11
  3. requirements.txt +117 -0
app.py CHANGED
@@ -2,32 +2,84 @@ import os
2
 
3
  import gradio as gr
4
  import weaviate
 
 
 
 
5
 
6
  collection_name = "Chunk"
7
 
 
 
 
8
 
9
- def predict(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  client = weaviate.Client(
11
  url=os.environ["WEAVIATE_URL"],
12
  auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
13
- additional_headers={
14
- "X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]
15
- }
16
  )
17
 
18
- return (
19
- client.query
20
- .get(class_name=collection_name, properties=["text"])
21
- .with_near_text({"concepts": input_text})
22
- .with_limit(1)
23
- .with_generate(single_prompt="{text}")
24
  .do()
25
  )
 
 
 
 
 
 
 
26
 
27
  iface = gr.Interface(
28
  fn=predict, # the function to wrap
29
  inputs="text", # the input type
30
  outputs="text", # the output type
 
 
 
 
 
 
 
 
31
  )
32
 
33
  if __name__ == "__main__":
 
2
 
3
  import gradio as gr
4
  import weaviate
5
+ from langchain import LLMChain
6
+ from langchain.chains import SequentialChain
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.prompts import ChatPromptTemplate
9
 
10
  collection_name = "Chunk"
11
 
12
+ MODEL = "gpt-3.5-turbo"
13
+ LANGUAGE = "en" # nl / en
14
+ llm = ChatOpenAI(temperature=0.0, openai_api_key=os.environ["OPENAI_API_KEY"])
15
 
16
+
17
+ def get_answer_given_the_context(llm, prompt, context) -> SequentialChain:
18
+ template = f"""
19
+ Provide an answer to the prompt given the context.
20
+
21
+ <PROMPT>
22
+
23
+ {prompt}
24
+
25
+ <CONTEXT>
26
+
27
+ {context}
28
+
29
+ """
30
+
31
+ prompt_get_skills_intersection = ChatPromptTemplate.from_template(template=template)
32
+ skills_match_chain = LLMChain(
33
+ llm=llm,
34
+ prompt=prompt_get_skills_intersection,
35
+ output_key="answer",
36
+ )
37
+
38
+ chain = SequentialChain(
39
+ chains=[skills_match_chain],
40
+ input_variables=["prompt", "context"],
41
+ output_variables=[
42
+ skills_match_chain.output_key,
43
+ ],
44
+ verbose=False,
45
+ )
46
+ return chain({"prompt": prompt, "context": context})["answer"]
47
+
48
+
49
+ def predict(prompt):
50
  client = weaviate.Client(
51
  url=os.environ["WEAVIATE_URL"],
52
  auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
53
+ additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
 
 
54
  )
55
 
56
+ search_result = (
57
+ client.query.get(class_name=collection_name, properties=["text"])
58
+ .with_near_text({"concepts": prompt})
59
+ # .with_generate(single_prompt="{text}")
60
+ .with_limit(5)
 
61
  .do()
62
  )
63
+ context_list = [
64
+ element["text"] for element in search_result["data"]["Get"]["Chunk"]
65
+ ]
66
+ context = "\n".join(context_list)
67
+
68
+ return get_answer_given_the_context(llm=llm, prompt=prompt, context=context)
69
+
70
 
71
  iface = gr.Interface(
72
  fn=predict, # the function to wrap
73
  inputs="text", # the input type
74
  outputs="text", # the output type
75
+ examples=[
76
+ [f"what is the process of raising an incident?"],
77
+ [f"What is Cx0 program management?"],
78
+ [
79
+ f"What is process for identifying risksthat can impact the desired outcomes of a project?"
80
+ ],
81
+ [f"What is the release management process?"],
82
+ ],
83
  )
84
 
85
  if __name__ == "__main__":
import_data.py CHANGED
@@ -6,6 +6,7 @@ from llama_index import VectorStoreIndex, StorageContext
6
  from pathlib import Path
7
  import argparse
8
 
 
9
  def get_pdf_files(base_path, loader):
10
  """
11
  Get paths to all PDF files in a directory and its subdirectories.
@@ -22,13 +23,15 @@ def get_pdf_files(base_path, loader):
22
  if not os.path.exists(base_path):
23
  raise FileNotFoundError(f"The specified base path does not exist: {base_path}")
24
  if not os.path.isdir(base_path):
25
- raise NotADirectoryError(f"The specified base_path is not a directory: {base_path}")
 
 
26
 
27
  # Loop through all directories and files starting from the base path
28
  for root, dirs, files in os.walk(base_path):
29
  for filename in files:
30
  # If a file has a .pdf extension, add its path to the list
31
- if filename.endswith('.pdf'):
32
  pdf_file = loader.load_data(file=Path(root, filename))
33
  pdf_paths.extend(pdf_file)
34
 
@@ -44,13 +47,13 @@ def main(args):
44
  client = weaviate.Client(
45
  url=os.environ["WEAVIATE_URL"],
46
  auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
47
- additional_headers={
48
- "X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]
49
- }
50
  )
51
 
52
  # construct vector store
53
- vector_store = WeaviateVectorStore(weaviate_client=client, index_name=args.customer, text_key="content")
 
 
54
 
55
  # setting up the storage for the embeddings
56
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
@@ -63,11 +66,15 @@ def main(args):
63
 
64
 
65
  if __name__ == "__main__":
66
- parser = argparse.ArgumentParser(description='Process and query PDF files.')
67
-
68
- parser.add_argument('--customer', default='Ausy', help='Customer name')
69
- parser.add_argument('--pdf_dir', default='./data', help='Directory containing PDFs')
70
- parser.add_argument('--query', default='What is CX0 customer exprience office?', help='Query to execute')
 
 
 
 
71
 
72
  args = parser.parse_args()
73
 
 
6
  from pathlib import Path
7
  import argparse
8
 
9
+
10
  def get_pdf_files(base_path, loader):
11
  """
12
  Get paths to all PDF files in a directory and its subdirectories.
 
23
  if not os.path.exists(base_path):
24
  raise FileNotFoundError(f"The specified base path does not exist: {base_path}")
25
  if not os.path.isdir(base_path):
26
+ raise NotADirectoryError(
27
+ f"The specified base_path is not a directory: {base_path}"
28
+ )
29
 
30
  # Loop through all directories and files starting from the base path
31
  for root, dirs, files in os.walk(base_path):
32
  for filename in files:
33
  # If a file has a .pdf extension, add its path to the list
34
+ if filename.endswith(".pdf"):
35
  pdf_file = loader.load_data(file=Path(root, filename))
36
  pdf_paths.extend(pdf_file)
37
 
 
47
  client = weaviate.Client(
48
  url=os.environ["WEAVIATE_URL"],
49
  auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
50
+ additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
 
 
51
  )
52
 
53
  # construct vector store
54
+ vector_store = WeaviateVectorStore(
55
+ weaviate_client=client, index_name=args.customer, text_key="content"
56
+ )
57
 
58
  # setting up the storage for the embeddings
59
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
 
66
 
67
 
68
  if __name__ == "__main__":
69
+ parser = argparse.ArgumentParser(description="Process and query PDF files.")
70
+
71
+ parser.add_argument("--customer", default="Ausy", help="Customer name")
72
+ parser.add_argument("--pdf_dir", default="./data", help="Directory containing PDFs")
73
+ parser.add_argument(
74
+ "--query",
75
+ default="What is CX0 customer exprience office?",
76
+ help="Query to execute",
77
+ )
78
 
79
  args = parser.parse_args()
80
 
requirements.txt ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.1.1 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.5.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==3.7.1 ; python_version >= "3.9" and python_version < "4.0"
7
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "4.0"
8
+ attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
9
+ authlib==1.2.1 ; python_version >= "3.9" and python_version < "4.0"
10
+ beautifulsoup4==4.12.2 ; python_version >= "3.9" and python_version < "4.0"
11
+ blis==0.7.10 ; python_version >= "3.9" and python_version < "4.0"
12
+ catalogue==2.0.9 ; python_version >= "3.9" and python_version < "4.0"
13
+ certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0"
14
+ cffi==1.15.1 ; python_version >= "3.9" and python_version < "4.0"
15
+ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
16
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
17
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
18
+ confection==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
19
+ contourpy==1.1.1 ; python_version >= "3.9" and python_version < "4.0"
20
+ cryptography==41.0.4 ; python_version >= "3.9" and python_version < "4.0"
21
+ cycler==0.11.0 ; python_version >= "3.9" and python_version < "4.0"
22
+ cymem==2.0.8 ; python_version >= "3.9" and python_version < "4.0"
23
+ dataclasses-json==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
24
+ exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
25
+ fastapi==0.103.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ filelock==3.12.4 ; python_version >= "3.9" and python_version < "4.0"
28
+ fonttools==4.42.1 ; python_version >= "3.9" and python_version < "4.0"
29
+ frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ fsspec==2023.9.1 ; python_version >= "3.9" and python_version < "4.0"
31
+ goldenverba==0.2.3 ; python_version >= "3.9" and python_version < "4.0"
32
+ gradio-client==0.5.1 ; python_version >= "3.9" and python_version < "4.0"
33
+ gradio==3.44.4 ; python_version >= "3.9" and python_version < "4.0"
34
+ greenlet==2.0.2 ; python_version >= "3.9" and python_version < "4.0" and (platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64")
35
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
36
+ httpcore==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
37
+ httptools==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
38
+ httpx==0.25.0 ; python_version >= "3.9" and python_version < "4.0"
39
+ huggingface-hub==0.17.2 ; python_version >= "3.9" and python_version < "4.0"
40
+ idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
41
+ importlib-metadata==6.8.0 ; python_version >= "3.9" and python_version < "3.10"
42
+ importlib-resources==6.1.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
44
+ joblib==1.3.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ jsonschema-specifications==2023.7.1 ; python_version >= "3.9" and python_version < "4.0"
46
+ jsonschema==4.19.1 ; python_version >= "3.9" and python_version < "4.0"
47
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
48
+ langchain==0.0.296 ; python_version >= "3.9" and python_version < "4.0"
49
+ langcodes==3.3.0 ; python_version >= "3.9" and python_version < "4.0"
50
+ langsmith==0.0.38 ; python_version >= "3.9" and python_version < "4.0"
51
+ llama-index==0.8.29.post1 ; python_version >= "3.9" and python_version < "4.0"
52
+ markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
53
+ marshmallow==3.20.1 ; python_version >= "3.9" and python_version < "4.0"
54
+ matplotlib==3.8.0 ; python_version >= "3.9" and python_version < "4.0"
55
+ multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
56
+ murmurhash==1.0.10 ; python_version >= "3.9" and python_version < "4.0"
57
+ mypy-extensions==1.0.0 ; python_version >= "3.9" and python_version < "4.0"
58
+ nest-asyncio==1.5.8 ; python_version >= "3.9" and python_version < "4.0"
59
+ nltk==3.8.1 ; python_version >= "3.9" and python_version < "4.0"
60
+ numexpr==2.8.6 ; python_version >= "3.9" and python_version < "4.0"
61
+ numpy==1.25.2 ; python_version >= "3.9" and python_version < "4.0"
62
+ openai==0.28.0 ; python_version >= "3.9" and python_version < "4.0"
63
+ orjson==3.9.7 ; python_version >= "3.9" and python_version < "4.0"
64
+ packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ pandas==2.1.0 ; python_version >= "3.9" and python_version < "4.0"
66
+ pathy==0.10.2 ; python_version >= "3.9" and python_version < "4.0"
67
+ pillow==10.0.1 ; python_version >= "3.9" and python_version < "4.0"
68
+ preshed==3.0.9 ; python_version >= "3.9" and python_version < "4.0"
69
+ pycparser==2.21 ; python_version >= "3.9" and python_version < "4.0"
70
+ pydantic-core==2.6.3 ; python_version >= "3.9" and python_version < "4.0"
71
+ pydantic==2.3.0 ; python_version >= "3.9" and python_version < "4.0"
72
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
73
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
74
+ pypdf==3.16.1 ; python_version >= "3.9" and python_version < "4.0"
75
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
76
+ python-dotenv==1.0.0 ; python_version >= "3.9" and python_version < "4.0"
77
+ python-multipart==0.0.6 ; python_version >= "3.9" and python_version < "4.0"
78
+ pytz==2023.3.post1 ; python_version >= "3.9" and python_version < "4.0"
79
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
80
+ referencing==0.30.2 ; python_version >= "3.9" and python_version < "4.0"
81
+ regex==2023.8.8 ; python_version >= "3.9" and python_version < "4.0"
82
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
83
+ rpds-py==0.10.3 ; python_version >= "3.9" and python_version < "4.0"
84
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
85
+ setuptools-scm==8.0.1 ; python_version >= "3.9" and python_version < "4.0"
86
+ setuptools==68.2.2 ; python_version >= "3.9" and python_version < "4.0"
87
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
88
+ smart-open==6.4.0 ; python_version >= "3.9" and python_version < "4.0"
89
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
90
+ soupsieve==2.5 ; python_version >= "3.9" and python_version < "4.0"
91
+ spacy-legacy==3.0.12 ; python_version >= "3.9" and python_version < "4.0"
92
+ spacy-loggers==1.0.5 ; python_version >= "3.9" and python_version < "4.0"
93
+ spacy==3.6.1 ; python_version >= "3.9" and python_version < "4.0"
94
+ sqlalchemy==2.0.21 ; python_version >= "3.9" and python_version < "4.0"
95
+ srsly==2.4.7 ; python_version >= "3.9" and python_version < "4.0"
96
+ starlette==0.27.0 ; python_version >= "3.9" and python_version < "4.0"
97
+ tenacity==8.2.3 ; python_version >= "3.9" and python_version < "4.0"
98
+ thinc==8.1.12 ; python_version >= "3.9" and python_version < "4.0"
99
+ tiktoken==0.5.1 ; python_version >= "3.9" and python_version < "4.0"
100
+ tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
101
+ toolz==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
102
+ tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0"
103
+ typer==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
104
+ typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "4.0"
105
+ typing-inspect==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
106
+ tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
107
+ urllib3==1.26.16 ; python_version >= "3.9" and python_version < "4.0"
108
+ uvicorn==0.23.2 ; python_version >= "3.9" and python_version < "4.0"
109
+ uvicorn[standard]==0.23.2 ; python_version >= "3.9" and python_version < "4.0"
110
+ uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "4.0"
111
+ validators==0.22.0 ; python_version >= "3.9" and python_version < "4.0"
112
+ wasabi==1.1.2 ; python_version >= "3.9" and python_version < "4.0"
113
+ watchfiles==0.20.0 ; python_version >= "3.9" and python_version < "4.0"
114
+ weaviate-client==3.24.1 ; python_version >= "3.9" and python_version < "4.0"
115
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
116
+ yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0"
117
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"