asofter commited on
Commit
727d1ca
1 Parent(s): 19ee1e4

* upgrade version of llm-guard

Browse files
Files changed (3) hide show
  1. output.py +35 -134
  2. prompt.py +25 -117
  3. requirements.txt +4 -4
output.py CHANGED
@@ -5,25 +5,8 @@ from typing import Dict, List
5
 
6
  import streamlit as st
7
  from llm_guard.input_scanners.anonymize import default_entity_types
8
- from llm_guard.output_scanners import (
9
- JSON,
10
- BanSubstrings,
11
- BanTopics,
12
- Bias,
13
- Code,
14
- Deanonymize,
15
- FactualConsistency,
16
- Language,
17
- LanguageSame,
18
- MaliciousURLs,
19
- NoRefusal,
20
- Regex,
21
- Relevance,
22
- Sensitive,
23
- )
24
  from llm_guard.output_scanners.relevance import all_models as relevance_models
25
- from llm_guard.output_scanners.sentiment import Sentiment
26
- from llm_guard.output_scanners.toxicity import Toxicity
27
  from llm_guard.vault import Vault
28
  from streamlit_tags import st_tags
29
 
@@ -145,7 +128,14 @@ def init_settings() -> (List, Dict):
145
 
146
  st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)
147
 
148
- settings["Code"] = {"languages": st_cd_languages, "mode": st_cd_mode}
 
 
 
 
 
 
 
149
 
150
  if "JSON" in st_enabled_scanners:
151
  st_json_expander = st.sidebar.expander(
@@ -181,61 +171,26 @@ def init_settings() -> (List, Dict):
181
  st_lan_valid_language = st.multiselect(
182
  "Languages",
183
  [
184
- "af",
185
  "ar",
186
  "bg",
187
- "bn",
188
- "ca",
189
- "cs",
190
- "cy",
191
- "da",
192
  "de",
193
  "el",
194
  "en",
195
  "es",
196
- "et",
197
- "fa",
198
- "fi",
199
  "fr",
200
- "gu",
201
- "he",
202
  "hi",
203
- "hr",
204
- "hu",
205
- "id",
206
  "it",
207
  "ja",
208
- "kn",
209
- "ko",
210
- "lt",
211
- "lv",
212
- "mk",
213
- "ml",
214
- "mr",
215
- "ne",
216
  "nl",
217
- "no",
218
- "pa",
219
  "pl",
220
  "pt",
221
- "ro",
222
  "ru",
223
- "sk",
224
- "sl",
225
- "so",
226
- "sq",
227
- "sv",
228
  "sw",
229
- "ta",
230
- "te",
231
  "th",
232
- "tl",
233
  "tr",
234
- "uk",
235
  "ur",
236
  "vi",
237
- "zh-cn",
238
- "zh-tw",
239
  ],
240
  default=["en"],
241
  )
@@ -322,9 +277,16 @@ def init_settings() -> (List, Dict):
322
  "Redact", value=False, help="Replace the matched bad patterns with [REDACTED]"
323
  )
324
 
 
 
 
 
 
 
 
325
  settings["Regex"] = {
326
- "patterns": st_regex_patterns,
327
- "type": st_regex_type,
328
  "redact": st_redact,
329
  }
330
 
@@ -427,86 +389,25 @@ def init_settings() -> (List, Dict):
427
  def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
428
  logger.debug(f"Initializing {scanner_name} scanner")
429
 
430
- if scanner_name == "BanSubstrings":
431
- return BanSubstrings(
432
- substrings=settings["substrings"],
433
- match_type=settings["match_type"],
434
- case_sensitive=settings["case_sensitive"],
435
- redact=settings["redact"],
436
- contains_all=settings["contains_all"],
437
- )
438
-
439
- if scanner_name == "BanTopics":
440
- return BanTopics(topics=settings["topics"], threshold=settings["threshold"])
441
-
442
- if scanner_name == "Bias":
443
- return Bias(threshold=settings["threshold"], use_onnx=True)
444
-
445
  if scanner_name == "Deanonymize":
446
- return Deanonymize(vault=vault)
447
-
448
- if scanner_name == "JSON":
449
- return JSON(required_elements=settings["required_elements"], repair=settings["repair"])
450
-
451
- if scanner_name == "Language":
452
- return Language(valid_languages=settings["valid_languages"])
453
-
454
- if scanner_name == "LanguageSame":
455
- return LanguageSame()
456
-
457
- if scanner_name == "Code":
458
- mode = settings["mode"]
459
-
460
- allowed_languages = None
461
- denied_languages = None
462
- if mode == "allowed":
463
- allowed_languages = settings["languages"]
464
- elif mode == "denied":
465
- denied_languages = settings["languages"]
466
-
467
- return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)
468
-
469
- if scanner_name == "MaliciousURLs":
470
- return MaliciousURLs(threshold=settings["threshold"], use_onnx=True)
471
-
472
- if scanner_name == "NoRefusal":
473
- return NoRefusal(threshold=settings["threshold"])
474
-
475
- if scanner_name == "FactualConsistency":
476
- return FactualConsistency(minimum_score=settings["minimum_score"])
477
 
478
- if scanner_name == "Regex":
479
- match_type = settings["type"]
480
-
481
- good_patterns = None
482
- bad_patterns = None
483
- if match_type == "good":
484
- good_patterns = settings["patterns"]
485
- elif match_type == "bad":
486
- bad_patterns = settings["patterns"]
487
-
488
- return Regex(
489
- good_patterns=good_patterns, bad_patterns=bad_patterns, redact=settings["redact"]
490
- )
491
-
492
- if scanner_name == "Relevance":
493
- return Relevance(threshold=settings["threshold"], model=settings["model"])
494
-
495
- if scanner_name == "Sensitive":
496
- return Sensitive(
497
- entity_types=settings["entity_types"],
498
- redact=settings["redact"],
499
- threshold=settings["threshold"],
500
- use_onnx=True,
501
- )
502
-
503
- if scanner_name == "Sentiment":
504
- return Sentiment(threshold=settings["threshold"])
505
-
506
- if scanner_name == "Toxicity":
507
- return Toxicity(threshold=settings["threshold"], use_onnx=True)
508
 
509
- raise ValueError("Unknown scanner name")
510
 
511
 
512
  def scan(
 
5
 
6
  import streamlit as st
7
  from llm_guard.input_scanners.anonymize import default_entity_types
8
+ from llm_guard.output_scanners import get_scanner_by_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from llm_guard.output_scanners.relevance import all_models as relevance_models
 
 
10
  from llm_guard.vault import Vault
11
  from streamlit_tags import st_tags
12
 
 
128
 
129
  st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)
130
 
131
+ allowed_languages = None
132
+ denied_languages = None
133
+ if st_cd_mode == "allowed":
134
+ allowed_languages = st_cd_languages
135
+ elif st_cd_mode == "denied":
136
+ denied_languages = st_cd_languages
137
+
138
+ settings["Code"] = {"allowed": allowed_languages, "denied": denied_languages}
139
 
140
  if "JSON" in st_enabled_scanners:
141
  st_json_expander = st.sidebar.expander(
 
171
  st_lan_valid_language = st.multiselect(
172
  "Languages",
173
  [
 
174
  "ar",
175
  "bg",
 
 
 
 
 
176
  "de",
177
  "el",
178
  "en",
179
  "es",
 
 
 
180
  "fr",
 
 
181
  "hi",
 
 
 
182
  "it",
183
  "ja",
 
 
 
 
 
 
 
 
184
  "nl",
 
 
185
  "pl",
186
  "pt",
 
187
  "ru",
 
 
 
 
 
188
  "sw",
 
 
189
  "th",
 
190
  "tr",
 
191
  "ur",
192
  "vi",
193
+ "zh",
 
194
  ],
195
  default=["en"],
196
  )
 
277
  "Redact", value=False, help="Replace the matched bad patterns with [REDACTED]"
278
  )
279
 
280
+ good_patterns = None
281
+ bad_patterns = None
282
+ if st_regex_type == "good":
283
+ good_patterns = st_regex_patterns
284
+ elif st_regex_type == "bad":
285
+ bad_patterns = st_regex_patterns
286
+
287
  settings["Regex"] = {
288
+ "good_patterns": good_patterns,
289
+ "bad_patterns": bad_patterns,
290
  "redact": st_redact,
291
  }
292
 
 
389
  def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
390
  logger.debug(f"Initializing {scanner_name} scanner")
391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  if scanner_name == "Deanonymize":
393
+ settings["vault"] = vault
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
+ if scanner_name in [
396
+ "BanTopics",
397
+ "Bias",
398
+ "Code",
399
+ "Language",
400
+ "LanguageSame",
401
+ "MaliciousURLs",
402
+ "NoRefusal",
403
+ "FactualConsistency",
404
+ "Relevance",
405
+ "Sensitive",
406
+ "Toxicity",
407
+ ]:
408
+ settings["use_onnx"] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
+ return get_scanner_by_name(scanner_name, settings)
411
 
412
 
413
  def scan(
prompt.py CHANGED
@@ -4,19 +4,7 @@ from datetime import timedelta
4
  from typing import Dict, List
5
 
6
  import streamlit as st
7
- from llm_guard.input_scanners import (
8
- Anonymize,
9
- BanSubstrings,
10
- BanTopics,
11
- Code,
12
- Language,
13
- PromptInjection,
14
- Regex,
15
- Secrets,
16
- Sentiment,
17
- TokenLimit,
18
- Toxicity,
19
- )
20
  from llm_guard.input_scanners.anonymize import default_entity_types
21
  from llm_guard.input_scanners.prompt_injection import ALL_MODELS as PI_ALL_MODELS
22
  from llm_guard.vault import Vault
@@ -181,9 +169,16 @@ def init_settings() -> (List, Dict):
181
 
182
  st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)
183
 
 
 
 
 
 
 
 
184
  settings["Code"] = {
185
- "languages": st_cd_languages,
186
- "mode": st_cd_mode,
187
  }
188
 
189
  if "Language" in st_enabled_scanners:
@@ -196,61 +191,26 @@ def init_settings() -> (List, Dict):
196
  st_lan_valid_language = st.multiselect(
197
  "Languages",
198
  [
199
- "af",
200
  "ar",
201
  "bg",
202
- "bn",
203
- "ca",
204
- "cs",
205
- "cy",
206
- "da",
207
  "de",
208
  "el",
209
  "en",
210
  "es",
211
- "et",
212
- "fa",
213
- "fi",
214
  "fr",
215
- "gu",
216
- "he",
217
  "hi",
218
- "hr",
219
- "hu",
220
- "id",
221
  "it",
222
  "ja",
223
- "kn",
224
- "ko",
225
- "lt",
226
- "lv",
227
- "mk",
228
- "ml",
229
- "mr",
230
- "ne",
231
  "nl",
232
- "no",
233
- "pa",
234
  "pl",
235
  "pt",
236
- "ro",
237
  "ru",
238
- "sk",
239
- "sl",
240
- "so",
241
- "sq",
242
- "sv",
243
  "sw",
244
- "ta",
245
- "te",
246
  "th",
247
- "tl",
248
  "tr",
249
- "uk",
250
  "ur",
251
  "vi",
252
- "zh-cn",
253
- "zh-tw",
254
  ],
255
  default=["en"],
256
  )
@@ -303,9 +263,16 @@ def init_settings() -> (List, Dict):
303
  "Redact", value=False, help="Replace the matched bad patterns with [REDACTED]"
304
  )
305
 
 
 
 
 
 
 
 
306
  settings["Regex"] = {
307
- "patterns": st_regex_patterns,
308
- "type": st_regex_type,
309
  "redact": st_redact,
310
  }
311
 
@@ -392,74 +359,15 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
392
  logger.debug(f"Initializing {scanner_name} scanner")
393
 
394
  if scanner_name == "Anonymize":
395
- return Anonymize(
396
- vault=vault,
397
- allowed_names=settings["allowed_names"],
398
- hidden_names=settings["hidden_names"],
399
- entity_types=settings["entity_types"],
400
- preamble=settings["preamble"],
401
- use_faker=settings["use_faker"],
402
- threshold=settings["threshold"],
403
- use_onnx=True,
404
- )
405
-
406
- if scanner_name == "BanSubstrings":
407
- return BanSubstrings(
408
- substrings=settings["substrings"],
409
- match_type=settings["match_type"],
410
- case_sensitive=settings["case_sensitive"],
411
- redact=settings["redact"],
412
- contains_all=settings["contains_all"],
413
- )
414
-
415
- if scanner_name == "BanTopics":
416
- return BanTopics(topics=settings["topics"], threshold=settings["threshold"])
417
-
418
- if scanner_name == "Code":
419
- mode = settings["mode"]
420
-
421
- allowed_languages = None
422
- denied_languages = None
423
- if mode == "allowed":
424
- allowed_languages = settings["languages"]
425
- elif mode == "denied":
426
- denied_languages = settings["languages"]
427
-
428
- return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)
429
-
430
- if scanner_name == "Language":
431
- return Language(valid_languages=settings["valid_languages"])
432
 
433
  if scanner_name == "PromptInjection":
434
- return PromptInjection(threshold=settings["threshold"], models=PI_ALL_MODELS, use_onnx=True)
435
-
436
- if scanner_name == "Regex":
437
- match_type = settings["type"]
438
-
439
- good_patterns = None
440
- bad_patterns = None
441
- if match_type == "good":
442
- good_patterns = settings["patterns"]
443
- elif match_type == "bad":
444
- bad_patterns = settings["patterns"]
445
-
446
- return Regex(
447
- good_patterns=good_patterns, bad_patterns=bad_patterns, redact=settings["redact"]
448
- )
449
-
450
- if scanner_name == "Secrets":
451
- return Secrets(redact_mode=settings["redact_mode"])
452
-
453
- if scanner_name == "Sentiment":
454
- return Sentiment(threshold=settings["threshold"])
455
-
456
- if scanner_name == "TokenLimit":
457
- return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])
458
 
459
- if scanner_name == "Toxicity":
460
- return Toxicity(threshold=settings["threshold"], use_onnx=True)
461
 
462
- raise ValueError("Unknown scanner name")
463
 
464
 
465
  def scan(
 
4
  from typing import Dict, List
5
 
6
  import streamlit as st
7
+ from llm_guard.input_scanners import get_scanner_by_name
 
 
 
 
 
 
 
 
 
 
 
 
8
  from llm_guard.input_scanners.anonymize import default_entity_types
9
  from llm_guard.input_scanners.prompt_injection import ALL_MODELS as PI_ALL_MODELS
10
  from llm_guard.vault import Vault
 
169
 
170
  st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)
171
 
172
+ allowed_languages = None
173
+ denied_languages = None
174
+ if st_cd_mode == "allowed":
175
+ allowed_languages = st_cd_languages
176
+ elif st_cd_mode == "denied":
177
+ denied_languages = st_cd_languages
178
+
179
  settings["Code"] = {
180
+ "allowed": allowed_languages,
181
+ "denied": denied_languages,
182
  }
183
 
184
  if "Language" in st_enabled_scanners:
 
191
  st_lan_valid_language = st.multiselect(
192
  "Languages",
193
  [
 
194
  "ar",
195
  "bg",
 
 
 
 
 
196
  "de",
197
  "el",
198
  "en",
199
  "es",
 
 
 
200
  "fr",
 
 
201
  "hi",
 
 
 
202
  "it",
203
  "ja",
 
 
 
 
 
 
 
 
204
  "nl",
 
 
205
  "pl",
206
  "pt",
 
207
  "ru",
 
 
 
 
 
208
  "sw",
 
 
209
  "th",
 
210
  "tr",
 
211
  "ur",
212
  "vi",
213
+ "zh",
 
214
  ],
215
  default=["en"],
216
  )
 
263
  "Redact", value=False, help="Replace the matched bad patterns with [REDACTED]"
264
  )
265
 
266
+ good_patterns = None
267
+ bad_patterns = None
268
+ if st_regex_type == "good":
269
+ good_patterns = st_regex_patterns
270
+ elif st_regex_type == "bad":
271
+ bad_patterns = st_regex_patterns
272
+
273
  settings["Regex"] = {
274
+ "good_patterns": good_patterns,
275
+ "bad_patterns": bad_patterns,
276
  "redact": st_redact,
277
  }
278
 
 
359
  logger.debug(f"Initializing {scanner_name} scanner")
360
 
361
  if scanner_name == "Anonymize":
362
+ settings["vault"] = vault
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  if scanner_name == "PromptInjection":
365
+ settings["models"] = PI_ALL_MODELS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
+ if scanner_name in ["Anonymize", "BanTopics", "Code", "PromptInjection", "Toxicity"]:
368
+ settings["use_onnx"] = True
369
 
370
+ return get_scanner_by_name(scanner_name, settings)
371
 
372
 
373
  def scan(
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- llm-guard==0.3.1
2
- llm-guard[onnxruntime]==0.3.1
3
- pandas==2.1.2
4
- streamlit==1.28.1
5
  streamlit-tags==1.2.8
 
1
+ llm-guard==0.3.2
2
+ llm-guard[onnxruntime]==0.3.2
3
+ pandas==2.1.3
4
+ streamlit==1.28.2
5
  streamlit-tags==1.2.8