# safeguard-api / app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import uvicorn

# --- Configuration ---
MODEL_PATH = "./BERT_Bullying_Detector_Model"
device = torch.device("cpu")


# --- Data Model for Input ---
class TextInput(BaseModel):
    text: str
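# Example request body (illustrative; any string works):
#   {"text": "some message to classify"}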

# --- Load Model and Tokenizer ---
model = None
tokenizer = None
try:
    model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()  # inference mode: disables dropout
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")

# --- Create FastAPI App ---
app = FastAPI()


@app.get("/")
def read_root():
    # Simple health-check endpoint.
    return {"status": "API is running"}
@app.post("/predict")
def predict_toxicity(input_data: TextInput):
if not model or not tokenizer:
return {"error": "Model not loaded."}
text = input_data.text
try:
encoding = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=128,
return_token_type_ids=False,
padding='max_length',
return_tensors='pt',
truncation=True
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
_, prediction = torch.max(outputs.logits, dim=1)
label_map = {0: "Not Bullying", 1: "Bullying"}
result_label = label_map[prediction.item()]
return {"label": result_label, "score": prediction.item()}
except Exception as e:
return {"error": str(e)}

# --- Start the Server ---
if __name__ == "__main__":
    # Bind to all interfaces on port 7860, the default Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)
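# Example request once the server is running (text is illustrative; the
# response shape follows the /predict handler above):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "nobody wants you here"}'
#   -> {"label": "Bullying" or "Not Bullying", "score": <softmax probability>}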