import logging
import time
from typing import Any, Dict, List, Tuple
import streamlit as st
from llm_guard.input_scanners.anonymize import default_entity_types
from llm_guard.input_scanners.code import SUPPORTED_LANGUAGES as SUPPORTED_CODE_LANGUAGES
from llm_guard.output_scanners import get_scanner_by_name
from llm_guard.output_scanners.bias import MatchType as BiasMatchType
from llm_guard.output_scanners.deanonymize import MatchingStrategy as DeanonymizeMatchingStrategy
from llm_guard.output_scanners.gibberish import MatchType as GibberishMatchType
from llm_guard.output_scanners.language import MatchType as LanguageMatchType
from llm_guard.output_scanners.toxicity import MatchType as ToxicityMatchType
from llm_guard.vault import Vault
from streamlit_tags import st_tags

logger = logging.getLogger("llm-guard-playground")


def init_settings() -> Tuple[List, Dict]:
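    """Render sidebar widgets for every scanner and collect their settings.

    Returns the list of scanners enabled in the multiselect and a dict mapping
    each scanner name to its configured settings.
    """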
all_scanners = [
"BanCode",
"BanCompetitors",
"BanSubstrings",
"BanTopics",
"Bias",
"Code",
"Deanonymize",
"JSON",
"Language",
"LanguageSame",
"MaliciousURLs",
"NoRefusal",
"NoRefusalLight" "ReadingTime",
"FactualConsistency",
"Gibberish",
"Regex",
"Relevance",
"Sensitive",
"Sentiment",
"Toxicity",
"URLReachability",
]
st_enabled_scanners = st.sidebar.multiselect(
"Select scanners",
options=all_scanners,
default=all_scanners,
help="The list can be found here: https://llm-guard.com/output_scanners/bias/",
)
settings = {}
if "BanCode" in st_enabled_scanners:
st_bc_expander = st.sidebar.expander(
"Ban Code",
expanded=False,
)
with st_bc_expander:
st_bc_threshold = st.slider(
label="Threshold",
value=0.95,
min_value=0.0,
max_value=1.0,
step=0.05,
key="ban_code_threshold",
)
settings["BanCode"] = {"threshold": st_bc_threshold}
if "BanCompetitors" in st_enabled_scanners:
st_bc_expander = st.sidebar.expander(
"Ban Competitors",
expanded=False,
)
with st_bc_expander:
st_bc_competitors = st_tags(
label="List of competitors",
text="Type and press enter",
value=["openai", "anthropic", "deepmind", "google"],
suggestions=[],
maxtags=30,
key="bc_competitors",
)
st_bc_threshold = st.slider(
label="Threshold",
value=0.5,
min_value=0.0,
max_value=1.0,
step=0.05,
key="ban_competitors_threshold",
)
settings["BanCompetitors"] = {
"competitors": st_bc_competitors,
"threshold": st_bc_threshold,
}
if "BanSubstrings" in st_enabled_scanners:
st_bs_expander = st.sidebar.expander(
"Ban Substrings",
expanded=False,
)
with st_bs_expander:
st_bs_substrings = st.text_area(
"Enter substrings to ban (one per line)",
value="test\nhello\nworld\n",
height=200,
).split("\n")
st_bs_match_type = st.selectbox(
"Match type", ["str", "word"], index=0, key="bs_match_type"
)
st_bs_case_sensitive = st.checkbox(
"Case sensitive", value=False, key="bs_case_sensitive"
)
st_bs_redact = st.checkbox("Redact", value=False, key="bs_redact")
st_bs_contains_all = st.checkbox("Contains all", value=False, key="bs_contains_all")
settings["BanSubstrings"] = {
"substrings": st_bs_substrings,
"match_type": st_bs_match_type,
"case_sensitive": st_bs_case_sensitive,
"redact": st_bs_redact,
"contains_all": st_bs_contains_all,
}
if "BanTopics" in st_enabled_scanners:
st_bt_expander = st.sidebar.expander(
"Ban Topics",
expanded=False,
)
with st_bt_expander:
st_bt_topics = st_tags(
label="List of topics",
text="Type and press enter",
value=["violence"],
suggestions=[],
maxtags=30,
key="bt_topics",
)
st_bt_threshold = st.slider(
label="Threshold",
value=0.6,
min_value=0.0,
max_value=1.0,
step=0.05,
key="ban_topics_threshold",
)
settings["BanTopics"] = {"topics": st_bt_topics, "threshold": st_bt_threshold}
if "Bias" in st_enabled_scanners:
st_bias_expander = st.sidebar.expander(
"Bias",
expanded=False,
)
with st_bias_expander:
st_bias_threshold = st.slider(
label="Threshold",
value=0.75,
min_value=0.0,
max_value=1.0,
step=0.05,
key="bias_threshold",
)
st_bias_match_type = st.selectbox(
"Match type", [e.value for e in BiasMatchType], index=1, key="bias_match_type"
)
settings["Bias"] = {
"threshold": st_bias_threshold,
"match_type": BiasMatchType(st_bias_match_type),
}
if "Code" in st_enabled_scanners:
st_cd_expander = st.sidebar.expander(
"Code",
expanded=False,
)
with st_cd_expander:
st_cd_languages = st.multiselect(
"Programming languages",
options=SUPPORTED_CODE_LANGUAGES,
default=["Python"],
)
st_cd_is_blocked = st.checkbox("Is blocked", value=False, key="cd_is_blocked")
settings["Code"] = {
"languages": st_cd_languages,
"is_blocked": st_cd_is_blocked,
}
if "Deanonymize" in st_enabled_scanners:
st_de_expander = st.sidebar.expander(
"Deanonymize",
expanded=False,
)
with st_de_expander:
st_de_matching_strategy = st.selectbox(
"Matching strategy", [e.value for e in DeanonymizeMatchingStrategy], index=0
)
settings["Deanonymize"] = {
"matching_strategy": DeanonymizeMatchingStrategy(st_de_matching_strategy)
}
if "JSON" in st_enabled_scanners:
st_json_expander = st.sidebar.expander(
"JSON",
expanded=False,
)
with st_json_expander:
st_json_required_elements = st.slider(
label="Required elements",
value=0,
min_value=0,
max_value=10,
step=1,
key="json_required_elements",
help="The minimum number of JSON elements that should be present",
)
st_json_repair = st.checkbox(
"Repair", value=False, help="Attempt to repair the JSON", key="json_repair"
)
settings["JSON"] = {
"required_elements": st_json_required_elements,
"repair": st_json_repair,
}
if "Language" in st_enabled_scanners:
st_lan_expander = st.sidebar.expander(
"Language",
expanded=False,
)
with st_lan_expander:
st_lan_valid_language = st.multiselect(
"Languages",
[
"ar",
"bg",
"de",
"el",
"en",
"es",
"fr",
"hi",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"sw",
"th",
"tr",
"ur",
"vi",
"zh",
],
default=["en"],
)
st_lan_match_type = st.selectbox(
"Match type", [e.value for e in LanguageMatchType], index=1, key="lan_match_type"
)
settings["Language"] = {
"valid_languages": st_lan_valid_language,
"match_type": LanguageMatchType(st_lan_match_type),
}
if "MaliciousURLs" in st_enabled_scanners:
st_murls_expander = st.sidebar.expander(
"Malicious URLs",
expanded=False,
)
with st_murls_expander:
st_murls_threshold = st.slider(
label="Threshold",
value=0.75,
min_value=0.0,
max_value=1.0,
step=0.05,
key="murls_threshold",
)
settings["MaliciousURLs"] = {"threshold": st_murls_threshold}
if "NoRefusal" in st_enabled_scanners:
st_no_ref_expander = st.sidebar.expander(
"No refusal",
expanded=False,
)
with st_no_ref_expander:
st_no_ref_threshold = st.slider(
label="Threshold",
value=0.5,
min_value=0.0,
max_value=1.0,
step=0.05,
key="no_ref_threshold",
)
settings["NoRefusal"] = {"threshold": st_no_ref_threshold}
if "NoRefusalLight" in st_enabled_scanners:
settings["NoRefusalLight"] = {}
if "ReadingTime" in st_enabled_scanners:
st_rt_expander = st.sidebar.expander(
"Reading Time",
expanded=False,
)
with st_rt_expander:
st_rt_max_reading_time = st.slider(
label="Max reading time (in minutes)",
value=5,
min_value=0,
max_value=3600,
step=5,
key="rt_max_reading_time",
)
st_rt_truncate = st.checkbox(
"Truncate",
value=False,
help="Truncate the text to the max reading time",
key="rt_truncate",
)
settings["ReadingTime"] = {"max_time": st_rt_max_reading_time, "truncate": st_rt_truncate}
if "FactualConsistency" in st_enabled_scanners:
st_fc_expander = st.sidebar.expander(
"FactualConsistency",
expanded=False,
)
with st_fc_expander:
st_fc_minimum_score = st.slider(
label="Minimum score",
value=0.5,
min_value=0.0,
max_value=1.0,
step=0.05,
key="fc_threshold",
)
settings["FactualConsistency"] = {"minimum_score": st_fc_minimum_score}
if "Regex" in st_enabled_scanners:
st_regex_expander = st.sidebar.expander(
"Regex",
expanded=False,
)
with st_regex_expander:
st_regex_patterns = st.text_area(
"Enter patterns to ban (one per line)",
value="Bearer [A-Za-z0-9-._~+/]+",
height=200,
).split("\n")
st_regex_is_blocked = st.checkbox("Is blocked", value=True, key="regex_is_blocked")
st_regex_redact = st.checkbox(
"Redact",
value=False,
help="Replace the matched bad patterns with [REDACTED]",
key="regex_redact",
)
settings["Regex"] = {
"patterns": st_regex_patterns,
"is_blocked": st_regex_is_blocked,
"redact": st_regex_redact,
}
if "Relevance" in st_enabled_scanners:
st_rele_expander = st.sidebar.expander(
"Relevance",
expanded=False,
)
with st_rele_expander:
st_rele_threshold = st.slider(
label="Threshold",
value=0.5,
min_value=0.0,
max_value=1.0,
step=0.05,
key="rele_threshold",
)
settings["Relevance"] = {"threshold": st_rele_threshold}
if "Sensitive" in st_enabled_scanners:
st_sens_expander = st.sidebar.expander(
"Sensitive",
expanded=False,
)
with st_sens_expander:
st_sens_entity_types = st_tags(
label="Sensitive entities",
text="Type and press enter",
value=default_entity_types,
suggestions=default_entity_types
+ ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
maxtags=30,
key="sensitive_entity_types",
)
st.caption(
"Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
)
st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
st_sens_threshold = st.slider(
label="Threshold",
value=0.0,
min_value=0.0,
max_value=1.0,
step=0.1,
key="sens_threshold",
)
settings["Sensitive"] = {
"entity_types": st_sens_entity_types,
"redact": st_sens_redact,
"threshold": st_sens_threshold,
}
if "Sentiment" in st_enabled_scanners:
st_sent_expander = st.sidebar.expander(
"Sentiment",
expanded=False,
)
with st_sent_expander:
st_sent_threshold = st.slider(
label="Threshold",
value=-0.5,
min_value=-1.0,
max_value=1.0,
step=0.1,
key="sentiment_threshold",
help="Negative values are negative sentiment, positive values are positive sentiment",
)
settings["Sentiment"] = {"threshold": st_sent_threshold}
if "Toxicity" in st_enabled_scanners:
st_tox_expander = st.sidebar.expander(
"Toxicity",
expanded=False,
)
with st_tox_expander:
st_tox_threshold = st.slider(
label="Threshold",
value=0.5,
min_value=0.0,
max_value=1.0,
step=0.05,
key="toxicity_threshold",
)
st_tox_match_type = st.selectbox(
"Match type",
[e.value for e in ToxicityMatchType],
index=1,
key="toxicity_match_type",
)
settings["Toxicity"] = {
"threshold": st_tox_threshold,
"match_type": ToxicityMatchType(st_tox_match_type),
}
if "URLReachability" in st_enabled_scanners:
st_url_expander = st.sidebar.expander(
"URL Reachability",
expanded=False,
)
        with st_url_expander:
            settings["URLReachability"] = {}
if "Gibberish" in st_enabled_scanners:
st_gib_expander = st.sidebar.expander(
"Gibberish",
expanded=False,
)
with st_gib_expander:
st_gib_threshold = st.slider(
label="Threshold",
value=0.7,
min_value=0.0,
max_value=1.0,
step=0.1,
key="gib_threshold",
)
st_gib_match_type = st.selectbox(
"Match type", [e.value for e in GibberishMatchType], index=1, key="gib_match_type"
)
settings["Gibberish"] = {"match_type": st_gib_match_type, "threshold": st_gib_threshold}
    return st_enabled_scanners, settings


def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
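    """Instantiate a scanner by name with the given settings."""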
logger.debug(f"Initializing {scanner_name} scanner")
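    # Deanonymize needs the Vault so it can restore values that the Anonymize
    # input scanner replaced with placeholders.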
if scanner_name == "Deanonymize":
settings["vault"] = vault
if scanner_name in [
"BanCode",
"BanTopics",
"Bias",
"Code",
"Gibberish",
"Language",
"LanguageSame",
"MaliciousURLs",
"NoRefusal",
"FactualConsistency",
"Relevance",
"Sensitive",
"Toxicity",
]:
settings["use_onnx"] = True
    return get_scanner_by_name(scanner_name, settings)


def scan(
    vault: Vault,
    enabled_scanners: List[str],
    settings: Dict,
    prompt: str,
    text: str,
    fail_fast: bool = False,
) -> Tuple[str, List[Dict[str, Any]]]:
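    """Run each enabled scanner in order over the model output.

    The output sanitized by one scanner is fed to the next, so redactions
    accumulate. Returns the final sanitized output and a list of per-scanner
    results (validity, risk score, elapsed seconds).
    """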
sanitized_output = text
results = []
status_text = "Scanning prompt..."
if fail_fast:
status_text = "Scanning prompt (fail fast mode)..."
with st.status(status_text, expanded=True) as status:
for scanner_name in enabled_scanners:
st.write(f"{scanner_name} scanner...")
            scanner = get_scanner(scanner_name, vault, settings.get(scanner_name, {}))
start_time = time.monotonic()
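            # Each scanner sees the original prompt and the output as sanitized so far.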
sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
end_time = time.monotonic()
results.append(
{
"scanner": scanner_name,
"is_valid": is_valid,
"risk_score": risk_score,
"took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
}
)
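            # In fail-fast mode, stop at the first scanner that rejects the output.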
if fail_fast and not is_valid:
break
status.update(label="Scanning complete", state="complete", expanded=False)
return sanitized_output, results
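

# Hypothetical usage sketch (not part of the playground app; `prompt` and
# `model_output` below are placeholder variables):
#
#   vault = Vault()
#   enabled_scanners, settings = init_settings()
#   sanitized, results = scan(vault, enabled_scanners, settings, prompt, model_output)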