"""Sidebar configuration and scanning helpers for the LLM Guard playground (input scanners)."""

import logging
import time
from datetime import timedelta
from typing import Any, Dict, List, Tuple

import streamlit as st
from llm_guard.input_scanners import (
    Anonymize,
    BanSubstrings,
    BanTopics,
    Code,
    Language,
    PromptInjection,
    Regex,
    Secrets,
    Sentiment,
    TokenLimit,
    Toxicity,
)
from llm_guard.input_scanners.anonymize import default_entity_types
from llm_guard.input_scanners.prompt_injection import ALL_MODELS as PI_ALL_MODELS
from llm_guard.vault import Vault
from streamlit_tags import st_tags

logger = logging.getLogger("llm-guard-playground")


def _non_empty_lines(text: str) -> List[str]:
    """Split multi-line text-area input into entries, dropping blank lines.

    Blank entries are dropped on purpose: an empty string is a substring of every
    prompt and an empty regex matches anywhere, so a stray trailing newline in the
    text area would otherwise (presumably) make the scanner flag every input.
    Non-blank lines are kept verbatim (no stripping), so intentional leading or
    trailing spaces inside a pattern survive.
    """
    return [line for line in text.splitlines() if line.strip()]


def init_settings() -> Tuple[List[str], Dict[str, Dict[str, Any]]]:
    """Render the sidebar widgets used to pick and configure input scanners.

    Returns:
        A pair ``(enabled_scanners, settings)``: the scanner names selected in the
        sidebar multiselect, and a mapping from each selected scanner name to the
        dict of settings that ``get_scanner`` reads for that scanner.
    """
    all_scanners = [
        "Anonymize",
        "BanSubstrings",
        "BanTopics",
        "Code",
        "Language",
        "PromptInjection",
        "Regex",
        "Secrets",
        "Sentiment",
        "TokenLimit",
        "Toxicity",
    ]

    st_enabled_scanners = st.sidebar.multiselect(
        "Select scanners",
        options=all_scanners,
        default=all_scanners,
        help="The list can be found here: https://laiyer-ai.github.io/llm-guard/input_scanners/anonymize/",
    )

    settings: Dict[str, Dict[str, Any]] = {}

    if "Anonymize" in st_enabled_scanners:
        st_anon_expander = st.sidebar.expander(
            "Anonymize",
            expanded=False,
        )

        with st_anon_expander:
            st_anon_entity_types = st_tags(
                label="Anonymize entities",
                text="Type and press enter",
                value=default_entity_types,
                suggestions=default_entity_types
                + ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
                maxtags=30,
                key="anon_entity_types",
            )
            st.caption(
                "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
            )
            st_anon_hidden_names = st_tags(
                label="Hidden names to be anonymized",
                text="Type and press enter",
                value=[],
                suggestions=[],
                maxtags=30,
                key="anon_hidden_names",
            )
            st.caption("These names will be hidden e.g. [REDACTED_CUSTOM1].")
            st_anon_allowed_names = st_tags(
                label="Allowed names to ignore",
                text="Type and press enter",
                value=[],
                suggestions=[],
                maxtags=30,
                key="anon_allowed_names",
            )
            st.caption("These names will be ignored even if flagged by the detector.")
            st_anon_preamble = st.text_input(
                "Preamble", value="Text to prepend to sanitized prompt: "
            )
            st_anon_use_faker = st.checkbox(
                "Use Faker", value=False, help="Use Faker library to generate fake data"
            )
            st_anon_threshold = st.slider(
                label="Threshold",
                value=0.0,
                min_value=0.0,
                max_value=1.0,
                step=0.1,
                key="anon_threshold",
            )

        settings["Anonymize"] = {
            "entity_types": st_anon_entity_types,
            "hidden_names": st_anon_hidden_names,
            "allowed_names": st_anon_allowed_names,
            "preamble": st_anon_preamble,
            "use_faker": st_anon_use_faker,
            "threshold": st_anon_threshold,
        }

    if "BanSubstrings" in st_enabled_scanners:
        st_bs_expander = st.sidebar.expander(
            "Ban Substrings",
            expanded=False,
        )

        with st_bs_expander:
            # Blank lines are dropped: an empty substring would match every prompt.
            st_bs_substrings = _non_empty_lines(
                st.text_area(
                    "Enter substrings to ban (one per line)",
                    value="test\nhello\nworld",
                    height=200,
                )
            )

            st_bs_match_type = st.selectbox("Match type", ["str", "word"])
            st_bs_case_sensitive = st.checkbox("Case sensitive", value=False)
            st_bs_redact = st.checkbox("Redact", value=False)
            st_bs_contains_all = st.checkbox("Contains all", value=False)

        settings["BanSubstrings"] = {
            "substrings": st_bs_substrings,
            "match_type": st_bs_match_type,
            "case_sensitive": st_bs_case_sensitive,
            "redact": st_bs_redact,
            "contains_all": st_bs_contains_all,
        }

    if "BanTopics" in st_enabled_scanners:
        st_bt_expander = st.sidebar.expander(
            "Ban Topics",
            expanded=False,
        )

        with st_bt_expander:
            st_bt_topics = st_tags(
                label="List of topics",
                text="Type and press enter",
                value=["violence"],
                suggestions=[],
                maxtags=30,
                key="bt_topics",
            )

            st_bt_threshold = st.slider(
                label="Threshold",
                value=0.6,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_topics_threshold",
            )

        settings["BanTopics"] = {
            "topics": st_bt_topics,
            "threshold": st_bt_threshold,
        }

    if "Code" in st_enabled_scanners:
        st_cd_expander = st.sidebar.expander(
            "Code",
            expanded=False,
        )

        with st_cd_expander:
            st_cd_languages = st.multiselect(
                "Programming languages",
                ["python", "java", "javascript", "go", "php", "ruby"],
                default=["python"],
            )

            # "allowed" vs "denied" decides how get_scanner passes the languages to Code.
            st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)

        settings["Code"] = {
            "languages": st_cd_languages,
            "mode": st_cd_mode,
        }

    if "Language" in st_enabled_scanners:
        st_lan_expander = st.sidebar.expander(
            "Language",
            expanded=False,
        )

        with st_lan_expander:
            st_lan_valid_language = st.multiselect(
                "Languages",
                [
                    "af",
                    "ar",
                    "bg",
                    "bn",
                    "ca",
                    "cs",
                    "cy",
                    "da",
                    "de",
                    "el",
                    "en",
                    "es",
                    "et",
                    "fa",
                    "fi",
                    "fr",
                    "gu",
                    "he",
                    "hi",
                    "hr",
                    "hu",
                    "id",
                    "it",
                    "ja",
                    "kn",
                    "ko",
                    "lt",
                    "lv",
                    "mk",
                    "ml",
                    "mr",
                    "ne",
                    "nl",
                    "no",
                    "pa",
                    "pl",
                    "pt",
                    "ro",
                    "ru",
                    "sk",
                    "sl",
                    "so",
                    "sq",
                    "sv",
                    "sw",
                    "ta",
                    "te",
                    "th",
                    "tl",
                    "tr",
                    "uk",
                    "ur",
                    "vi",
                    "zh-cn",
                    "zh-tw",
                ],
                default=["en"],
            )

        settings["Language"] = {
            "valid_languages": st_lan_valid_language,
        }

    if "PromptInjection" in st_enabled_scanners:
        st_pi_expander = st.sidebar.expander(
            "Prompt Injection",
            expanded=False,
        )

        with st_pi_expander:
            st_pi_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="prompt_injection_threshold",
            )

        settings["PromptInjection"] = {
            "threshold": st_pi_threshold,
        }

    if "Regex" in st_enabled_scanners:
        st_regex_expander = st.sidebar.expander(
            "Regex",
            expanded=False,
        )

        with st_regex_expander:
            # Blank lines are dropped: an empty regex matches anywhere in a prompt.
            st_regex_patterns = _non_empty_lines(
                st.text_area(
                    "Enter patterns to ban (one per line)",
                    value="Bearer [A-Za-z0-9-._~+/]+",
                    height=200,
                )
            )

            st_regex_type = st.selectbox(
                "Match type",
                ["good", "bad"],
                index=1,
                help="good: allow only good patterns, bad: ban bad patterns",
            )

            st_redact = st.checkbox(
                "Redact", value=False, help="Replace the matched bad patterns with [REDACTED]"
            )

        settings["Regex"] = {
            "patterns": st_regex_patterns,
            "type": st_regex_type,
            "redact": st_redact,
        }

    if "Secrets" in st_enabled_scanners:
        st_sec_expander = st.sidebar.expander(
            "Secrets",
            expanded=False,
        )

        with st_sec_expander:
            st_sec_redact_mode = st.selectbox("Redact mode", ["all", "partial", "hash"])

        settings["Secrets"] = {
            "redact_mode": st_sec_redact_mode,
        }

    if "Sentiment" in st_enabled_scanners:
        st_sent_expander = st.sidebar.expander(
            "Sentiment",
            expanded=False,
        )

        with st_sent_expander:
            st_sent_threshold = st.slider(
                label="Threshold",
                value=-0.1,
                min_value=-1.0,
                max_value=1.0,
                step=0.1,
                key="sentiment_threshold",
                help="Negative values are negative sentiment, positive values are positive sentiment",
            )

        settings["Sentiment"] = {
            "threshold": st_sent_threshold,
        }

    if "TokenLimit" in st_enabled_scanners:
        st_tl_expander = st.sidebar.expander(
            "Token Limit",
            expanded=False,
        )

        with st_tl_expander:
            st_tl_limit = st.number_input(
                "Limit", value=4096, min_value=0, max_value=10000, step=10
            )
            st_tl_encoding_name = st.selectbox(
                "Encoding name",
                ["cl100k_base", "p50k_base", "r50k_base"],
                index=0,
                help="Read more: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb",
            )

        settings["TokenLimit"] = {
            "limit": st_tl_limit,
            "encoding_name": st_tl_encoding_name,
        }

    if "Toxicity" in st_enabled_scanners:
        st_tox_expander = st.sidebar.expander(
            "Toxicity",
            expanded=False,
        )

        with st_tox_expander:
            st_tox_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="toxicity_threshold",
            )

        settings["Toxicity"] = {
            "threshold": st_tox_threshold,
        }

    return st_enabled_scanners, settings


def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
    """Construct the llm_guard input scanner called ``scanner_name``.

    Args:
        scanner_name: One of the names offered by ``init_settings``.
        vault: Passed to ``Anonymize`` only; ignored by the other scanners.
        settings: The per-scanner settings dict built by ``init_settings``.

    Returns:
        A freshly constructed scanner instance. NOTE(review): a new instance is
        built on every call, so any model loading the scanner does is repeated
        per scan — consider caching if that proves slow.

    Raises:
        ValueError: If ``scanner_name`` is not a known scanner.
    """
    logger.debug("Initializing %s scanner", scanner_name)

    if scanner_name == "Anonymize":
        return Anonymize(
            vault=vault,
            allowed_names=settings["allowed_names"],
            hidden_names=settings["hidden_names"],
            entity_types=settings["entity_types"],
            preamble=settings["preamble"],
            use_faker=settings["use_faker"],
            threshold=settings["threshold"],
            use_onnx=True,
        )

    if scanner_name == "BanSubstrings":
        return BanSubstrings(
            substrings=settings["substrings"],
            match_type=settings["match_type"],
            case_sensitive=settings["case_sensitive"],
            redact=settings["redact"],
            contains_all=settings["contains_all"],
        )

    if scanner_name == "BanTopics":
        return BanTopics(topics=settings["topics"], threshold=settings["threshold"])

    if scanner_name == "Code":
        # The selected languages go to exactly one of allowed/denied, per the mode.
        mode = settings["mode"]
        languages = settings["languages"]
        return Code(
            allowed=languages if mode == "allowed" else None,
            denied=languages if mode == "denied" else None,
            use_onnx=True,
        )

    if scanner_name == "Language":
        return Language(valid_languages=settings["valid_languages"])

    if scanner_name == "PromptInjection":
        return PromptInjection(
            threshold=settings["threshold"], models=PI_ALL_MODELS, use_onnx=True
        )

    if scanner_name == "Regex":
        # The patterns go to exactly one of good/bad, per the selected match type.
        match_type = settings["type"]
        patterns = settings["patterns"]
        return Regex(
            good_patterns=patterns if match_type == "good" else None,
            bad_patterns=patterns if match_type == "bad" else None,
            redact=settings["redact"],
        )

    if scanner_name == "Secrets":
        return Secrets(redact_mode=settings["redact_mode"])

    if scanner_name == "Sentiment":
        return Sentiment(threshold=settings["threshold"])

    if scanner_name == "TokenLimit":
        return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])

    if scanner_name == "Toxicity":
        return Toxicity(threshold=settings["threshold"], use_onnx=True)

    raise ValueError(f"Unknown scanner name: {scanner_name}")


def scan(
    vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
) -> Tuple[str, List[Dict[str, Any]]]:
    """Run the enabled scanners over ``text`` in order, showing progress in the UI.

    Each scanner receives the sanitized output of the previous one, so sanitization
    is cumulative.

    Args:
        vault: Forwarded to ``get_scanner``.
        enabled_scanners: Scanner names to run, in order.
        settings: Mapping of scanner name to its settings dict (from ``init_settings``).
        text: The prompt to scan.
        fail_fast: When true, stop after the first scanner that reports the prompt
            as invalid.

    Returns:
        ``(sanitized_prompt, results)`` where ``results`` has one dict per scanner
        that ran, with keys ``scanner``, ``is_valid``, ``risk_score`` and
        ``took_sec`` (wall time in seconds, rounded to 2 decimals).
    """
    sanitized_prompt = text
    results: List[Dict[str, Any]] = []

    status_text = (
        "Scanning prompt (fail fast mode)..." if fail_fast else "Scanning prompt..."
    )

    with st.status(status_text, expanded=True) as status:
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")
            scanner = get_scanner(scanner_name, vault, settings[scanner_name])

            # Monotonic clock: immune to wall-clock adjustments during the scan.
            start_time = time.monotonic()
            sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
            end_time = time.monotonic()

            results.append(
                {
                    "scanner": scanner_name,
                    "is_valid": is_valid,
                    "risk_score": risk_score,
                    "took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
                }
            )

            if fail_fast and not is_valid:
                break

        status.update(label="Scanning complete", state="complete", expanded=False)

    return sanitized_prompt, results