Commit 7f7dfaf
Parent(s): 924a7c9

Updated handler implementation

handler.py CHANGED (+96 -32)
@@ -1,42 +1,106 @@

Before:

import os
import torch
from joblib import load
from transformers import BertTokenizer

def load_model(model_path):
    ...

class EndpointHandler:
    def __init__(self, path=""):
        ...

    def __call__(self, data):
        ...
After:

import os
import math  # needed for math.sqrt in the attention scaling below
import torch
from joblib import load
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertSelfAttention

class BertSdpaSelfAttention(BertSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # Add any custom initialization here
        self.sdpa_head = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False):
        # Custom forward pass
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Apply SDPA head
        sdpa_output = self.sdpa_head(context_layer)

        outputs = (sdpa_output, attention_probs) if output_attentions else (sdpa_output,)
        return outputs

# Register the custom class
setattr(torch.nn.modules, 'BertSdpaSelfAttention', BertSdpaSelfAttention)

def load_model(model_path):
    try:
        return load(model_path)
    except AttributeError as e:
        print(f"Error loading model: {e}")
        print("Ensure all custom classes are properly defined.")
        raise

class EndpointHandler:
    def __init__(self, path=""):
        try:
            # Load the model in the __init__ method
            model_path = os.path.join(path, "model.joblib")
            self.model = load_model(model_path)
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
        except Exception as e:
            print(f"Error initializing EndpointHandler: {e}")
            raise

    def __call__(self, data):
        try:
            inputs = data.pop("inputs", data)

            # Ensure inputs is a list
            if isinstance(inputs, str):
                inputs = [inputs]

            # Tokenize inputs
            encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, max_length=128, return_tensors="pt")

            # Move inputs to the correct device
            input_ids = encoded_inputs['input_ids'].to(self.device)
            attention_mask = encoded_inputs['attention_mask'].to(self.device)

            # Perform inference
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                predictions = torch.argmax(probabilities, dim=-1)

            # Convert predictions to human-readable labels
            class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
            results = [{"label": class_names[pred], "score": prob[pred].item()} for pred, prob in zip(predictions, probabilities)]

            return {"predictions": results}
        except Exception as e:
            print(f"Error during inference: {e}")
            return {"error": str(e)}

# For local testing
if __name__ == "__main__":
    handler = EndpointHandler()
    test_input = {"inputs": "This is a test input"}
    result = handler(test_input)
    print(result)
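The commit registers BertSdpaSelfAttention under torch.nn.modules, presumably so that unpickling model.joblib can resolve the class, but the handler itself never installs it into a model; that wiring would have happened when the artifact was produced. As a minimal sketch of how such a swap could look, the snippet below patches the self-attention modules of a stock BertModel with the custom class. The patch_self_attention helper and the bert-base-uncased checkpoint are illustrative assumptions, not part of this commit.

# Minimal sketch (not part of this commit): install BertSdpaSelfAttention into a
# BERT encoder. Assumes handler.py is importable from the working directory.
import torch
from transformers import BertModel, BertTokenizer
from handler import BertSdpaSelfAttention

def patch_self_attention(bert_model):
    """Hypothetical helper: replace each encoder layer's self-attention module."""
    for layer in bert_model.encoder.layer:
        custom_attn = BertSdpaSelfAttention(bert_model.config)
        # Reuse the pretrained query/key/value weights; the extra sdpa_head
        # stays randomly initialized (strict=False tolerates the missing keys).
        custom_attn.load_state_dict(layer.attention.self.state_dict(), strict=False)
        layer.attention.self = custom_attn
    return bert_model

if __name__ == "__main__":
    # bert-base-uncased is only a stand-in for the repository's real checkpoint.
    model = patch_self_attention(BertModel.from_pretrained("bert-base-uncased"))
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoded = tokenizer("This is a test input", return_tensors="pt")
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    print(hidden.shape)  # e.g. torch.Size([1, 7, 768])

A model patched along these lines and then saved with joblib.dump would match the AttributeError guard in load_model, which surfaces exactly the case where the custom class is not defined at load time.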