import streamlit as st
import torch
import cohere
import google.generativeai as genai
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient
from transformers import pipeline

available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]

AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.

Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.**

Query: {question}
Relevant Information: {context}
Answer:
"""


@st.cache_resource(show_spinner=False)
def initialize_rag_components():
    """Build the heavyweight shared components once per session: the Cohere
    rerank client, the PairRM response ranker, the gen_fuser response fuser,
    and the Qdrant-backed retriever."""
    components = {
        'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
        'pair_ranker': pipeline(
            "text-classification",
            model="llm-blender/PairRM",
            tokenizer="llm-blender/PairRM",
            return_all_scores=True
        ),
        'gen_fuser': pipeline(
            "text-generation",
            model="llm-blender/gen_fuser_3b",
            tokenizer="llm-blender/gen_fuser_3b",
            max_length=2048,
            do_sample=False
        ),
        'retriever': get_retriever()
    }
    return components


class SimpleResponse:
    """Minimal stand-in for a LangChain message: every model wrapper returns
    one of these so callers can uniformly read `.content`."""

    def __init__(self, content):
        self.content = content


class AllModelsWrapper:
    """Ensemble wrapper: query all backends, rank their answers with PairRM,
    and fuse the top answers with gen_fuser."""

    def invoke(self, messages):
        prompt = messages[0]["content"]
        # The hosting app is expected to have stored the shared components here.
        rag_components = st.session_state.app_models['rag_components']
        responses = get_all_responses(prompt)
        fused = rank_and_fuse(prompt, responses, rag_components)
        return SimpleResponse(fused)


def get_all_responses(prompt):
    """Collect one answer per backend model for the same prompt."""
    openai_resp = ChatOpenAI(
        model="gpt-4o",
        temperature=0.2,
        api_key=st.secrets["OPENAI_API_KEY"]
    ).invoke([{"role": "user", "content": prompt}]).content

    # Gemini must be configured with an API key before use.
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
    gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
    gemini_resp = gemini.generate_content(prompt).text

    llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
    llama_resp = llama.chat.completions.create(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",
        messages=[{"role": "user", "content": prompt}],
        temperature=1,
        max_completion_tokens=1024,
        top_p=1,
        stream=False
    ).choices[0].message.content

    return [openai_resp, gemini_resp, llama_resp]


def rank_and_fuse(prompt, responses, rag_components):
    """Score each response against the prompt with PairRM, then fuse the two
    highest-scoring answers into one reply."""
    # With return_all_scores=True the pipeline yields one [{label, score}, ...]
    # list per input; [0][1]['score'] reads the second label's score, assumed
    # here to be the "preferred" class.
    ranked = [
        (resp, rag_components['pair_ranker'](f"{prompt}\n\n{resp}")[0][1]['score'])
        for resp in responses
    ]
    ranked.sort(key=lambda x: x[1], reverse=True)

    # Fuse the top two responses.
    fusion_input = "\n\n".join(
        f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked[:2])
    )
    return rag_components['gen_fuser'](
        f"Fuse these responses:\n{fusion_input}",
        return_full_text=False
    )[0]['generated_text']


def get_retriever():
    """Connect to the hosted Qdrant collection and expose it as a retriever."""
    # Credentials belong in .streamlit/secrets.toml, not in source control.
    qdrant_client = QdrantClient(
        url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
        api_key=st.secrets["QDRANT_API_KEY"]
    )
    collection_name = "ks_collection_1.5BE"

    # HuggingFaceEmbeddings loads the gte-Qwen2 model itself; no separate
    # SentenceTransformer instance is needed.
    local_embedding = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={
            "trust_remote_code": True,
            "device": "cuda" if torch.cuda.is_available() else "cpu"
        }
    )
    print("Qwen2-1.5B local embedding model loaded.")

    vector_store = Qdrant(
        client=qdrant_client,
        collection_name=collection_name,
        embeddings=local_embedding
    )
    return vector_store.as_retriever()
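
# --- Hypothetical ingestion sketch -------------------------------------------
# get_retriever() assumes "ks_collection_1.5BE" already exists and was embedded
# with the same gte-Qwen2-1.5B-instruct model. Below is a minimal sketch of how
# such a collection could be built with Qdrant.from_texts; QDRANT_URL is an
# assumed secret name and `texts` is caller-supplied. Not part of the app.
def build_collection_sketch(texts):
    embeddings = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={"trust_remote_code": True}
    )
    return Qdrant.from_texts(
        texts,
        embedding=embeddings,
        url=st.secrets["QDRANT_URL"],  # assumed secret name
        api_key=st.secrets["QDRANT_API_KEY"],
        collection_name="ks_collection_1.5BE"
    )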
def initialize_llm(_model_name):
    """Initialize the selected model behind a uniform .invoke() interface."""
    print(f"Model name: {_model_name}")
    if "OpenAI" in _model_name:
        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
    elif "LLaMA" in _model_name:
        client = Groq(api_key=st.secrets["GROQ_API_KEY"])

        class GroqWrapper:
            def invoke(self, messages):
                completion = client.chat.completions.create(
                    model="meta-llama/llama-4-maverick-17b-128e-instruct",
                    messages=[{"role": "user", "content": messages[0]["content"]}],
                    temperature=1,
                    max_completion_tokens=1024,
                    top_p=1,
                    stream=False
                )
                # Wrap in SimpleResponse so callers can read .content uniformly.
                return SimpleResponse(completion.choices[0].message.content)

        return GroqWrapper()
    elif "Gemini" in _model_name:
        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
        gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")

        class GeminiWrapper:
            def invoke(self, messages):
                response = gemini_model.generate_content(messages[0]["content"])
                return SimpleResponse(response.text)

        return GeminiWrapper()
    elif "Ensemble" in _model_name:
        return AllModelsWrapper()
    else:
        raise ValueError("Unsupported model selected")


def load_rag_chain(llm):
    """Build a stuff-documents RetrievalQA chain over the Qdrant retriever.
    Note: RetrievalQA performs its own retrieval and does not apply the
    Cohere rerank step."""
    prompt_template = PromptTemplate(
        template=AI_PROMPT_TEMPLATE,
        input_variables=["question", "context"]
    )
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={
            "prompt": prompt_template,
            "document_variable_name": "context"
        }
    )
    return rag_chain


def rerank_with_cohere(query, documents, co, top_n=5):
    """Rerank retrieved documents with Cohere rerank-v3.5 and return the
    top_n originals in the new order."""
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(
        query=query,
        documents=raw_texts,
        top_n=min(top_n, len(raw_texts)),
        model="rerank-v3.5"
    )
    return [documents[result.index] for result in results.results]


def get_reranked_response(query, llm, rag_components):
    """Retrieve, rerank with Cohere, and answer using the reranked context."""
    docs = rag_components['retriever'].get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
    context = "\n\n".join(doc.page_content for doc in reranked_docs)

    if isinstance(llm, ChatOpenAI):
        # RetrievalQA only accepts a "query" input and retrieves on its own,
        # so this path does not use the reranked context computed above.
        return load_rag_chain(llm).invoke({"query": query})['result']
    # The custom wrappers (Groq, Gemini, Ensemble) take OpenAI-style messages
    # and return objects exposing .content, so the reranked context is stuffed
    # into the prompt directly.
    prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
    return llm.invoke([{"role": "user", "content": prompt}]).content


if __name__ == "__main__":
    print("This is a module - import it instead of running directly")
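
# --- Minimal usage sketch -----------------------------------------------------
# A hypothetical Streamlit page wiring this module together; assumes the API
# keys referenced above live in .streamlit/secrets.toml. Call run_demo_app()
# from a script launched with `streamlit run`.
def run_demo_app():
    st.title("Dermatology RAG Chatbot")
    rag_components = initialize_rag_components()
    # AllModelsWrapper expects the components under st.session_state.app_models.
    if 'app_models' not in st.session_state:
        st.session_state.app_models = {}
    st.session_state.app_models['rag_components'] = rag_components
    model_name = st.selectbox("Model", available_models)
    llm = initialize_llm(model_name)
    query = st.text_input("Describe the skin symptoms:")
    if query:
        with st.spinner("Thinking..."):
            answer = get_reranked_response(query, llm, rag_components)
        st.markdown(answer)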