Spaces:
Running
Running
| """ | |
| NER inference module. | |
| Loads fine-tuned DistilBERT and extracts legal entities from query text. | |
| Loaded once at FastAPI startup via load_ner_model(). | |
| Fails gracefully — app runs without NER if model not found. | |
| Example: | |
| Input: "What did Justice Chandrachud say about Section 302 IPC?" | |
| Output: {"JUDGE": ["Justice Chandrachud"], | |
| "PROVISION": ["Section 302"], | |
| "STATUTE": ["IPC"]} | |
| The augmented query becomes: | |
| "What did Justice Chandrachud say about Section 302 IPC? | |
| JUDGE: Justice Chandrachud PROVISION: Section 302 STATUTE: IPC" | |
| """ | |
import os
import logging

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

logger = logging.getLogger(__name__)

# Filesystem path to the fine-tuned NER model; overridable via environment.
NER_MODEL_PATH = os.getenv("NER_MODEL_PATH", "models/ner_model")

# Entity labels kept by extract_entities(); any other label is discarded.
TARGET_ENTITIES = {
    "JUDGE", "COURT", "STATUTE", "PROVISION",
    "CASE_NUMBER", "DATE", "PRECEDENT", "LAWYER",
    "PETITIONER", "RESPONDENT", "GPE", "ORG"
}

# Module-level singleton: set by load_ner_model() at startup,
# stays None when the model is missing or failed to load.
_ner_pipeline = None
def load_ner_model() -> None:
    """
    Load the fine-tuned NER model once at startup.

    Populates the module-level ``_ner_pipeline`` singleton. Fails
    gracefully — the app runs without NER if the model directory is
    missing or loading raises. Call this from api/main.py after
    download_models().

    Returns:
        None. On any failure ``_ner_pipeline`` is left as ``None`` and
        entity extraction stays disabled.
    """
    global _ner_pipeline
    if not os.path.exists(NER_MODEL_PATH):
        # Lazy %-style args: the message is only formatted if the record
        # is actually emitted (logging best practice vs. eager f-strings).
        logger.warning(
            "NER model not found at %s. "
            "Entity extraction disabled. App will run without NER.",
            NER_MODEL_PATH,
        )
        return
    try:
        logger.info("Loading NER model from %s...", NER_MODEL_PATH)
        tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
        model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
        # "simple" aggregation merges word-piece tokens into whole-entity
        # spans, so results carry "entity_group" / "word" keys.
        _ner_pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple"
        )
        logger.info("NER model ready.")
    except Exception as e:
        # Broad catch is deliberate at this startup boundary: a bad model
        # must never prevent the API from serving non-NER traffic.
        logger.error("NER model load failed: %s. Entity extraction disabled.", e)
        _ner_pipeline = None
def extract_entities(text: str) -> dict:
    """
    Run NER on input text.

    Args:
        text: Raw query text; only the first 512 characters are fed to
            the pipeline (a character cap, not a token cap — TODO confirm
            this stays under the model's token limit for typical queries).

    Returns:
        Dict of {entity_type: [entity_text, ...]} restricted to
        TARGET_ENTITIES, with per-type de-duplication and entity strings
        shorter than 2 characters dropped. Empty dict if the NER model is
        not loaded, the text is blank, or inference fails.
    """
    if _ner_pipeline is None:
        return {}
    if not text.strip():
        return {}
    try:
        results = _ner_pipeline(text[:512])
    except Exception as e:
        # Best-effort: retrieval still works without entities.
        logger.warning("NER inference failed: %s", e)
        return {}
    entities: dict = {}
    for result in results:
        entity_type = result["entity_group"]
        entity_text = result["word"].strip()
        if entity_type not in TARGET_ENTITIES:
            continue
        if len(entity_text) < 2:
            # Drop noise fragments (single chars from word-piece merges).
            continue
        # setdefault replaces the manual "check then create list" pattern.
        bucket = entities.setdefault(entity_type, [])
        if entity_text not in bucket:
            bucket.append(entity_text)
    return entities
def augment_query(query: str, entities: dict) -> str:
    """
    Append extracted entities to the query string to improve FAISS
    retrieval recall.

    Args:
        query: The original user query.
        entities: Mapping of {entity_type: [entity_text, ...]} as
            produced by extract_entities().

    Returns:
        "<query> TYPE: text TYPE: text ..." when entities are present;
        the query unchanged when the mapping is empty.
    """
    if not entities:
        return query
    tagged = [
        f"{label}: {span}"
        for label, spans in entities.items()
        for span in spans
    ]
    return query + " " + " ".join(tagged)