# agentic-intent-classifier / iab_retrieval.py
# (Hugging Face Hub viewer residue removed: uploaded by manikumargouni via
# huggingface_hub, commit 0584798 verified.)
from __future__ import annotations
import json
import re
from functools import lru_cache
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from config import (
IAB_RETRIEVAL_DEPTH_BONUS,
IAB_RETRIEVAL_MODEL_MAX_LENGTH,
IAB_RETRIEVAL_MODEL_NAME,
IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS,
IAB_RETRIEVAL_TOP_K,
IAB_TAXONOMY_EMBEDDINGS_PATH,
IAB_TAXONOMY_NODES_PATH,
IAB_TAXONOMY_VERSION,
ensure_artifact_dirs,
)
from iab_taxonomy import IabNode, get_iab_taxonomy, path_to_label
# Tokens too generic to carry retrieval signal; filtered out of both query
# text and taxonomy keywords before lexical-overlap scoring (_keyword_tokens).
RETRIEVAL_STOPWORDS = {
    "a",
    "an",
    "and",
    "are",
    "as",
    "at",
    "best",
    "buy",
    "for",
    "from",
    "how",
    "i",
    "in",
    "is",
    "it",
    "me",
    "my",
    "need",
    "of",
    "on",
    "or",
    "should",
    "the",
    "to",
    "tonight",
    "what",
    "which",
    "with",
}
# Task instruction prepended to query texts for GTE-Qwen models, which expect
# queries wrapped in an "Instruct: ...\nQuery: ..." prompt (see _encode_texts).
GTE_QWEN_QUERY_INSTRUCTION = "Given a user query, retrieve the most relevant IAB content taxonomy category."
def round_score(value: float) -> float:
    """Coerce *value* to float and round it to four decimal places."""
    as_float = float(value)
    return round(as_float, 4)
def _normalize_keyword(value: str) -> str:
value = value.lower().replace("&", " and ")
value = re.sub(r"[^a-z0-9]+", " ", value)
return " ".join(value.split())
def _keyword_tokens(value: str) -> set[str]:
    """Normalize *value* and return its tokens, minus stopwords and single characters."""
    tokens: set[str] = set()
    for token in _normalize_keyword(value).split():
        if len(token) > 1 and token not in RETRIEVAL_STOPWORDS:
            tokens.add(token)
    return tokens
def _is_gte_qwen_model(model_name: str) -> bool:
normalized = model_name.lower()
return "gte-qwen" in normalized
def _last_token_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
left_padding = bool(torch.all(attention_mask[:, -1] == 1))
if left_padding:
return last_hidden_state[:, -1]
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_indices = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
return last_hidden_state[batch_indices, sequence_lengths]
def _node_keywords(node: IabNode) -> list[str]:
    """Collect, normalize, and sort the keyword strings describing *node*."""
    raw = {node.label, node.path_label, *node.path}
    normalized = {_normalize_keyword(item) for item in raw if item.strip()}
    normalized.discard("")
    return sorted(normalized)
def _node_retrieval_text(node: IabNode) -> str:
    """Build the document string embedded for *node* (path, label, tier, keywords)."""
    sections = [
        f"IAB category path: {node.path_label}",
        f"Canonical label: {node.label}",
        f"Tier depth: {node.level}",
    ]
    ancestors = node.path[:-1]
    if ancestors:
        sections.append(f"Parent path: {' > '.join(ancestors)}")
    node_keywords = _node_keywords(node)
    if node_keywords:
        sections.append(f"Keywords: {', '.join(node_keywords)}")
    return ". ".join(sections)
def _serialize_node(node: IabNode) -> dict:
    """Serialize *node* to the JSON-ready record stored in the node index file."""
    record: dict = {}
    record["unique_id"] = node.unique_id
    record["parent_id"] = node.parent_id
    record["label"] = node.label
    record["path"] = list(node.path)
    record["path_label"] = node.path_label
    record["level"] = node.level
    record["keywords"] = _node_keywords(node)
    record["retrieval_text"] = _node_retrieval_text(node)
    return record
class LocalTextEmbedder:
    """Thin wrapper around a Hugging Face encoder producing sentence embeddings.

    The tokenizer and model are loaded lazily on first use. GTE-Qwen models
    get last-token pooling and an instruction-prefixed query format; all other
    models use attention-masked mean pooling. Outputs are L2-normalized
    float32 rows on CPU.
    """

    def __init__(self, model_name: str, max_length: int):
        self.model_name = model_name
        self.max_length = max_length
        # Lazily initialized by the `tokenizer` / `model` properties.
        self._tokenizer = None
        self._model = None
        self._batch_size = 32  # default encode batch size when caller passes none
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._is_gte_qwen = _is_gte_qwen_model(model_name)

    @property
    def tokenizer(self):
        """Lazily construct and cache the tokenizer."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                # GTE-Qwen repos ship custom modeling code that must be trusted.
                trust_remote_code=self._is_gte_qwen,
            )
        return self._tokenizer

    @property
    def model(self):
        """Lazily construct, place on device, and cache the model in eval mode."""
        if self._model is None:
            model_kwargs = {"trust_remote_code": self._is_gte_qwen}
            if self._device == "cuda":
                model_kwargs["torch_dtype"] = torch.float16  # halve GPU memory use
            self._model = AutoModel.from_pretrained(self.model_name, **model_kwargs)
            self._model.to(self._device)
            self._model.eval()
        return self._model

    def encode_documents(self, texts: list[str], batch_size: int | None = None) -> torch.Tensor:
        """Embed document/taxonomy texts (no query instruction prefix)."""
        return self._encode_texts(texts, batch_size=batch_size, treat_as_query=False)

    def encode_queries(self, texts: list[str], batch_size: int | None = None) -> torch.Tensor:
        """Embed user queries (GTE-Qwen models receive an instruction prefix)."""
        return self._encode_texts(texts, batch_size=batch_size, treat_as_query=True)

    def _encode_texts(
        self,
        texts: list[str],
        batch_size: int | None = None,
        treat_as_query: bool = False,
    ) -> torch.Tensor:
        """Return L2-normalized float32 CPU embeddings, one row per input text.

        An empty input list yields an empty (0, 0) tensor.
        """
        if not texts:
            return torch.empty(0, 0)
        effective_batch_size = batch_size or self._batch_size
        rows: list[torch.Tensor] = []
        for start in range(0, len(texts), effective_batch_size):
            batch_texts = texts[start : start + effective_batch_size]
            if treat_as_query and self._is_gte_qwen:
                # GTE-Qwen expects queries wrapped in an instruction prompt.
                batch_texts = [
                    f"Instruct: {GTE_QWEN_QUERY_INSTRUCTION}\nQuery: {text}"
                    for text in batch_texts
                ]
            inputs = self.tokenizer(
                batch_texts,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=self.max_length,
            )
            inputs = {key: value.to(self._device) for key, value in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            hidden = outputs.last_hidden_state
            if self._is_gte_qwen:
                pooled = _last_token_pool(hidden, inputs["attention_mask"])
            else:
                # Mean pooling over non-padding positions only; clamp avoids
                # division by zero for an all-padding row.
                mask = inputs["attention_mask"].unsqueeze(-1)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            rows.append(F.normalize(pooled.float(), p=2, dim=1).cpu())
        return torch.cat(rows, dim=0)
@lru_cache(maxsize=1)
def get_iab_text_embedder() -> LocalTextEmbedder:
    """Return the process-wide embedder singleton, built from config on first call."""
    embedder = LocalTextEmbedder(
        IAB_RETRIEVAL_MODEL_NAME,
        IAB_RETRIEVAL_MODEL_MAX_LENGTH,
    )
    return embedder
def build_iab_taxonomy_embedding_index(batch_size: int = 32) -> dict:
    """Serialize the taxonomy nodes and their embeddings to the artifact files.

    Writes the node records as JSON and the embedding matrix (plus metadata
    for later validation) via torch.save, then returns a summary dict.
    """
    ensure_artifact_dirs()
    taxonomy = get_iab_taxonomy()
    serialized = [_serialize_node(node) for node in taxonomy.nodes]
    texts = [entry["retrieval_text"] for entry in serialized]
    embedder = get_iab_text_embedder()
    vectors = embedder.encode_documents(texts, batch_size=batch_size)
    IAB_TAXONOMY_NODES_PATH.write_text(
        json.dumps(serialized, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
    payload = {
        "model_name": embedder.model_name,
        "taxonomy_version": IAB_TAXONOMY_VERSION,
        "embedding_dim": int(vectors.shape[1]),
        "node_count": len(serialized),
        "embeddings": vectors,
    }
    torch.save(payload, IAB_TAXONOMY_EMBEDDINGS_PATH)
    return {
        "taxonomy_version": IAB_TAXONOMY_VERSION,
        "model_name": embedder.model_name,
        "node_count": len(serialized),
        "embedding_dim": int(vectors.shape[1]),
        "nodes_path": str(IAB_TAXONOMY_NODES_PATH),
        "embeddings_path": str(IAB_TAXONOMY_EMBEDDINGS_PATH),
    }
class IabEmbeddingRetriever:
    """Resolve free-text queries to IAB taxonomy paths via embedding retrieval.

    Uses the on-disk index produced by build_iab_taxonomy_embedding_index:
    queries are embedded, scored against every taxonomy node by cosine
    similarity, lexically reranked, and the winning path is truncated to the
    deepest prefix whose confidence clears a per-depth threshold.

    Fix vs. previous revision: removed the unused `query_normalized` local in
    _rerank_candidates.
    """

    def __init__(self):
        self.taxonomy = get_iab_taxonomy()
        self.embedder = get_iab_text_embedder()
        # Populated lazily by _load_index(); stay None until the on-disk
        # index has been validated.
        self._nodes: list[dict] | None = None
        self._embeddings: torch.Tensor | None = None

    def _load_index(self) -> bool:
        """Load and validate the serialized node/embedding index.

        Returns False (leaving the retriever unusable) when the files are
        missing, were built with a different model or taxonomy version, or
        when the node and embedding row counts disagree.
        """
        if self._nodes is not None and self._embeddings is not None:
            return True
        if not IAB_TAXONOMY_NODES_PATH.exists() or not IAB_TAXONOMY_EMBEDDINGS_PATH.exists():
            return False
        nodes = json.loads(IAB_TAXONOMY_NODES_PATH.read_text(encoding="utf-8"))
        payload = torch.load(IAB_TAXONOMY_EMBEDDINGS_PATH, map_location="cpu")
        if payload.get("model_name") != IAB_RETRIEVAL_MODEL_NAME:
            return False
        if payload.get("taxonomy_version") != IAB_TAXONOMY_VERSION:
            return False
        embeddings = payload.get("embeddings")
        if not isinstance(embeddings, torch.Tensor):
            embeddings = torch.tensor(embeddings, dtype=torch.float32)
        if len(nodes) != embeddings.shape[0]:
            return False
        self._nodes = nodes
        # Re-normalize so the dot products in _top_candidates_from_embedding
        # are true cosine similarities even if stored rows drifted.
        self._embeddings = F.normalize(embeddings.float(), p=2, dim=1)
        return True

    def ready(self) -> bool:
        """Return True when the retrieval index is loaded (or loadable)."""
        return self._load_index()

    @staticmethod
    def _score_to_confidence(score: float) -> float:
        """Map a cosine similarity in [-1, 1] linearly onto [0, 1]."""
        return min(max((score + 1.0) / 2.0, 0.0), 1.0)

    def _candidate_from_index(self, score: float, index: int) -> dict:
        """Build a candidate record for the node at `index` from its raw cosine score."""
        assert self._nodes is not None
        node = self._nodes[index]
        confidence = self._score_to_confidence(float(score))
        # Deeper (more specific) tiers earn a flat per-level bonus so specific
        # matches can outrank their generic parents.
        adjusted_confidence = confidence + (IAB_RETRIEVAL_DEPTH_BONUS * max(int(node["level"]) - 1, 0))
        return {
            "unique_id": node["unique_id"],
            "label": node["label"],
            "path": tuple(node["path"]),
            "path_label": node["path_label"],
            "level": int(node["level"]),
            "confidence": round_score(confidence),
            "adjusted_confidence": round_score(adjusted_confidence),
            "keywords": list(node.get("keywords", [])),
        }

    def _rerank_candidates(self, query_text: str, candidates: list[dict]) -> list[dict]:
        """Re-sort candidates with a small lexical-overlap bonus.

        The bonus rewards query tokens shared with a candidate's keywords and
        path label, capped at 0.04 so the embedding score dominates.
        """
        if not candidates:
            return []
        query_tokens = _keyword_tokens(query_text)
        reranked = []
        for candidate in candidates:
            keyword_tokens = set()
            for keyword in candidate.get("keywords", []):
                keyword_tokens.update(_keyword_tokens(keyword))
            token_overlap = len(query_tokens & keyword_tokens)
            path_overlap = len(query_tokens & _keyword_tokens(candidate["path_label"]))
            lexical_bonus = min(0.04, (0.008 * token_overlap) + (0.004 * path_overlap))
            reranked.append(
                {
                    **candidate,
                    "token_overlap": token_overlap,
                    "path_overlap": path_overlap,
                    "lexical_bonus": round_score(lexical_bonus),
                    "rerank_score": round_score(candidate["adjusted_confidence"] + lexical_bonus),
                }
            )
        # Ties on rerank_score fall back to depth-adjusted, then raw confidence.
        reranked.sort(
            key=lambda item: (
                item["rerank_score"],
                item["adjusted_confidence"],
                item["confidence"],
            ),
            reverse=True,
        )
        return reranked

    def _top_candidates_from_embedding(self, query_text: str, query_embedding: torch.Tensor) -> list[dict]:
        """Score one normalized query embedding against all nodes; return reranked top-k."""
        if not self._load_index():
            return []
        assert self._embeddings is not None
        scores = torch.mv(self._embeddings, query_embedding)
        top_k = min(IAB_RETRIEVAL_TOP_K, scores.shape[0])
        top_scores, top_indices = torch.topk(scores, k=top_k)
        candidates = [
            self._candidate_from_index(score, index)
            for score, index in zip(top_scores.tolist(), top_indices.tolist())
        ]
        return self._rerank_candidates(query_text, candidates)

    def _top_candidates(self, text: str) -> list[dict]:
        """Embed `text` and return its reranked top-k candidates."""
        if not self._load_index():
            return []
        query_embedding = self.embedder.encode_queries([text])[0]
        return self._top_candidates_from_embedding(text, query_embedding)

    def _select_path(self, candidates: list[dict]) -> dict | None:
        """Choose the deepest trustworthy taxonomy path from reranked candidates.

        Walks the top candidate's path prefix-by-prefix; each prefix is kept
        only while the best confidence among candidates sharing it clears that
        depth's threshold. When a sibling of the top candidate scores within
        0.03, the selection falls back to the shared parent.

        Returns None when even the tier-1 prefix fails its threshold.
        """
        if not candidates:
            return None
        top_candidate = candidates[0]
        top_path = tuple(top_candidate["path"])
        # Margin over the runner-up; with a single candidate the confidence
        # itself serves as the margin.
        top_margin = round_score(
            top_candidate["confidence"] - candidates[1]["confidence"] if len(candidates) > 1 else top_candidate["confidence"]
        )
        # Best supporting confidence for every prefix of the top path.
        prefix_support: dict[tuple[str, ...], float] = {}
        for depth in range(1, len(top_path) + 1):
            prefix = top_path[:depth]
            prefix_support[prefix] = max(
                candidate["confidence"]
                for candidate in candidates
                if tuple(candidate["path"][:depth]) == prefix
            )
        selected_path: tuple[str, ...] | None = None
        selected_threshold = 0.0
        for depth in range(1, len(top_path) + 1):
            threshold = IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS.get(depth, 0.62)
            prefix = top_path[:depth]
            if prefix_support[prefix] >= threshold:
                selected_path = prefix
                selected_threshold = threshold
                continue
            break
        if selected_path is None:
            return None
        stopped_reason = "accepted" if selected_path == top_path else "parent_fallback"
        if len(top_path) > 1:
            ambiguous_sibling = any(
                tuple(candidate["path"][:-1]) == top_path[:-1]
                and (top_candidate["confidence"] - candidate["confidence"]) <= 0.03
                for candidate in candidates[1:]
            )
            if ambiguous_sibling:
                # NOTE(review): this override can select top_path[:-1] even when
                # it is deeper than the deepest threshold-approved prefix above —
                # confirm that is the intended precedence.
                selected_path = top_path[:-1]
                selected_threshold = IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS.get(len(selected_path), 0.62)
                stopped_reason = "ambiguous_sibling_parent_fallback"
        mapping_confidence = prefix_support[selected_path]
        return {
            "path": selected_path,
            "path_label": path_to_label(selected_path),
            "mapping_mode": "nearest_equivalent",
            "mapping_confidence": round_score(mapping_confidence),
            "confidence_threshold": round_score(selected_threshold),
            "top_candidate_confidence": round_score(top_candidate["confidence"]),
            "top_margin": top_margin,
            "stopped_reason": stopped_reason,
        }

    def predict(self, text: str) -> dict | None:
        """Return the prediction payload for `text`, or None when no path qualifies."""
        candidates = self._top_candidates(text)
        return self._prediction_from_candidates(candidates)

    def _prediction_from_candidates(self, candidates: list[dict]) -> dict | None:
        """Assemble the public prediction payload from reranked candidates."""
        selection = self._select_path(candidates)
        if selection is None:
            return None
        content = self.taxonomy.build_content_object(
            path=selection["path"],
            mapping_mode=selection["mapping_mode"],
            mapping_confidence=selection["mapping_confidence"],
        )
        return {
            "label": selection["path_label"],
            "confidence": selection["mapping_confidence"],
            "raw_confidence": selection["top_candidate_confidence"],
            "confidence_threshold": selection["confidence_threshold"],
            "calibrated": False,
            "meets_confidence_threshold": True,
            "content": content,
            "path": selection["path"],
            "mapping_mode": selection["mapping_mode"],
            "mapping_confidence": selection["mapping_confidence"],
            "source": "embedding_retrieval",
            "retrieval_model_name": IAB_RETRIEVAL_MODEL_NAME,
            "stopped_reason": selection["stopped_reason"],
            "top_margin": selection["top_margin"],
            "top_candidates": [
                {
                    **candidate,
                    "path": list(candidate["path"]),
                    "keywords": candidate["keywords"][:12],
                }
                for candidate in candidates
            ],
        }

    def predict_batch(self, texts: list[str], batch_size: int | None = None) -> list[dict | None]:
        """Predict for many texts with one batched query-encoding pass.

        Returns one entry per input; None where no path clears the thresholds
        or the index is unavailable.
        """
        if not texts:
            return []
        if not self._load_index():
            return [None for _ in texts]
        query_embeddings = self.embedder.encode_queries(texts, batch_size=batch_size)
        return [
            self._prediction_from_candidates(self._top_candidates_from_embedding(text, query_embedding))
            for text, query_embedding in zip(texts, query_embeddings)
        ]
@lru_cache(maxsize=1)
def get_iab_embedding_retriever() -> IabEmbeddingRetriever:
    """Return the process-wide retriever singleton, constructed on first call."""
    retriever = IabEmbeddingRetriever()
    return retriever
def predict_iab_content_retrieval(text: str) -> dict | None:
    """Predict an IAB category for `text`; None when the index is unavailable."""
    retriever = get_iab_embedding_retriever()
    return retriever.predict(text) if retriever.ready() else None
def predict_iab_content_retrieval_batch(texts: list[str], batch_size: int | None = None) -> list[dict | None]:
    """Predict IAB categories for many texts; all-None when the index is unavailable."""
    retriever = get_iab_embedding_retriever()
    if retriever.ready():
        return retriever.predict_batch(texts, batch_size=batch_size)
    return [None] * len(texts)