Spaces:
Sleeping
Sleeping
import traceback | |
from typing import Optional | |
from fastapi import HTTPException | |
import litellm | |
from litellm._logging import verbose_proxy_logger | |
from litellm.caching.caching import DualCache | |
from litellm.integrations.custom_logger import CustomLogger | |
from litellm.proxy._types import UserAPIKeyAuth | |
class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """Proxy hook that screens text with Azure AI Content Safety.

    Input messages (pre-call) and model completions (post-call) are sent to
    the Azure ``analyze_text`` API. If any category's reported severity is at
    or above its configured threshold, an ``HTTPException`` with status 400 is
    raised.
    """

    def __init__(self, endpoint, api_key, thresholds=None):
        """Build the async Azure Content Safety client.

        Args:
            endpoint: Endpoint passed to ``ContentSafetyClient``.
            api_key: Key wrapped in an ``AzureKeyCredential``.
            thresholds: Optional mapping of ``TextCategory`` -> severity at
                which content is blocked. Categories not given default to 4.
                The mapping is copied, not modified.

        Raises:
            Exception: if the ``azure-ai-contentsafety`` package (or its
                ``azure.core`` dependency) cannot be imported.
        """
        try:
            # Imported lazily so the proxy only needs the Azure SDK when this
            # hook is actually configured.
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except ImportError as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            ) from e

        self.endpoint = endpoint
        self.api_key = api_key
        # Keep SDK types on the instance because they are only importable here.
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError
        self.thresholds = self._configure_thresholds(thresholds)
        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        """Return the effective per-category blocking thresholds.

        User-supplied values override the default of 4 for HATE, SELF_HARM,
        SEXUAL and VIOLENCE. A new dict is always returned; the caller's
        ``thresholds`` mapping is left untouched.
        """
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }
        if thresholds is None:
            return default_thresholds
        # Merge into a fresh dict instead of writing defaults into the caller's
        # (possibly shared) config object.
        return {**default_thresholds, **thresholds}

    def _compute_result(self, response):
        """Map each analysed category to its severity and blocked flag.

        Returns:
            dict of ``category -> {"filtered": bool, "severity": severity}``,
            containing only categories for which Azure reported a severity.
        """
        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        result = {}
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is None:
                continue
            result[category] = {
                "filtered": severity >= self.thresholds[category],
                "severity": severity,
            }
        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        """Analyse ``content`` and raise if it violates any threshold.

        Args:
            content: Text to screen.
            source: Label echoed in the error detail ("input" or "output").

        Raises:
            HTTPException: status 400 on the first category at/over threshold.
            HttpResponseError: propagated from the Azure client on API failure.
        """
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for category, info in result.items():
            if info["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": category,
                        "severity": info["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Screen string message contents of completion requests.

        Policy violations (``HTTPException``) propagate and reject the request.
        Any other error is logged and swallowed, so a failure of the safety
        service does not block traffic (fail-open).
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    # Only plain-string contents are screened; list-style
                    # (multi-part) contents are skipped.
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")
        except HTTPException:
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Screen the text of every completion choice in a model response.

        Every ``Choices`` entry is checked, so requests with ``n > 1`` cannot
        bypass the filter through choices after the first. Empty or ``None``
        contents (e.g. tool-call-only replies) have nothing to screen and
        are skipped.

        Raises:
            HTTPException: status 400 if any choice violates a threshold.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if not isinstance(response, litellm.ModelResponse):
            return
        for choice in response.choices:
            if not isinstance(choice, litellm.utils.Choices):
                continue
            content = choice.message.content
            if content:
                await self.test_violation(content=content, source="output")

    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")