elyxlz commited on
Commit
0c7add2
·
1 Parent(s): be8dc83

initial commit

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -33
  2. .gitignore +2 -0
  3. app.py +104 -0
  4. config/conf_0.1.yaml +52 -0
  5. data/store.pkl +3 -0
  6. ingest.py +78 -0
  7. modules.py +75 -0
  8. requirements.txt +7 -0
  9. tools.py +76 -0
.gitattributes CHANGED
@@ -1,34 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tflite filter=lfs diff=lfs merge=lfs -text
29
- *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
- *.xz filter=lfs diff=lfs merge=lfs -text
32
- *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pkl filter=lfs diff=lfs merge=lfs -text
2
+ *.pkl* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ __pycache__
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Tuple
3
+ import gradio as gr
4
+ import argparse
5
+ import datetime
6
+ import pickle
7
+ #import whisper
8
+ import dotenv
9
+ import sys
10
+ from io import StringIO
11
+ import re
12
+ dotenv.load_dotenv()
13
+
14
+ from langchain.callbacks import get_openai_callback
15
+
16
+ import hydra
17
+ from omegaconf import DictConfig, open_dict, OmegaConf
18
+
19
+
20
+
21
+ class ChatbotAgentGradio():
22
+ def __init__(
23
+ self,
24
+ config_name
25
+ ):
26
+
27
+ config = OmegaConf.load(f'./config/{config_name}.yaml')
28
+ self.chatbot = hydra.utils.instantiate(config.model, _convert_="partial")
29
+
30
+ def chat(self,
31
+ inp: str,
32
+ history: Optional[Tuple[str, str]],
33
+ ):
34
+
35
+ """Method for integration with gradio Chatbot"""
36
+ print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
37
+ print("inp: " + inp)
38
+ history = history or []
39
+
40
+ output = self.chatbot.run(inp)
41
+
42
+
43
+ history.append((inp, output))
44
+
45
+ return history, history#, ""
46
+
47
+ def update_foo(self, widget, state):
48
+ if widget:
49
+ state = widget
50
+ return state
51
+
52
+ def launch_app(self):
53
+
54
+ block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
55
+
56
+ with block:
57
+ instance = gr.State()
58
+ show_chain_state = gr.State(False)
59
+
60
+
61
+ with gr.Row():
62
+ gr.Markdown("<h3><center>UNHCR</center></h3>")
63
+
64
+
65
+
66
+ with gr.Row():
67
+ chatbot = gr.Chatbot()
68
+
69
+
70
+ with gr.Row():
71
+ message = gr.Textbox(
72
+ label="What's your question?",
73
+ lines=1,
74
+ )
75
+ submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
76
+
77
+
78
+ state = gr.State()
79
+ agent_state = gr.State()
80
+
81
+ submit.click(self.chat, inputs=[message, state], outputs=[chatbot, state])
82
+ message.submit(self.chat, inputs=[message, state], outputs=[chatbot, state])
83
+
84
+
85
+ block.launch(debug=True, share=False, server_port=7861)#, server_name='192.168.0.73', )
86
+
87
+
88
+
89
+ def simple(config):
90
+ config = OmegaConf.load(f'./config/{config}.yaml')
91
+ chatbot = hydra.utils.instantiate(config.model, _convert_="partial")
92
+
93
+ while True:
94
+ inp = input("\nUser: ")
95
+ print(chatbot.run(inp))
96
+ if __name__ == '__main__':
97
+
98
+ #simple('conf_0.1')
99
+ app = ChatbotAgentGradio('conf_0.1')
100
+ app.launch_app()
101
+
102
+ #QA = QA(store, k=1)
103
+ #app = QAGradio(QA)
104
+ #app.launch_app()
config/conf_0.1.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # chatbot model
3
+ model:
4
+ _target_: modules.initialize_agent
5
+
6
+ agent: "conversational-react-description" # langchain template for agent
7
+
8
+ tools:
9
+ - _target_: langchain.agents.Tool
10
+ name: "Content Search"
11
+ func:
12
+ _target_: tools.SemanticSearch
13
+ threshold: 0.5
14
+ k: 5
15
+ description: ^
16
+ A content search through the UNHCR documents, it will return relevant extracts for your query.
17
+ The action input should be a full english sentence.
18
+ ALWAYS use this to answer ANY question. If the tool doesn't return anything, say that you don't know.
19
+
20
+
21
+ llm:
22
+ _target_: langchain.llms.OpenAI
23
+ temperature: 0
24
+ openai_api_key: ${oc.env:OPENAI_API_KEY} # environment variable
25
+
26
+ memory:
27
+ _target_: langchain.chains.conversation.memory.ConversationBufferWindowMemory
28
+ memory_key: "chat_history"
29
+ k: 5 # how many of the past interactions it keeps
30
+ #verbose: True
31
+
32
+ prefix: |
33
+ - You are an AI whose purpose is to help answer questions about the UNHCR documents.
34
+ - You answer in a factual manner, always basing your answer on the context provided to you
35
+ - You are free to ignore irrelevant information
36
+ - If you do not know something, you will say that you don't know.
37
+ - Give long answers, answering every question with a lot of detail.
38
+
39
+ TOOLS:
40
+ ------
41
+ You have access to the following tools:
42
+
43
+ suffix: |
44
+ Begin!
45
+ Previous conversation history:
46
+ {chat_history}
47
+ New input: {input}
48
+ {agent_scratchpad}
49
+
50
+ ai_prefix: "AI"
51
+
52
+ verbose: True
data/store.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cccec9eb3ff0488f652e5db4ec9f263979c25deaecf773b4413c108f5493fb0e
3
+ size 6998194
ingest.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import faiss
3
+ import pickle
4
+ from PyPDF2 import PdfReader
5
+ from tqdm import tqdm
6
+ import glob
7
+ import os
8
+ import re
9
+
10
+ from langchain.embeddings.openai import OpenAIEmbeddings
11
+ from langchain.text_splitter import CharacterTextSplitter
12
+ from langchain.vectorstores import FAISS
13
+ from langchain.document_loaders import TextLoader
14
+
15
+
16
+ import dotenv
17
+
18
+ dotenv.load_dotenv()
19
+
20
+ def get_all_pdf_filenames(paths, recursive):
21
+ extensions = ["pdf"]
22
+ filenames = []
23
+ for ext_name in extensions:
24
+ ext = f"**/*.{ext_name}" if recursive else f"*.{ext_name}"
25
+ for path in paths:
26
+ filenames.extend(glob.glob(os.path.join(path, ext), recursive=recursive))
27
+ return filenames
28
+
29
+
30
+ #all_pdf_paths = get_all_pdf_filenames(["/mnt/c/users/elio/Downloads/UNHCR Emergency Manual"], recursive=True)
31
+ #print(f"Found {len(all_pdf_paths)} PDF files")
32
+ #assert len(all_pdf_paths) > 0
33
+ #all_pdf_paths = ['/mnt/c/users/elio/Downloads/UNHCR Emergency Manual/UNHCR Emergency Manual/46a9e29a2.pdf']
34
+
35
+ class Ingester():
36
+ """
37
+ Vectorises chunks of the data and puts source as metadata
38
+ """
39
+ def __init__(
40
+ self,
41
+ separator='\n',
42
+ chunk_overlap=200,
43
+ chunk_size=200,
44
+ ):
45
+
46
+ self.splitter = CharacterTextSplitter(chunk_size=chunk_size, separator=separator, chunk_overlap=chunk_overlap)
47
+
48
+ def ingest(self, path):
49
+ #ps = get_all_pdf_filenames([path], recursive=True) # get paths
50
+ ps = ['/mnt/c/users/elio/Downloads/UNHCR Emergency Manual/UNHCR Emergency Manual/46a9e29a2.pdf']
51
+ data = []
52
+ sources = []
53
+ for p in tqdm(ps): # extract data from paths
54
+ reader = PdfReader(p)
55
+ page = '\n'.join([reader.pages[i].extract_text() for i in range(len(reader.pages))])
56
+ data.append(page)
57
+ sources.append(p)
58
+
59
+ docs = []
60
+ metadatas = []
61
+ for i, d in tqdm(enumerate(data)): # split text and make documents
62
+ splits = self.splitter.split_text(d)
63
+ if all(s != "" for s in splits):
64
+ docs.extend(splits)
65
+ metadatas.extend([{"source": sources[i]}] * len(splits))
66
+
67
+ assert len(docs) > 0
68
+
69
+ print("Extracting embeddings")
70
+ store = FAISS.from_texts(docs, OpenAIEmbeddings(), metadatas=metadatas)
71
+
72
+ with open(os.path.join('./data', 'store.pkl'), "wb") as f:
73
+ pickle.dump(store, f)
74
+
75
+ print(f"Saved store at {os.path.join('./data', 'store.pkl')}.")
76
+
77
+ ingester = Ingester(chunk_size=2000)
78
+ ingester.ingest("/mnt/c/users/elio/Downloads/UNHCR Emergency Manual")
modules.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Load agent."""
3
+ from typing import Any, List, Optional
4
+
5
+ from langchain.agents.agent import AgentExecutor
6
+ from langchain.agents.loading import AGENT_TO_CLASS, load_agent
7
+ from langchain.agents.tools import Tool
8
+ from langchain.callbacks.base import BaseCallbackManager
9
+ from langchain.llms.base import BaseLLM
10
+
11
+
12
+ def initialize_agent(
13
+ tools: List[Tool],
14
+ llm: BaseLLM,
15
+ agent: Optional[str] = None,
16
+ callback_manager: Optional[BaseCallbackManager] = None,
17
+ agent_path: Optional[str] = None,
18
+ prefix: Optional[str] = None,
19
+ suffix: Optional[str]= None,
20
+ ai_prefix: Optional[str] = None,
21
+ human_prefix: Optional[str] = None,
22
+ **kwargs: Any,
23
+ ) -> AgentExecutor:
24
+ """Load agent given tools and LLM.
25
+
26
+ Args:
27
+ tools: List of tools this agent has access to.
28
+ llm: Language model to use as the agent.
29
+ agent: The agent to use. Valid options are:
30
+ `zero-shot-react-description`
31
+ `react-docstore`
32
+ `self-ask-with-search`
33
+ `conversational-react-description`
34
+ If None and agent_path is also None, will default to
35
+ `zero-shot-react-description`.
36
+ callback_manager: CallbackManager to use. Global callback manager is used if
37
+ not provided. Defaults to None.
38
+ agent_path: Path to serialized agent to use.
39
+ **kwargs: Additional key word arguments to pass to the agent.
40
+
41
+ Returns:
42
+ An agent.
43
+ """
44
+ if agent is None and agent_path is None:
45
+ agent = "zero-shot-react-description"
46
+ if agent is not None and agent_path is not None:
47
+ raise ValueError(
48
+ "Both `agent` and `agent_path` are specified, "
49
+ "but at most only one should be."
50
+ )
51
+ if agent is not None:
52
+ if agent not in AGENT_TO_CLASS:
53
+ raise ValueError(
54
+ f"Got unknown agent type: {agent}. "
55
+ f"Valid types are: {AGENT_TO_CLASS.keys()}."
56
+ )
57
+ agent_cls = AGENT_TO_CLASS[agent]
58
+ agent_obj = agent_cls.from_llm_and_tools(
59
+ llm, tools, prefix=prefix, suffix=suffix, ai_prefix=ai_prefix, human_prefix=human_prefix, callback_manager=callback_manager # added prefix and suffix
60
+ )
61
+ elif agent_path is not None:
62
+ agent_obj = load_agent(
63
+ agent_path, llm=llm, tools=tools, callback_manager=callback_manager
64
+ )
65
+ else:
66
+ raise ValueError(
67
+ "Somehow both `agent` and `agent_path` are None, "
68
+ "this should never happen."
69
+ )
70
+ return AgentExecutor.from_agent_and_tools(
71
+ agent=agent_obj,
72
+ tools=tools,
73
+ callback_manager=callback_manager,
74
+ **kwargs,
75
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ faiss-cpu
2
+ langchain
3
+ openai
4
+ numpy
5
+ gradio
6
+ PyPDF2
7
+ python-dotenv
tools.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import pickle
3
+ import os
4
+ from PyPDF2 import PdfReader
5
+ import glob
6
+ from pathlib import Path
7
+ import re
8
+ import requests
9
+
10
+ from langchain.chains import LLMChain
11
+ from langchain.llms import OpenAI
12
+ from langchain import PromptTemplate
13
+
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.embeddings import OpenAIEmbeddings
16
+
17
+ import dotenv
18
+ dotenv.load_dotenv()
19
+
20
+
21
+ def call_semantic_api(query, store_path, k):
22
+ payload = {
23
+ "query": query,
24
+ "store_path": store_path,
25
+ "k": k,
26
+ }
27
+
28
+ # response = requests.post("http://localhost:3001/search", json=payload)
29
+ response = semantic_search.search(payload)
30
+ return response
31
+
32
+
33
+ class SemanticSearch():
34
+ def __init__(
35
+ self,
36
+ threshold: float,
37
+ with_source=False,
38
+ k=5,
39
+ ):
40
+ self.threshold = threshold
41
+ self.with_source = with_source
42
+ self.k = k
43
+
44
+ with open('./data/store.pkl', 'rb') as f:
45
+ self.db = pickle.load(f)
46
+
47
+ def __call__(self, query):
48
+
49
+ documents = self.db.similarity_search_with_score(query, k=self.k)
50
+ if len(documents) == 0:
51
+ return None
52
+
53
+ if not self.with_source:
54
+ output = '\n\n\n'.join([i[0].page_content for i in documents])
55
+ else:
56
+ output = '\n\n\n'.join([i[0].page_content + '\n\nSource:' + os.path.basename(
57
+ str(i[0].metadata['source']) + '\n') for i in documents])
58
+
59
+ return output
60
+
61
+
62
+ class ContentSearch():
63
+ def __init__(
64
+ self,
65
+ semantic_search,
66
+ prompt_template,
67
+ ):
68
+ self.semantic_search = semantic_search
69
+ self.prompt_template = prompt_template
70
+
71
+ def __call__(self, query):
72
+ content = self.semantic_search(query)
73
+ if content is None:
74
+ return "No results found"
75
+ else:
76
+ return self.prompt_template.format(content=content)