import os
import logging
from pathlib import Path

from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# MergerRetriever is unused in the single-index flow below; it is kept for
# the commented multi-index sketch in generate_answer().
from langchain.retrievers import MergerRetriever
# Load environment variables from .env file
load_dotenv()

# Retrieve the OpenAI API key from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
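
# A small defensive check not in the original script: fail fast with a clear
# message if the key is missing, instead of a harder-to-read error from
# ChatOpenAI at request time.
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY is not set; add it to your .env file")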
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def load_faiss_index(folder_path: str, model_name: str) -> FAISS:
    """
    Load a FAISS index with a specific embedding model.

    Args:
        folder_path: Path to the FAISS index folder
        model_name: Name of the HuggingFace embedding model

    Returns:
        FAISS: Loaded FAISS index object

    Raises:
        ValueError: If the folder path doesn't exist
    """
    try:
        if not os.path.exists(folder_path):
            raise ValueError(f"FAISS index folder not found: {folder_path}")

        logger.info(f"Loading FAISS index from {folder_path}")
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        # FAISS indexes are stored with pickle, so deserialization must be
        # explicitly allowed; only load index files you created yourself.
        return FAISS.load_local(
            folder_path=folder_path,
            embeddings=embeddings,
            allow_dangerous_deserialization=True
        )
    except Exception as e:
        logger.error(f"Error loading FAISS index: {str(e)}")
        raise
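

# Companion sketch, not part of the original script: one way the "faiss_v4"
# index loaded above could have been built and saved. The helper name and the
# `docs` argument (a list of langchain Document objects) are assumptions.
def build_faiss_index(docs, folder_path: str, model_name: str) -> None:
    """Embed the given documents and persist a FAISS index to folder_path."""
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    index = FAISS.from_documents(docs, embeddings)
    index.save_local(folder_path)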


def generate_answer(query: str) -> str:
    """
    Generate an answer for the given query using RAG.

    Args:
        query: User's question

    Returns:
        str: Generated answer

    Raises:
        ValueError: If query is empty or required files are missing
    """
    try:
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        # Get the current directory and construct paths
        current_dir = Path(__file__).parent
        vectors_dir = current_dir / "vectors_data"

        # Validate that the vectors directory exists
        if not vectors_dir.exists():
            raise ValueError(f"Vectors directory not found at {vectors_dir}")

        # Load the FAISS index (a single index in this version)
        logger.info("Loading FAISS indices...")
        data_vec = load_faiss_index(
            str(vectors_dir / "faiss_v4"),
            "sentence-transformers/all-MiniLM-L12-v2"
        )

        # Create the LLM instance
        llm = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0,
            openai_api_key=OPENAI_API_KEY
        )
template = """You are a knowledgeable and approachable medical information assistant. Use the context provided to answer the medical question at the end. Follow these guidelines to ensure a clear, user-friendly, and professional response: | |
Important Guidelines: | |
1. **Clarity and Accessibility:** | |
- Write in simple, understandable language suitable for a general audience. | |
- Explain any technical terms briefly, if used. | |
2. **Structure:** | |
- Use clear paragraphs or bullet points for organization. | |
- Start with a concise summary of the issue before diving into details. | |
3. **Accuracy and Reliability:** | |
- Base your response strictly on the context provided. | |
- If you cannot provide an answer based on the context, state this honestly. | |
4. **Medical Safety and Disclaimers:** | |
- Include a disclaimer emphasizing the need to consult a healthcare professional for a personalized diagnosis or treatment plan. | |
5. **Treatment Information (if applicable):** | |
- Clearly outline treatment options, including: | |
- Drug name | |
- Drug class | |
- Dosage | |
- Frequency and duration | |
- Potential side effects | |
- Risks and additional recommendations | |
- Specify that these options are general and should be discussed with a healthcare provider. | |
6. **Encourage Engagement:** | |
- Invite users to ask clarifying questions or provide additional details for a more tailored response. | |
Context: {context} | |
Question: {question} | |
Medical Information Assistant:""" | |
        QA_CHAIN_PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )

        # Initialize the retriever (only one index is active here)
        logger.info("Setting up retrieval chain...")
        data_retriever = data_vec.as_retriever()
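        # Hedged multi-index sketch: the MergerRetriever import and the
        # original "combine retrievers" comment suggest more than one index
        # was planned. Assuming a second, hypothetical folder "faiss_v5"
        # existed under vectors_data, the two could be merged like this:
        #
        #   other_vec = load_faiss_index(
        #       str(vectors_dir / "faiss_v5"),
        #       "sentence-transformers/all-MiniLM-L12-v2"
        #   )
        #   data_retriever = MergerRetriever(
        #       retrievers=[data_vec.as_retriever(), other_vec.as_retriever()]
        #   )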

        # Initialize the RetrievalQA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=data_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )

        # Run the chain
        logger.info("Generating answer...")
        result = qa_chain.invoke({"query": query})
        logger.info("Answer generated successfully")

        # Extract and log the source documents used for the answer
        extracted_docs = result.get("source_documents", [])
        logger.info(f"Extracted documents: {extracted_docs}")

        return result["result"]
    except Exception as e:
        logger.error(f"Error generating answer: {str(e)}")
        raise


def main():
    """
    Main function to demonstrate the usage of the RAG system.
    """
    try:
        # Example usage
        query = "suggest me some medicine for bronchitis"
        logger.info(f"Processing query: {query}")
        response = generate_answer(query)

        print("\nQuery:", query)
        print("\nResponse:", response)
    except Exception as e:
        logger.error(f"Error in main function: {str(e)}")
        print(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
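
# Example invocation (the filename "app.py" is an assumption; use whatever
# this script is saved as), with vectors_data/faiss_v4 present and
# OPENAI_API_KEY set in .env:
#
#   $ python app.py
#
#   Query: suggest me some medicine for bronchitis
#
#   Response: <model answer following the template, including the disclaimer>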