# FastAPI / app.py
# Uploaded by Yesser4 ("Upload 2 files", commit 2f806a7, verified).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import uvicorn
import logging
from typing import Union
# Set up logging: DEBUG-level output for the whole process is configured on
# the root logger; the module logger below propagates to it.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)

# ASGI application; routes are registered on it with decorators.
app = FastAPI()
# Load the model
# Hugging Face Hub id of the fine-tuned classifier; the tokenizer and the
# weights are both loaded from this repo.
MODEL_ID = "Yesser4/Last_version_phishing"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    # The service only runs inference; put the module in eval mode explicitly
    # so layers like dropout behave deterministically.
    model.eval()
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Failed to load model or tokenizer: %s", e)
    # RuntimeError subclasses Exception, so existing handlers still match;
    # `from e` keeps the original failure as the chained cause.
    raise RuntimeError(f"Model initialization failed: {str(e)}") from e
# Input models
# Object form of the payload field: {"email_text": {"text": "..."}}.
class EmailText(BaseModel):
    text: str  # raw email body


# Request body for POST /predict. `email_text` may be a plain string or an
# EmailText object.
# NOTE(review): member order in the Union can affect pydantic's coercion
# (pydantic v1 tries members left-to-right) - keep `str` first unless re-verified.
class EmailInput(BaseModel):
    email_text: Union[str, EmailText]  # Accept either a string or a dictionary with 'text'
def _extract_text(email_text: Union[str, EmailText]) -> str:
    """Return the email body from either accepted payload shape, stripped."""
    if isinstance(email_text, str):
        return email_text.strip()
    # Otherwise it is the EmailText object form: {"text": "..."}.
    return email_text.text.strip()


def _predict_probs(inputs) -> torch.Tensor:
    """Run the classifier on tokenized *inputs* and return class probabilities.

    Returns the softmax over the logits of the first sequence in the batch
    (the endpoint always tokenizes exactly one text). Intended to run in a
    worker thread: ``torch.no_grad`` is thread-local, so it is entered here
    rather than by the caller.
    """
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.nn.functional.softmax(outputs.logits, dim=-1)[0]


@app.post("/predict")
async def predict_email(input: EmailInput):
    """Classify an email as phishing or legitimate.

    Returns a dict with ``prediction`` ("phishing" or "legitimate"),
    ``confidence`` (the largest class probability) and per-class
    ``probabilities``.

    Raises:
        HTTPException: 400 if the email text is empty or whitespace-only;
            500 if tokenization or inference fails.
    """
    # Local import keeps this block self-contained; it is a cached
    # sys.modules lookup after the first request.
    from fastapi.concurrency import run_in_threadpool

    # Log only metadata, not the email body: logging is configured at DEBUG
    # and email contents may be sensitive.
    logger.debug("Type of email_text: %s", type(input.email_text))
    text = _extract_text(input.email_text)
    logger.debug("Email text length after strip: %d", len(text))

    # Checked outside the try below so this client error reaches the caller
    # as a 400 instead of being rewrapped as a 500 by the generic handler.
    if not text:
        logger.warning("Empty email text received")
        raise HTTPException(status_code=400, detail="Email text is empty")

    try:
        # Tokenization is cheap and stays on the event-loop thread, so the
        # shared tokenizer is never called from several threads at once.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        # The forward pass is compute-bound; running it inline in this async
        # handler would stall the event loop for every other request.
        probs = await run_in_threadpool(_predict_probs, inputs)
        # NOTE(review): assumes a 2-label head with index 0 = legitimate and
        # index 1 = phishing - confirm against model.config.id2label.
        response = {
            "prediction": "phishing" if probs[1] > probs[0] else "legitimate",
            "confidence": float(probs.max()),
            "probabilities": {
                "legitimate": float(probs[0]),
                "phishing": float(probs[1]),
            },
        }
    except Exception as e:
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

    logger.debug("Prediction response: %s", response)
    return response
if __name__ == "__main__":
    # Script entry point: serve `app` on every interface at port 7860.
    uvicorn.run(app=app, port=7860, host="0.0.0.0")