Spaces:
Running
Running
| """ | |
| Intent Classification Service — runs the trained IntentClassifier model | |
| for real-time inference on user queries. | |
| Loaded once at startup, reused for all search requests. | |
| Intents: single_search, multi_search, filtered_search, free_form (multi-label) | |
| """ | |
| import os | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from transformers import BertModel, BertTokenizer | |
| from config import INTENT_MODEL_PATH, BERT_MODEL_NAME, INTENT_MAX_LENGTH | |
class IntentClassifier(nn.Module):
    """BERT encoder plus a two-layer head for multi-label intent classification.

    The submodule names (``bert``, ``dropout``, ``fc1``, ``relu``, ``fc2``) and
    their shapes must match the architecture used during training, otherwise a
    saved ``state_dict`` will not load.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_intents: Number of output logits (one per intent label).
    """

    def __init__(self, bert_model_name="bert-base-multilingual-uncased", num_intents=4):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.3)
        # Take the encoder width from the loaded model rather than assuming
        # 768, so non-"base" BERT variants work too. For bert-base this is
        # still 768, so existing checkpoints keep loading unchanged.
        hidden_size = self.bert.config.hidden_size
        self.fc1 = nn.Linear(hidden_size, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_intents)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) intent logits for a tokenized batch.

        Args:
            input_ids: Token id tensor forwarded to the BERT encoder.
            attention_mask: Attention mask tensor forwarded to the BERT encoder.

        Returns:
            The output of ``fc2``: one logit per intent for every batch row.
            No activation is applied here; callers apply sigmoid themselves.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 (the [CLS] token) as the
        # pooled representation of the whole query.
        cls_output = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(cls_output)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
class IntentService:
    """Singleton-style service for intent classification at inference time.

    Call :meth:`load` once at application startup, then :meth:`predict` for
    each query. Until a model has been loaded successfully, :meth:`predict`
    returns an empty result instead of raising.
    """

    # Label set used when the model directory ships no usable label metadata.
    _DEFAULT_LABELS = ("single_search", "multi_search", "filtered_search", "free_form")

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.label_names = []
        self._loaded = False

    @staticmethod
    def _reconcile_labels(label_names, num_intents):
        """Return a list of exactly ``num_intents`` label names.

        * No names and ``num_intents`` equal to the default label count:
          the default label set is used.
        * Too many names: extras are dropped, so indexing the model output by
          label position can never run past the last logit.
        * Too few names: missing positions get ``intent_<index>`` placeholders
          so every logit is still reported.
        """
        defaults = IntentService._DEFAULT_LABELS
        names = list(label_names or [])
        if not names and num_intents == len(defaults):
            return list(defaults)
        names = names[:num_intents]
        names.extend(f"intent_{i}" for i in range(len(names), num_intents))
        return names

    @staticmethod
    def _read_label_metadata(model_dir):
        """Read label names and the intent count from ``model_dir``.

        Looks for ``config.json`` first, then ``label_map.json`` (a mapping of
        label name to index), and finally falls back to the default labels.

        Returns:
            ``(label_names, num_intents)`` with ``len(label_names) == num_intents``.
        """
        config_path = os.path.join(model_dir, "config.json")
        label_map_path = os.path.join(model_dir, "label_map.json")
        default_count = len(IntentService._DEFAULT_LABELS)

        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            raw_names = config.get("label_names") or []
            # Prefer the explicit count; otherwise infer it from the names
            # before falling back to the default label count.
            num_intents = config.get("num_intents") or len(raw_names) or default_count
        elif os.path.exists(label_map_path):
            with open(label_map_path, "r", encoding="utf-8") as f:
                label_map = json.load(f)
            # Order names by their index so position i matches logit i.
            raw_names = sorted(label_map, key=label_map.get)
            num_intents = len(raw_names) or default_count
        else:
            raw_names = []
            num_intents = default_count

        if raw_names and len(raw_names) != num_intents:
            print(
                f"[IntentService] WARNING: {len(raw_names)} label names for "
                f"{num_intents} intents; adjusting label list."
            )
        return IntentService._reconcile_labels(raw_names, num_intents), num_intents

    def load(self):
        """Load the trained intent classifier. Call once at app startup.

        Does nothing if a model is already loaded. If ``model.pt`` is missing
        under ``INTENT_MODEL_PATH``, prints a warning and leaves the service
        unavailable. Errors raised while reading the checkpoint or building
        the model propagate to the caller; in that case no attribute of the
        service is modified.
        """
        if self._loaded:
            return

        model_dir = INTENT_MODEL_PATH
        model_path = os.path.join(model_dir, "model.pt")
        if not os.path.exists(model_path):
            print(f"[IntentService] WARNING: Model not found at {model_path}")
            print("[IntentService] Intent classification will be unavailable.")
            return

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        label_names, num_intents = self._read_label_metadata(model_dir)

        print(f"[IntentService] Loading intent classifier ({num_intents} intents)...")
        model = IntentClassifier(
            bert_model_name=BERT_MODEL_NAME,
            num_intents=num_intents,
        )
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary Python objects, which can execute code. Only safe while
        # model.pt comes from a trusted source; switch to weights_only=True
        # once the checkpoint is confirmed to contain only tensors/primitives.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)
        model.eval()
        tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

        # Publish state only after every step succeeded, so a failed load
        # never leaves a half-initialized service behind.
        self.device = device
        self.label_names = label_names
        self.model = model
        self.tokenizer = tokenizer
        self._loaded = True
        print(f"[IntentService] Intent classifier loaded on {self.device}")
        print(f"[IntentService] Labels: {self.label_names}")

    def predict(self, query: str, threshold: float = 0.5) -> dict:
        """Classify a query's intents (multi-label).

        Args:
            query: The user query text.
            threshold: A label is reported as active when its sigmoid
                probability is strictly greater than this value.

        Returns:
            A dict of the form::

                {
                    "intents": ["single_search", "filtered_search"],
                    "probabilities": {
                        "single_search": 0.92,
                        "multi_search": 0.03,
                        "filtered_search": 0.87,
                        "free_form": 0.01
                    }
                }

            Probabilities are rounded to 4 decimals. Both fields are empty
            when no model is loaded or when ``query`` is not a non-blank
            string.
        """
        empty_result = {"intents": [], "probabilities": {}}
        if not self._loaded:
            return empty_result
        # A missing or whitespace-only query carries no intent; skip the model.
        if not isinstance(query, str) or not query.strip():
            return empty_result

        tokens = self.tokenizer(
            query,
            padding="max_length",
            truncation=True,
            max_length=INTENT_MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].to(self.device)
        attention_mask = tokens["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            # Independent sigmoid per label (multi-label); row 0 is the query.
            probs = torch.sigmoid(logits).cpu().numpy()[0]

        # zip() pairs each label with its probability and stops at the shorter
        # sequence, so a label/logit count mismatch cannot raise IndexError.
        scored = list(zip(self.label_names, probs))
        return {
            "intents": [name for name, p in scored if p > threshold],
            "probabilities": {name: round(float(p), 4) for name, p in scored},
        }
# Shared module-level instance: import this object rather than constructing
# IntentService directly, so the model is held by a single service.
intent_service: IntentService = IntentService()