dylanglenister committed
Commit 6d1027d · Parent(s): 47e3582
CHORE: Removing typing import
Files changed: src/services/guard.py (+9 -10)
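The change itself is a typing modernisation: the `typing.Dict` / `List` / `Tuple` aliases are dropped in favour of the built-in generics available since Python 3.9 (PEP 585). A minimal before/after illustration using a hypothetical function, not code from this repo:

# Before (needs `from typing import Dict, List, Tuple`):
#     def summarise(rows: List[Dict]) -> Tuple[bool, str]: ...

# After (built-in generics, no typing import required on Python 3.9+):
def summarise(rows: list[dict]) -> tuple[bool, str]:
    """Hypothetical example of the annotation style adopted in this commit."""
    return bool(rows), f"{len(rows)} row(s)"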
src/services/guard.py CHANGED
@@ -2,7 +2,6 @@
 
 import os
 import re
-from typing import Dict, List, Tuple
 
 import requests
 
@@ -33,7 +32,7 @@ class SafetyGuard:
         self.fail_open = settings.SAFETY_GUARD_FAIL_OPEN
 
     @staticmethod
-    def _chunk_text(text: str, chunk_size: int = 2800, overlap: int = 200) -> List[str]:
+    def _chunk_text(text: str, chunk_size: int = 2800, overlap: int = 200) -> list[str]:
         """Chunk long text to keep request payloads small enough for the guard.
         Uses character-based approximation with small overlap.
         """
@@ -42,7 +41,7 @@ class SafetyGuard:
         n = len(text)
         if n <= chunk_size:
             return [text]
-        chunks: List[str] = []
+        chunks: list[str] = []
         start = 0
         while start < n:
             end = min(start + chunk_size, n)
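Taken together, the `_chunk_text` hunks describe a sliding character window with a small overlap. A standalone sketch of the whole helper follows; the loop body between old lines 48 and 52 is not part of this diff, so the append/exit details are assumptions:

def _chunk_text(text: str, chunk_size: int = 2800, overlap: int = 200) -> list[str]:
    """Character-based chunking with overlap (sketch; mirrors the docstring above)."""
    n = len(text)
    if n <= chunk_size:
        return [text]
    chunks: list[str] = []
    start = 0
    while start < n:
        end = min(start + chunk_size, n)
        chunks.append(text[start:end])  # assumption: the hidden lines append the slice
        if end >= n:                    # assumption: loop exits once the tail is emitted
            break
        start = max(0, end - overlap)   # step back by `overlap` characters (old line 52)
    return chunks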
@@ -52,7 +51,7 @@ class SafetyGuard:
             start = max(0, end - overlap)
         return chunks
 
-    def _call_guard(self, messages: List[Dict], max_tokens: int = 512) -> str:
+    def _call_guard(self, messages: list[dict], max_tokens: int = 512) -> str:
         # Enhance messages with medical context if detected
         enhanced_messages = self._enhance_messages_with_context(messages)
 
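`_call_guard` now takes `list[dict]` chat messages and returns the guard model's raw text. The transport details are outside this diff; a hedged sketch assuming an OpenAI-compatible chat endpoint reached with the `requests` import visible in the first hunk (the URL, payload shape, and response shape are assumptions, not this project's API):

import requests

def call_guard(messages: list[dict], max_tokens: int = 512,
               url: str = "http://localhost:8000/v1/chat/completions") -> str:
    """Sketch only: POST chat messages to a guard model and return its reply text."""
    resp = requests.post(url, json={"messages": messages, "max_tokens": max_tokens}, timeout=30)
    resp.raise_for_status()
    # assumption: OpenAI-style response body
    return resp.json()["choices"][0]["message"]["content"]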
@@ -129,7 +128,7 @@ class SafetyGuard:
             return ""
 
     @staticmethod
-    def _parse_guard_reply(text: str) -> Tuple[bool, str]:
+    def _parse_guard_reply(text: str) -> tuple[bool, str]:
         """Parse guard reply; expect 'SAFE' or 'UNSAFE: <reason>' (case-insensitive)."""
         if not text:
             # Fail-open: treat as SAFE if guard unavailable to avoid false blocks
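The `_parse_guard_reply` docstring fixes the reply grammar ('SAFE' or 'UNSAFE: <reason>', case-insensitive) and the comment spells out the fail-open default. A minimal sketch of a parser with that contract; the body beyond these lines is not shown in the diff:

def _parse_guard_reply(text: str) -> tuple[bool, str]:
    """Return (is_safe, reason); expects 'SAFE' or 'UNSAFE: <reason>' (case-insensitive)."""
    if not text:
        # Fail-open: treat as SAFE if the guard is unavailable to avoid false blocks
        return True, ""
    reply = text.strip()
    if reply.upper().startswith("UNSAFE"):
        _, _, reason = reply.partition(":")  # anything after the first ':' is the reason
        return False, reason.strip() or "unsafe content detected"
    return True, ""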
@@ -231,7 +230,7 @@ class SafetyGuard:
 
         return False
 
-    def check_user_query(self, user_query: str) -> Tuple[bool, str]:
+    def check_user_query(self, user_query: str) -> tuple[bool, str]:
         """Validate the user query is safe to process with medical context awareness."""
         if not self.enabled:
             logger().info("[SafetyGuard] Safety guard disabled, allowing query through")
@@ -253,7 +252,7 @@ class SafetyGuard:
             return False, reason
         return True, ""
 
-    def _detect_harmful_content(self, text: str) -> Tuple[bool, str]:
+    def _detect_harmful_content(self, text: str) -> tuple[bool, str]:
         """Detect harmful content using sophisticated pattern matching."""
         if not text:
             return True, ""
@@ -339,7 +338,7 @@ class SafetyGuard:
 
         return True, ""
 
-    def _enhance_messages_with_context(self, messages: List[Dict]) -> List[Dict]:
+    def _enhance_messages_with_context(self, messages: list[dict]) -> list[dict]:
         """Enhance messages with medical context awareness for better guard performance."""
         if not messages:
             return messages
@@ -376,7 +375,7 @@ Please evaluate this message for safety, keeping in mind that:
 
         return messages
 
-    def _assess_risk_level(self, text: str) -> Tuple[str, float]:
+    def _assess_risk_level(self, text: str) -> tuple[str, float]:
         """Assess the risk level of content using multiple indicators."""
         if not text:
             return "low", 0.0
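`_assess_risk_level` returns a `(level, score)` pair, and the hunk at old line 453 below shows the score being bucketed into levels. A free-standing sketch of that shape; the indicators and thresholds here are placeholders, not the ones in the file:

def assess_risk_level(text: str) -> tuple[str, float]:
    """Sketch: accumulate a heuristic score, then bucket it (placeholder indicators and thresholds)."""
    if not text:
        return "low", 0.0
    risk_score = 0.0
    for indicator in ("overdose", "self-harm", "lethal dose"):  # placeholders only
        if indicator in text.lower():
            risk_score += 0.4
    if risk_score >= 0.8:
        return "high", risk_score
    elif risk_score >= 0.4:
        return "medium", risk_score
    else:
        return "low", risk_score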
@@ -453,7 +452,7 @@ Please evaluate this message for safety, keeping in mind that:
         else:
             return "low", risk_score
 
-    def check_model_answer(self, user_query: str, model_answer: str) -> Tuple[bool, str]:
+    def check_model_answer(self, user_query: str, model_answer: str) -> tuple[bool, str]:
         """Validate the model's answer is safe with medical context awareness."""
         if not self.enabled:
             logger().info("[SafetyGuard] Safety guard disabled, allowing response through")
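With both public checks now returning `tuple[bool, str]`, the implied call pattern is to screen the query before generation and the answer afterwards. A hypothetical orchestration sketch (the helper name and the generator callable are not from this repo):

from collections.abc import Callable

def answer_safely(guard: "SafetyGuard", user_query: str, generate: Callable[[str], str]) -> str:
    """Hypothetical flow: screen the query, generate, then screen the model's answer."""
    ok, reason = guard.check_user_query(user_query)
    if not ok:
        return f"Request declined: {reason}"
    answer = generate(user_query)
    ok, reason = guard.check_model_answer(user_query, answer)
    if not ok:
        return f"Response withheld: {reason}"
    return answer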