Spaces:
Running
Running
File size: 2,309 Bytes
1cfcd72 7340cb6 1cfcd72 7340cb6 1cfcd72 7340cb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
from langchain_chroma import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.callbacks import StdOutCallbackHandler
from langchain.memory import ConversationBufferMemory
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from pydantic import SecretStr
CHROMA_DB_DIR = "./chroma_db"

# Load variables from a .env file (if present) into the process environment.
load_dotenv()


def _require_env(name):
    """Return the value of environment variable *name*, which must be non-empty.

    Raises:
        ValueError: If *name* is unset or empty. The message names the
            variable that was actually looked up, so the user knows which
            entry to add to their .env file.
    """
    value = os.getenv(name)
    if not value:
        raise ValueError(
            f"{name} not found in environment variables. Please check your .env file."
        )
    return value


# Credentials are read from the GROQ_* variables but kept under the OPENAI_*
# names, which are what the OpenAI-compatible chat client is configured with.
OPENAI_API_KEY = _require_env("GROQ_API_KEY")
OPENAI_API_BASE = _require_env("GROQ_API_BASE")
# Sentence-transformers checkpoint used to embed both documents and queries.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Keyword arguments for loading the model: run it on the CPU.
model_kwargs = dict(device="cpu")
# Keyword arguments for encoding: vectors are returned without normalization.
encode_kwargs = dict(normalize_embeddings=False)

# Shared embedding function, passed to the Chroma vector store.
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
def get_qa_chain(*, k=3, model="llama-3.1-8b-instant", temperature=0.2):
    """Build a conversational retrieval chain over the ``docs_collection`` store.

    Every call creates a new Chroma handle, a new chat model and a new, empty
    ``ConversationBufferMemory``. Chains returned by separate calls therefore
    do not share chat history.

    Args:
        k: Number of documents the retriever fetches per query (default 3).
        model: Chat model name sent to the endpoint at ``OPENAI_API_BASE``.
        temperature: Sampling temperature for the chat model (default 0.2).

    Returns:
        A ``ConversationalRetrievalChain``. Its ``invoke`` takes a
        ``{"question": ...}`` dict and returns a dict containing ``"answer"``.
    """
    # NOTE(review): persist_directory=None means this presumably opens a
    # non-persistent store, and the module-level CHROMA_DB_DIR is not used.
    # Documents must reach "docs_collection" some other way for retrieval to
    # find anything. Confirm this is intended.
    vectordb = Chroma(
        persist_directory=None,
        embedding_function=embeddings,
        collection_name="docs_collection",
    )
    # _collection is a private attribute of the LangChain wrapper. It is used
    # here only for this diagnostic count.
    print(f"Number of embedded documents: {vectordb._collection.count()}")
    retriever = vectordb.as_retriever(search_kwargs={"k": k})
    llm = ChatOpenAI(
        model=model,
        # The module raises at import time when the key is missing, so it is
        # always set by the time this runs.
        api_key=SecretStr(OPENAI_API_KEY),
        base_url=OPENAI_API_BASE,
        temperature=temperature,
    )
    # output_key="answer" selects which chain output is saved into memory.
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        callbacks=[StdOutCallbackHandler()],
    )
def answer_question(question, qa_chain=None):
    """Answer *question* with a retrieval chain and wrap the result as Markdown.

    Args:
        question: The user's question. It is passed to the chain under the
            ``"question"`` key.
        qa_chain: Optional pre-built chain to reuse. When omitted, a fresh
            chain is built with ``get_qa_chain()`` for this call only, as
            before, so its conversation memory starts empty. Reusing one chain
            across calls lets its memory carry history forward. Do not share
            one chain between unrelated users.

    Returns:
        A Markdown string containing:
            - an "Answer" heading,
            - the chain's answer,
            - a horizontal rule,
            - an attribution line.
    """
    if qa_chain is None:
        qa_chain = get_qa_chain()
    result = qa_chain.invoke({"question": question})
    answer = result["answer"]
    # The blank lines around the answer are required. In Markdown, a "---"
    # line directly under a paragraph turns that paragraph into a (setext)
    # heading instead of drawing a horizontal rule.
    formatted_answer = (
        "\n## Answer\n\n"
        f"{answer}\n\n"
        "---\n\n"
        "*Generated using AI-powered document analysis*\n"
    )
    return formatted_answer
|