# os37-demo-v2 / ontology-mcp / biobert_embeddings.py
# OS37 v2: Production medical AI infrastructure (author: goldenski, commit 468c313)
"""
BioBERT embeddings for medical concept similarity.
Maps Spanish medical terms to standardized concepts.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
# Module-level logger; inherits handlers/level from the host application's config.
logger = logging.getLogger(__name__)
class BioBERTEmbeddings:
    """Produces dense sentence vectors for medical text with a BioBERT encoder.

    All public methods are best-effort: on failure they log the error and
    return a neutral value (zero vectors, 0.0 similarity, or an empty list)
    rather than raising.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-base-cased-v1.1"):
        """Pick a device and eagerly download/load the pretrained weights.

        Args:
            model_name: Hugging Face hub id of the BioBERT checkpoint.

        Raises:
            Exception: re-raised from model loading if the checkpoint
                cannot be fetched or initialized.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        # Prefer GPU when available; every tensor is moved to this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load tokenizer + encoder from the hub and switch to inference mode."""
        try:
            logger.info(f"Loading BioBERT model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()  # inference only: disables dropout etc.
            logger.info("BioBERT model loaded successfully")
        except Exception as e:
            # Loading failure is fatal for this object — propagate after logging.
            logger.error(f"Failed to load BioBERT model: {e}")
            raise

    def encode_text(self, text: str) -> np.ndarray:
        """Embed one string as a flat 768-d vector (zeros on failure)."""
        try:
            tokens = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)
            with torch.no_grad():
                hidden = self.model(**tokens).last_hidden_state
            # Sentence representation = [CLS] position of the final layer.
            return hidden[:, 0, :].cpu().numpy().flatten()
        except Exception as e:
            logger.error(f"Error encoding text '{text}': {e}")
            return np.zeros(768)  # BioBERT embedding size

    def encode_batch(self, texts: List[str]) -> np.ndarray:
        """Embed a list of strings; returns an (len(texts), 768) array.

        On failure the whole batch falls back to zeros of the same shape.
        """
        try:
            tokens = self.tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)
            with torch.no_grad():
                hidden = self.model(**tokens).last_hidden_state
            # One [CLS] vector per input text.
            return hidden[:, 0, :].cpu().numpy()
        except Exception as e:
            logger.error(f"Error encoding batch: {e}")
            return np.zeros((len(texts), 768))

    def calculate_similarity(self, text1: str, text2: str) -> float:
        """Return cosine similarity of the two texts' embeddings (0.0 on error)."""
        try:
            vec_a = self.encode_text(text1)
            vec_b = self.encode_text(text2)
            return float(cosine_similarity([vec_a], [vec_b])[0][0])
        except Exception as e:
            logger.error(f"Error calculating similarity: {e}")
            return 0.0

    def find_best_matches(self,
                         query_text: str,
                         candidate_texts: List[str],
                         top_k: int = 5) -> List[Tuple[str, float]]:
        """Rank candidates by cosine similarity to the query.

        Returns up to `top_k` (candidate_text, score) pairs, best first;
        an empty list on any failure.
        """
        try:
            query_vec = self.encode_text(query_text)
            candidate_vecs = self.encode_batch(candidate_texts)
            scores = cosine_similarity([query_vec], candidate_vecs)[0]
            # Descending order of similarity, truncated to top_k.
            ranked = np.argsort(scores)[::-1][:top_k]
            return [(candidate_texts[i], float(scores[i])) for i in ranked]
        except Exception as e:
            logger.error(f"Error finding matches: {e}")
            return []
class MedicalConceptMatcher:
    """Maps Spanish medical terms to standardized concept codes.

    Two strategies, in order: (1) a direct lookup of previously learned
    SpanishTerm -[:MAPS_TO]-> Concept edges in Neo4j; (2) a BioBERT semantic
    search over candidate concepts, whose confident results are written back
    to Neo4j so future lookups become direct hits.
    """

    def __init__(self, neo4j_driver, biobert_embeddings: "BioBERTEmbeddings"):
        """Initialize medical concept matcher with Neo4j and BioBERT.

        Args:
            neo4j_driver: async Neo4j driver (must support `driver.session()`
                as an async context manager).
            biobert_embeddings: a loaded BioBERTEmbeddings instance.

        Note: the annotation previously referenced the undefined name
        `BioBERT`, which raised NameError when this class was defined;
        it now names the actual class (as a string, so the class need not
        be defined yet at evaluation time).
        """
        self.driver = neo4j_driver
        self.biobert = biobert_embeddings
        # Minimum confidence/cosine similarity for a mapping to be accepted.
        self.confidence_threshold = 0.8

    async def map_spanish_term(self, spanish_term: str, context: str = "") -> Dict[str, Any]:
        """Map a Spanish medical term to a standardized concept.

        Tries a learned direct mapping first; falls back to semantic search.
        A confident semantic match is persisted for future direct lookups.

        Returns:
            Dict with `success` and `confidence`; on success also the mapped
            code/display/system and `mapping_type` ("direct" or "semantic").
        """
        try:
            direct_match = await self._check_direct_mapping(spanish_term)
            if direct_match and direct_match["confidence"] >= self.confidence_threshold:
                return direct_match

            semantic_matches = await self._semantic_concept_search(spanish_term, context)
            if semantic_matches:
                best_match = semantic_matches[0]
                # Persist so the next lookup of this term is a direct hit.
                await self._learn_mapping(spanish_term, best_match, context)
                return best_match

            return {
                "success": False,
                "error": "No suitable mapping found",
                "confidence": 0.0
            }
        except Exception as e:
            logger.error(f"Error mapping Spanish term '{spanish_term}': {e}")
            return {"success": False, "error": str(e), "confidence": 0.0}

    async def _check_direct_mapping(self, spanish_term: str) -> Optional[Dict[str, Any]]:
        """Return the best existing mapping for the term, or None if absent.

        Terms are stored lowercased, so the lookup lowercases the input to
        match `_learn_mapping`'s write path.
        """
        async with self.driver.session() as session:
            result = await session.run("""
                MATCH (spanish:SpanishTerm {term: $term})-[r:MAPS_TO]->(concept:Concept)
                RETURN concept.code as code,
                       concept.display as display,
                       concept.system as system,
                       r.confidence as confidence,
                       r.usage_count as usage_count
                ORDER BY r.confidence DESC, r.usage_count DESC
                LIMIT 1
            """, term=spanish_term.lower())
            record = await result.single()
            if record:
                return {
                    "success": True,
                    "spanish_term": spanish_term,
                    "mapped_code": record["code"],
                    "mapped_display": record["display"],
                    "system": record["system"],
                    "confidence": float(record["confidence"]),
                    "mapping_type": "direct",
                    "usage_count": record["usage_count"]
                }
            return None

    async def _semantic_concept_search(self, spanish_term: str, context: str = "") -> List[Dict[str, Any]]:
        """Search for concepts using BioBERT semantic similarity.

        Pulls up to 1000 ICD-10 / SNOMED-CT candidates from Neo4j, scores
        them against "<term> <context>", and keeps matches at or above the
        confidence threshold, best first.
        """
        # Get candidate concepts from Neo4j.
        async with self.driver.session() as session:
            result = await session.run("""
                MATCH (concept:Concept)
                WHERE concept.vocabulary IN ['ICD-10', 'SNOMED-CT']
                RETURN concept.code as code,
                       concept.display as display,
                       concept.system as system,
                       concept.vocabulary as vocabulary,
                       concept.synonyms as synonyms
                LIMIT 1000
            """)
            candidates = []
            candidate_texts = []
            async for record in result:
                candidates.append({
                    "code": record["code"],
                    "display": record["display"],
                    "system": record["system"],
                    "vocabulary": record["vocabulary"]
                })
                # Searchable text = display plus any synonyms.
                search_text = record["display"]
                if record["synonyms"]:
                    search_text += " " + " ".join(record["synonyms"])
                candidate_texts.append(search_text)

        # No candidates: nothing to score (avoids an error inside BioBERT).
        if not candidate_texts:
            return []

        query_text = spanish_term
        if context:
            query_text += f" {context}"
        matches = self.biobert.find_best_matches(query_text, candidate_texts, top_k=5)

        # Format results (already ordered best-first by find_best_matches).
        semantic_matches = []
        for match_text, similarity in matches:
            if similarity >= self.confidence_threshold:
                # NOTE(review): .index() returns the first candidate with this
                # text; if two concepts share identical searchable text the
                # first one wins. Their scores are identical, so ranking is
                # unaffected, but the chosen code may be ambiguous.
                match_idx = candidate_texts.index(match_text)
                concept = candidates[match_idx]
                semantic_matches.append({
                    "success": True,
                    "spanish_term": spanish_term,
                    "mapped_code": concept["code"],
                    "mapped_display": concept["display"],
                    "system": concept["system"],
                    "vocabulary": concept["vocabulary"],
                    "confidence": similarity,
                    "mapping_type": "semantic",
                    "match_text": match_text
                })
        return semantic_matches

    async def _learn_mapping(self, spanish_term: str, mapping: Dict[str, Any], context: str = ""):
        """Persist a Spanish -> concept mapping (upsert; bumps usage_count)."""
        async with self.driver.session() as session:
            await session.run("""
                MERGE (spanish:SpanishTerm {term: $spanish_term})
                MERGE (concept:Concept {code: $code, system: $system})
                MERGE (spanish)-[r:MAPS_TO]->(concept)
                SET r.confidence = $confidence,
                    r.mapping_type = $mapping_type,
                    r.learned_context = $context,
                    r.usage_count = COALESCE(r.usage_count, 0) + 1,
                    r.last_updated = datetime()
            """,
            spanish_term=spanish_term.lower(),
            code=mapping["mapped_code"],
            system=mapping["system"],
            confidence=mapping["confidence"],
            mapping_type=mapping["mapping_type"],
            context=context
            )

    async def validate_mapping(self, spanish_term: str, expected_code: str) -> Dict[str, Any]:
        """Check whether the current mapping for a term matches an expected code.

        Returns a dict with `valid` plus the expected/actual codes; on
        lookup failure, `valid` is False with an `error` message.
        """
        try:
            current_mapping = await self.map_spanish_term(spanish_term)
            if not current_mapping["success"]:
                return {
                    "valid": False,
                    "error": "No mapping found",
                    "expected": expected_code,
                    "actual": None
                }
            actual_code = current_mapping["mapped_code"]
            return {
                "valid": actual_code == expected_code,
                "spanish_term": spanish_term,
                "expected_code": expected_code,
                "actual_code": actual_code,
                "confidence": current_mapping["confidence"],
                "mapping_type": current_mapping["mapping_type"]
            }
        except Exception as e:
            return {
                "valid": False,
                "error": str(e),
                "spanish_term": spanish_term,
                "expected_code": expected_code
            }

    def get_embedding_stats(self) -> Dict[str, Any]:
        """Get BioBERT embedding statistics for diagnostics/reporting."""
        return {
            "model_name": self.biobert.model_name,
            "device": str(self.biobert.device),
            "embedding_size": 768,
            "confidence_threshold": self.confidence_threshold
        }