# File size: 5,649 Bytes
# Revision hashes (web-viewer blame column): 6f1c63c 65018a5 6f1c63c 65018a5 6f1c63c 65018a5 6f1c63c c6e6058 6f1c63c 65018a5 c6e6058 65018a5 6f1c63c c6e6058 6f1c63c c6e6058
# (Line-number gutter 1-142 from the web file viewer omitted.)
from typing import Dict, Any, List
import torch
from transformers import AutoTokenizer, AutoModel
import os
import json
class EndpointHandler:
    """Rank candidate labels for one or many queries.

    Each (query, candidate) pair is rendered into a single marked-up string,
    encoded by the transformer loaded from ``path``, and the hidden state at
    sequence position 0 is passed through a one-output linear head followed
    by a sigmoid to produce a score.
    """

    # Marker tokens that delimit the three fields of every scored text.
    SPECIAL_TOKENS = ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]

    def __init__(self, path: str = ""):
        """Load the tokenizer, encoder and linear scoring head from ``path``.

        Args:
            path: Directory passed to ``from_pretrained`` for both tokenizer
                and model; it must also contain ``classifier_head.json`` with
                the keys ``scorer_weight`` and ``scorer_bias``.

        Raises:
            OSError: If ``classifier_head.json`` cannot be opened.
            KeyError: If the head file lacks one of the expected keys.
            RuntimeError: If the stored head does not fit a
                ``Linear(hidden_size, 1)`` layer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": list(self.SPECIAL_TOKENS)}
        )
        # NOTE(review): the model's embeddings are not resized after adding
        # the marker tokens; this presumably relies on the saved tokenizer
        # already containing them -- confirm against the checkpoint.
        self.model = AutoModel.from_pretrained(path).to(self.device)

        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r", encoding="utf-8") as f:
            head = json.load(f)

        self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
        # Normalise the stored values to the layer's own (1, hidden) / (1,)
        # layout and copy in place, so a mismatched head fails here at load
        # time instead of during the first request.
        weight = torch.tensor(head["scorer_weight"], dtype=torch.float32).reshape(1, -1)
        bias = torch.tensor(head["scorer_bias"], dtype=torch.float32).reshape(1)
        with torch.no_grad():
            self.classifier.weight.copy_(weight)
            self.classifier.bias.copy_(bias)

        self.model.eval()

        # Batch processing configuration
        self.max_batch_size = 128  # Adjust based on GPU memory
        self.max_length = 64

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Dispatch a request to single-query or multi-query scoring.

        Args:
            data: Request body. The payload is ``data["inputs"]`` when that
                key exists, otherwise ``data`` itself. A payload containing
                ``"queries"`` is scored as a batch; any other payload must
                provide ``"query"``. Both forms require ``"candidates"``.

        Returns:
            For a single query, a list of result dicts sorted by descending
            score. For a batch, one such list per query, in query order.

        Raises:
            KeyError: If a required payload key is missing.
        """
        payload = data.get("inputs", data)
        if "queries" in payload:
            return self._process_batch(payload)
        return self._process_single(payload)

    @staticmethod
    def _format_pair(query: Any, candidate: Dict[str, Any]) -> str:
        """Render one query/candidate pair into the marked-up model input."""
        return (
            f"[QUERY] {query} [LABEL_NAME] {candidate['label']} "
            f"[LABEL_DESCRIPTION] {candidate['description']}"
        )

    def _score_texts(self, texts: List[str]) -> List[float]:
        """Return one sigmoid score per text, running the model in chunks."""
        scores: List[float] = []
        # Guard against a zero/negative setting, which would make range() fail.
        step = max(1, self.max_batch_size)
        with torch.no_grad():
            for start in range(0, len(texts), step):
                chunk = texts[start:start + step]
                tokens = self.tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                ).to(self.device)
                out = self.model(**tokens)
                # Hidden state at the first sequence position of every row.
                cls = out.last_hidden_state[:, 0, :]
                # Flatten to 1-D so a one-item chunk needs no special case.
                probs = torch.sigmoid(self.classifier(cls)).reshape(-1)
                scores.extend(probs.cpu().tolist())
        return scores

    @staticmethod
    def _rank(candidates: List[Dict[str, Any]], scores: List[float]) -> List[Dict[str, Any]]:
        """Pair candidates with rounded scores, sorted by descending score."""
        ranked = [
            {
                "label": candidate["label"],
                "description": candidate["description"],
                "score": round(score, 4),
            }
            for candidate, score in zip(candidates, scores)
        ]
        # list.sort is stable, so equal scores keep their candidate order.
        ranked.sort(key=lambda row: row["score"], reverse=True)
        return ranked

    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every candidate against ``payload["query"]``.

        Kept for backward compatibility with single-query requests.

        Returns:
            Dicts with ``label``, ``description`` and ``score`` (rounded to
            four decimals), sorted by descending score.
        """
        query = payload["query"]
        candidates = payload["candidates"]
        texts = [self._format_pair(query, candidate) for candidate in candidates]
        return self._rank(candidates, self._score_texts(texts))

    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score every candidate against every entry of ``payload["queries"]``.

        Returns:
            One ranked result list per query, in the order the queries were
            given; each list has the same form as a single-query result.
        """
        queries = payload["queries"]
        candidates = payload["candidates"]

        # Query-major layout: the scores for query i occupy the slice
        # [i * n_candidates, (i + 1) * n_candidates).
        texts = [
            self._format_pair(query, candidate)
            for query in queries
            for candidate in candidates
        ]
        scores = self._score_texts(texts)

        n_candidates = len(candidates)
        return [
            self._rank(
                candidates,
                scores[q_idx * n_candidates:(q_idx + 1) * n_candidates],
            )
            for q_idx in range(len(queries))
        ]

    def get_batch_stats(self) -> Dict[str, Any]:
        """Return the batch configuration, device and model name.

        ``model_name`` falls back to ``"unknown"`` when the model config has
        no ``name_or_path`` attribute.
        """
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": getattr(self.model.config, "name_or_path", "unknown"),
        }
|