| from __future__ import annotations |
|
|
| import json |
| import re |
| from functools import lru_cache |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModel, AutoTokenizer |
|
|
| from config import ( |
| IAB_RETRIEVAL_DEPTH_BONUS, |
| IAB_RETRIEVAL_MODEL_MAX_LENGTH, |
| IAB_RETRIEVAL_MODEL_NAME, |
| IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS, |
| IAB_RETRIEVAL_TOP_K, |
| IAB_TAXONOMY_EMBEDDINGS_PATH, |
| IAB_TAXONOMY_NODES_PATH, |
| IAB_TAXONOMY_VERSION, |
| ensure_artifact_dirs, |
| ) |
| from iab_taxonomy import IabNode, get_iab_taxonomy, path_to_label |
|
|
| RETRIEVAL_STOPWORDS = { |
| "a", |
| "an", |
| "and", |
| "are", |
| "as", |
| "at", |
| "best", |
| "buy", |
| "for", |
| "from", |
| "how", |
| "i", |
| "in", |
| "is", |
| "it", |
| "me", |
| "my", |
| "need", |
| "of", |
| "on", |
| "or", |
| "should", |
| "the", |
| "to", |
| "tonight", |
| "what", |
| "which", |
| "with", |
| } |
| GTE_QWEN_QUERY_INSTRUCTION = "Given a user query, retrieve the most relevant IAB content taxonomy category." |
|
|
|
|
| def round_score(value: float) -> float: |
| return round(float(value), 4) |
|
|
|
|
| def _normalize_keyword(value: str) -> str: |
| value = value.lower().replace("&", " and ") |
| value = re.sub(r"[^a-z0-9]+", " ", value) |
| return " ".join(value.split()) |
|
|
|
|
| def _keyword_tokens(value: str) -> set[str]: |
| return { |
| token |
| for token in _normalize_keyword(value).split() |
| if token and token not in RETRIEVAL_STOPWORDS and len(token) > 1 |
| } |
|
|
|
|
| def _is_gte_qwen_model(model_name: str) -> bool: |
| normalized = model_name.lower() |
| return "gte-qwen" in normalized |
|
|
|
|
| def _last_token_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: |
| left_padding = bool(torch.all(attention_mask[:, -1] == 1)) |
| if left_padding: |
| return last_hidden_state[:, -1] |
|
|
| sequence_lengths = attention_mask.sum(dim=1) - 1 |
| batch_indices = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device) |
| return last_hidden_state[batch_indices, sequence_lengths] |
|
|
|
|
| def _node_keywords(node: IabNode) -> list[str]: |
| keywords = {node.label, node.path_label} |
| keywords.update(node.path) |
| normalized = {_normalize_keyword(keyword) for keyword in keywords if keyword.strip()} |
| return sorted(keyword for keyword in normalized if keyword) |
|
|
|
|
| def _node_retrieval_text(node: IabNode) -> str: |
| keywords = _node_keywords(node) |
| parts = [ |
| f"IAB category path: {node.path_label}", |
| f"Canonical label: {node.label}", |
| f"Tier depth: {node.level}", |
| ] |
| if len(node.path) > 1: |
| parts.append(f"Parent path: {' > '.join(node.path[:-1])}") |
| if keywords: |
| parts.append(f"Keywords: {', '.join(keywords)}") |
| return ". ".join(parts) |
|
|
|
|
| def _serialize_node(node: IabNode) -> dict: |
| return { |
| "unique_id": node.unique_id, |
| "parent_id": node.parent_id, |
| "label": node.label, |
| "path": list(node.path), |
| "path_label": node.path_label, |
| "level": node.level, |
| "keywords": _node_keywords(node), |
| "retrieval_text": _node_retrieval_text(node), |
| } |
|
|
|
|
| class LocalTextEmbedder: |
| def __init__(self, model_name: str, max_length: int): |
| self.model_name = model_name |
| self.max_length = max_length |
| self._tokenizer = None |
| self._model = None |
| self._batch_size = 32 |
| self._device = "cuda" if torch.cuda.is_available() else "cpu" |
| self._is_gte_qwen = _is_gte_qwen_model(model_name) |
|
|
| @property |
| def tokenizer(self): |
| if self._tokenizer is None: |
| self._tokenizer = AutoTokenizer.from_pretrained( |
| self.model_name, |
| trust_remote_code=self._is_gte_qwen, |
| ) |
| return self._tokenizer |
|
|
| @property |
| def model(self): |
| if self._model is None: |
| model_kwargs = {"trust_remote_code": self._is_gte_qwen} |
| if self._device == "cuda": |
| model_kwargs["torch_dtype"] = torch.float16 |
| self._model = AutoModel.from_pretrained(self.model_name, **model_kwargs) |
| self._model.to(self._device) |
| self._model.eval() |
| return self._model |
|
|
| def encode_documents(self, texts: list[str], batch_size: int | None = None) -> torch.Tensor: |
| return self._encode_texts(texts, batch_size=batch_size, treat_as_query=False) |
|
|
| def encode_queries(self, texts: list[str], batch_size: int | None = None) -> torch.Tensor: |
| return self._encode_texts(texts, batch_size=batch_size, treat_as_query=True) |
|
|
| def _encode_texts( |
| self, |
| texts: list[str], |
| batch_size: int | None = None, |
| treat_as_query: bool = False, |
| ) -> torch.Tensor: |
| if not texts: |
| return torch.empty(0, 0) |
|
|
| effective_batch_size = batch_size or self._batch_size |
| rows: list[torch.Tensor] = [] |
| for start in range(0, len(texts), effective_batch_size): |
| batch_texts = texts[start : start + effective_batch_size] |
| if treat_as_query and self._is_gte_qwen: |
| batch_texts = [ |
| f"Instruct: {GTE_QWEN_QUERY_INSTRUCTION}\nQuery: {text}" |
| for text in batch_texts |
| ] |
| inputs = self.tokenizer( |
| batch_texts, |
| return_tensors="pt", |
| truncation=True, |
| padding=True, |
| max_length=self.max_length, |
| ) |
| inputs = {key: value.to(self._device) for key, value in inputs.items()} |
| with torch.no_grad(): |
| outputs = self.model(**inputs) |
| hidden = outputs.last_hidden_state |
| if self._is_gte_qwen: |
| pooled = _last_token_pool(hidden, inputs["attention_mask"]) |
| else: |
| mask = inputs["attention_mask"].unsqueeze(-1) |
| pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) |
| rows.append(F.normalize(pooled.float(), p=2, dim=1).cpu()) |
| return torch.cat(rows, dim=0) |
|
|
|
|
| @lru_cache(maxsize=1) |
| def get_iab_text_embedder() -> LocalTextEmbedder: |
| return LocalTextEmbedder(IAB_RETRIEVAL_MODEL_NAME, IAB_RETRIEVAL_MODEL_MAX_LENGTH) |
|
|
|
|
| def build_iab_taxonomy_embedding_index(batch_size: int = 32) -> dict: |
| ensure_artifact_dirs() |
| taxonomy = get_iab_taxonomy() |
| nodes = [_serialize_node(node) for node in taxonomy.nodes] |
| embedder = get_iab_text_embedder() |
| embeddings = embedder.encode_documents([node["retrieval_text"] for node in nodes], batch_size=batch_size) |
|
|
| IAB_TAXONOMY_NODES_PATH.write_text(json.dumps(nodes, indent=2, sort_keys=True) + "\n", encoding="utf-8") |
| torch.save( |
| { |
| "model_name": embedder.model_name, |
| "taxonomy_version": IAB_TAXONOMY_VERSION, |
| "embedding_dim": int(embeddings.shape[1]), |
| "node_count": len(nodes), |
| "embeddings": embeddings, |
| }, |
| IAB_TAXONOMY_EMBEDDINGS_PATH, |
| ) |
| return { |
| "taxonomy_version": IAB_TAXONOMY_VERSION, |
| "model_name": embedder.model_name, |
| "node_count": len(nodes), |
| "embedding_dim": int(embeddings.shape[1]), |
| "nodes_path": str(IAB_TAXONOMY_NODES_PATH), |
| "embeddings_path": str(IAB_TAXONOMY_EMBEDDINGS_PATH), |
| } |
|
|
|
|
| class IabEmbeddingRetriever: |
| def __init__(self): |
| self.taxonomy = get_iab_taxonomy() |
| self.embedder = get_iab_text_embedder() |
| self._nodes: list[dict] | None = None |
| self._embeddings: torch.Tensor | None = None |
|
|
| def _load_index(self) -> bool: |
| if self._nodes is not None and self._embeddings is not None: |
| return True |
| if not IAB_TAXONOMY_NODES_PATH.exists() or not IAB_TAXONOMY_EMBEDDINGS_PATH.exists(): |
| return False |
|
|
| nodes = json.loads(IAB_TAXONOMY_NODES_PATH.read_text(encoding="utf-8")) |
| payload = torch.load(IAB_TAXONOMY_EMBEDDINGS_PATH, map_location="cpu") |
| if payload.get("model_name") != IAB_RETRIEVAL_MODEL_NAME: |
| return False |
| if payload.get("taxonomy_version") != IAB_TAXONOMY_VERSION: |
| return False |
|
|
| embeddings = payload.get("embeddings") |
| if not isinstance(embeddings, torch.Tensor): |
| embeddings = torch.tensor(embeddings, dtype=torch.float32) |
| if len(nodes) != embeddings.shape[0]: |
| return False |
|
|
| self._nodes = nodes |
| self._embeddings = F.normalize(embeddings.float(), p=2, dim=1) |
| return True |
|
|
| def ready(self) -> bool: |
| return self._load_index() |
|
|
| @staticmethod |
| def _score_to_confidence(score: float) -> float: |
| return min(max((score + 1.0) / 2.0, 0.0), 1.0) |
|
|
| def _candidate_from_index(self, score: float, index: int) -> dict: |
| assert self._nodes is not None |
| node = self._nodes[index] |
| confidence = self._score_to_confidence(float(score)) |
| adjusted_confidence = confidence + (IAB_RETRIEVAL_DEPTH_BONUS * max(int(node["level"]) - 1, 0)) |
| return { |
| "unique_id": node["unique_id"], |
| "label": node["label"], |
| "path": tuple(node["path"]), |
| "path_label": node["path_label"], |
| "level": int(node["level"]), |
| "confidence": round_score(confidence), |
| "adjusted_confidence": round_score(adjusted_confidence), |
| "keywords": list(node.get("keywords", [])), |
| } |
|
|
| def _rerank_candidates(self, query_text: str, candidates: list[dict]) -> list[dict]: |
| if not candidates: |
| return [] |
|
|
| query_normalized = _normalize_keyword(query_text) |
| query_tokens = _keyword_tokens(query_text) |
|
|
| reranked = [] |
| for candidate in candidates: |
| keyword_tokens = set() |
| for keyword in candidate.get("keywords", []): |
| keyword_tokens.update(_keyword_tokens(keyword)) |
|
|
| token_overlap = len(query_tokens & keyword_tokens) |
| path_overlap = len(query_tokens & _keyword_tokens(candidate["path_label"])) |
| lexical_bonus = min(0.04, (0.008 * token_overlap) + (0.004 * path_overlap)) |
|
|
| reranked.append( |
| { |
| **candidate, |
| "token_overlap": token_overlap, |
| "path_overlap": path_overlap, |
| "lexical_bonus": round_score(lexical_bonus), |
| "rerank_score": round_score(candidate["adjusted_confidence"] + lexical_bonus), |
| } |
| ) |
|
|
| reranked.sort( |
| key=lambda item: ( |
| item["rerank_score"], |
| item["adjusted_confidence"], |
| item["confidence"], |
| ), |
| reverse=True, |
| ) |
| return reranked |
|
|
| def _top_candidates_from_embedding(self, query_text: str, query_embedding: torch.Tensor) -> list[dict]: |
| if not self._load_index(): |
| return [] |
|
|
| assert self._embeddings is not None |
|
|
| scores = torch.mv(self._embeddings, query_embedding) |
| top_k = min(IAB_RETRIEVAL_TOP_K, scores.shape[0]) |
| top_scores, top_indices = torch.topk(scores, k=top_k) |
|
|
| candidates = [self._candidate_from_index(score, index) for score, index in zip(top_scores.tolist(), top_indices.tolist())] |
| return self._rerank_candidates(query_text, candidates) |
|
|
| def _top_candidates(self, text: str) -> list[dict]: |
| if not self._load_index(): |
| return [] |
| query_embedding = self.embedder.encode_queries([text])[0] |
| return self._top_candidates_from_embedding(text, query_embedding) |
|
|
| def _select_path(self, candidates: list[dict]) -> dict | None: |
| if not candidates: |
| return None |
|
|
| top_candidate = candidates[0] |
| top_path = tuple(top_candidate["path"]) |
| top_margin = round_score( |
| top_candidate["confidence"] - candidates[1]["confidence"] if len(candidates) > 1 else top_candidate["confidence"] |
| ) |
| prefix_support: dict[tuple[str, ...], float] = {} |
| for depth in range(1, len(top_path) + 1): |
| prefix = top_path[:depth] |
| prefix_support[prefix] = max( |
| candidate["confidence"] |
| for candidate in candidates |
| if tuple(candidate["path"][:depth]) == prefix |
| ) |
|
|
| selected_path: tuple[str, ...] | None = None |
| selected_threshold = 0.0 |
| for depth in range(1, len(top_path) + 1): |
| threshold = IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS.get(depth, 0.62) |
| prefix = top_path[:depth] |
| if prefix_support[prefix] >= threshold: |
| selected_path = prefix |
| selected_threshold = threshold |
| continue |
| break |
|
|
| if selected_path is None: |
| return None |
|
|
| stopped_reason = "accepted" if selected_path == top_path else "parent_fallback" |
| if len(top_path) > 1: |
| ambiguous_sibling = any( |
| tuple(candidate["path"][:-1]) == top_path[:-1] |
| and (top_candidate["confidence"] - candidate["confidence"]) <= 0.03 |
| for candidate in candidates[1:] |
| ) |
| if ambiguous_sibling: |
| selected_path = top_path[:-1] |
| selected_threshold = IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS.get(len(selected_path), 0.62) |
| stopped_reason = "ambiguous_sibling_parent_fallback" |
|
|
| mapping_confidence = prefix_support[selected_path] |
| return { |
| "path": selected_path, |
| "path_label": path_to_label(selected_path), |
| "mapping_mode": "nearest_equivalent", |
| "mapping_confidence": round_score(mapping_confidence), |
| "confidence_threshold": round_score(selected_threshold), |
| "top_candidate_confidence": round_score(top_candidate["confidence"]), |
| "top_margin": top_margin, |
| "stopped_reason": stopped_reason, |
| } |
|
|
| def predict(self, text: str) -> dict | None: |
| candidates = self._top_candidates(text) |
| return self._prediction_from_candidates(candidates) |
|
|
| def _prediction_from_candidates(self, candidates: list[dict]) -> dict | None: |
| selection = self._select_path(candidates) |
| if selection is None: |
| return None |
|
|
| content = self.taxonomy.build_content_object( |
| path=selection["path"], |
| mapping_mode=selection["mapping_mode"], |
| mapping_confidence=selection["mapping_confidence"], |
| ) |
| return { |
| "label": selection["path_label"], |
| "confidence": selection["mapping_confidence"], |
| "raw_confidence": selection["top_candidate_confidence"], |
| "confidence_threshold": selection["confidence_threshold"], |
| "calibrated": False, |
| "meets_confidence_threshold": True, |
| "content": content, |
| "path": selection["path"], |
| "mapping_mode": selection["mapping_mode"], |
| "mapping_confidence": selection["mapping_confidence"], |
| "source": "embedding_retrieval", |
| "retrieval_model_name": IAB_RETRIEVAL_MODEL_NAME, |
| "stopped_reason": selection["stopped_reason"], |
| "top_margin": selection["top_margin"], |
| "top_candidates": [ |
| { |
| **candidate, |
| "path": list(candidate["path"]), |
| "keywords": candidate["keywords"][:12], |
| } |
| for candidate in candidates |
| ], |
| } |
|
|
| def predict_batch(self, texts: list[str], batch_size: int | None = None) -> list[dict | None]: |
| if not texts: |
| return [] |
| if not self._load_index(): |
| return [None for _ in texts] |
|
|
| query_embeddings = self.embedder.encode_queries(texts, batch_size=batch_size) |
| return [ |
| self._prediction_from_candidates(self._top_candidates_from_embedding(text, query_embedding)) |
| for text, query_embedding in zip(texts, query_embeddings) |
| ] |
|
|
|
|
| @lru_cache(maxsize=1) |
| def get_iab_embedding_retriever() -> IabEmbeddingRetriever: |
| return IabEmbeddingRetriever() |
|
|
|
|
| def predict_iab_content_retrieval(text: str) -> dict | None: |
| retriever = get_iab_embedding_retriever() |
| if not retriever.ready(): |
| return None |
| return retriever.predict(text) |
|
|
|
|
| def predict_iab_content_retrieval_batch(texts: list[str], batch_size: int | None = None) -> list[dict | None]: |
| retriever = get_iab_embedding_retriever() |
| if not retriever.ready(): |
| return [None for _ in texts] |
| return retriever.predict_batch(texts, batch_size=batch_size) |
|
|