# NOTE(review): removed web-scrape artifacts that preceded the code
# ("Spaces:" / "Runtime error" page residue) — not part of the program.
# setting device on GPU if available, else CPU
import os
import sys
from timeit import default_timer as timer
from typing import List

from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
from langchain.vectorstores.faiss import FAISS

from app_modules.init import app_init, get_device_types
from app_modules.llm_summarize_chain import SummarizeChain
def load_documents(source_pdfs_path, keep_page_info) -> List:
    """Load every PDF found under *source_pdfs_path*.

    Args:
        source_pdfs_path: Directory scanned (recursively, per
            PyPDFDirectoryLoader) for PDF files; load errors are silenced.
        keep_page_info: When False, all page documents are collapsed into a
            single document so downstream splitting sees one continuous text;
            when True, per-page documents are returned unchanged.

    Returns:
        A list of langchain Document objects (length 1 when pages are merged
        and at least one page was loaded; possibly empty).
    """
    loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
    documents = loader.load()
    # Guard against an empty directory: the original indexed documents[0]
    # unconditionally, which raised IndexError when nothing was loaded.
    if not keep_page_info and documents:
        # Join once instead of repeatedly concatenating onto
        # documents[0].page_content (the original loop was quadratic).
        merged = documents[0]
        merged.page_content = "\n".join(doc.page_content for doc in documents)
        documents = [merged]
    return documents
def split_chunks(documents: List, chunk_size, chunk_overlap) -> List:
    """Split *documents* into overlapping character chunks.

    Args:
        documents: Documents to split.
        chunk_size: Target maximum size of each chunk, in characters.
        chunk_overlap: Number of characters shared between adjacent chunks.

    Returns:
        The list of chunk documents produced by
        RecursiveCharacterTextSplitter.split_documents.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(documents)
# --- script body: load PDFs, chunk them, and summarize --------------------

llm_loader = app_init(False)[0]

# Configuration comes from positional CLI args with environment-variable
# fallbacks:
#   argv[1] / SOURCE_PDFS_PATH   argv[2] / CHUNCK_SIZE (sic)
#   argv[3] / CHUNK_OVERLAP      argv[4] / KEEP_PAGE_INFO
source_pdfs_path = (
    sys.argv[1] if len(sys.argv) > 1 else os.environ.get("SOURCE_PDFS_PATH")
)
# NOTE(review): the env var name "CHUNCK_SIZE" is misspelled; kept
# byte-identical for backward compatibility with existing deployments.
chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
# BUG FIX: originally read sys.argv[3] — the same slot already consumed by
# chunk_overlap — so the flag could never be supplied on the command line.
# keep_page_info is the fourth positional argument.
keep_page_info = (
    sys.argv[4] if len(sys.argv) > 4 else os.environ.get("KEEP_PAGE_INFO")
) == "true"

sources = load_documents(source_pdfs_path, keep_page_info)
print(f"Splitting {len(sources)} documents in to chunks ...")
chunks = split_chunks(
    sources, chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
)

print(f"Summarizing {len(chunks)} chunks ...")
start = timer()
summarize_chain = SummarizeChain(llm_loader)
result = summarize_chain.call_chain(
    {"input_documents": chunks},
    None,
    None,
    True,
)
end = timer()
total_time = end - start

print("\n\n***Summary:")
print(result["output_text"])
print(f"Total time used: {total_time:.3f} s")
print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
print(
    f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
)