# cricket_bot / utils.py
# Samagra07's picture
# Upload 3 files
# ea60252 verified
from langchain import hub
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import TextLoader
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import os
from dotenv import load_dotenv
# --- Environment and document index ------------------------------------------
# Read variables from a .env file into the process environment so that the
# os.getenv() lookups below can find the API tokens.
load_dotenv()

# Prompt published on the LangChain Hub under the handle 'rlm/rag-prompt'.
# It is fetched at import time and exposed as a module-level name.
prompt = hub.pull('rlm/rag-prompt')

# Load the source text; the path is resolved against the current working
# directory.
loader = TextLoader("dataset.txt")
docs = loader.load()

# Break the loaded documents into overlapping chunks. add_start_index=True
# asks the splitter to record each chunk's start offset in its metadata.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    add_start_index=True,
)
all_splits = text_splitter.split_documents(docs)

# Embed the chunks through the Hugging Face endpoint and index them in Chroma.
embeddings = HuggingFaceEndpointEmbeddings(
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)
vector_store = Chroma.from_documents(
    documents=all_splits,
    embedding=embeddings,
)

# Similarity-search retriever that returns the 6 closest chunks per query.
retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 6},
)
# Chat model reached through the OpenAI-compatible client, pointed at the
# OpenRouter endpoint. The key is read from the OPENAI_API_KEY environment
# variable (presumably holding an OpenRouter key -- confirm in deployment).
llm = ChatOpenAI(
    model="meta-llama/llama-3-8b-instruct:free",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENAI_API_KEY"),
    temperature=0.5,
)
# --- Question contextualisation ----------------------------------------------
# System instruction telling the model to rewrite the latest user message into
# a standalone question, without answering it.
# NOTE: the wording (including the misspelling "refrence") is kept verbatim so
# the text sent to the model is unchanged.
system_prompt = (
    "Given a chat history and the latest user question "
    "which might refrence context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return as it is."
)

# Message layout: system instruction, then the prior turns supplied under the
# "chat_history" key, then the new user message supplied under "input".
prompt_template = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Retriever that uses the LLM and the prompt above together with the base
# retriever, so that follow-up questions are resolved against chat history.
history_aware_retriever = create_history_aware_retriever(llm, retriever, prompt_template)
# --- Answer generation -------------------------------------------------------
# System instruction for the answering step. The retrieved documents are
# substituted into the {context} placeholder.
#
# The instruction text and the {context} placeholder are separated by a blank
# line. Previously a trailing line-continuation backslash joined them as
# "...concise.{context}", so the first retrieved document ran straight on from
# the last instruction sentence with no delimiter.
qa_system_prompt = (
    "You are a assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    "\n\n"
    "{context}"
)

# Message layout: system instruction (with retrieved context), then the prior
# turns supplied under "chat_history", then the user message under "input".
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", qa_system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Chain that stuffs the retrieved documents into qa_prompt and calls the LLM.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Full pipeline: history-aware retrieval followed by answer generation.
# Callers invoke it with the "input" and "chat_history" keys used above.
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)