menimeni123 commited on
Commit
f311c70
1 Parent(s): 1b297c3

added better handler

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. handler.py +20 -18
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
handler.py CHANGED
@@ -1,32 +1,34 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
2
  import torch
 
 
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, model_dir):
6
- self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
7
- self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
8
- self.label_mapping = {0: "SAFE", 1: "JAILBREAK", 2: "INJECTION", 3: "PHISHING"}
 
9
 
10
  def __call__(self, inputs):
11
  if isinstance(inputs, dict) and 'inputs' in inputs:
12
  return self.predict(inputs['inputs'])
13
  return self.predict(inputs)
14
 
15
- def predict(self, inputs):
16
- # Tokenize the input
17
- encoded_input = self.tokenizer(inputs, return_tensors='pt', truncation=True, padding=True)
18
-
19
- # Make prediction
20
  with torch.no_grad():
21
- output = self.model(**encoded_input)
22
-
23
- # Get the predicted class
24
- predicted_class = torch.argmax(output.logits, dim=1).item()
25
-
26
- # Map the predicted class to its label
27
- predicted_label = self.label_mapping[predicted_class]
28
-
29
- return {"label": predicted_label, "score": output.logits.softmax(dim=1).max().item()}
30
 
31
  def get_pipeline():
32
  return EndpointHandler
 
1
+ import os
2
+ import joblib
3
  import torch
4
+ import numpy as np
5
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
6
+ import torch.nn.functional as F
7
 
8
  class EndpointHandler:
9
  def __init__(self, model_dir):
10
+ self.model = DistilBertForSequenceClassification.from_pretrained(model_dir)
11
+ self.tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
12
+ self.label_mapping = joblib.load(os.path.join(model_dir, "label_mapping.joblib"))
13
+ self.labels = {v: k for k, v in self.label_mapping.items()}
14
 
15
  def __call__(self, inputs):
16
  if isinstance(inputs, dict) and 'inputs' in inputs:
17
  return self.predict(inputs['inputs'])
18
  return self.predict(inputs)
19
 
20
+ def predict(self, text):
21
+ if len(text.split()) < 4:
22
+ return {"label": "SAFE", "score": 1.0}
23
+
24
+ encoded_input = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
25
  with torch.no_grad():
26
+ outputs = self.model(**encoded_input)
27
+ probabilities = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
28
+ confidence = np.max(probabilities)
29
+ predicted_label_idx = int(np.argmax(probabilities))
30
+ predicted_label = self.labels[predicted_label_idx]
31
+ return {"label": predicted_label, "score": float(confidence)}
 
 
 
32
 
33
  def get_pipeline():
34
  return EndpointHandler