import os

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

from dotenv import load_dotenv

load_dotenv()  # load API keys from .env
mistral_api_key = os.getenv("MISTRAL_API_KEY")

from typing import Literal

from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
def load_chunk_persist_pdf() -> Chroma:
    """Load every PDF in the folder, split it into chunks, and persist a Chroma index."""
    pdf_folder_path = "data/pdf_folder/"
    documents = []
    for file in os.listdir(pdf_folder_path):
        if file.endswith('.pdf'):
            pdf_path = os.path.join(pdf_folder_path, file)
            loader = PyPDFLoader(pdf_path)
            documents.extend(loader.load())
    # Split into ~1000-character chunks with a small overlap to preserve context across boundaries
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    chunked_documents = text_splitter.split_documents(documents)
    os.makedirs("data/chroma_store/", exist_ok=True)
    vectorstore = Chroma.from_documents(
        documents=chunked_documents,
        embedding=MistralAIEmbeddings(),
        persist_directory="data/chroma_store/",
    )
    vectorstore.persist()
    return vectorstore
# Build (or rebuild) the index and expose it as a retriever
vectorstore = load_chunk_persist_pdf()
retriever = vectorstore.as_retriever()
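# Optional tweak (an assumption, not in the original script): cap retrieval at the
# top-k chunks via Chroma's search_kwargs, e.g.
# retriever = vectorstore.as_retriever(search_kwargs={"k": 4})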
# Data model for the router's structured output
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "websearch"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )
# LLM with function calling, used both for answering and for structured routing
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# structured_llm_router = llm.with_structured_output(RouteQuery, method="json_mode")
# Router prompt
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. For all else, use web search."""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)
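# Minimal router wiring (an addition, not in the original script): chain the route
# prompt into the structured-output LLM so it returns a RouteQuery. This uses the
# default function-calling method rather than the commented-out json_mode variant above.
structured_llm_router = llm.with_structured_output(RouteQuery)
question_router = route_prompt | structured_llm_router
# Example: question_router.invoke({"question": "What is prompt engineering?"})
# is expected to yield RouteQuery(datasource="vectorstore").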
# Generic RAG prompt pulled from the LangChain hub
prompt = hub.pull("rlm/rag-prompt")

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def format_docs(docs):
    # Join retrieved chunks into one context string for the prompt
    return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
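# Note on the chain above: RunnablePassthrough() forwards the raw question unchanged,
# while retriever | format_docs maps the same question to a context string; the dict
# fills the rag-prompt's {context} and {question} slots in parallel.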
print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises."))
# print(rag_chain.invoke("I am a 45-year-old woman and I need to lose weight for the summer. Provide me with a fitness program."))
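# A hedged sketch (not the author's code) of how routing could dispatch a question:
# the router picks the datasource, falling back to DuckDuckGo web search when the
# question is outside the vectorstore's topics. answer_with_routing is a hypothetical
# helper introduced here for illustration.
def answer_with_routing(question: str) -> str:
    route = question_router.invoke({"question": question})
    if route.datasource == "vectorstore":
        return rag_chain.invoke(question)
    # Otherwise search the web and answer from the retrieved snippets
    web_results = DuckDuckGoSearchRun().invoke(question)
    message = llm.invoke(
        f"Answer the question using this web context:\n{web_results}\n\nQuestion: {question}"
    )
    return message.content

# print(answer_with_routing("What are adversarial attacks on LLMs?"))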