# NOTE(review): removed pasted UI residue ("Spaces: / Sleeping / Sleeping") that preceded the module docstring.
| """ | |
| BioBERT embeddings for medical concept similarity. | |
| Maps Spanish medical terms to standardized concepts. | |
| """ | |
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
| logger = logging.getLogger(__name__) | |
class BioBERTEmbeddings:
    """BioBERT sentence embeddings for medical concept similarity.

    Encodes texts with a BioBERT encoder and uses the [CLS] token hidden
    state as the sentence representation. Encoding methods are best-effort:
    on failure they log and return zero vectors / empty results rather than
    raising, so downstream similarity pipelines keep running.
    """

    # Fallback embedding width used when the model config is unavailable
    # (hidden size of BioBERT-base).
    DEFAULT_EMBEDDING_SIZE = 768

    def __init__(self, model_name: str = "dmis-lab/biobert-base-cased-v1.1"):
        """Initialize BioBERT model for medical embeddings.

        Args:
            model_name: Hugging Face model identifier to load.

        Raises:
            Exception: propagated from transformers if the model cannot load.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    @property
    def embedding_size(self) -> int:
        """Width of the embedding vectors produced by the loaded model."""
        if self.model is not None:
            # Read the true hidden size from the model config when possible
            # instead of hard-coding 768.
            return getattr(self.model.config, "hidden_size", self.DEFAULT_EMBEDDING_SIZE)
        return self.DEFAULT_EMBEDDING_SIZE

    def _load_model(self):
        """Load the BioBERT tokenizer and model onto the selected device."""
        try:
            logger.info(f"Loading BioBERT model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()  # inference only; disables dropout
            logger.info("BioBERT model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BioBERT model: {e}")
            raise

    def _encode(self, texts: List[str]) -> np.ndarray:
        """Tokenize and run the model; return [CLS] embeddings.

        Shared core for encode_text/encode_batch — no error handling here so
        callers keep their own fallback contracts.
        """
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,  # BERT positional-embedding limit
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Use the [CLS] token embedding as the sentence representation.
        return outputs.last_hidden_state[:, 0, :].cpu().numpy()

    def encode_text(self, text: str) -> np.ndarray:
        """Generate a BioBERT embedding for a single medical text.

        Returns:
            1-D float array of length ``embedding_size``; zeros on failure.
        """
        try:
            # Delegate to the shared encoder so tokenization/inference logic
            # lives in exactly one place.
            return self._encode([text]).flatten()
        except Exception as e:
            logger.error(f"Error encoding text '{text}': {e}")
            return np.zeros(self.embedding_size)

    def encode_batch(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for a batch of texts.

        Returns:
            2-D array of shape (len(texts), embedding_size); zeros on failure.
        """
        if not texts:
            # The tokenizer rejects an empty batch; short-circuit with the
            # same (0, dim) result the error fallback used to produce.
            return np.zeros((0, self.embedding_size))
        try:
            return self._encode(texts)
        except Exception as e:
            logger.error(f"Error encoding batch: {e}")
            return np.zeros((len(texts), self.embedding_size))

    def calculate_similarity(self, text1: str, text2: str) -> float:
        """Calculate cosine similarity between two medical texts.

        Returns:
            Similarity in [-1, 1]; 0.0 on error.
        """
        try:
            emb1 = self.encode_text(text1)
            emb2 = self.encode_text(text2)
            similarity = cosine_similarity([emb1], [emb2])[0][0]
            return float(similarity)
        except Exception as e:
            logger.error(f"Error calculating similarity: {e}")
            return 0.0

    def find_best_matches(self,
                          query_text: str,
                          candidate_texts: List[str],
                          top_k: int = 5) -> List[Tuple[str, float]]:
        """Find the ``top_k`` candidates most similar to ``query_text``.

        Returns:
            (candidate_text, similarity) pairs, best first; empty list when
            there are no candidates or on error.
        """
        try:
            if not candidate_texts:
                return []
            query_emb = self.encode_text(query_text)
            candidate_embs = self.encode_batch(candidate_texts)
            similarities = cosine_similarity([query_emb], candidate_embs)[0]
            # argsort is ascending: take the tail and reverse for best-first.
            top_indices = np.argsort(similarities)[-top_k:][::-1]
            return [
                (candidate_texts[idx], float(similarities[idx]))
                for idx in top_indices
            ]
        except Exception as e:
            logger.error(f"Error finding matches: {e}")
            return []
class MedicalConceptMatcher:
    """Map Spanish medical terms to standardized concepts via Neo4j + BioBERT.

    Resolution order: (1) previously stored direct SpanishTerm->Concept
    mappings in Neo4j; (2) BioBERT semantic similarity against candidate
    ICD-10 / SNOMED-CT concepts. Accepted semantic matches are written back
    to Neo4j so the next lookup resolves directly.
    """

    # NOTE: annotation fixed — the original referenced the undefined name
    # `BioBERT` (the class is `BioBERTEmbeddings`), raising NameError at
    # class-definition time.
    def __init__(self, neo4j_driver, biobert_embeddings: "BioBERTEmbeddings"):
        """Initialize medical concept matcher with Neo4j and BioBERT.

        Args:
            neo4j_driver: async Neo4j driver; a session is opened per query.
            biobert_embeddings: loaded BioBERTEmbeddings instance used for
                semantic matching.
        """
        self.driver = neo4j_driver
        self.biobert = biobert_embeddings
        # Minimum confidence/similarity required to accept a mapping.
        self.confidence_threshold = 0.8

    async def map_spanish_term(self, spanish_term: str, context: str = "") -> Dict[str, Any]:
        """Map a Spanish medical term to a standardized concept.

        Args:
            spanish_term: term to map (matched case-insensitively).
            context: optional surrounding text used to disambiguate.

        Returns:
            Dict with ``success`` and ``confidence``; on success also the
            mapped code/display/system and ``mapping_type``.
        """
        try:
            # Prefer a curated or previously learned direct mapping.
            direct_match = await self._check_direct_mapping(spanish_term)
            if direct_match and direct_match["confidence"] >= self.confidence_threshold:
                return direct_match
            # Fall back to BioBERT semantic similarity.
            semantic_matches = await self._semantic_concept_search(spanish_term, context)
            if semantic_matches:
                best_match = semantic_matches[0]
                # Persist the mapping so future lookups resolve directly.
                await self._learn_mapping(spanish_term, best_match, context)
                return best_match
            return {
                "success": False,
                "error": "No suitable mapping found",
                "confidence": 0.0
            }
        except Exception as e:
            logger.error(f"Error mapping Spanish term '{spanish_term}': {e}")
            return {"success": False, "error": str(e), "confidence": 0.0}

    async def _check_direct_mapping(self, spanish_term: str) -> Optional[Dict[str, Any]]:
        """Return the best stored mapping for the term, or None when absent.

        "Best" is highest confidence, then highest usage count.
        """
        async with self.driver.session() as session:
            result = await session.run("""
                MATCH (spanish:SpanishTerm {term: $term})-[r:MAPS_TO]->(concept:Concept)
                RETURN concept.code as code,
                       concept.display as display,
                       concept.system as system,
                       r.confidence as confidence,
                       r.usage_count as usage_count
                ORDER BY r.confidence DESC, r.usage_count DESC
                LIMIT 1
            """, term=spanish_term.lower())
            record = await result.single()
            if record:
                return {
                    "success": True,
                    "spanish_term": spanish_term,
                    "mapped_code": record["code"],
                    "mapped_display": record["display"],
                    "system": record["system"],
                    "confidence": float(record["confidence"]),
                    "mapping_type": "direct",
                    "usage_count": record["usage_count"]
                }
            return None

    async def _semantic_concept_search(self, spanish_term: str, context: str = "") -> List[Dict[str, Any]]:
        """Search for concepts using BioBERT semantic similarity.

        Returns:
            Match dicts above the confidence threshold, best first
            (ordering comes from find_best_matches).
        """
        # Pull candidate concepts from Neo4j (capped to bound embedding cost).
        async with self.driver.session() as session:
            result = await session.run("""
                MATCH (concept:Concept)
                WHERE concept.vocabulary IN ['ICD-10', 'SNOMED-CT']
                RETURN concept.code as code,
                       concept.display as display,
                       concept.system as system,
                       concept.vocabulary as vocabulary,
                       concept.synonyms as synonyms
                LIMIT 1000
            """)
            candidates = []
            candidate_texts = []
            async for record in result:
                candidates.append({
                    "code": record["code"],
                    "display": record["display"],
                    "system": record["system"],
                    "vocabulary": record["vocabulary"]
                })
                # Searchable text = display plus any synonyms.
                search_text = record["display"]
                if record["synonyms"]:
                    search_text += " " + " ".join(record["synonyms"])
                candidate_texts.append(search_text)
        # Append context to the query to help disambiguation.
        query_text = spanish_term
        if context:
            query_text += f" {context}"
        matches = self.biobert.find_best_matches(query_text, candidate_texts, top_k=5)
        # Precompute first-occurrence indices (O(1) lookup instead of
        # list.index per match; duplicate display texts map to the first
        # concept, same as list.index would).
        first_idx: Dict[str, int] = {}
        for i, text in enumerate(candidate_texts):
            first_idx.setdefault(text, i)
        semantic_matches = []
        for match_text, similarity in matches:
            if similarity >= self.confidence_threshold:
                concept = candidates[first_idx[match_text]]
                semantic_matches.append({
                    "success": True,
                    "spanish_term": spanish_term,
                    "mapped_code": concept["code"],
                    "mapped_display": concept["display"],
                    "system": concept["system"],
                    "vocabulary": concept["vocabulary"],
                    "confidence": similarity,
                    "mapping_type": "semantic",
                    "match_text": match_text
                })
        return semantic_matches

    async def _learn_mapping(self, spanish_term: str, mapping: Dict[str, Any], context: str = ""):
        """Persist a Spanish term -> concept mapping in Neo4j.

        MERGE keeps the mapping idempotent; usage_count is incremented on
        every re-learn so popular mappings rank higher in direct lookups.
        """
        async with self.driver.session() as session:
            await session.run("""
                MERGE (spanish:SpanishTerm {term: $spanish_term})
                MERGE (concept:Concept {code: $code, system: $system})
                MERGE (spanish)-[r:MAPS_TO]->(concept)
                SET r.confidence = $confidence,
                    r.mapping_type = $mapping_type,
                    r.learned_context = $context,
                    r.usage_count = COALESCE(r.usage_count, 0) + 1,
                    r.last_updated = datetime()
            """,
            spanish_term=spanish_term.lower(),
            code=mapping["mapped_code"],
            system=mapping["system"],
            confidence=mapping["confidence"],
            mapping_type=mapping["mapping_type"],
            context=context
            )

    async def validate_mapping(self, spanish_term: str, expected_code: str) -> Dict[str, Any]:
        """Validate that a term currently maps to the expected concept code.

        Returns:
            Dict with ``valid`` plus expected/actual codes, or an ``error``
            entry when no mapping exists or lookup fails.
        """
        try:
            current_mapping = await self.map_spanish_term(spanish_term)
            if not current_mapping["success"]:
                return {
                    "valid": False,
                    "error": "No mapping found",
                    "expected": expected_code,
                    "actual": None
                }
            actual_code = current_mapping["mapped_code"]
            return {
                "valid": actual_code == expected_code,
                "spanish_term": spanish_term,
                "expected_code": expected_code,
                "actual_code": actual_code,
                "confidence": current_mapping["confidence"],
                "mapping_type": current_mapping["mapping_type"]
            }
        except Exception as e:
            return {
                "valid": False,
                "error": str(e),
                "spanish_term": spanish_term,
                "expected_code": expected_code
            }

    def get_embedding_stats(self) -> Dict[str, Any]:
        """Return BioBERT embedding and matcher configuration statistics."""
        return {
            "model_name": self.biobert.model_name,
            "device": str(self.biobert.device),
            # Read the width from the embedder when it exposes one;
            # otherwise fall back to BioBERT-base's 768.
            "embedding_size": getattr(self.biobert, "embedding_size", 768),
            "confidence_threshold": self.confidence_threshold
        }