asofter committed on
Commit
6b80b1f
β€’
1 Parent(s): e8cf854

* small model support

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +22 -24
  3. requirements.txt +7 -7
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ“
4
  colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.31.2
8
  pinned: true
9
  license: apache-2.0
10
  ---
 
4
  colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.39.0
8
  pinned: true
9
  license: apache-2.0
10
  ---
app.py CHANGED
@@ -21,7 +21,8 @@ from transformers import AutoTokenizer, pipeline
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
- hf_api = HfApi(token=os.getenv("HF_TOKEN"))
 
25
  num_processes = 2 # mp.cpu_count()
26
 
27
  lakera_api_key = os.getenv("LAKERA_API_KEY")
@@ -35,9 +36,15 @@ aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-
35
  @lru_cache(maxsize=2)
36
  def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
37
  hf_model = ORTModelForSequenceClassification.from_pretrained(
38
- prompt_injection_ort_model, export=False, subfolder=subfolder, file_name="model.onnx"
 
 
 
 
 
 
 
39
  )
40
- hf_tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, subfolder=subfolder)
41
  hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
42
 
43
  logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")
@@ -58,14 +65,17 @@ def convert_elapsed_time(diff_time) -> float:
58
 
59
 
60
  deepset_classifier = init_prompt_injection_model(
61
- "ProtectAI/deberta-v3-base-injection-onnx"
62
  ) # ONNX version of deepset/deberta-v3-base-injection
63
  protectai_v2_classifier = init_prompt_injection_model(
64
- "ProtectAI/deberta-v3-base-prompt-injection-v2", "onnx"
65
  )
66
  fmops_classifier = init_prompt_injection_model(
67
- "ProtectAI/fmops-distilbert-prompt-injection-onnx"
68
  ) # ONNX version of fmops/distilbert-prompt-injection
 
 
 
69
 
70
 
71
  def detect_hf(
@@ -93,6 +103,10 @@ def detect_hf_protectai_v2(prompt: str) -> (bool, bool):
93
  return detect_hf(prompt, classifier=protectai_v2_classifier)
94
 
95
 
 
 
 
 
96
  def detect_hf_deepset(prompt: str) -> (bool, bool):
97
  return detect_hf(prompt, classifier=deepset_classifier)
98
 
@@ -153,23 +167,6 @@ def detect_aws_comprehend(prompt: str) -> (bool, bool):
153
  EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
154
  Text=prompt,
155
  )
156
- response = {
157
- "Classes": [
158
- {"Name": "SAFE_PROMPT", "Score": 0.9010000228881836},
159
- {"Name": "UNSAFE_PROMPT", "Score": 0.0989999994635582},
160
- ],
161
- "ResponseMetadata": {
162
- "RequestId": "e8900fe1-3346-45c0-bad3-007b2840865a",
163
- "HTTPStatusCode": 200,
164
- "HTTPHeaders": {
165
- "x-amzn-requestid": "e8900fe1-3346-45c0-bad3-007b2840865a",
166
- "content-type": "application/x-amz-json-1.1",
167
- "content-length": "115",
168
- "date": "Mon, 19 Feb 2024 08:34:43 GMT",
169
- },
170
- "RetryAttempts": 0,
171
- },
172
- }
173
  logger.info(f"Prompt injection result from AWS Comprehend: {response}")
174
  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
175
  logger.error(f"Failed to call AWS Comprehend API: {response}")
@@ -209,13 +206,14 @@ def detect_sydelabs(prompt: str) -> (bool, bool):
209
 
210
  detection_providers = {
211
  "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
 
212
  "Deepset (HF model)": detect_hf_deepset,
213
  "FMOps (HF model)": detect_hf_fmops,
214
  "Lakera Guard": detect_lakera,
215
  # "Rebuff": detect_rebuff,
216
  "Azure Content Safety": detect_azure,
217
  "SydeLabs": detect_sydelabs,
218
- # "AWS Comprehend": detect_aws_comprehend,
219
  }
220
 
221
 
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
+ hf_token = os.getenv("HF_TOKEN")
25
+ hf_api = HfApi(token=hf_token)
26
  num_processes = 2 # mp.cpu_count()
27
 
28
  lakera_api_key = os.getenv("LAKERA_API_KEY")
 
36
  @lru_cache(maxsize=2)
37
  def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
38
  hf_model = ORTModelForSequenceClassification.from_pretrained(
39
+ prompt_injection_ort_model,
40
+ export=False,
41
+ subfolder=subfolder,
42
+ file_name="model.onnx",
43
+ token=hf_token,
44
+ )
45
+ hf_tokenizer = AutoTokenizer.from_pretrained(
46
+ prompt_injection_ort_model, subfolder=subfolder, token=hf_token
47
  )
 
48
  hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
49
 
50
  logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")
 
65
 
66
 
67
  deepset_classifier = init_prompt_injection_model(
68
+ "protectai/deberta-v3-base-injection-onnx"
69
  ) # ONNX version of deepset/deberta-v3-base-injection
70
  protectai_v2_classifier = init_prompt_injection_model(
71
+ "protectai/deberta-v3-base-prompt-injection-v2", "onnx"
72
  )
73
  fmops_classifier = init_prompt_injection_model(
74
+ "protectai/fmops-distilbert-prompt-injection-onnx"
75
  ) # ONNX version of fmops/distilbert-prompt-injection
76
+ protectai_v2_small_classifier = init_prompt_injection_model(
77
+ "protectai/deberta-v3-small-prompt-injection-v2", "onnx"
78
+ ) # ONNX version of protectai/deberta-v3-small-prompt-injection-v2
79
 
80
 
81
  def detect_hf(
 
103
  return detect_hf(prompt, classifier=protectai_v2_classifier)
104
 
105
 
106
+ def detect_hf_protectai_v2_small(prompt: str) -> (bool, bool):
107
+ return detect_hf(prompt, classifier=protectai_v2_small_classifier)
108
+
109
+
110
  def detect_hf_deepset(prompt: str) -> (bool, bool):
111
  return detect_hf(prompt, classifier=deepset_classifier)
112
 
 
167
  EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
168
  Text=prompt,
169
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  logger.info(f"Prompt injection result from AWS Comprehend: {response}")
171
  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
172
  logger.error(f"Failed to call AWS Comprehend API: {response}")
 
206
 
207
  detection_providers = {
208
  "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
209
+ "ProtectAI v2 Small (HF model)": detect_hf_protectai_v2_small,
210
  "Deepset (HF model)": detect_hf_deepset,
211
  "FMOps (HF model)": detect_hf_fmops,
212
  "Lakera Guard": detect_lakera,
213
  # "Rebuff": detect_rebuff,
214
  "Azure Content Safety": detect_azure,
215
  "SydeLabs": detect_sydelabs,
216
+ "AWS Comprehend": detect_aws_comprehend,
217
  }
218
 
219
 
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- boto3==1.34.104
2
- gradio==4.31.2
3
- huggingface_hub==0.23.0
4
- onnxruntime==1.17.3
5
- optimum[onnxruntime]==1.19.2
6
  rebuff==0.1.1
7
- requests==2.31.0
8
- transformers==4.39.3
 
1
+ boto3==1.34.146
2
+ gradio==4.39.0
3
+ huggingface_hub==0.24.0
4
+ onnxruntime==1.18.1
5
+ optimum[onnxruntime]==1.21.2
6
  rebuff==0.1.1
7
+ requests==2.32.3
8
+ transformers==4.42.4