import os

from dotenv import load_dotenv
from pydantic import SecretStr

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_chroma import Chroma
from langchain_core.callbacks import StdOutCallbackHandler
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
# Directory holding the persisted Chroma vector store
CHROMA_DB_DIR = "./chroma_db"
# Load environment variables from a .env file
load_dotenv()
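# Expected .env contents (values below are placeholders, not real credentials;
# the base URL shown is Groq's OpenAI-compatible endpoint):
#   GROQ_API_KEY=gsk_...
#   GROQ_API_BASE=https://api.groq.com/openai/v1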
# Groq exposes an OpenAI-compatible API, so the credentials are read from
# GROQ_* variables but passed to the OpenAI client classes below.
OPENAI_API_KEY = os.getenv("GROQ_API_KEY")
OPENAI_API_BASE = os.getenv("GROQ_API_BASE")
if not OPENAI_API_KEY:
    raise ValueError(
        "GROQ_API_KEY not found in environment variables. Please check your .env file."
    )
if not OPENAI_API_BASE:
    raise ValueError(
        "GROQ_API_BASE not found in environment variables. Please check your .env file."
    )
# Local sentence-transformers model used to embed queries; it must match the
# model that was used to embed the documents when the index was built.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}
embeddings = HuggingFaceEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)
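# The code below assumes "./chroma_db" was already populated offline. A minimal
# ingestion sketch (the function name, loader, and chunk sizes here are
# illustrative assumptions, not part of the original app) could look like:
#
#     from langchain_community.document_loaders import DirectoryLoader
#     from langchain_text_splitters import RecursiveCharacterTextSplitter
#
#     def build_index(docs_dir="./docs"):
#         docs = DirectoryLoader(docs_dir).load()
#         chunks = RecursiveCharacterTextSplitter(
#             chunk_size=1000, chunk_overlap=200
#         ).split_documents(docs)
#         Chroma.from_documents(
#             chunks,
#             embedding=embeddings,
#             persist_directory=CHROMA_DB_DIR,
#             collection_name="docs_collection",
#         )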
def get_qa_chain():
    # Open the persisted vector store; persist_directory must point at the
    # directory the index was written to, otherwise Chroma starts empty.
    vectordb = Chroma(
        persist_directory=CHROMA_DB_DIR,
        embedding_function=embeddings,
        collection_name="docs_collection",
    )
    # _collection is a private attribute, but it is a convenient sanity check
    # that the collection actually contains embedded documents.
    print(f"Number of embedded documents: {vectordb._collection.count()}")

    retriever = vectordb.as_retriever(search_kwargs={"k": 3})
    llm = ChatOpenAI(
        model="llama-3.1-8b-instant",
        api_key=SecretStr(OPENAI_API_KEY),  # validated above, so never None here
        base_url=OPENAI_API_BASE,
        temperature=0.2,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=retriever, memory=memory, callbacks=[StdOutCallbackHandler()]
    )
    return conversation_chain
def answer_question(question):
    # Note: building a fresh chain per call means the conversation memory does
    # not carry over between questions; reuse one chain for multi-turn chat.
    qa_chain = get_qa_chain()
    result = qa_chain.invoke({"question": question})
    answer = result["answer"]
    # Format the answer for better markdown display
    formatted_answer = f"""
## Answer

{answer}

---
*Generated using AI-powered document analysis*
"""
    return formatted_answer
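
# Minimal usage sketch (an assumption: in the actual Space this module is more
# likely imported by a UI layer; the question string is illustrative only):
if __name__ == "__main__":
    print(answer_question("What do these documents cover?"))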