import chromadb
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from retriever import retrieve
from utils import build_prompt, refine_response
# Module-level caches: the vector store and models are loaded once and reused.
_vector_store = None
_finetuned_llm = None
_base_model = None
def get_vector_store():
    """Load the vector store (lazy-loaded on first use)."""
    global _vector_store
    if _vector_store is None:
        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
        try:
            _vector_store = db_client.get_collection("medical_rag")
        except Exception:
            # The collection does not exist yet, so create it.
            _vector_store = db_client.create_collection(name="medical_rag")
    return _vector_store
def get_finetuned_llm():
    """Load the fine-tuned model (lazy-loaded on first use)."""
    global _finetuned_llm
    if _finetuned_llm is None:
        ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
        ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
        ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)

        _finetuned_llm = pipeline(
            "text2text-generation",
            model=ft_model,
            tokenizer=ft_tokenizer,
            # T5-style models use the pad token as the decoder start token.
            decoder_start_token_id=ft_model.config.pad_token_id
        )
    return _finetuned_llm
def rag(user_query):
    """Main RAG function: retrieve context and generate an answer.

    Takes a question string and returns an answer string annotated with a
    confidence estimate.
    """
    try:
        vector_store = get_vector_store()
        finetuned_llm = get_finetuned_llm()

        # Safety check: flag queries that describe a possible emergency.
        emergency_keywords = ["emergency", "severe pain", "bleeding",
                              "blind", "lose consciousness", "pass out"]

        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            emergency_msg = ("I am an AI and cannot provide medical advice for emergencies.\n"
                             "PLEASE contact emergency services or a medical professional immediately.")

            try:
                # Still retrieve and answer, but prepend the emergency warning.
                contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)

                if not contexts:
                    return f"{emergency_msg}\n\nNo relevant information found for your query."

                prompt = build_prompt(user_query, contexts)
                result = finetuned_llm(
                    prompt,
                    max_new_tokens=70,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    repetition_penalty=1.4,
                    eos_token_id=finetuned_llm.tokenizer.eos_token_id
                )

                answer = result[0]['generated_text'].strip()
                answer = refine_response(answer)

                # Estimate confidence from the average Chroma distance of the retrieved contexts.
                avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
                confidence_score = max(0, min(100, (1 - avg_distance) * 100))

                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"

            except Exception as e:
                return f"{emergency_msg}\n\nError generating answer: {str(e)}"

        # Standard path: retrieve supporting contexts for the query.
        contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)

        if not contexts:
            return ("I'm not confident about my answer (0%).\n\n"
                    "Couldn't find relevant information to answer your question.")

        # Build the prompt and generate the answer with beam search.
        prompt = build_prompt(user_query, contexts)

        result = finetuned_llm(
            prompt,
            max_new_tokens=70,
            num_beams=3,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.4,
            eos_token_id=finetuned_llm.tokenizer.eos_token_id
        )

        answer = result[0]['generated_text'].strip()
        answer = refine_response(answer)

        # Estimate confidence from the average Chroma distance of the retrieved contexts.
        avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
        confidence_score = max(0, min(100, (1 - avg_distance) * 100))

        # Flag low-confidence answers instead of presenting them as authoritative.
        if confidence_score < 40:
            final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
        else:
            final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"

        return final_response

    except Exception as e:
        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
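# Example usage (a minimal sketch for local testing; the sample question is only
# illustrative and assumes the MedQuAD_db collection has already been populated):
if __name__ == "__main__":
    print(rag("What are the symptoms of glaucoma?"))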