|
import os, requests, tiktoken |
|
|
|
from llama_hub.youtube_transcript import YoutubeTranscriptReader |
|
from llama_index import download_loader, PromptTemplate, ServiceContext |
|
from llama_index.callbacks import CallbackManager, TokenCountingHandler |
|
from llama_index.embeddings import OpenAIEmbedding |
|
from llama_index.indices.vector_store.base import VectorStoreIndex |
|
from llama_index.llms import OpenAI |
|
from llama_index.prompts import PromptTemplate |
|
from llama_index.storage.storage_context import StorageContext |
|
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch |
|
|
|
from pathlib import Path |
|
from pymongo import MongoClient |
|
from rag_base import BaseRAG |
|
|
|
class LlamaIndexRAG(BaseRAG):
    """RAG pipeline built on LlamaIndex with a MongoDB Atlas vector store.

    Ingestion loads a PDF, a web page and two YouTube transcripts, and indexes
    them into MongoDB Atlas using OpenAI embeddings. Retrieval queries that
    index with an OpenAI LLM and reports the LLM token usage recorded by a
    ``TokenCountingHandler``.

    The source URLs and MongoDB settings read here (``PDF_URL``, ``WEB_URL``,
    ``YOUTUBE_URL_1``, ``YOUTUBE_URL_2``, ``MONGODB_ATLAS_CLUSTER_URI``,
    ``MONGODB_COLLECTION_NAME``, ``MONGODB_INDEX_NAME``) are not defined in
    this class. NOTE(review): presumably they are provided by ``BaseRAG`` —
    confirm there.
    """

    # MongoDB database that holds this implementation's vector collection.
    MONGODB_DB_NAME = "llamaindex_db"

    # Seconds to wait on the PDF download socket before giving up, so a
    # stalled server cannot hang ingestion indefinitely.
    PDF_DOWNLOAD_TIMEOUT = 60

    def load_documents(self):
        """Load every source document used for ingestion.

        On first use, the PDF at ``self.PDF_URL`` is downloaded to
        ``data/gpt-4.pdf``. Later calls reuse the cached file. The PDF is then
        loaded together with the page at ``self.WEB_URL`` and the transcripts
        of ``self.YOUTUBE_URL_1`` and ``self.YOUTUBE_URL_2``.

        Returns:
            list: LlamaIndex documents from the PDF, the web page and the
            YouTube transcripts, in that order.

        Raises:
            requests.HTTPError: if the PDF download returns an error status.
        """
        docs = []

        # --- PDF (downloaded once, then cached on disk) -------------------
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()

        out_dir = Path("data")
        # exist_ok avoids the check-then-create race of exists()/makedirs().
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / "gpt-4.pdf"

        if not out_path.exists():
            response = requests.get(self.PDF_URL,
                                    timeout=self.PDF_DOWNLOAD_TIMEOUT)
            # Fail loudly rather than caching an HTML error page as the PDF.
            # The exists() check above would otherwise reuse that corrupt
            # file on every later run.
            response.raise_for_status()
            out_path.write_bytes(response.content)

        docs.extend(loader.load_data(file=out_path))

        # --- Web page ------------------------------------------------------
        SimpleWebPageReader = download_loader("SimpleWebPageReader")
        loader = SimpleWebPageReader()
        docs.extend(loader.load_data(urls=[self.WEB_URL]))

        # --- YouTube transcripts ------------------------------------------
        loader = YoutubeTranscriptReader()
        docs.extend(loader.load_data(ytlinks=[self.YOUTUBE_URL_1,
                                              self.YOUTUBE_URL_2]))

        return docs

    def get_callback_manager(self, config):
        """Build a callback manager holding a single token-counting handler.

        Args:
            config: mapping with a ``"model_name"`` key. Its value selects
                the tiktoken encoding used to count tokens.

        Returns:
            CallbackManager: a manager whose only handler (``handlers[0]``)
            is the ``TokenCountingHandler``.
        """
        token_counter = TokenCountingHandler(
            tokenizer=tiktoken.encoding_for_model(config["model_name"]).encode
        )

        token_counter.reset_counts()

        return CallbackManager([token_counter])

    def get_callback(self, token_counter):
        """Format the LLM token usage recorded by ``token_counter``.

        Args:
            token_counter: the ``TokenCountingHandler`` created by
                ``get_callback_manager``.

        Returns:
            str: three newline-separated lines with the total, prompt and
            completion LLM token counts.
        """
        return (f"Tokens Used: {token_counter.total_llm_token_count}\n"
                f"Prompt Tokens: {token_counter.prompt_llm_token_count}\n"
                f"Completion Tokens: "
                f"{token_counter.completion_llm_token_count}")

    def get_llm(self, config):
        """Return an OpenAI LLM configured from ``config``.

        Args:
            config: mapping with ``"model_name"`` and ``"temperature"`` keys.
        """
        return OpenAI(
            model=config["model_name"],
            temperature=config["temperature"],
        )

    def get_vector_store(self):
        """Return a MongoDB Atlas vector store for this class's collection.

        NOTE(review): each call opens a new ``MongoClient`` that is never
        closed. Consider sharing one client if this is called often.
        """
        return MongoDBAtlasVectorSearch(
            MongoClient(self.MONGODB_ATLAS_CLUSTER_URI),
            db_name=self.MONGODB_DB_NAME,
            collection_name=self.MONGODB_COLLECTION_NAME,
            index_name=self.MONGODB_INDEX_NAME,
        )

    def get_service_context(self, config):
        """Build the LlamaIndex service context for one ingestion/query run.

        Args:
            config: mapping with ``"model_name"``, ``"temperature"``,
                ``"chunk_size"`` and ``"chunk_overlap"`` keys.

        Returns:
            ServiceContext: uses OpenAI embeddings, the configured OpenAI
            LLM, and a fresh token-counting callback manager.
        """
        return ServiceContext.from_defaults(
            callback_manager=self.get_callback_manager(config),
            chunk_overlap=config["chunk_overlap"],
            chunk_size=config["chunk_size"],
            embed_model=OpenAIEmbedding(),
            llm=self.get_llm(config),
        )

    def get_storage_context(self):
        """Return a storage context backed by a new MongoDB Atlas vector store."""
        return StorageContext.from_defaults(
            vector_store=self.get_vector_store()
        )

    def store_documents(self, config, docs):
        """Embed ``docs`` and write them into the MongoDB Atlas vector store.

        Args:
            config: settings consumed by ``get_service_context``.
            docs: LlamaIndex documents, e.g. from ``load_documents``.
        """
        # A single storage context is enough. The previous extra, unused one
        # opened a second MongoClient that was never closed.
        VectorStoreIndex.from_documents(
            docs,
            service_context=self.get_service_context(config),
            storage_context=self.get_storage_context(),
        )

    def ingestion(self, config):
        """Load all source documents and store them in the vector index."""
        docs = self.load_documents()

        self.store_documents(config, docs)

    def retrieval(self, config, prompt):
        """Answer ``prompt`` from the vector index and report token usage.

        Args:
            config: settings consumed by ``get_service_context``, plus
                ``"k"``, the number of chunks to retrieve.
            prompt: the user question.

        Returns:
            tuple: ``(completion, callback)``. ``completion`` is the query
            engine's response. ``callback`` is the token-usage string from
            ``get_callback``.

        Raises:
            KeyError: if the ``LLAMAINDEX_TEMPLATE`` environment variable is
                unset. NOTE(review): presumably it must contain LlamaIndex's
                ``{context_str}`` / ``{query_str}`` placeholders — confirm.
        """
        service_context = self.get_service_context(config)

        # Pass the configured service context so the query embedding uses the
        # same embed model and callback manager as the rest of this run,
        # instead of an implicitly built default context.
        index = VectorStoreIndex.from_vector_store(
            vector_store=self.get_vector_store(),
            service_context=service_context,
        )

        query_engine = index.as_query_engine(
            service_context=service_context,
            similarity_top_k=config["k"],
            text_qa_template=PromptTemplate(os.environ["LLAMAINDEX_TEMPLATE"]),
        )

        completion = query_engine.query(prompt)

        # get_callback_manager registers exactly one handler: the token counter.
        token_counter = service_context.callback_manager.handlers[0]
        callback = self.get_callback(token_counter)

        return completion, callback