# RetailTalk/backend/models/intent_service.py
# Dashm
# Initial commit — RetailTalk backend for HuggingFace Spaces
# 26d82f3
"""
Intent Classification Service — runs the trained IntentClassifier model
for real-time inference on user queries.
Loaded once at startup, reused for all search requests.
Intents: single_search, multi_search, filtered_search, free_form (multi-label)
"""
import os
import json
import torch
import torch.nn as nn
import numpy as np
from transformers import BertModel, BertTokenizer
from config import INTENT_MODEL_PATH, BERT_MODEL_NAME, INTENT_MAX_LENGTH
class IntentClassifier(nn.Module):
    """BERT encoder plus a two-layer head for multi-label intent classification.

    The attribute names (``bert``, ``dropout``, ``fc1``, ``relu``, ``fc2``) and
    the layer sizes define the ``state_dict`` layout, so they must match the
    architecture used during training or a saved checkpoint will not load.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_intents: Number of output logits, one per intent label.
    """

    def __init__(self, bert_model_name="bert-base-multilingual-uncased", num_intents=4):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.3)
        # Read the head's input width from the encoder config instead of
        # hard-coding 768, so encoders with a different hidden size also work.
        # For BERT-base checkpoints this value is still 768.
        hidden_size = self.bert.config.hidden_size
        self.fc1 = nn.Linear(hidden_size, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_intents)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) intent logits for a batch of token ids.

        Args:
            input_ids: Token id tensor passed straight to the BERT encoder.
            attention_mask: Mask tensor passed straight to the BERT encoder.

        Returns:
            The output of ``fc2``: one logit per intent for each sequence.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first position of every sequence ([CLS]) as its summary vector.
        cls_output = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(cls_output)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
class IntentService:
    """Load the trained intent classifier once and serve predictions from it.

    An instance starts unloaded. ``load()`` populates the model, tokenizer,
    device and label names; until it succeeds, ``predict()`` returns an empty
    result instead of raising.
    """

    # Label order used when neither config.json nor label_map.json names them.
    _DEFAULT_LABELS = ("single_search", "multi_search", "filtered_search", "free_form")

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.label_names = []
        self._loaded = False

    def load(self):
        """Load the trained intent classifier. Call once at app startup.

        Reads ``model.pt`` (and, when present, ``config.json`` and
        ``label_map.json``) from ``INTENT_MODEL_PATH``. If ``model.pt`` is
        missing, a warning is printed and the service stays unloaded. Calling
        this again after a successful load does nothing.
        """
        if self._loaded:
            return

        model_dir = INTENT_MODEL_PATH
        model_path = os.path.join(model_dir, "model.pt")
        if not os.path.exists(model_path):
            print(f"[IntentService] WARNING: Model not found at {model_path}")
            print("[IntentService] Intent classification will be unavailable.")
            return

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.label_names, num_intents = self._resolve_labels(model_dir)

        print(f"[IntentService] Loading intent classifier ({num_intents} intents)...")
        if len(self.label_names) != num_intents:
            # predict() pairs labels with outputs positionally, so a count
            # mismatch means some outputs or labels are silently dropped.
            print(
                f"[IntentService] WARNING: {len(self.label_names)} label names "
                f"for {num_intents} model outputs; predictions may be incomplete."
            )

        model = IntentClassifier(
            bert_model_name=BERT_MODEL_NAME,
            num_intents=num_intents,
        )
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code from a tampered checkpoint. Only load
        # model.pt from a trusted source; consider weights_only=True if the
        # checkpoint contains nothing but tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        # Accept both a training checkpoint wrapping the weights under
        # "model_state_dict" and a bare state dict saved directly.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

        # Publish the model and tokenizer only after everything loaded, so a
        # failure above never leaves a half-initialised model on the instance.
        self.model = model
        self.tokenizer = tokenizer
        self._loaded = True
        print(f"[IntentService] Intent classifier loaded on {self.device}")
        print(f"[IntentService] Labels: {self.label_names}")

    def _resolve_labels(self, model_dir):
        """Return ``(label_names, num_intents)`` for the model in *model_dir*.

        ``config.json`` takes precedence. When it is absent or does not list
        any label names, ``label_map.json`` (name -> index) supplies them in
        index order. If neither file names the labels, the built-in defaults
        are used, but only when their count matches ``num_intents``.
        """
        config_path = os.path.join(model_dir, "config.json")
        label_map_path = os.path.join(model_dir, "label_map.json")

        label_names = []
        num_intents = None
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            label_names = list(config.get("label_names") or [])
            num_intents = config.get("num_intents", 4)

        if not label_names and os.path.exists(label_map_path):
            with open(label_map_path, "r", encoding="utf-8") as f:
                label_map = json.load(f)
            label_names = sorted(label_map, key=label_map.get)
            if num_intents is None:
                num_intents = len(label_names)

        if num_intents is None:
            num_intents = len(self._DEFAULT_LABELS)
        if not label_names and num_intents == len(self._DEFAULT_LABELS):
            label_names = list(self._DEFAULT_LABELS)
        return label_names, num_intents

    @staticmethod
    def _format_prediction(label_names, probs, threshold):
        """Pair labels with probabilities and build the ``predict`` result.

        ``zip`` stops at the shorter sequence, so a label list that is longer
        than the model output cannot raise ``IndexError``.
        """
        scored = list(zip(label_names, probs))
        return {
            "intents": [name for name, prob in scored if prob > threshold],
            "probabilities": {name: round(float(prob), 4) for name, prob in scored},
        }

    def predict(self, query: str, threshold: float = 0.5) -> dict:
        """Classify a query's intents (multi-label).

        Args:
            query: The user query text to classify.
            threshold: An intent is reported as active when its sigmoid
                probability is strictly greater than this value.

        Returns:
            A dict of the form::

                {
                    "intents": ["single_search", "filtered_search"],
                    "probabilities": {
                        "single_search": 0.92,
                        "multi_search": 0.03,
                        "filtered_search": 0.87,
                        "free_form": 0.01
                    }
                }

            Probabilities are rounded to 4 decimals. If the model has not
            been loaded, both entries are empty.
        """
        if not self._loaded:
            return {"intents": [], "probabilities": {}}

        tokens = self.tokenizer(
            query,
            padding="max_length",
            truncation=True,
            max_length=INTENT_MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].to(self.device)
        attention_mask = tokens["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
        # Independent sigmoid per output (multi-label); [0] selects the single
        # query in the batch.
        probs = torch.sigmoid(logits).cpu().numpy()[0]
        return self._format_prediction(self.label_names, probs, threshold)
# Module-level shared instance. It is created unloaded; call
# intent_service.load() once at app startup before using predict().
intent_service = IntentService()