from typing import Dict, Any, List, Union
import torch
from transformers import AutoTokenizer, AutoModel
import os
import json

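# This follows the custom-handler convention used by Hugging Face Inference
# Endpoints: a class named EndpointHandler exposing __init__(path) and
# __call__(data). It scores each (query, candidate label) pair from the
# encoder's [CLS] vector through a single-logit linear head loaded from
# classifier_head.json, and returns candidates ranked by their sigmoid score.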
class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
        })
        self.model = AutoModel.from_pretrained(path).to(self.device)
        # If the special tokens above were not already part of the saved tokenizer,
        # resize the embedding matrix so their ids are indexable (no-op otherwise).
        self.model.resize_token_embeddings(len(self.tokenizer))

        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r") as f:
            head = json.load(f)

        self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
        self.classifier.weight.data = torch.tensor(head["scorer_weight"]).to(self.device)
        self.classifier.bias.data = torch.tensor(head["scorer_bias"]).to(self.device)

        self.model.eval()
        
        # Batch processing configuration
        self.max_batch_size = 128  # Adjust based on GPU memory
        self.max_length = 64

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        payload = data.get("inputs", data)
        
        # Check if this is batch processing (multiple queries) or single query
        if "queries" in payload:
            return self._process_batch(payload)
        else:
            return self._process_single(payload)
    
    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Original single query processing for backward compatibility"""
        query = payload["query"]
        candidates = payload["candidates"]
        results = []

        with torch.no_grad():
            for entry in candidates:
                text = f"[QUERY] {query} [LABEL_NAME] {entry['label']} [LABEL_DESCRIPTION] {entry['description']}"
                tokens = self.tokenizer(
                    text,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)

                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]
                score = torch.sigmoid(self.classifier(cls)).item()
                results.append({
                    "label": entry["label"],
                    "description": entry["description"],
                    "score": round(score, 4)
                })

        return sorted(results, key=lambda x: x["score"], reverse=True)
    
    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """True batch processing for multiple queries"""
        queries = payload["queries"]
        candidates = payload["candidates"]
        
        # Create all query-candidate combinations in row-major order:
        # all candidates for query 0, then all candidates for query 1, and so on.
        all_texts = []
        for query in queries:
            for candidate in candidates:
                text = f"[QUERY] {query} [LABEL_NAME] {candidate['label']} [LABEL_DESCRIPTION] {candidate['description']}"
                all_texts.append(text)
        
        # Process in batches to avoid memory issues
        all_scores = []
        total_combinations = len(all_texts)
        
        with torch.no_grad():
            for i in range(0, total_combinations, self.max_batch_size):
                batch_texts = all_texts[i:i + self.max_batch_size]
                
                # Tokenize batch
                tokens = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)
                
                # Single forward pass for entire batch
                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]
                scores = torch.sigmoid(self.classifier(cls)).squeeze()
                
                # Handle single item case
                if scores.dim() == 0:
                    scores = scores.unsqueeze(0)
                
                all_scores.extend(scores.cpu().tolist())
        
        # Reshape results back to query structure
        results = []
        for q_idx in range(len(queries)):
            query_results = []
            for c_idx, candidate in enumerate(candidates):
                # Find the score for this query-candidate combination
                combination_idx = q_idx * len(candidates) + c_idx
                score = all_scores[combination_idx]
                
                query_results.append({
                    "label": candidate["label"],
                    "description": candidate["description"],
                    "score": round(score, 4)
                })
            
            # Sort by score for this query
            query_results.sort(key=lambda x: x["score"], reverse=True)
            results.append(query_results)
        
        return results
    
    def get_batch_stats(self) -> Dict[str, Any]:
        """Return batch processing statistics"""
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": self.model.config.name_or_path if hasattr(self.model.config, 'name_or_path') else "unknown"
        }
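

# A minimal local smoke test, not part of the endpoint contract: it assumes a
# model directory (here called "./model") containing the base encoder, the
# tokenizer, and classifier_head.json. The directory name and the candidate
# labels below are made-up examples; only the payload shapes mirror what
# __call__ expects for single-query and batched requests.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")

    # Illustrative candidate set (label/description pairs are placeholders).
    candidates = [
        {"label": "billing", "description": "Questions about invoices and payments"},
        {"label": "technical", "description": "Bug reports and technical support"},
    ]

    # Single query: returns one list of candidates ranked by score.
    single = handler({"inputs": {"query": "my invoice is wrong", "candidates": candidates}})
    print(single)

    # Multiple queries: returns one ranked list per query, in input order.
    batch = handler({"inputs": {"queries": ["my invoice is wrong", "the app keeps crashing"],
                                "candidates": candidates}})
    print(batch)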