import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
from dotenv import load_dotenv
load_dotenv() # load .env api keys
mistral_api_key = os.getenv("MISTRAL_API_KEY")
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma, FAISS
from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
from langchain import hub
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
def load_chunk_persist_pdf() -> Chroma:
    """Load every PDF in the folder, split it into chunks, and persist a Chroma store."""
    pdf_folder_path = "data/pdf_folder/"
    documents = []
    for file in os.listdir(pdf_folder_path):
        if file.endswith('.pdf'):
            pdf_path = os.path.join(pdf_folder_path, file)
            loader = PyPDFLoader(pdf_path)
            documents.extend(loader.load())
    # Split into ~1000-character chunks with a small overlap
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    chunked_documents = text_splitter.split_documents(documents)
    os.makedirs("data/chroma_store/", exist_ok=True)
    # Embed the chunks with Mistral embeddings and persist the index to disk
    vectorstore = Chroma.from_documents(
        documents=chunked_documents,
        embedding=MistralAIEmbeddings(),
        persist_directory="data/chroma_store/"
    )
    vectorstore.persist()
    return vectorstore
vectorstore = load_chunk_persist_pdf()
retriever = vectorstore.as_retriever()
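# Illustrative sanity check (a sketch, not in the original script): the
# retriever is a LangChain Runnable, so it can be invoked directly to inspect
# the top-k most similar chunks before wiring it into a chain.
# docs = retriever.invoke("strength training basics")
# print(len(docs), docs[0].page_content[:200])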
prompt = hub.pull("rlm/rag-prompt")
# Data model for query routing
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "websearch"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )
# LLM with function calling (used for structured routing output)
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# structured_llm_router = llm.with_structured_output(RouteQuery, method="json_mode")
# Prompt
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. For all else, use web-search."""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)
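# Sketch (an assumption, not wired up in the original): completing the router
# by binding RouteQuery as structured output and piping the prompt into it.
# `question_router` is a hypothetical name introduced here for illustration;
# its `datasource` field can then drive the vectorstore / web-search branch.
# structured_llm_router = llm.with_structured_output(RouteQuery)
# question_router = route_prompt | structured_llm_router
# route = question_router.invoke({"question": "How do adversarial attacks work?"})
# if route.datasource == "websearch":
#     answer = DuckDuckGoSearchRun().run("How do adversarial attacks work?")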
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
def format_docs(docs):
    """Concatenate retrieved chunks into a single context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises"))
# print(rag_chain.invoke("I am a 45-year-old woman and I have to lose weight for the summer. Provide me with a fitness program"))
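# The chain is a LangChain Runnable, so token streaming is also available; a
# minimal sketch, assuming StrOutputParser yields string chunks as they arrive:
# for chunk in rag_chain.stream("Build a fitness program for me"):
#     print(chunk, end="", flush=True)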