# Last commit by dejanseo: "Update handler.py"
# Commit: c6e6058 (verified)
from typing import Dict, Any, List
import torch
from transformers import AutoTokenizer, AutoModel
import os
import json
class EndpointHandler:
    """Inference Endpoint handler that scores (query, label) pairs.

    Each pair is rendered as
    ``"[QUERY] {query} [LABEL_NAME] {label} [LABEL_DESCRIPTION] {description}"``,
    encoded by the transformer, and the hidden state at position 0 is passed
    through a one-output linear head followed by a sigmoid. Candidates are
    returned sorted by that score, highest first.
    """

    def __init__(self, path: str = "", max_batch_size: int = 128, max_length: int = 64):
        """Load the tokenizer, encoder and linear scoring head from ``path``.

        Args:
            path: Directory containing the tokenizer, the model weights and
                ``classifier_head.json`` (keys ``scorer_weight``, ``scorer_bias``).
            max_batch_size: Maximum number of texts per forward pass; adjust
                to the available GPU memory.
            max_length: Token length every input is padded/truncated to.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): no embedding resize happens here, so this assumes the
        # saved tokenizer/model already contain these markers — confirm.
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
        })
        self.model = AutoModel.from_pretrained(path).to(self.device)
        self.model.eval()

        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r", encoding="utf-8") as f:
            head = json.load(f)
        self.classifier = self._build_classifier(
            self.model.config.hidden_size, head, self.device
        )

        # Batch processing configuration
        self.max_batch_size = max_batch_size
        self.max_length = max_length

    @staticmethod
    def _build_classifier(
        hidden_size: int, head: Dict[str, Any], device: torch.device
    ) -> torch.nn.Linear:
        """Create a ``Linear(hidden_size, 1)`` head initialised from ``head``.

        ``scorer_weight`` may be stored either flat (``hidden_size`` values) or
        nested (``[[...]]``); ``scorer_bias`` either as a scalar or a
        one-element list. Both are reshaped to the layer's own parameter
        shapes, so a size mismatch fails here rather than at inference time.
        """
        classifier = torch.nn.Linear(hidden_size, 1)
        weight = torch.as_tensor(head["scorer_weight"], dtype=classifier.weight.dtype)
        bias = torch.as_tensor(head["scorer_bias"], dtype=classifier.bias.dtype)
        with torch.no_grad():
            classifier.weight.copy_(weight.reshape(classifier.weight.shape))
            classifier.bias.copy_(bias.reshape(classifier.bias.shape))
        return classifier.to(device)

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Dispatch a request to single-query or multi-query scoring.

        The payload is ``data["inputs"]`` when present, otherwise ``data``.
        A payload containing ``"queries"`` is scored in batch mode and yields
        one ranked list per query; otherwise ``"query"`` is scored alone and a
        single ranked list is returned.

        Raises:
            TypeError: if the payload is not a dict.
        """
        payload = data.get("inputs", data)
        # Guard: on a str payload, `"queries" in payload` would be a substring test.
        if not isinstance(payload, dict):
            raise TypeError(
                f"Expected a JSON object payload, got {type(payload).__name__}"
            )
        if "queries" in payload:
            return self._process_batch(payload)
        return self._process_single(payload)

    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every candidate against ``payload["query"]`` (backward-compatible API)."""
        query = payload["query"]
        candidates = payload["candidates"]
        texts = [self._format_text(query, candidate) for candidate in candidates]
        return self._rank(candidates, self._score_texts(texts))

    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score every candidate against every query in ``payload["queries"]``.

        Returns one ranked candidate list per query, in query order. A bare
        string is treated as a single query rather than iterated per character.
        """
        queries = payload["queries"]
        if isinstance(queries, str):
            queries = [queries]
        candidates = payload["candidates"]

        # Query-major order: pair (q, c) lives at index q * len(candidates) + c.
        texts = [
            self._format_text(query, candidate)
            for query in queries
            for candidate in candidates
        ]
        scores = self._score_texts(texts)

        n_candidates = len(candidates)
        return [
            self._rank(candidates, scores[q * n_candidates:(q + 1) * n_candidates])
            for q in range(len(queries))
        ]

    @staticmethod
    def _format_text(query: str, candidate: Dict[str, Any]) -> str:
        """Render one (query, candidate) pair in the marker format the model reads."""
        return (
            f"[QUERY] {query} [LABEL_NAME] {candidate['label']} "
            f"[LABEL_DESCRIPTION] {candidate['description']}"
        )

    def _score_texts(self, texts: List[str]) -> List[float]:
        """Return one sigmoid score per text, running at most ``max_batch_size`` per pass."""
        scores: List[float] = []
        with torch.no_grad():
            for start in range(0, len(texts), self.max_batch_size):
                chunk = texts[start:start + self.max_batch_size]
                tokens = self.tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                ).to(self.device)
                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]
                # Cast so a reduced-precision encoder still matches the head's dtype.
                logits = self.classifier(cls.to(self.classifier.weight.dtype))
                # reshape(-1) flattens (B, 1) -> (B,) and also copes with B == 1.
                scores.extend(torch.sigmoid(logits).reshape(-1).cpu().tolist())
        return scores

    @staticmethod
    def _rank(candidates: List[Dict[str, Any]], scores: List[float]) -> List[Dict[str, Any]]:
        """Pair candidates with their scores (rounded to 4 places), highest first.

        The sort is stable, so candidates with equal rounded scores keep their input order.
        """
        ranked = [
            {
                "label": candidate["label"],
                "description": candidate["description"],
                "score": round(score, 4),
            }
            for candidate, score in zip(candidates, scores)
        ]
        ranked.sort(key=lambda item: item["score"], reverse=True)
        return ranked

    def get_batch_stats(self) -> Dict[str, Any]:
        """Return the batching configuration, device and model name."""
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": getattr(self.model.config, "name_or_path", "unknown"),
        }