|
|
|
import os |
|
import sys |
|
from timeit import default_timer as timer |
|
from typing import List |
|
|
|
from langchain.document_loaders import PyPDFDirectoryLoader |
|
from langchain.embeddings import HuggingFaceInstructEmbeddings |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores.base import VectorStore |
|
from langchain.vectorstores.chroma import Chroma |
|
from langchain.vectorstores.faiss import FAISS |
|
|
|
from app_modules.init import app_init, get_device_types |
|
from app_modules.llm_summarize_chain import SummarizeChain |
|
|
|
|
|
def load_documents(source_pdfs_path, keep_page_info) -> List:
    """Load the PDF documents found under ``source_pdfs_path``.

    Args:
        source_pdfs_path: Directory passed to ``PyPDFDirectoryLoader``.
        keep_page_info: When truthy, the documents produced by the loader are
            returned unchanged. When falsy, the text of every loaded document
            is merged (newline separated) into the first document, which is
            then returned on its own.

    Returns:
        The list of loaded documents. In merge mode this is a single-element
        list holding the first document (mutated in place to carry the
        combined text), or an empty list when the loader produced nothing.
    """
    loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
    documents = loader.load()

    # Nothing to merge when page info is kept. The emptiness guard avoids an
    # IndexError on documents[0] when no document could be loaded.
    if keep_page_info or not documents:
        return documents

    # Build the combined text in one pass instead of repeatedly re-creating
    # the first document's growing string. The join is fully evaluated before
    # the assignment, so the first document's original text is included once.
    merged = documents[0]
    merged.page_content = "\n".join(doc.page_content for doc in documents)
    return [merged]
|
|
|
|
|
def split_chunks(documents: List, chunk_size, chunk_overlap) -> List:
    """Split ``documents`` into chunks with a ``RecursiveCharacterTextSplitter``.

    Args:
        documents: Documents to split.
        chunk_size: Forwarded to the splitter as ``chunk_size``.
        chunk_overlap: Forwarded to the splitter as ``chunk_overlap``.

    Returns:
        Whatever ``split_documents`` returns for ``documents``.
    """
    splitter_options = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
    return RecursiveCharacterTextSplitter(**splitter_options).split_documents(
        documents
    )
|
|
|
|
|
# Initialise the application; the first element of app_init's result is used
# as the LLM loader below.
# NOTE(review): app_init presumably also prepares the environment that the
# os.environ lookups below rely on -- keep it ahead of them; confirm.
llm_loader = app_init(False)[0]

# Settings: each one is taken from the command line when the corresponding
# positional argument is present, otherwise from the environment.
#   argv[1] / SOURCE_PDFS_PATH
#   argv[2] / CHUNCK_SIZE
#   argv[3] / CHUNK_OVERLAP
#   argv[4] / KEEP_PAGE_INFO  (enabled only by the exact string "true")
source_pdfs_path = (
    sys.argv[1] if len(sys.argv) > 1 else os.environ.get("SOURCE_PDFS_PATH")
)
# NOTE(review): the "CHUNCK_SIZE" spelling is kept as-is because existing
# configuration may rely on it -- confirm before renaming.
chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
# This is the fourth positional argument; reading argv[3] here would collide
# with chunk_overlap and make the flag unusable from the command line.
keep_page_info = (
    sys.argv[4] if len(sys.argv) > 4 else os.environ.get("KEEP_PAGE_INFO")
) == "true"

# Fail early with a readable message instead of an opaque TypeError from
# int(None) or from the loader further down.
missing_settings = [
    label
    for label, value in (
        ("SOURCE_PDFS_PATH", source_pdfs_path),
        ("CHUNCK_SIZE", chunk_size),
        ("CHUNK_OVERLAP", chunk_overlap),
    )
    if value is None
]
if missing_settings:
    sys.exit(
        "Missing required setting(s): "
        + ", ".join(missing_settings)
        + " (pass them as positional arguments or set the environment variables)"
    )

sources = load_documents(source_pdfs_path, keep_page_info)

print(f"Splitting {len(sources)} documents in to chunks ...")

chunks = split_chunks(
    sources, chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
)

print(f"Summarizing {len(chunks)} chunks ...")
start = timer()

summarize_chain = SummarizeChain(llm_loader)
result = summarize_chain.call_chain(
    {"input_documents": chunks},
    None,
    None,
    True,
)

end = timer()
total_time = end - start

print("\n\n***Summary:")
print(result["output_text"])

# Throughput figures are based on the token count exposed by the loader's
# streamer and the wall-clock time of the summarization call above.
print(f"Total time used: {total_time:.3f} s")
print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
print(
    f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
)
|
|