| import gradio as gr |
| import os |
| import json |
| import pickle |
| from datetime import datetime |
| import requests |
| from bs4 import BeautifulSoup |
| import fitz |
| import numpy as np |
| from sentence_transformers import SentenceTransformer |
| from sklearn.metrics.pairwise import cosine_similarity |
| import sqlite3 |
| import hashlib |
| from typing import List, Dict, Any, Tuple |
| import logging |
| import tempfile |
| import shutil |
| from urllib.parse import urlparse, urljoin |
| import re |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
class MedicalRAGSystem:
    """SQLite-backed RAG knowledge base for medical-device regulatory content.

    Three content sources (uploaded documents, scraped websites, pasted
    standards) are stored in dedicated tables; their text is split into
    overlapping word chunks, and each chunk is indexed with a
    sentence-transformer embedding so queries can be answered by brute-force
    cosine-similarity search.
    """

    def __init__(self):
        # Set by load_embedding_model(); stays None if loading fails, and
        # every embedding-dependent method checks for that.
        self.embedding_model = None
        self.db_path = "medical_rag.db"
        # NOTE(review): currently unused — presumably reserved for an
        # in-memory embedding cache; confirm before removing.
        self.embeddings_cache = {}
        self.init_database()
        self.load_embedding_model()

    def load_embedding_model(self):
        """Load a free sentence-transformer model (all-MiniLM-L6-v2).

        On failure, logs the error and leaves ``self.embedding_model`` as
        None so the rest of the app degrades gracefully.  (The original
        had a dead ``return None`` in the except branch; the explicit
        attribute reset makes the failure state unambiguous.)
        """
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading embedding model: {e}")
            self.embedding_model = None

    def init_database(self):
        """Create the SQLite schema if it does not exist yet.

        Tables: ``documents``, ``websites``, ``standards`` (each with a
        UNIQUE content hash for dedup) and ``embeddings`` (one row per
        text chunk, blob-pickled vector).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    category TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS websites (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    url TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    title TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS standards (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    standard_name TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    version TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    source_type TEXT NOT NULL,
                    source_id INTEGER NOT NULL,
                    chunk_index INTEGER NOT NULL,
                    embedding BLOB NOT NULL,
                    text_chunk TEXT NOT NULL
                )
            ''')

            conn.commit()
        finally:
            conn.close()
        logger.info("Database initialized successfully")

    def get_content_hash(self, content: str) -> str:
        """Return an MD5 hex digest of ``content``.

        Used only as a duplicate-detection key (UNIQUE column), not as a
        security boundary, so MD5 is acceptable here.
        """
        return hashlib.md5(content.encode()).hexdigest()

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split ``text`` into overlapping word chunks for retrieval.

        Args:
            text: Raw text; split on whitespace.
            chunk_size: Maximum words per chunk.
            overlap: Words shared between consecutive chunks.

        Returns:
            Non-empty chunk strings in document order.

        The step is clamped to at least 1: the original used
        ``chunk_size - overlap`` directly, which silently produced no
        chunks (or raised for step 0) whenever ``overlap >= chunk_size``.
        """
        words = text.split()
        step = max(chunk_size - overlap, 1)  # guard against non-positive step
        chunks = []

        for start in range(0, len(words), step):
            chunk = ' '.join(words[start:start + chunk_size])
            if chunk.strip():
                chunks.append(chunk)

        return chunks

    def process_pdf_document(self, file_path: str) -> Tuple[str, Dict]:
        """Extract all page text from a PDF via PyMuPDF.

        Returns:
            (text, metadata) — metadata has page count and format;
            ("", {}) on any extraction failure.
        """
        try:
            doc = fitz.open(file_path)
            text_content = ""
            metadata = {"pages": doc.page_count, "format": "PDF"}

            for page_num in range(doc.page_count):
                page = doc[page_num]
                text_content += page.get_text()

            doc.close()
            return text_content, metadata
        except Exception as e:
            logger.error(f"Error processing PDF: {e}")
            return "", {}

    def process_text_document(self, file_path: str) -> Tuple[str, Dict]:
        """Read a UTF-8 text file.

        Returns:
            (content, metadata); ("", {}) on read/decode failure.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            return content, {"format": "TEXT"}
        except Exception as e:
            logger.error(f"Error processing text document: {e}")
            return "", {}

    def scrape_website(self, url: str) -> Tuple[str, str, Dict]:
        """Download ``url`` and return (plain text, page title, metadata).

        Returns:
            ("", "", {}) on any network or parsing failure.
        """
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Drop non-content elements before extracting visible text.
            for script in soup(["script", "style"]):
                script.decompose()

            # BUG FIX: a <title> tag with no text has .string == None; the
            # original would then store None as the title. Fall back to the
            # URL in that case too.
            title = soup.title.string if soup.title and soup.title.string else url

            # Collapse all whitespace runs into single spaces.
            content = soup.get_text()
            content = re.sub(r'\s+', ' ', content).strip()

            metadata = {
                "title": title,
                "url": url,
                "scraped_at": datetime.now().isoformat()
            }

            return content, title, metadata

        except Exception as e:
            logger.error(f"Error scraping website {url}: {e}")
            return "", "", {}

    def add_document(self, file_path: str, filename: str, category: str) -> str:
        """Extract text from an uploaded file, store it, and index it.

        Args:
            file_path: Path to the uploaded file on disk.
            filename: Display name stored alongside the content.
            category: Regulatory category chosen in the UI.

        Returns:
            Human-readable status message for the UI.
        """
        try:
            # PDFs are text-extracted via PyMuPDF; everything else is read
            # as UTF-8 text.
            if filename.lower().endswith('.pdf'):
                content, metadata = self.process_pdf_document(file_path)
            else:
                content, metadata = self.process_text_document(file_path)

            if not content:
                return "Error: Could not extract content from document"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO documents (filename, content, content_hash, category, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (filename, content, content_hash, category, json.dumps(metadata)))

                doc_id = cursor.lastrowid
                conn.commit()
            except sqlite3.IntegrityError:
                # content_hash is UNIQUE, so re-adding the same text is a no-op.
                return "Document already exists in the knowledge base"
            finally:
                # Always close — the original leaked the connection if
                # anything raised after the insert.
                conn.close()

            # Index after commit; this opens its own connection.
            self.generate_embeddings_for_content(content, 'document', doc_id)

            # BUG FIX: the original returned the literal placeholder
            # '(unknown)' instead of the actual filename.
            return f"Document '{filename}' added successfully to category '{category}'"

        except Exception as e:
            logger.error(f"Error adding document: {e}")
            return f"Error adding document: {str(e)}"

    def add_website(self, url: str) -> str:
        """Scrape ``url``, store its text, and index it.

        Returns:
            Human-readable status message for the UI.
        """
        try:
            content, title, metadata = self.scrape_website(url)

            if not content:
                return "Error: Could not scrape website content"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO websites (url, content, content_hash, title, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (url, content, content_hash, title, json.dumps(metadata)))

                website_id = cursor.lastrowid
                conn.commit()
            except sqlite3.IntegrityError:
                return "Website already exists in the knowledge base"
            finally:
                conn.close()

            self.generate_embeddings_for_content(content, 'website', website_id)

            return f"Website '{title}' added successfully"

        except Exception as e:
            logger.error(f"Error adding website: {e}")
            return f"Error adding website: {str(e)}"

    def add_standard(self, standard_name: str, content: str, version: str = "") -> str:
        """Store and index pasted standard text.

        Args:
            standard_name: e.g. "ISO 13485:2016".
            content: Full standard text (must be non-blank).
            version: Optional version string recorded in metadata.

        Returns:
            Human-readable status message for the UI.
        """
        try:
            if not content.strip():
                return "Error: Standard content cannot be empty"

            content_hash = self.get_content_hash(content)
            metadata = {"version": version, "added_at": datetime.now().isoformat()}

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO standards (standard_name, content, content_hash, version, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (standard_name, content, content_hash, version, json.dumps(metadata)))

                standard_id = cursor.lastrowid
                conn.commit()
            except sqlite3.IntegrityError:
                return "Standard already exists in the knowledge base"
            finally:
                conn.close()

            self.generate_embeddings_for_content(content, 'standard', standard_id)

            return f"Standard '{standard_name}' added successfully"

        except Exception as e:
            logger.error(f"Error adding standard: {e}")
            return f"Error adding standard: {str(e)}"

    def generate_embeddings_for_content(self, content: str, source_type: str, source_id: int):
        """Chunk ``content`` and store one embedding row per chunk.

        Args:
            source_type: 'document', 'website', or 'standard'.
            source_id: Row id in the corresponding source table.

        No-op (with an error log) when the embedding model failed to load.
        """
        if not self.embedding_model:
            logger.error("Embedding model not available")
            return

        chunks = self.chunk_text(content)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            for i, chunk in enumerate(chunks):
                try:
                    embedding = self.embedding_model.encode(chunk)
                    # Vector is stored as a pickle blob and read back in
                    # search_knowledge_base. Safe only because the blobs are
                    # written exclusively by this application.
                    cursor.execute('''
                        INSERT INTO embeddings (source_type, source_id, chunk_index, embedding, text_chunk)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (source_type, source_id, i, pickle.dumps(embedding), chunk))
                except Exception as e:
                    # Best-effort: one bad chunk should not abort the rest.
                    logger.error(f"Error generating embedding for chunk {i}: {e}")

            conn.commit()
        finally:
            conn.close()

    def search_knowledge_base(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` chunks most similar to ``query``.

        Brute-force scan: every stored embedding is loaded and compared via
        cosine similarity — fine for small knowledge bases.

        Returns:
            Dicts with source_type, source_id, text_chunk, source_name and
            a plain-float similarity, sorted best-first. Empty list when the
            model is unavailable or anything fails.
        """
        if not self.embedding_model:
            return []

        try:
            query_embedding = self.embedding_model.encode(query)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                # Join back to the source tables to recover a display name.
                cursor.execute('''
                    SELECT e.source_type, e.source_id, e.text_chunk, e.embedding,
                           CASE
                               WHEN e.source_type = 'document' THEN d.filename
                               WHEN e.source_type = 'website' THEN w.title
                               WHEN e.source_type = 'standard' THEN s.standard_name
                           END as source_name
                    FROM embeddings e
                    LEFT JOIN documents d ON e.source_type = 'document' AND e.source_id = d.id
                    LEFT JOIN websites w ON e.source_type = 'website' AND e.source_id = w.id
                    LEFT JOIN standards s ON e.source_type = 'standard' AND e.source_id = s.id
                ''')
                rows = cursor.fetchall()
            finally:
                conn.close()

            results = []
            for source_type, source_id, text_chunk, blob, source_name in rows:
                try:
                    # SECURITY NOTE: pickle.loads on app-written blobs only;
                    # never feed externally supplied data through this path.
                    stored_embedding = pickle.loads(blob)
                    # float() so results are plain-Python (JSON-friendly),
                    # not numpy scalars.
                    similarity = float(cosine_similarity([query_embedding], [stored_embedding])[0][0])

                    results.append({
                        'source_type': source_type,
                        'source_id': source_id,
                        'text_chunk': text_chunk,
                        'source_name': source_name,
                        'similarity': similarity
                    })
                except Exception as e:
                    logger.error(f"Error processing embedding: {e}")

            results.sort(key=lambda x: x['similarity'], reverse=True)
            return results[:top_k]

        except Exception as e:
            logger.error(f"Error searching knowledge base: {e}")
            return []

    def get_knowledge_base_stats(self) -> Dict:
        """Return row counts keyed 'documents', 'websites', 'standards',
        'embeddings'."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            stats = {}
            # Table names come from a fixed tuple, so the f-string SQL is safe.
            for table in ('documents', 'websites', 'standards', 'embeddings'):
                cursor.execute(f"SELECT COUNT(*) FROM {table}")
                stats[table] = cursor.fetchone()[0]
            return stats
        finally:
            conn.close()
|
|
| |
# Single shared RAG-system instance; created at import time (opens/creates
# the SQLite file and downloads the embedding model) and used by all callbacks.
rag_system = MedicalRAGSystem()
|
|
def handle_document_upload(files, category):
    """Add every uploaded file to the knowledge base.

    Returns one status line per file, or a notice when nothing was selected.
    """
    if not files:
        return "No files selected"

    outcomes = [
        rag_system.add_document(item.name, os.path.basename(item.name), category)
        for item in files
    ]
    return "\n".join(outcomes)
|
|
def handle_website_addition(url):
    """Validate the URL field, then hand it to the RAG system for scraping."""
    cleaned = url.strip()
    if not cleaned:
        return "Please enter a valid URL"
    return rag_system.add_website(cleaned)
|
|
def handle_standard_addition(standard_name, content, version):
    """Require both a standard name and content, then register the standard."""
    name = standard_name.strip()
    body = content.strip()
    if not (name and body):
        return "Please provide both standard name and content"
    return rag_system.add_standard(name, body, version.strip())
|
|
def handle_search(query):
    """Run a semantic search and return (formatted hit list, generated answer)."""
    query = query.strip()
    if not query:
        return "Please enter a search query", ""

    hits = rag_system.search_knowledge_base(query)
    if not hits:
        return "No relevant results found", ""

    rendered = []
    context_chunks = []

    for rank, hit in enumerate(hits, 1):
        pct = hit['similarity'] * 100
        chunk = hit['text_chunk']
        rendered.append(f"""
**Result {rank}** (Similarity: {pct:.1f}%)
**Source:** {hit['source_name']} ({hit['source_type']})
**Content:** {chunk[:300]}{'...' if len(chunk) > 300 else ''}
---
""")
        context_chunks.append(chunk)

    return "\n".join(rendered), generate_answer(query, context_chunks)
|
|
def generate_answer(query: str, context: List[str]) -> str:
    """Build an extractive answer from the retrieved context.

    Simple keyword heuristic (no LLM): keep every sentence from the context
    chunks that contains any word of the query, dedupe preserving order, and
    return the first three.
    """
    terms = query.lower().split()

    matches = []
    for passage in context:
        for raw_sentence in passage.split('.'):
            lowered = raw_sentence.lower()
            if any(term in lowered for term in terms):
                matches.append(raw_sentence.strip())

    if not matches:
        return ("The retrieved content may contain relevant information, but I "
                "couldn't extract a specific answer. Please review the search "
                "results above.")

    deduped = list(dict.fromkeys(matches))
    return "Based on the regulatory documents:\n\n" + "\n\n".join(deduped[:3])
|
|
def get_stats():
    """Render knowledge-base row counts as a human-readable summary."""
    counts = rag_system.get_knowledge_base_stats()
    return f"""
Knowledge Base Statistics:
- Documents: {counts['documents']}
- Websites: {counts['websites']}
- Standards: {counts['standards']}
- Total Text Chunks: {counts['embeddings']}
"""
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: one tab per workflow (search / add documents / add websites /
# add standards / stats). Component declaration order is load-bearing for
# the layout; event wiring routes each tab to its handler above.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Medical Devices RAG System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # π₯ Medical Devices Regulatory RAG System

    A comprehensive knowledge base system for medical device regulatory analysts.
    Add documents, websites, and standards to build your regulatory knowledge base.
    """)

    with gr.Tabs():
        # --- Tab 1: semantic search over the knowledge base ---------------
        with gr.Tab("π Search Knowledge Base"):
            gr.Markdown("### Search your regulatory knowledge base")

            search_input = gr.Textbox(
                placeholder="Enter your regulatory question (e.g., 'What are the requirements for Class II medical devices?')",
                label="Search Query",
                lines=2
            )
            search_button = gr.Button("Search", variant="primary")

            # Two columns: raw ranked chunks on the left, extractive answer
            # on the right.
            with gr.Row():
                with gr.Column():
                    search_results = gr.Markdown(label="Search Results")
                with gr.Column():
                    answer_output = gr.Markdown(label="Generated Answer")

            search_button.click(
                handle_search,
                inputs=[search_input],
                outputs=[search_results, answer_output]
            )

        # --- Tab 2: upload regulatory documents ---------------------------
        with gr.Tab("π Add Documents"):
            gr.Markdown("### Add regulatory documents (PDF, TXT)")

            # NOTE(review): .docx is accepted here but the backend only has
            # PDF and plain-text extractors — confirm .docx handling.
            document_files = gr.File(
                label="Upload Documents",
                file_count="multiple",
                file_types=[".pdf", ".txt", ".docx"]
            )
            document_category = gr.Dropdown(
                choices=["EU MDR 2017/745", "CMDR SOR/98-282", "MDCG", "MDSAP Audit Approach", "UK MDR", "Other"],
                label="Document Category",
                value="Other"
            )
            add_doc_button = gr.Button("Add Documents", variant="primary")
            doc_output = gr.Textbox(label="Result", lines=3)

            add_doc_button.click(
                handle_document_upload,
                inputs=[document_files, document_category],
                outputs=[doc_output]
            )

        # --- Tab 3: scrape and index regulatory websites ------------------
        with gr.Tab("π Add Websites"):
            gr.Markdown("### Add regulatory websites")

            website_url = gr.Textbox(
                placeholder="https://www.fda.gov/medical-devices/...",
                label="Website URL",
                lines=1
            )
            add_website_button = gr.Button("Add Website", variant="primary")
            website_output = gr.Textbox(label="Result", lines=3)

            gr.Markdown("**Suggested regulatory websites:**")
            gr.Markdown("""
            - US FDA 21CFR: https://www.accessdata.fda.gov/scripts/cdrh/cfdocs/cfcfr/cfrsearch.cfm
            - EU Medical Devices: https://ec.europa.eu/health/medical-devices-sector_en
            - Health Canada Medical Devices: https://www.canada.ca/en/health-canada/services/drugs-health-products/medical-devices.html
            """)

            add_website_button.click(
                handle_website_addition,
                inputs=[website_url],
                outputs=[website_output]
            )

        # --- Tab 4: paste standard text directly --------------------------
        with gr.Tab("π Add Standards"):
            gr.Markdown("### Add regulatory standards")

            standard_name = gr.Textbox(
                placeholder="ISO 13485:2016",
                label="Standard Name",
                lines=1
            )
            standard_version = gr.Textbox(
                placeholder="2016 (optional)",
                label="Version",
                lines=1
            )
            standard_content = gr.Textbox(
                placeholder="Enter or paste the standard content here...",
                label="Standard Content",
                lines=10
            )
            add_standard_button = gr.Button("Add Standard", variant="primary")
            standard_output = gr.Textbox(label="Result", lines=3)

            add_standard_button.click(
                handle_standard_addition,
                inputs=[standard_name, standard_content, standard_version],
                outputs=[standard_output]
            )

        # --- Tab 5: knowledge-base statistics -----------------------------
        with gr.Tab("π Knowledge Base Stats"):
            gr.Markdown("### Knowledge Base Statistics")

            stats_button = gr.Button("Refresh Statistics", variant="secondary")
            stats_output = gr.Textbox(label="Statistics", lines=8)

            stats_button.click(
                get_stats,
                outputs=[stats_output]
            )

    # Populate the stats box once when the page first loads.
    demo.load(get_stats, outputs=[stats_output])
|
|
if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link; set to False for
    # private/internal deployments.
    demo.launch(share=True)