AISimplyExplained commited on
Commit
338f4c1
1 Parent(s): bd1a56f
Files changed (1) hide show
  1. main.py +8 -14
main.py CHANGED
@@ -3,8 +3,8 @@ from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
  import torch
5
  from detoxify import Detoxify
6
-
7
-
8
 
9
  class Guardrail:
10
  def __init__(self):
@@ -20,8 +20,8 @@ class Guardrail:
20
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  )
22
 
23
- def guard(self, prompt):
24
- return self.classifier(prompt)
25
 
26
  def determine_level(self, label, score):
27
  if label == "SAFE":
@@ -36,18 +36,15 @@ class Guardrail:
36
  else:
37
  return 1, "very low"
38
 
39
-
40
  class TextPrompt(BaseModel):
41
  prompt: str
42
 
43
-
44
  class ClassificationResult(BaseModel):
45
  label: str
46
  score: float
47
  level: int
48
  severity_label: str
49
 
50
-
51
  app = FastAPI()
52
  guardrail = Guardrail()
53
  toxicity_classifier = Detoxify('original')
@@ -61,9 +58,9 @@ class ToxicityResult(BaseModel):
61
  identity_attack: float
62
 
63
  @app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
64
- def classify_toxicity(text_prompt: TextPrompt):
65
  try:
66
- result = toxicity_classifier.predict(text_prompt.prompt)
67
  return {
68
  "toxicity": result['toxicity'],
69
  "severe_toxicity": result['severe_toxicity'],
@@ -75,11 +72,10 @@ def classify_toxicity(text_prompt: TextPrompt):
75
  except Exception as e:
76
  raise HTTPException(status_code=500, detail=str(e))
77
 
78
-
79
  @app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
80
- def classify_text(text_prompt: TextPrompt):
81
  try:
82
- result = guardrail.guard(text_prompt.prompt)
83
  label = result[0]['label']
84
  score = result[0]['score']
85
  level, severity_label = guardrail.determine_level(label, score)
@@ -87,8 +83,6 @@ def classify_text(text_prompt: TextPrompt):
87
  except Exception as e:
88
  raise HTTPException(status_code=500, detail=str(e))
89
 
90
-
91
  if __name__ == "__main__":
92
  import uvicorn
93
-
94
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
  import torch
5
  from detoxify import Detoxify
6
+ import asyncio
7
+ from fastapi.concurrency import run_in_threadpool
8
 
9
  class Guardrail:
10
  def __init__(self):
 
20
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  )
22
 
23
+ async def guard(self, prompt):
24
+ return await run_in_threadpool(self.classifier, prompt)
25
 
26
  def determine_level(self, label, score):
27
  if label == "SAFE":
 
36
  else:
37
  return 1, "very low"
38
 
 
39
  class TextPrompt(BaseModel):
40
  prompt: str
41
 
 
42
  class ClassificationResult(BaseModel):
43
  label: str
44
  score: float
45
  level: int
46
  severity_label: str
47
 
 
48
  app = FastAPI()
49
  guardrail = Guardrail()
50
  toxicity_classifier = Detoxify('original')
 
58
  identity_attack: float
59
 
60
  @app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
61
+ async def classify_toxicity(text_prompt: TextPrompt):
62
  try:
63
+ result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
64
  return {
65
  "toxicity": result['toxicity'],
66
  "severe_toxicity": result['severe_toxicity'],
 
72
  except Exception as e:
73
  raise HTTPException(status_code=500, detail=str(e))
74
 
 
75
  @app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
76
+ async def classify_text(text_prompt: TextPrompt):
77
  try:
78
+ result = await guardrail.guard(text_prompt.prompt)
79
  label = result[0]['label']
80
  score = result[0]['score']
81
  level, severity_label = guardrail.determine_level(label, score)
 
83
  except Exception as e:
84
  raise HTTPException(status_code=500, detail=str(e))
85
 
 
86
  if __name__ == "__main__":
87
  import uvicorn
 
88
  uvicorn.run(app, host="0.0.0.0", port=8000)