|
from typing import Dict, Any, List |
|
import torch |
|
from transformers import AutoTokenizer, AutoModel |
|
import os |
|
import json |
|
|
|
class EndpointHandler:
    """Score (query, candidate label) pairs with an encoder plus a linear head.

    Each pair is rendered as
    ``"[QUERY] {query} [LABEL_NAME] {label} [LABEL_DESCRIPTION] {description}"``,
    encoded, and the first-token hidden state is passed through a 1-output
    linear layer followed by a sigmoid. Results are returned sorted by score,
    highest first.
    """

    def __init__(
        self,
        path: str = "",
        *,
        max_batch_size: int = 128,
        max_length: int = 64,
    ):
        """Load the tokenizer, encoder and scoring head from ``path``.

        Args:
            path: Directory holding the tokenizer/model files and
                ``classifier_head.json`` (keys ``scorer_weight``, ``scorer_bias``).
            max_batch_size: Maximum number of texts sent through the model in
                one forward pass.
            max_length: Token length every input is padded/truncated to.

        Raises:
            ValueError: if ``max_batch_size`` is smaller than 1.
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): the model's embedding matrix is not resized here. This is
        # only safe if the saved tokenizer already contains these tokens (so the
        # call adds nothing) — confirm against the checkpoint.
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
        })

        self.model = AutoModel.from_pretrained(path).to(self.device)
        self.model.eval()

        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r", encoding="utf-8") as f:
            head = json.load(f)

        self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
        # copy_ (rather than assigning .data) keeps the parameter's dtype/device
        # and fails fast if the stored head does not match the layer size;
        # reshape_as accepts both a flat and a nested [[...]] weight list.
        with torch.no_grad():
            weight = self.classifier.weight
            bias = self.classifier.bias
            weight.copy_(
                torch.as_tensor(head["scorer_weight"], dtype=weight.dtype).reshape_as(weight)
            )
            bias.copy_(
                torch.as_tensor(head["scorer_bias"], dtype=bias.dtype).reshape_as(bias)
            )

        self.max_batch_size = max_batch_size
        self.max_length = max_length

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Dispatch a request to the single-query or multi-query path.

        ``data`` may wrap the payload under an ``"inputs"`` key. A payload with
        a ``"queries"`` key is scored by :meth:`_process_batch` and yields one
        ranked list per query. Otherwise the payload must hold ``"query"`` and
        yields a single ranked list. Both forms read ``"candidates"``, a list
        of dicts with ``"label"`` and ``"description"`` keys.
        """
        payload = data.get("inputs", data)
        if "queries" in payload:
            return self._process_batch(payload)
        return self._process_single(payload)

    @staticmethod
    def _format_text(query: str, candidate: Dict[str, Any]) -> str:
        """Build the model input string for one (query, candidate) pair."""
        return (
            f"[QUERY] {query} [LABEL_NAME] {candidate['label']} "
            f"[LABEL_DESCRIPTION] {candidate['description']}"
        )

    def _score_texts(self, texts: List[str]) -> List[float]:
        """Return one sigmoid score per text, running the model in mini-batches."""
        scores: List[float] = []
        with torch.no_grad():
            for start in range(0, len(texts), self.max_batch_size):
                batch = texts[start:start + self.max_batch_size]
                tokens = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                ).to(self.device)
                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]
                # reshape(-1) always yields a 1-D tensor, even for a batch of one
                # (a bare .squeeze() would collapse that case to a 0-d scalar).
                batch_scores = torch.sigmoid(self.classifier(cls)).reshape(-1)
                scores.extend(batch_scores.cpu().tolist())
        return scores

    @staticmethod
    def _rank(candidates: List[Dict[str, Any]], scores: List[float]) -> List[Dict[str, Any]]:
        """Pair candidates with their scores (rounded to 4 places), best first."""
        results = [
            {
                "label": candidate["label"],
                "description": candidate["description"],
                "score": round(score, 4),
            }
            for candidate, score in zip(candidates, scores)
        ]
        # Stable sort: equal scores keep their original candidate order.
        results.sort(key=lambda r: r["score"], reverse=True)
        return results

    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score all candidates for a single ``payload["query"]``.

        Kept for backward compatibility with single-query requests; candidates
        are now scored in mini-batches instead of one forward pass each.
        """
        query = payload["query"]
        candidates = list(payload["candidates"])
        texts = [self._format_text(query, candidate) for candidate in candidates]
        return self._rank(candidates, self._score_texts(texts))

    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score every candidate against every query in ``payload["queries"]``.

        All (query, candidate) pairs are flattened query-major, scored in
        mini-batches of ``max_batch_size``, then split back per query.
        """
        queries = payload["queries"]
        candidates = list(payload["candidates"])
        n_candidates = len(candidates)

        texts = [
            self._format_text(query, candidate)
            for query in queries
            for candidate in candidates
        ]
        all_scores = self._score_texts(texts)

        # Query q's scores occupy the contiguous slice [q*n, (q+1)*n).
        return [
            self._rank(candidates, all_scores[q * n_candidates:(q + 1) * n_candidates])
            for q in range(len(queries))
        ]

    def get_batch_stats(self) -> Dict[str, Any]:
        """Return the handler's batching configuration and model identity."""
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": getattr(self.model.config, "name_or_path", "unknown"),
        }
|
|
|
|