added refine summary chain
Files changed:
- app_modules/init.py  +34 -30
- app_modules/llm_summarize_chain.py  +20 -0
- summarize.py  +70 -0
app_modules/init.py
CHANGED

app_init() previously always loaded the embeddings model, loaded the FAISS or Chroma vector store, and built a QAChain. It now takes an initQAChain flag (default True): LLM_MODEL_TYPE and the CPU-core count are read up front, the embeddings model and vector store are only loaded when initQAChain is True, and qa_chain is set to None otherwise. The hunk after the change:

@@ -23,55 +23,59 @@ load_dotenv(found_dotenv, override=False)
init_settings()


def app_init(initQAChain: bool = True):
    # https://github.com/huggingface/transformers/issues/17611
    os.environ["CURL_CA_BUNDLE"] = ""

    llm_model_type = os.environ.get("LLM_MODEL_TYPE")
    n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")

    hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
    print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
    print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")

    if initQAChain:
        hf_embeddings_model_name = (
            os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
        )

        index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
            "CHROMADB_INDEX_PATH"
        )
        using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None

        start = timer()
        embeddings = HuggingFaceInstructEmbeddings(
            model_name=hf_embeddings_model_name,
            model_kwargs={"device": hf_embeddings_device_type},
        )
        end = timer()

        print(f"Completed in {end - start:.3f}s")

        start = timer()

        print(
            f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}"
        )

        if not os.path.isdir(index_path):
            raise ValueError(f"{index_path} does not exist!")
        elif using_faiss:
            vectorstore = FAISS.load_local(index_path, embeddings)
        else:
            vectorstore = Chroma(
                embedding_function=embeddings, persist_directory=index_path
            )

        end = timer()

        print(f"Completed in {end - start:.3f}s")

    start = timer()
    llm_loader = LLMLoader(llm_model_type)
    llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
    qa_chain = QAChain(vectorstore, llm_loader) if initQAChain else None
    end = timer()
    print(f"Completed in {end - start:.3f}s")
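app_init's return statement falls outside this hunk, but summarize.py below calls app_init(False)[0] to get just the LLM loader, so it appears to return a (llm_loader, qa_chain) tuple. A minimal sketch of the two call modes, assuming that return shape:

# Assumption: app_init returns (llm_loader, qa_chain); the return itself is
# not shown in the hunk, but summarize.py below indexes the result with [0].
from app_modules.init import app_init

# Full init: load embeddings, load the FAISS/Chroma index, build the QA chain.
llm_loader, qa_chain = app_init()

# Summarization-only init: skip the vector store entirely; qa_chain is None.
llm_loader, qa_chain = app_init(initQAChain=False)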
app_modules/llm_summarize_chain.py
ADDED
@@ -0,0 +1,20 @@
import os
from typing import List, Optional

from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain

from app_modules.llm_inference import LLMInference


class SummarizeChain(LLMInference):
    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        chain = load_summarize_chain(self.llm_loader.llm, chain_type="refine")
        return chain

    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
        result = chain(inputs, return_only_outputs=True)
        return result
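For context on chain_type="refine": LangChain's refine summarizer produces an initial summary from the first document, then walks the remaining documents one at a time, asking the LLM to refine the running summary with each new chunk. A self-contained sketch of that behavior; FakeListLLM is only a stub so the example runs offline:

from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.llms.fake import FakeListLLM  # stub LLM, illustration only

docs = [
    Document(page_content="chunk one ..."),
    Document(page_content="chunk two ..."),
]

# Two docs => two LLM calls: an initial summary, then one refine step.
llm = FakeListLLM(responses=["summary v1", "summary v2"])
chain = load_summarize_chain(llm, chain_type="refine")

result = chain({"input_documents": docs}, return_only_outputs=True)
print(result["output_text"])  # "summary v2"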
summarize.py
ADDED
@@ -0,0 +1,70 @@
# setting device on GPU if available, else CPU
import os
import sys
from timeit import default_timer as timer
from typing import List

from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
from langchain.vectorstores.faiss import FAISS

from app_modules.init import app_init, get_device_types
from app_modules.llm_summarize_chain import SummarizeChain


def load_documents(source_pdfs_path, urls) -> List:
    loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
    documents = loader.load()
    if urls is not None and len(urls) > 0:
        for doc in documents:
            source = doc.metadata["source"]
            filename = source.split("/")[-1]
            for url in urls:
                if url.endswith(filename):
                    doc.metadata["url"] = url
                    break
    return documents


def split_chunks(documents: List, chunk_size, chunk_overlap) -> List:
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(documents)


llm_loader = app_init(False)[0]

source_pdfs_path = (
    sys.argv[1] if len(sys.argv) > 1 else os.environ.get("SOURCE_PDFS_PATH")
)
chunk_size = os.environ.get("CHUNCK_SIZE")
chunk_overlap = os.environ.get("CHUNK_OVERLAP")

sources = load_documents(source_pdfs_path, None)

print(f"Splitting {len(sources)} PDF pages in to chunks ...")

chunks = split_chunks(
    sources, chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
)

print(f"Summarizing {len(chunks)} chunks ...")
start = timer()

summarize_chain = SummarizeChain(llm_loader)
result = summarize_chain.call_chain(
    {"input_documents": chunks},
    None,
    None,
    True,
)

end = timer()
print(f"Completed in {end - start:.3f}s")

print("\n\n***Summary:")
print(result["output_text"])
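The script is run directly, e.g. python summarize.py ./data/pdfs (path illustrative), falling back to SOURCE_PDFS_PATH when no argument is given; note that the chunk-size variable is spelled CHUNCK_SIZE in the code. The urls parameter of load_documents goes unused here (None is passed); a sketch of what it does when supplied, with hypothetical inputs:

# Hypothetical inputs: attach a source URL to each loaded PDF page whose
# filename matches the end of one of the given URLs.
docs = load_documents(
    "./data/pdfs",                                  # hypothetical directory
    urls=["https://example.com/files/report.pdf"],  # hypothetical URL
)
# Pages loaded from report.pdf now carry metadata["url"] set to that URL.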