| import os |
| import re |
| import json |
| import base64 |
| import threading |
| from pathlib import Path |
| from typing import Any |
|
|
| import pycountry |
|
|
| |
BASE_DIR = Path(".")
HF_TOKEN_PATH = BASE_DIR / "hf_token"
# The token file is optional: a missing or empty file means anonymous Hub
# access (the original raised FileNotFoundError when the file was absent).
HF_TOKEN = (
    HF_TOKEN_PATH.read_text(encoding="utf-8").strip() or None
    if HF_TOKEN_PATH.is_file()
    else None
)
if HF_TOKEN is not None:
    from huggingface_hub import login

    login(token=HF_TOKEN, add_to_git_credential=False)

# Model / detector identifiers and the shared blocking threshold, all
# overridable via environment variables.
HF_MODEL = os.environ.get("HF_MODEL", "google/gemma-4-E2B-it")
JAILBREAK_MODEL = os.environ.get("JAILBREAK_MODEL", "DerivedFunction1/xlmr-prompt-injection")
JAILBREAK_THRESHOLD = float(os.environ.get("JAILBREAK_THRESHOLD", "0.65"))
PROMPT_INJECTION_MODEL = os.environ.get(
    "PROMPT_INJECTION_MODEL", "protectai/deberta-v3-base-prompt-injection-v2"
)
REFUSAL_LANGUAGE_MODEL = os.environ.get(
    "REFUSAL_LANGUAGE_MODEL",
    "polyglot-tagger/multilabel-language-identification",
)
|
|
# Upper-case ISO 639-1 codes of languages the Gemma model is trusted with.
SUPPORTED_GEMMA_LANGS = set(
    "EN ES FR DE IT PT NL "
    "DA RU PL "
    "ZH JA KO VI "
    "HI BN TH ID MS MR TE TA GU PA "
    "AR TR HE SW".split()
)
|
|
# Upper-case ISO 639-1 codes the jailbreak detector was trained on.
SUPPORTED_JAILBREAK_LANGS = set(
    "EN AR DE ES FR HI IT JA KO NL TH ZH".split()
)
|
|
| |
| from transformers import AutoProcessor, Gemma4ForConditionalGeneration, BitsAndBytesConfig, pipeline |
|
|
| |
print(f"Loading model: {HF_MODEL}")
# Left padding is required for decoder-only batched generation.
_processor = AutoProcessor.from_pretrained(HF_MODEL, padding_side="left")
_bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
_model = Gemma4ForConditionalGeneration.from_pretrained(
    HF_MODEL,
    # Fix: _bnb_config was constructed but never passed, so the model loaded
    # in full precision; wire it in so 8-bit quantization actually applies.
    quantization_config=_bnb_config,
    device_map="auto",
)
|
|
# Decoding defaults shared by generate_response and generate_response_stream.
_GENERATION_CONFIG = dict(
    max_new_tokens=8192,
    temperature=1.2,
    do_sample=True,
    # Silence the "pad token not set" warning by padding with EOS.
    pad_token_id=_processor.tokenizer.eos_token_id,
)
|
|
# Guardrail classifiers: three HF text-classification pipelines loaded eagerly
# at import time (device placement left to transformers defaults).

print(f"Loading jailbreak detector: {JAILBREAK_MODEL}")
_jailbreak_pipe = pipeline("text-classification", model=JAILBREAK_MODEL)


print(f"Loading prompt injection detector: {PROMPT_INJECTION_MODEL}")
_prompt_injection_pipe = pipeline("text-classification", model=PROMPT_INJECTION_MODEL)


print(f"Loading refusal language detector: {REFUSAL_LANGUAGE_MODEL}")
_refusal_language_pipe = pipeline("text-classification", model=REFUSAL_LANGUAGE_MODEL)
|
|
| |
# Full tool-call span: optional <|tool_call|> opener (or start of string), an
# optional "call:" prefix, a function name, a {...}/(...) argument payload, and
# a closing marker (another tool_call tag, <eos>, <end_of_turn>, <turn|>, </s>,
# or end of string).  NOTE(review): the args group is non-greedy and brace
# pairs are not balanced — nested braces inside args would be cut short;
# confirm the model only emits flat payloads.
TOOL_CALL_RE = re.compile(
    r"(?:<\|?tool_call\|?>|^)\s*"
    r"(?:call:)?(?P<name>[a-zA-Z_][a-zA-Z0-9_\-\s]*?)\s*"
    r"(?:\{|\()(?P<args>.*?)(?:\}|\))\s*"
    r"(?P<close><\|?tool_call\|?>|<eos>|<end_of_turn>|<turn\|?>|</s>|$)",
    re.DOTALL,
)


# Everything from a <|tool_call|> marker up to its closer (another tool_call
# tag, <eos>, or end of string) — deleted wholesale from display text.
TOOL_CALL_MARKUP_RE = re.compile(
    r"<\|?tool_call\|?>.*?(?:<\|?tool_call\|?>|<eos>|$)",
    re.DOTALL,
)


# A <|tool_response|> marker and everything after it, through end of string.
TOOL_RESPONSE_RE = re.compile(
    r"<\|?tool_response\|?>.*$",
    re.DOTALL,
)


# Stray control tokens and the [REDIRECT] sentinel, removed individually.
CLEANUP_RE = re.compile(
    r"(<\|?turn\|?>|<eos>|</s>|\[REDIRECT\])",
    re.DOTALL,
)


# A <|channel|> "thought" block up to the (possibly malformed) closing
# <channel|> tag, or to end of string when the closer is missing.
THOUGHT_BLOCK_RE = re.compile(
    r"<\|?channel\|?>(?:thought\s*)?.*?(?:<channel\|>|$)",
    re.DOTALL,
)




# Model-specific escaped double-quote token, replaced with a plain ".
QUOTES_RE = re.compile(r"<\|\"\|>")
# Bare <|tool_response|> markers (marker only, trailing content kept).
TOOL_RESPONSE_MARKERS_RE = re.compile(r"<\|?tool_response\|?>", re.DOTALL)
# Truncated tool-token tails at end of string (e.g. "<|tool_call", "<|", "<|?")
# indicating generation was cut off mid-token.
MALFORMED_TOOL_TAIL_RE = re.compile(r"(<\|?tool_call(?:\|)?$|<\|?$|<\|?\?$)")
|
|
|
|
def _strip_tool_call_markup(text: str) -> str:
    """Normalize raw model output and delete tool-call/response/thought markup."""
    working = (text or "").replace("\r", "").strip()
    if not working:
        return ""

    # Substitution order matters: quote tokens are unescaped first, then the
    # block-level patterns, then the loose control tokens.
    for pattern, replacement in (
        (QUOTES_RE, '"'),
        (THOUGHT_BLOCK_RE, ""),
        (TOOL_CALL_MARKUP_RE, ""),
        (TOOL_RESPONSE_RE, ""),
        (CLEANUP_RE, ""),
    ):
        working = pattern.sub(replacement, working)
    return working.strip()
|
|
|
|
def _clean_tool_text(text: str) -> str:
    """Strip tool-call markup, then remove any bare tool-response markers."""
    stripped = _strip_tool_call_markup(text)
    if not stripped:
        return ""
    return TOOL_RESPONSE_MARKERS_RE.sub("", stripped).strip()
|
|
|
|
def _strip_trailing_malformed_tool_tokens(text: str) -> str:
    """Chop characters off the tail while it looks like a truncated tool token."""
    remainder = (text or "").strip()
    # Remove one trailing character at a time until the tail no longer matches
    # a partial token (or the string is exhausted).
    while remainder and MALFORMED_TOOL_TAIL_RE.search(remainder):
        remainder = remainder[:-1].rstrip()
    return remainder
|
|
|
|
| def _clean_language_detector_text(text: str) -> str: |
| cleaned = [] |
| for ch in str(text or ""): |
| if ch.isalpha() or ch.isspace(): |
| cleaned.append(ch) |
| else: |
| cleaned.append(" ") |
| return " ".join("".join(cleaned).split()) |
|
|
|
|
def detect_jailbreak(text: str) -> dict:
    """Return detector metadata for a user message."""
    prediction = _jailbreak_pipe(text, truncation=True, max_length=512)[0]
    predicted = str(prediction.get("label", "")).lower()
    raw_score = float(prediction.get("score", 0.0))
    if predicted == "safe":
        # "safe" confidence is inverted to get the unsafe probability.
        unsafe = 1.0 - raw_score
    else:
        # "unsafe" (and any unrecognized label) keeps the raw confidence.
        unsafe = raw_score

    return {
        "score": unsafe,
        "blocked": unsafe >= JAILBREAK_THRESHOLD,
        "predicted_label": predicted,
    }
|
|
|
|
def detect_prompt_injection(text: str) -> dict:
    """Return detector metadata for a user message using the prompt injection model.

    Returns:
        dict with "score" (probability of injection), "blocked" (True when the
        score crosses JAILBREAK_THRESHOLD), and "predicted_label" (raw
        lower-cased classifier label).
    """
    result = _prompt_injection_pipe(text, truncation=True, max_length=512)[0]
    label = str(result.get("label", "")).lower()
    score = float(result.get("score", 0.0))

    # label is already lower-cased above (the original redundantly re-lowered
    # it); "safe" confidence is inverted, unknown labels keep the raw score.
    if label == "injection":
        unsafe_score = score
    elif label == "safe":
        unsafe_score = 1.0 - score
    else:
        unsafe_score = score

    return {
        "score": unsafe_score,
        "blocked": unsafe_score >= JAILBREAK_THRESHOLD,
        "predicted_label": label,
    }
|
|
def detect_refusal_language(text: str) -> str:
    """Detect the text's language, constrained to Gemma-supported codes.

    Falls back to "EN" when the detected code is unsupported.
    """
    prediction = _refusal_language_pipe(
        _clean_language_detector_text(text), truncation=True, max_length=512
    )[0]
    code = _normalize_language_label(str(prediction.get("label", "")).upper().strip())
    return code if code in SUPPORTED_GEMMA_LANGS else "EN"
|
|
|
|
def detect_preferred_language(text: str) -> str:
    """Detect the text's language without restricting to Gemma-supported codes.

    Falls back to "EN" only when normalization yields an empty label.
    """
    prediction = _refusal_language_pipe(
        _clean_language_detector_text(text), truncation=True, max_length=512
    )[0]
    code = _normalize_language_label(str(prediction.get("label", "")).upper().strip())
    return code or "EN"
|
|
|
|
def _normalize_language_label(label: str) -> str:
    """Best-effort mapping of a raw detector label to an upper-case ISO code.

    Known Gemma codes pass straight through; otherwise pycountry resolves the
    label (alpha-2, then alpha-3, then fuzzy lookup), preferring the alpha-2
    code of the match. Unresolvable labels are returned upper-cased as-is.
    """
    raw = str(label or "").strip()
    if not raw:
        return ""
    upper = raw.upper()
    if upper in SUPPORTED_GEMMA_LANGS:
        return upper

    lowered = raw.lower()
    match = pycountry.languages.get(alpha_2=lowered)
    if match is None and len(lowered) == 3:
        match = pycountry.languages.get(alpha_3=lowered)
    if match is None:
        try:
            match = pycountry.languages.lookup(raw)
        except LookupError:
            match = None
    if match is None:
        return upper

    # Prefer the two-letter code; fall back to three letters, then the input.
    for attr in ("alpha_2", "alpha_3"):
        code = getattr(match, attr, None)
        if code:
            return str(code).upper()
    return upper
|
|
|
|
def _sanitize_display_text(text: str, system_prompt: str | None = None) -> str:
    """Strip tool markup and unwrap JSON content payloads for display.

    Args:
        text: Raw assistant output, possibly containing tool markup.
        system_prompt: Unused; kept for backward compatibility with callers.

    Returns:
        The cleaned display text, or "" when nothing remains.
    """
    cleaned = _clean_tool_text(text)
    if not cleaned:
        return ""

    # The model sometimes emits a JSON list of content parts; unwrap the first
    # text part when present.
    try:
        parsed_json = json.loads(cleaned)
    except json.JSONDecodeError:
        return cleaned.strip()
    if (
        isinstance(parsed_json, list)
        and parsed_json
        and isinstance(parsed_json[0], dict)
        # Fix: also require the payload to be a string — the original called
        # .strip() on whatever was stored and crashed on non-str values.
        and isinstance(parsed_json[0].get("text"), str)
    ):
        return parsed_json[0]["text"].strip()
    return cleaned.strip()
|
|
|
|
| |
| |
| from bob_resources import ( |
| connect, |
| validate, |
| skip, |
| clarify_intent, |
| store_policy, |
| store_information, |
| store_app_website, |
| food_safety_endpoint, |
| legal_endpoint, |
| emergency_crisis, |
| apply_discount, |
| loyalty_program, |
| competitor_mentions, |
| take_order |
| ) |
|
|
def generate_response(
    messages: list,
    system_prompt: str,
    enable_thinking: bool = False,
) -> str:
    """Generate a complete (non-streaming) assistant reply.

    Args:
        messages: Chat history as {"role", "content"} dicts.
        system_prompt: System instructions prepended to the history.
        enable_thinking: Forwarded to the chat template.

    Returns:
        The decoded newly-generated tokens, special tokens stripped.
    """
    import torch  # replaces the opaque __import__("torch") call

    full = [{"role": "system", "content": system_prompt}] + messages
    # Fix: the original appended an empty assistant turn while also passing
    # add_generation_prompt=True, which already inserts the generation prefix;
    # dropped for consistency with generate_response_stream.
    inputs = _processor.apply_chat_template(
        full,
        tools=[connect, validate, skip, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    ).to(_model.device)
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            **_GENERATION_CONFIG,
        )
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return _processor.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
def generate_response_stream(
    messages: list,
    system_prompt: str,
    enable_thinking: bool = False,
):
    """Yield assistant reply text chunks as they are generated.

    Generation runs on a background daemon thread feeding a
    TextIteratorStreamer; chunks are yielded as they arrive and the thread is
    joined once the streamer is exhausted.

    Args:
        messages: Chat history as {"role", "content"} dicts.
        system_prompt: System instructions prepended to the history.
        enable_thinking: Forwarded to the chat template.

    Yields:
        str: Decoded text chunks (special tokens included, for downstream
        markup cleanup).
    """
    full = [{"role": "system", "content": system_prompt}] + messages
    inputs = _processor.apply_chat_template(
        full,
        tools=[connect, validate, skip, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    ).to(_model.device)

    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(_processor.tokenizer, skip_prompt=True, skip_special_tokens=False)
    thread = threading.Thread(
        target=_model.generate,
        kwargs={
            **inputs,
            **_GENERATION_CONFIG,
            "streamer": streamer,
        },
        daemon=True,
    )
    thread.start()
    # Fix: dropped the unused "generated" accumulator — it concatenated every
    # chunk (quadratic string growth) and was never read.
    yield from streamer
    thread.join()
|
|