menimeni123 committed
Commit 7f7dfaf · 1 Parent(s): 924a7c9

Updated handler implementation

Files changed (1):
  1. handler.py +97 -32
handler.py CHANGED
@@ -1,42 +1,107 @@
  import os
  import torch
  from joblib import load
- from transformers import BertTokenizer
+ import math
+ from transformers import BertTokenizer, BertModel
+ from transformers.models.bert.modeling_bert import BertSelfAttention
+
+ class BertSdpaSelfAttention(BertSelfAttention):
+     def __init__(self, config):
+         super().__init__(config)
+         # Add any custom initialization here
+         self.sdpa_head = torch.nn.Linear(config.hidden_size, config.hidden_size)
+
+     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False):
+         # Custom forward pass
+         mixed_query_layer = self.query(hidden_states)
+         mixed_key_layer = self.key(hidden_states)
+         mixed_value_layer = self.value(hidden_states)
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+         key_layer = self.transpose_for_scores(mixed_key_layer)
+         value_layer = self.transpose_for_scores(mixed_value_layer)
+
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+         if attention_mask is not None:
+             attention_scores = attention_scores + attention_mask
+
+         attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+         attention_probs = self.dropout(attention_probs)
+
+         if head_mask is not None:
+             attention_probs = attention_probs * head_mask
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_context_layer_shape)
+
+         # Apply SDPA head
+         sdpa_output = self.sdpa_head(context_layer)
+
+         outputs = (sdpa_output, attention_probs) if output_attentions else (sdpa_output,)
+         return outputs
+
+ # Register the custom class
+ setattr(torch.nn.modules, 'BertSdpaSelfAttention', BertSdpaSelfAttention)

  def load_model(model_path):
-     return load(model_path)
+     try:
+         return load(model_path)
+     except AttributeError as e:
+         print(f"Error loading model: {e}")
+         print("Ensure all custom classes are properly defined.")
+         raise

  class EndpointHandler:
      def __init__(self, path=""):
-         # Load the model in the __init__ method
-         self.model = load_model(os.path.join(path, "model.joblib"))
-         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model.to(self.device)
+         try:
+             # Load the model in the __init__ method
+             model_path = os.path.join(path, "model.joblib")
+             self.model = load_model(model_path)
+             self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             self.model.to(self.device)
+         except Exception as e:
+             print(f"Error initializing EndpointHandler: {e}")
+             raise

      def __call__(self, data):
-         inputs = data.pop("inputs", data)
-
-         # Ensure inputs is a list
-         if isinstance(inputs, str):
-             inputs = [inputs]
-
-         # Tokenize inputs
-         encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, max_length=128, return_tensors="pt")
-
-         # Move inputs to the correct device
-         input_ids = encoded_inputs['input_ids'].to(self.device)
-         attention_mask = encoded_inputs['attention_mask'].to(self.device)
-
-         # Perform inference
-         with torch.no_grad():
-             outputs = self.model(input_ids, attention_mask=attention_mask)
-         logits = outputs.logits
-         probabilities = torch.nn.functional.softmax(logits, dim=-1)
-         predictions = torch.argmax(probabilities, dim=-1)
-
-         # Convert predictions to human-readable labels
-         class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
-         results = [{"label": class_names[pred], "score": prob[pred].item()} for pred, prob in zip(predictions, probabilities)]
-
-         return {"predictions": results}
+         try:
+             inputs = data.pop("inputs", data)
+
+             # Ensure inputs is a list
+             if isinstance(inputs, str):
+                 inputs = [inputs]
+
+             # Tokenize inputs
+             encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, max_length=128, return_tensors="pt")
+
+             # Move inputs to the correct device
+             input_ids = encoded_inputs['input_ids'].to(self.device)
+             attention_mask = encoded_inputs['attention_mask'].to(self.device)
+
+             # Perform inference
+             with torch.no_grad():
+                 outputs = self.model(input_ids, attention_mask=attention_mask)
+             logits = outputs.logits
+             probabilities = torch.nn.functional.softmax(logits, dim=-1)
+             predictions = torch.argmax(probabilities, dim=-1)
+
+             # Convert predictions to human-readable labels
+             class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
+             results = [{"label": class_names[pred], "score": prob[pred].item()} for pred, prob in zip(predictions, probabilities)]
+
+             return {"predictions": results}
+         except Exception as e:
+             print(f"Error during inference: {e}")
+             return {"error": str(e)}
+
+ # For local testing
+ if __name__ == "__main__":
+     handler = EndpointHandler()
+     test_input = {"inputs": "This is a test input"}
+     result = handler(test_input)
+     print(result)
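
Why the commit defines and registers BertSdpaSelfAttention before calling load(): joblib serializes models via pickle, and pickle stores only a module path plus class name for each object it writes, so unpickling raises AttributeError in any process where that class cannot be imported from the recorded location. Below is a minimal sketch of the failure mode, assuming nothing beyond joblib itself; the Custom class and file name are illustrative, not part of this repo:

    import joblib

    # Stand-in for BertSdpaSelfAttention; pickle records it as __main__.Custom
    class Custom:
        pass

    joblib.dump(Custom(), "obj.joblib")

    # Succeeds here because Custom is still defined in this interpreter.
    # In a fresh process that never defines Custom, the same call raises
    # an AttributeError along the lines of:
    #   AttributeError: Can't get attribute 'Custom' on <module '__main__'>
    obj = joblib.load("obj.joblib")

This is why load_model catches AttributeError specifically, and why the class definition (plus the setattr registration) must execute before load(model_path) runs.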