Spaces:
Sleeping
Sleeping
"""
text_emotion_engine.py — DistilBERT text emotion classifier.
Uses: bhadresh-savani/distilbert-base-uncased-emotion
Output: scores for every emotion label, sorted by model confidence (descending);
callers typically consume the top-N.
Runs inference in asyncio.to_thread to avoid blocking the event loop.
"""
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from typing import List, Optional | |
| from app.schemas import EmotionLabel | |
logger: logging.Logger = logging.getLogger(__name__)

# Module-level model state, written by _load_pipeline().
# The HuggingFace text-classification pipeline once it has been loaded;
# remains None until a load succeeds.
_pipeline = None
# Set by _load_pipeline() to the exception text when loading fails.
_load_error: Optional[str] = None
def _load_pipeline(model_name: str) -> None:
    """Load the HuggingFace text-classification pipeline into module state.

    Intended to be called once at startup (via ``initialize``). Prefers a
    bundled model directory at ``app/ml_assets/distilbert_model`` (a relative
    path, so resolved against the current working directory) and falls back
    to ``model_name`` (e.g. a HuggingFace hub id) when that path is absent.

    On success ``_pipeline`` holds the pipeline and ``_load_error`` is reset
    to None. On failure the exception is logged with its traceback and its
    text is stored in ``_load_error``. Nothing is re-raised, so startup
    continues and callers can inspect ``is_loaded()`` / ``load_error()``.

    Args:
        model_name: Fallback model identifier used when the bundled local
            copy does not exist.
    """
    global _pipeline, _load_error
    import os

    local_path = os.path.join("app", "ml_assets", "distilbert_model")
    # Resolve the source up front so the log line reports what is actually
    # loaded, rather than always naming the local path.
    model_source = local_path if os.path.exists(local_path) else model_name
    try:
        # Imported lazily so importing this module does not require
        # transformers until a model is actually loaded.
        from transformers import pipeline as hf_pipeline

        logger.info("Loading DistilBERT text emotion model from %s", model_source)
        _pipeline = hf_pipeline(
            "text-classification",
            model=model_source,
            top_k=None,  # Return scores for ALL labels, not only the best one
            truncation=True,  # Tokenizer-level truncation to max_length tokens
            max_length=512,
        )
    except Exception as exc:
        _load_error = str(exc)
        logger.exception("❌ Failed to load DistilBERT model: %s", exc)
    else:
        # Clear any error left over from an earlier failed attempt.
        _load_error = None
        logger.info("✅ DistilBERT emotion model loaded successfully.")
def initialize(model_name: str) -> None:
    """Pre-warm the text emotion model at application startup.

    Thin public entry point that delegates to ``_load_pipeline``. Load
    failures are caught, logged and recorded there; nothing is raised here.

    Args:
        model_name: Model identifier forwarded to ``_load_pipeline``.
    """
    _load_pipeline(model_name=model_name)
class TextEmotionEngine:
    """Async-friendly wrapper around the module-level DistilBERT pipeline.

    The class keeps no per-instance state. Every method reads the module
    globals ``_pipeline`` / ``_load_error`` that ``_load_pipeline`` populates,
    so one shared instance is sufficient.
    """

    def _classify_sync(self, text: str) -> List[EmotionLabel]:
        """Run the pipeline synchronously and return labels sorted by score.

        Returns an empty list in four cases: the model is not loaded, ``text``
        is not a string, ``text`` is blank, or inference raises. In the last
        case the error is logged with its traceback. Meant to be run off the
        event loop via ``classify``.
        """
        if _pipeline is None:
            return []
        # Blank input carries no signal. Skip the model rather than report
        # a spurious "dominant" emotion for an empty string.
        if not isinstance(text, str) or not text.strip():
            return []
        try:
            # Coarse character pre-trim; the pipeline additionally truncates
            # to 512 *tokens* (truncation=True, max_length=512 at load time).
            # NOTE(review): 512 characters is far fewer than 512 tokens, so
            # long inputs lose context the model could use — confirm intent.
            results = _pipeline(text[:512])
            if not results:
                return []
            # Expected shape with top_k=None is a list holding one list of
            # {"label", "score"} dicts; a flat list of dicts is also accepted.
            raw = results[0] if isinstance(results[0], list) else results
            labels = [
                EmotionLabel(label=item["label"].lower(), score=round(item["score"], 4))
                for item in raw
            ]
            # Highest confidence first.
            return sorted(labels, key=lambda lbl: lbl.score, reverse=True)
        except Exception as exc:
            logger.exception("DistilBERT inference error: %s", exc)
            return []

    async def classify(self, text: str) -> List[EmotionLabel]:
        """Classify ``text`` without blocking the event loop.

        The CPU-bound inference runs in a worker thread via
        ``asyncio.to_thread``. Returns ``EmotionLabel`` objects sorted by
        score, highest first. The list is empty whenever ``_classify_sync``
        produces no result.
        """
        return await asyncio.to_thread(self._classify_sync, text)

    async def top_emotion(self, text: str) -> str:
        """Return the highest-scoring emotion label, or ``"neutral"`` if none.

        ``"neutral"`` is a fallback used when classification yields no
        labels; it is not necessarily one of the model's own labels.
        """
        labels = await self.classify(text)
        return labels[0].label if labels else "neutral"

    def summary_string(self, labels: List[EmotionLabel], top_k: int = 3) -> str:
        """Format the first ``top_k`` labels as ``label(score)`` pairs.

        Intended for LLM prompt injection. ``labels`` is assumed to be sorted
        already, as returned by ``classify``. A non-positive ``top_k`` yields
        an empty string instead of slicing from the end of the list.

        Example: "sadness(0.87), fear(0.08), anger(0.03)"
        """
        return ", ".join(
            f"{lbl.label}({lbl.score:.2f})" for lbl in labels[: max(top_k, 0)]
        )

    def is_loaded(self) -> bool:
        """Return True once the module-level pipeline has been loaded."""
        return _pipeline is not None

    def load_error(self) -> Optional[str]:
        """Return the recorded load-failure message, or None."""
        return _load_error
# Shared module-level engine instance. TextEmotionEngine holds no
# per-instance state, so this one object can be imported and reused.
text_emotion_engine: TextEmotionEngine = TextEmotionEngine()