asofter commited on
Commit
a6b53fb
1 Parent(s): 8950a27

* ONNX runtime

Browse files

* use llm-guard 0.3.1
* google analytics tracking
* linter to fix code

Files changed (6) hide show
  1. .pre-commit-config.yaml +38 -0
  2. Dockerfile +1 -1
  3. app.py +28 -19
  4. output.py +43 -27
  5. prompt.py +25 -45
  6. requirements.txt +4 -5
.pre-commit-config.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.4.0
4
+ hooks:
5
+ - id: check-yaml
6
+ - id: end-of-file-fixer
7
+ - id: trailing-whitespace
8
+ - id: end-of-file-fixer
9
+ types: [ python ]
10
+ - id: requirements-txt-fixer
11
+
12
+ - repo: https://github.com/psf/black
13
+ rev: 23.7.0
14
+ hooks:
15
+ - id: black
16
+ args: [ --line-length=100, --exclude="" ]
17
+
18
+ # this is not technically always safe but usually is
19
+ # use comments `# isort: off` and `# isort: on` to disable/re-enable isort
20
+ - repo: https://github.com/pycqa/isort
21
+ rev: 5.12.0
22
+ hooks:
23
+ - id: isort
24
+ args: [ --line-length=100, --profile=black ]
25
+
26
+ # this is slightly dangerous because python imports have side effects
27
+ # and this tool removes unused imports, which may be providing
28
+ # necessary side effects for the code to run
29
+ - repo: https://github.com/PyCQA/autoflake
30
+ rev: v2.2.0
31
+ hooks:
32
+ - id: autoflake
33
+ args:
34
+ - "--in-place"
35
+ - "--expand-star-imports"
36
+ - "--remove-duplicate-keys"
37
+ - "--remove-unused-variables"
38
+ - "--remove-all-unused-imports"
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.10-slim
2
 
3
  RUN apt-get update && apt-get install -y \
4
  build-essential \
 
1
+ FROM python:3.11-slim
2
 
3
  RUN apt-get update && apt-get install -y \
4
  build-essential \
app.py CHANGED
@@ -1,16 +1,33 @@
1
  import logging
2
- import time
3
  import traceback
4
- from datetime import timedelta
5
 
6
  import pandas as pd
7
  import streamlit as st
 
 
 
8
  from output import init_settings as init_output_settings
9
  from output import scan as scan_output
10
  from prompt import init_settings as init_prompt_settings
11
  from prompt import scan as scan_prompt
12
 
13
- from llm_guard.vault import Vault
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  PROMPT = "prompt"
16
  OUTPUT = "output"
@@ -48,6 +65,8 @@ if scanner_type == PROMPT:
48
  elif scanner_type == OUTPUT:
49
  enabled_scanners, settings = init_output_settings()
50
 
 
 
51
  # Main pannel
52
  with st.expander("About", expanded=False):
53
  st.info(
@@ -93,32 +112,24 @@ elif scanner_type == OUTPUT:
93
  st_result_text = None
94
  st_analysis = None
95
  st_is_valid = None
96
- st_time_delta = None
97
 
98
  try:
99
  with st.form("text_form", clear_on_submit=False):
100
  submitted = st.form_submit_button("Process")
101
  if submitted:
102
- results_valid = {}
103
- results_score = {}
104
 
105
- start_time = time.monotonic()
106
  if scanner_type == PROMPT:
107
- st_result_text, results_valid, results_score = scan_prompt(
108
  vault, enabled_scanners, settings, st_prompt_text, st_fail_fast
109
  )
110
  elif scanner_type == OUTPUT:
111
- st_result_text, results_valid, results_score = scan_output(
112
  vault, enabled_scanners, settings, st_prompt_text, st_output_text, st_fail_fast
113
  )
114
- end_time = time.monotonic()
115
- st_time_delta = timedelta(seconds=end_time - start_time)
116
 
117
- st_is_valid = all(results_valid.values())
118
- st_analysis = [
119
- {"scanner": k, "is valid": results_valid[k], "risk score": results_score[k]}
120
- for k in results_valid
121
- ]
122
 
123
  except Exception as e:
124
  logger.error(e)
@@ -127,9 +138,7 @@ except Exception as e:
127
 
128
  # After:
129
  if st_is_valid is not None:
130
- st.subheader(
131
- f"Results - {'valid' if st_is_valid else 'invalid'} ({round(st_time_delta.total_seconds())} seconds)"
132
- )
133
 
134
  col1, col2 = st.columns(2)
135
 
 
1
  import logging
 
2
  import traceback
 
3
 
4
  import pandas as pd
5
  import streamlit as st
6
+ from llm_guard.vault import Vault
7
+ from streamlit.components.v1 import html
8
+
9
  from output import init_settings as init_output_settings
10
  from output import scan as scan_output
11
  from prompt import init_settings as init_prompt_settings
12
  from prompt import scan as scan_prompt
13
 
14
+
15
+ def add_google_analytics(ga4_id):
16
+ """
17
+ Add Google Analytics 4 to a Streamlit app
18
+ """
19
+ ga_code = f"""
20
+ <script async src="https://www.googletagmanager.com/gtag/js?id={ga4_id}"></script>
21
+ <script>
22
+ window.dataLayer = window.dataLayer || [];
23
+ function gtag(){{dataLayer.push(arguments);}}
24
+ gtag('js', new Date());
25
+ gtag('config', '{ga4_id}');
26
+ </script>
27
+ """
28
+
29
+ html(ga_code)
30
+
31
 
32
  PROMPT = "prompt"
33
  OUTPUT = "output"
 
65
  elif scanner_type == OUTPUT:
66
  enabled_scanners, settings = init_output_settings()
67
 
68
+ add_google_analytics("G-0HBVNHEZBW")
69
+
70
  # Main pannel
71
  with st.expander("About", expanded=False):
72
  st.info(
 
112
  st_result_text = None
113
  st_analysis = None
114
  st_is_valid = None
 
115
 
116
  try:
117
  with st.form("text_form", clear_on_submit=False):
118
  submitted = st.form_submit_button("Process")
119
  if submitted:
120
+ results = {}
 
121
 
 
122
  if scanner_type == PROMPT:
123
+ st_result_text, results = scan_prompt(
124
  vault, enabled_scanners, settings, st_prompt_text, st_fail_fast
125
  )
126
  elif scanner_type == OUTPUT:
127
+ st_result_text, results = scan_output(
128
  vault, enabled_scanners, settings, st_prompt_text, st_output_text, st_fail_fast
129
  )
 
 
130
 
131
+ st_is_valid = all(item["is_valid"] for item in results)
132
+ st_analysis = results
 
 
 
133
 
134
  except Exception as e:
135
  logger.error(e)
 
138
 
139
  # After:
140
  if st_is_valid is not None:
141
+ st.subheader(f"Results - {'valid' if st_is_valid else 'invalid'}")
 
 
142
 
143
  col1, col2 = st.columns(2)
144
 
output.py CHANGED
@@ -1,9 +1,9 @@
1
  import logging
 
 
2
  from typing import Dict, List
3
 
4
  import streamlit as st
5
- from streamlit_tags import st_tags
6
-
7
  from llm_guard.input_scanners.anonymize import default_entity_types
8
  from llm_guard.output_scanners import (
9
  JSON,
@@ -12,11 +12,11 @@ from llm_guard.output_scanners import (
12
  Bias,
13
  Code,
14
  Deanonymize,
 
15
  Language,
16
  LanguageSame,
17
  MaliciousURLs,
18
  NoRefusal,
19
- Refutation,
20
  Regex,
21
  Relevance,
22
  Sensitive,
@@ -25,6 +25,7 @@ from llm_guard.output_scanners.relevance import all_models as relevance_models
25
  from llm_guard.output_scanners.sentiment import Sentiment
26
  from llm_guard.output_scanners.toxicity import Toxicity
27
  from llm_guard.vault import Vault
 
28
 
29
  logger = logging.getLogger("llm-guard-playground")
30
 
@@ -41,7 +42,7 @@ def init_settings() -> (List, Dict):
41
  "LanguageSame",
42
  "MaliciousURLs",
43
  "NoRefusal",
44
- "Refutation",
45
  "Regex",
46
  "Relevance",
47
  "Sensitive",
@@ -163,7 +164,12 @@ def init_settings() -> (List, Dict):
163
  help="The minimum number of JSON elements that should be present",
164
  )
165
 
166
- settings["JSON"] = {"required_elements": st_json_required_elements}
 
 
 
 
 
167
 
168
  if "Language" in st_enabled_scanners:
169
  st_lan_expander = st.sidebar.expander(
@@ -274,23 +280,23 @@ def init_settings() -> (List, Dict):
274
 
275
  settings["NoRefusal"] = {"threshold": st_no_ref_threshold}
276
 
277
- if "Refutation" in st_enabled_scanners:
278
- st_refu_expander = st.sidebar.expander(
279
- "Refutation",
280
  expanded=False,
281
  )
282
 
283
- with st_refu_expander:
284
- st_refu_threshold = st.slider(
285
- label="Threshold",
286
  value=0.5,
287
  min_value=0.0,
288
  max_value=1.0,
289
  step=0.05,
290
- key="refu_threshold",
291
  )
292
 
293
- settings["Refutation"] = {"threshold": st_refu_threshold}
294
 
295
  if "Regex" in st_enabled_scanners:
296
  st_regex_expander = st.sidebar.expander(
@@ -359,7 +365,7 @@ def init_settings() -> (List, Dict):
359
  key="sensitive_entity_types",
360
  )
361
  st.caption(
362
- "Check all supported entities: https://microsoft.github.io/presidio/supported_entities/#list-of-supported-entities"
363
  )
364
  st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
365
  st_sens_threshold = st.slider(
@@ -434,13 +440,13 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
434
  return BanTopics(topics=settings["topics"], threshold=settings["threshold"])
435
 
436
  if scanner_name == "Bias":
437
- return Bias(threshold=settings["threshold"])
438
 
439
  if scanner_name == "Deanonymize":
440
  return Deanonymize(vault=vault)
441
 
442
  if scanner_name == "JSON":
443
- return JSON(required_elements=settings["required_elements"])
444
 
445
  if scanner_name == "Language":
446
  return Language(valid_languages=settings["valid_languages"])
@@ -458,16 +464,16 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
458
  elif mode == "denied":
459
  denied_languages = settings["languages"]
460
 
461
- return Code(allowed=allowed_languages, denied=denied_languages)
462
 
463
  if scanner_name == "MaliciousURLs":
464
- return MaliciousURLs(threshold=settings["threshold"])
465
 
466
  if scanner_name == "NoRefusal":
467
  return NoRefusal(threshold=settings["threshold"])
468
 
469
- if scanner_name == "Refutation":
470
- return Refutation(threshold=settings["threshold"])
471
 
472
  if scanner_name == "Regex":
473
  match_type = settings["type"]
@@ -491,13 +497,14 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
491
  entity_types=settings["entity_types"],
492
  redact=settings["redact"],
493
  threshold=settings["threshold"],
 
494
  )
495
 
496
  if scanner_name == "Sentiment":
497
  return Sentiment(threshold=settings["threshold"])
498
 
499
  if scanner_name == "Toxicity":
500
- return Toxicity(threshold=settings["threshold"])
501
 
502
  raise ValueError("Unknown scanner name")
503
 
@@ -509,10 +516,9 @@ def scan(
509
  prompt: str,
510
  text: str,
511
  fail_fast: bool = False,
512
- ) -> (str, Dict[str, bool], Dict[str, float]):
513
  sanitized_output = text
514
- results_valid = {}
515
- results_score = {}
516
 
517
  status_text = "Scanning prompt..."
518
  if fail_fast:
@@ -524,13 +530,23 @@ def scan(
524
  scanner = get_scanner(
525
  scanner_name, vault, settings[scanner_name] if scanner_name in settings else {}
526
  )
 
 
527
  sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
528
- results_valid[scanner_name] = is_valid
529
- results_score[scanner_name] = risk_score
 
 
 
 
 
 
 
 
530
 
531
  if fail_fast and not is_valid:
532
  break
533
 
534
  status.update(label="Scanning complete", state="complete", expanded=False)
535
 
536
- return sanitized_output, results_valid, results_score
 
1
  import logging
2
+ import time
3
+ from datetime import timedelta
4
  from typing import Dict, List
5
 
6
  import streamlit as st
 
 
7
  from llm_guard.input_scanners.anonymize import default_entity_types
8
  from llm_guard.output_scanners import (
9
  JSON,
 
12
  Bias,
13
  Code,
14
  Deanonymize,
15
+ FactualConsistency,
16
  Language,
17
  LanguageSame,
18
  MaliciousURLs,
19
  NoRefusal,
 
20
  Regex,
21
  Relevance,
22
  Sensitive,
 
25
  from llm_guard.output_scanners.sentiment import Sentiment
26
  from llm_guard.output_scanners.toxicity import Toxicity
27
  from llm_guard.vault import Vault
28
+ from streamlit_tags import st_tags
29
 
30
  logger = logging.getLogger("llm-guard-playground")
31
 
 
42
  "LanguageSame",
43
  "MaliciousURLs",
44
  "NoRefusal",
45
+ "FactualConsistency",
46
  "Regex",
47
  "Relevance",
48
  "Sensitive",
 
164
  help="The minimum number of JSON elements that should be present",
165
  )
166
 
167
+ st_json_repair = st.checkbox("Repair", value=False, help="Attempt to repair the JSON")
168
+
169
+ settings["JSON"] = {
170
+ "required_elements": st_json_required_elements,
171
+ "repair": st_json_repair,
172
+ }
173
 
174
  if "Language" in st_enabled_scanners:
175
  st_lan_expander = st.sidebar.expander(
 
280
 
281
  settings["NoRefusal"] = {"threshold": st_no_ref_threshold}
282
 
283
+ if "FactualConsistency" in st_enabled_scanners:
284
+ st_fc_expander = st.sidebar.expander(
285
+ "FactualConsistency",
286
  expanded=False,
287
  )
288
 
289
+ with st_fc_expander:
290
+ st_fc_minimum_score = st.slider(
291
+ label="Minimum score",
292
  value=0.5,
293
  min_value=0.0,
294
  max_value=1.0,
295
  step=0.05,
296
+ key="fc_threshold",
297
  )
298
 
299
+ settings["FactualConsistency"] = {"minimum_score": st_fc_minimum_score}
300
 
301
  if "Regex" in st_enabled_scanners:
302
  st_regex_expander = st.sidebar.expander(
 
365
  key="sensitive_entity_types",
366
  )
367
  st.caption(
368
+ "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
369
  )
370
  st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
371
  st_sens_threshold = st.slider(
 
440
  return BanTopics(topics=settings["topics"], threshold=settings["threshold"])
441
 
442
  if scanner_name == "Bias":
443
+ return Bias(threshold=settings["threshold"], use_onnx=True)
444
 
445
  if scanner_name == "Deanonymize":
446
  return Deanonymize(vault=vault)
447
 
448
  if scanner_name == "JSON":
449
+ return JSON(required_elements=settings["required_elements"], repair=settings["repair"])
450
 
451
  if scanner_name == "Language":
452
  return Language(valid_languages=settings["valid_languages"])
 
464
  elif mode == "denied":
465
  denied_languages = settings["languages"]
466
 
467
+ return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)
468
 
469
  if scanner_name == "MaliciousURLs":
470
+ return MaliciousURLs(threshold=settings["threshold"], use_onnx=True)
471
 
472
  if scanner_name == "NoRefusal":
473
  return NoRefusal(threshold=settings["threshold"])
474
 
475
+ if scanner_name == "FactualConsistency":
476
+ return FactualConsistency(minimum_score=settings["minimum_score"])
477
 
478
  if scanner_name == "Regex":
479
  match_type = settings["type"]
 
497
  entity_types=settings["entity_types"],
498
  redact=settings["redact"],
499
  threshold=settings["threshold"],
500
+ use_onnx=True,
501
  )
502
 
503
  if scanner_name == "Sentiment":
504
  return Sentiment(threshold=settings["threshold"])
505
 
506
  if scanner_name == "Toxicity":
507
+ return Toxicity(threshold=settings["threshold"], use_onnx=True)
508
 
509
  raise ValueError("Unknown scanner name")
510
 
 
516
  prompt: str,
517
  text: str,
518
  fail_fast: bool = False,
519
+ ) -> (str, List[Dict[str, any]]):
520
  sanitized_output = text
521
+ results = []
 
522
 
523
  status_text = "Scanning prompt..."
524
  if fail_fast:
 
530
  scanner = get_scanner(
531
  scanner_name, vault, settings[scanner_name] if scanner_name in settings else {}
532
  )
533
+
534
+ start_time = time.monotonic()
535
  sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
536
+ end_time = time.monotonic()
537
+
538
+ results.append(
539
+ {
540
+ "scanner": scanner_name,
541
+ "is_valid": is_valid,
542
+ "risk_score": risk_score,
543
+ "took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
544
+ }
545
+ )
546
 
547
  if fail_fast and not is_valid:
548
  break
549
 
550
  status.update(label="Scanning complete", state="complete", expanded=False)
551
 
552
+ return sanitized_output, results
prompt.py CHANGED
@@ -1,9 +1,9 @@
1
  import logging
 
 
2
  from typing import Dict, List
3
 
4
  import streamlit as st
5
- from streamlit_tags import st_tags
6
-
7
  from llm_guard.input_scanners import (
8
  Anonymize,
9
  BanSubstrings,
@@ -11,7 +11,6 @@ from llm_guard.input_scanners import (
11
  Code,
12
  Language,
13
  PromptInjection,
14
- PromptInjectionV2,
15
  Regex,
16
  Secrets,
17
  Sentiment,
@@ -19,8 +18,9 @@ from llm_guard.input_scanners import (
19
  Toxicity,
20
  )
21
  from llm_guard.input_scanners.anonymize import default_entity_types
22
- from llm_guard.input_scanners.anonymize_helpers.analyzer import RECOGNIZER_SPACY_EN_PII_DISTILBERT, RECOGNIZER_SPACY_EN_PII_FAST
23
  from llm_guard.vault import Vault
 
24
 
25
  logger = logging.getLogger("llm-guard-playground")
26
 
@@ -33,7 +33,6 @@ def init_settings() -> (List, Dict):
33
  "Code",
34
  "Language",
35
  "PromptInjection",
36
- "PromptInjectionV2",
37
  "Regex",
38
  "Secrets",
39
  "Sentiment",
@@ -67,7 +66,7 @@ def init_settings() -> (List, Dict):
67
  key="anon_entity_types",
68
  )
69
  st.caption(
70
- "Check all supported entities: https://microsoft.github.io/presidio/supported_entities/#list-of-supported-entities"
71
  )
72
  st_anon_hidden_names = st_tags(
73
  label="Hidden names to be anonymized",
@@ -101,11 +100,6 @@ def init_settings() -> (List, Dict):
101
  step=0.1,
102
  key="anon_threshold",
103
  )
104
- st_anon_recognizer = st.selectbox(
105
- "Recognizer",
106
- [RECOGNIZER_SPACY_EN_PII_DISTILBERT, RECOGNIZER_SPACY_EN_PII_FAST],
107
- index=1,
108
- )
109
 
110
  settings["Anonymize"] = {
111
  "entity_types": st_anon_entity_types,
@@ -114,7 +108,6 @@ def init_settings() -> (List, Dict):
114
  "preamble": st_anon_preamble,
115
  "use_faker": st_anon_use_faker,
116
  "threshold": st_anon_threshold,
117
- "recognizer": st_anon_recognizer,
118
  }
119
 
120
  if "BanSubstrings" in st_enabled_scanners:
@@ -286,26 +279,6 @@ def init_settings() -> (List, Dict):
286
  "threshold": st_pi_threshold,
287
  }
288
 
289
- if "PromptInjectionV2" in st_enabled_scanners:
290
- st_piv2_expander = st.sidebar.expander(
291
- "Prompt Injection V2",
292
- expanded=False,
293
- )
294
-
295
- with st_piv2_expander:
296
- st_piv2_threshold = st.slider(
297
- label="Threshold",
298
- value=0.5,
299
- min_value=0.0,
300
- max_value=1.0,
301
- step=0.05,
302
- key="prompt_injection_v2_threshold",
303
- )
304
-
305
- settings["PromptInjectionV2"] = {
306
- "threshold": st_piv2_threshold,
307
- }
308
-
309
  if "Regex" in st_enabled_scanners:
310
  st_regex_expander = st.sidebar.expander(
311
  "Regex",
@@ -427,7 +400,7 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
427
  preamble=settings["preamble"],
428
  use_faker=settings["use_faker"],
429
  threshold=settings["threshold"],
430
- recognizer=settings["recognizer"],
431
  )
432
 
433
  if scanner_name == "BanSubstrings":
@@ -452,16 +425,13 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
452
  elif mode == "denied":
453
  denied_languages = settings["languages"]
454
 
455
- return Code(allowed=allowed_languages, denied=denied_languages)
456
 
457
  if scanner_name == "Language":
458
  return Language(valid_languages=settings["valid_languages"])
459
 
460
  if scanner_name == "PromptInjection":
461
- return PromptInjection(threshold=settings["threshold"])
462
-
463
- if scanner_name == "PromptInjectionV2":
464
- return PromptInjectionV2(threshold=settings["threshold"])
465
 
466
  if scanner_name == "Regex":
467
  match_type = settings["type"]
@@ -487,17 +457,16 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
487
  return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])
488
 
489
  if scanner_name == "Toxicity":
490
- return Toxicity(threshold=settings["threshold"])
491
 
492
  raise ValueError("Unknown scanner name")
493
 
494
 
495
  def scan(
496
  vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
497
- ) -> (str, Dict[str, bool], Dict[str, float]):
498
  sanitized_prompt = text
499
- results_valid = {}
500
- results_score = {}
501
 
502
  status_text = "Scanning prompt..."
503
  if fail_fast:
@@ -507,12 +476,23 @@ def scan(
507
  for scanner_name in enabled_scanners:
508
  st.write(f"{scanner_name} scanner...")
509
  scanner = get_scanner(scanner_name, vault, settings[scanner_name])
 
 
510
  sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
511
- results_valid[scanner_name] = is_valid
512
- results_score[scanner_name] = risk_score
 
 
 
 
 
 
 
 
513
 
514
  if fail_fast and not is_valid:
515
  break
 
516
  status.update(label="Scanning complete", state="complete", expanded=False)
517
 
518
- return sanitized_prompt, results_valid, results_score
 
1
  import logging
2
+ import time
3
+ from datetime import timedelta
4
  from typing import Dict, List
5
 
6
  import streamlit as st
 
 
7
  from llm_guard.input_scanners import (
8
  Anonymize,
9
  BanSubstrings,
 
11
  Code,
12
  Language,
13
  PromptInjection,
 
14
  Regex,
15
  Secrets,
16
  Sentiment,
 
18
  Toxicity,
19
  )
20
  from llm_guard.input_scanners.anonymize import default_entity_types
21
+ from llm_guard.input_scanners.prompt_injection import ALL_MODELS as PI_ALL_MODELS
22
  from llm_guard.vault import Vault
23
+ from streamlit_tags import st_tags
24
 
25
  logger = logging.getLogger("llm-guard-playground")
26
 
 
33
  "Code",
34
  "Language",
35
  "PromptInjection",
 
36
  "Regex",
37
  "Secrets",
38
  "Sentiment",
 
66
  key="anon_entity_types",
67
  )
68
  st.caption(
69
+ "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
70
  )
71
  st_anon_hidden_names = st_tags(
72
  label="Hidden names to be anonymized",
 
100
  step=0.1,
101
  key="anon_threshold",
102
  )
 
 
 
 
 
103
 
104
  settings["Anonymize"] = {
105
  "entity_types": st_anon_entity_types,
 
108
  "preamble": st_anon_preamble,
109
  "use_faker": st_anon_use_faker,
110
  "threshold": st_anon_threshold,
 
111
  }
112
 
113
  if "BanSubstrings" in st_enabled_scanners:
 
279
  "threshold": st_pi_threshold,
280
  }
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  if "Regex" in st_enabled_scanners:
283
  st_regex_expander = st.sidebar.expander(
284
  "Regex",
 
400
  preamble=settings["preamble"],
401
  use_faker=settings["use_faker"],
402
  threshold=settings["threshold"],
403
+ use_onnx=True,
404
  )
405
 
406
  if scanner_name == "BanSubstrings":
 
425
  elif mode == "denied":
426
  denied_languages = settings["languages"]
427
 
428
+ return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)
429
 
430
  if scanner_name == "Language":
431
  return Language(valid_languages=settings["valid_languages"])
432
 
433
  if scanner_name == "PromptInjection":
434
+ return PromptInjection(threshold=settings["threshold"], models=PI_ALL_MODELS, use_onnx=True)
 
 
 
435
 
436
  if scanner_name == "Regex":
437
  match_type = settings["type"]
 
457
  return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])
458
 
459
  if scanner_name == "Toxicity":
460
+ return Toxicity(threshold=settings["threshold"], use_onnx=True)
461
 
462
  raise ValueError("Unknown scanner name")
463
 
464
 
465
  def scan(
466
  vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
467
+ ) -> (str, List[Dict[str, any]]):
468
  sanitized_prompt = text
469
+ results = []
 
470
 
471
  status_text = "Scanning prompt..."
472
  if fail_fast:
 
476
  for scanner_name in enabled_scanners:
477
  st.write(f"{scanner_name} scanner...")
478
  scanner = get_scanner(scanner_name, vault, settings[scanner_name])
479
+
480
+ start_time = time.monotonic()
481
  sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
482
+ end_time = time.monotonic()
483
+
484
+ results.append(
485
+ {
486
+ "scanner": scanner_name,
487
+ "is_valid": is_valid,
488
+ "risk_score": risk_score,
489
+ "took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
490
+ }
491
+ )
492
 
493
  if fail_fast and not is_valid:
494
  break
495
+
496
  status.update(label="Scanning complete", state="complete", expanded=False)
497
 
498
+ return sanitized_prompt, results
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
- https://huggingface.co/beki/en_spacy_pii_distilbert/resolve/main/en_spacy_pii_distilbert-any-py3-none-any.whl
2
- llm-guard==0.3.0
3
- pandas==2.1.0
4
- streamlit==1.27.2
5
  streamlit-tags==1.2.8
6
- https://huggingface.co/beki/en_spacy_pii_fast/resolve/main/en_spacy_pii_fast-any-py3-none-any.whl
 
1
+ llm-guard==0.3.1
2
+ llm-guard[onnxruntime]==0.3.1
3
+ pandas==2.1.2
4
+ streamlit==1.28.1
5
  streamlit-tags==1.2.8