# medico-bot: src/bot/bot.py
# Author: pasupuletkarthiksai
# Last change: "Update src/bot/bot.py" (commit 29865e5, verified)
import os
import toml
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
import logging
from langchain_groq import ChatGroq
from src.bot.extract_metadata import Metadata
# Module-wide logging: INFO level, each record prefixed with a timestamp and
# its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class Medibot:
    """Retrieval-augmented QA bot over a local FAISS index and a Groq chat model.

    Documents are retrieved from a FAISS index (OpenAI embeddings, MMR search,
    k=10), answered by ``llama-3.1-8b-instant`` on Groq using prompts loaded
    from a TOML file, and table/image references for the retrieved documents
    are resolved through ``Metadata``.
    """

    def __init__(self, config_path: str = "src/bot/configs/prompt.toml",
                 metadata_database: str = "database/metadata.csv",
                 faiss_database: str = "database/faiss_index"
                 ):
        """Initialize Medibot with configuration and Groq client.

        Args:
            config_path: TOML file with a ``[rag_prompt]`` table providing
                ``system_prompt`` and ``user_prompt_template``.
            metadata_database: Path passed to ``Metadata`` for reference lookup.
            faiss_database: Directory loaded with ``FAISS.load_local``.

        Raises:
            ValueError: If ``GROQ_API_KEY`` or ``OPENAI_API_KEY`` is unset or empty.
            FileNotFoundError, toml.TomlDecodeError, KeyError: If the prompt
                config is missing, malformed, or lacks the expected keys
                (logged before being re-raised).
        """
        # Neither key is passed to the clients explicitly; ChatGroq and
        # OpenAIEmbeddings read them from the environment, so fail fast with a
        # clear message if either is missing.
        self._require_env("GROQ_API_KEY")
        self._require_env("OPENAI_API_KEY")

        # Load prompt configuration
        try:
            config = toml.load(config_path)
            system_prompt = config["rag_prompt"]["system_prompt"]
            user_prompt_template = config["rag_prompt"]["user_prompt_template"]
        except (FileNotFoundError, toml.TomlDecodeError, KeyError) as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            raise

        # Initialize prompt template
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("user", user_prompt_template)
        ])

        # Initialize vector database.
        # NOTE: allow_dangerous_deserialization lets load_local unpickle the
        # stored docstore; only point faiss_database at an index this project
        # produced itself, never at untrusted files.
        embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        vector_store = FAISS.load_local(
            faiss_database, embeddings, allow_dangerous_deserialization=True
        )
        self.retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10})

        # Initialize Groq client
        self.model = ChatGroq(
            model="llama-3.1-8b-instant",
            temperature=0.2,
            max_tokens=None,
            timeout=None,
            max_retries=2,
        )
        # Attribute name (sic) kept as-is for existing callers.
        self.metadata_extactor = Metadata(metadata_database)

    @staticmethod
    def _require_env(name: str) -> str:
        """Return environment variable *name*; log and raise ValueError if unset or empty."""
        value = os.environ.get(name)
        if not value:
            logger.error(f"{name} not found in environment variables")
            raise ValueError(f"{name} is required")
        return value

    def query(self, question: str) -> tuple:
        """Answer *question* with retrieved context and collect its references.

        Args:
            question: The user's question, used both as the retrieval query
                and as the ``question`` prompt variable.

        Returns:
            A 4-tuple ``(answer, retrieved_docs, refered_tables, refered_images)``:
            the parsed model output string, the documents returned by the
            retriever, and the two values ``Metadata.get_data_from_ref``
            returns for those documents.
        """
        # Retrieve once so the exact same documents feed both the prompt and
        # the metadata lookup.
        retrieved_docs = self.retriever.invoke(question)
        answer_chain = self.prompt_template | self.model | StrOutputParser()
        # Supply both prompt variables directly so "question" is the plain
        # string rather than a wrapper dict.
        answer = answer_chain.invoke({"context": retrieved_docs, "question": question})
        refered_tables, refered_images = self.metadata_extactor.get_data_from_ref(retrieved_docs)
        return answer, retrieved_docs, refered_tables, refered_images