Spaces:
Sleeping
Sleeping
File size: 5,566 Bytes
f228a1c 99351b6 338f4c1 c311b0d f228a1c 338f4c1 f228a1c 14c8502 f228a1c 14c8502 99351b6 f43f094 5c19b8d 4fa87d4 5c19b8d 4fa87d4 5c19b8d 4fa87d4 5c19b8d c311b0d 5c19b8d 99351b6 338f4c1 99351b6 338f4c1 f43f094 99351b6 f228a1c bc21776 338f4c1 f228a1c 338f4c1 14c8502 f228a1c 5c19b8d 4fa87d4 5c19b8d 4fa87d4 5c19b8d c311b0d f228a1c 5c19b8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
from detoxify import Detoxify
import asyncio
from fastapi.concurrency import run_in_threadpool
from typing import List, Optional
class Guardrail:
def __init__(self):
tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
self.classifier = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
truncation=True,
max_length=512,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
async def guard(self, prompt):
return await run_in_threadpool(self.classifier, prompt)
def determine_level(self, label, score):
if label == "SAFE":
return 0, "safe"
else:
if score > 0.9:
return 4, "high"
elif score > 0.75:
return 3, "medium"
elif score > 0.5:
return 2, "low"
else:
return 1, "very low"
class TextPrompt(BaseModel):
prompt: str
class ClassificationResult(BaseModel):
label: str
score: float
level: int
severity_label: str
class ToxicityResult(BaseModel):
toxicity: float
severe_toxicity: float
obscene: float
threat: float
insult: float
identity_attack: float
@classmethod
def from_dict(cls, data: dict):
return cls(**{k: float(v) for k, v in data.items()})
class TopicBannerClassifier:
def __init__(self):
self.classifier = pipeline(
"zero-shot-classification",
model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
self.hypothesis_template = "This text is about {}"
async def classify(self, text, labels):
return await run_in_threadpool(
self.classifier,
text,
labels,
hypothesis_template=self.hypothesis_template,
multi_label=False
)
class TopicBannerRequest(BaseModel):
prompt: str
labels: List[str]
class TopicBannerResult(BaseModel):
sequence: str
labels: list
scores: list
class GuardrailsRequest(BaseModel):
prompt: str
guardrails: List[str]
labels: Optional[List[str]] = None
class GuardrailsResponse(BaseModel):
prompt_injection: Optional[ClassificationResult] = None
toxicity: Optional[ToxicityResult] = None
topic_banner: Optional[TopicBannerResult] = None
app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify('original')
topic_banner_classifier = TopicBannerClassifier()
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
try:
result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
return ToxicityResult.from_dict(result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
try:
result = await guardrail.guard(text_prompt.prompt)
label = result[0]['label']
score = result[0]['score']
level, severity_label = guardrail.determine_level(label, score)
return {"label": label, "score": score, "level": level, "severity_label": severity_label}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(request: TopicBannerRequest):
try:
result = await topic_banner_classifier.classify(request.prompt, request.labels)
return {
"sequence": result["sequence"],
"labels": result["labels"],
"scores": result["scores"]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/guardrails", response_model=GuardrailsResponse)
async def evaluate_guardrails(request: GuardrailsRequest):
tasks = []
response = GuardrailsResponse()
if "pi" in request.guardrails:
tasks.append(classify_text(TextPrompt(prompt=request.prompt)))
if "tox" in request.guardrails:
tasks.append(classify_toxicity(TextPrompt(prompt=request.prompt)))
if "top" in request.guardrails:
if not request.labels:
raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")
tasks.append(classify_topic_banner(TopicBannerRequest(prompt=request.prompt, labels=request.labels)))
results = await asyncio.gather(*tasks, return_exceptions=True)
for result, guardrail in zip(results, request.guardrails):
if isinstance(result, Exception):
# Handle the exception as needed
continue
if guardrail == "pi":
response.prompt_injection = result
elif guardrail == "tox":
response.toxicity = result
elif guardrail == "top":
response.topic_banner = result
return response
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) |