from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import uvicorn
import logging
from typing import Union

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face Hub identifier of the sequence-classification model served here.
MODEL_NAME = "Yesser4/Last_version_phishing"

# Load the tokenizer and model once at import time; every request reuses them.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
except Exception as e:
    logger.exception("Failed to load model or tokenizer: %s", e)
    # Chain the original error so the root cause stays in the traceback.
    # RuntimeError is an Exception subclass, so existing handlers still match.
    raise RuntimeError(f"Model initialization failed: {e}") from e


# Input models
class EmailText(BaseModel):
    """Object form of the payload: ``{"text": "<email body>"}``."""

    text: str


class EmailInput(BaseModel):
    """Request body for ``POST /predict``.

    ``email_text`` may be either a plain string or an object with a ``text``
    field (see ``EmailText``).
    """

    email_text: Union[str, EmailText]  # Accept either a string or a dictionary with 'text'


def _extract_text(email_text: Union[str, EmailText]) -> str:
    """Return the email body with surrounding whitespace stripped, from either payload shape."""
    raw = email_text if isinstance(email_text, str) else email_text.text
    return raw.strip()


def _classify(text: str) -> dict:
    """Run the classifier on *text* and build the JSON response payload.

    Returns a dict with ``prediction`` ("phishing" or "legitimate"),
    ``confidence`` (the larger class probability) and per-class
    ``probabilities``.

    NOTE(review): assumes a two-label model where index 0 is "legitimate" and
    index 1 is "phishing" — confirm against ``model.config.id2label``.
    """
    # Inputs longer than 512 tokens are truncated.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # A single text was tokenized, so take the first (only) row of the batch.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    return {
        "prediction": "phishing" if probs[1] > probs[0] else "legitimate",
        "confidence": float(probs.max()),
        "probabilities": {
            "legitimate": float(probs[0]),
            "phishing": float(probs[1]),
        },
    }


@app.post("/predict")
async def predict_email(input: EmailInput):
    """Classify an email as phishing or legitimate.

    Args:
        input: Request body; ``email_text`` is a string or ``{"text": ...}``.

    Returns:
        The payload built by ``_classify``.

    Raises:
        HTTPException: 400 if the email text is empty or only whitespace;
            500 if tokenization or inference fails.
    """
    # Log the incoming input for debugging
    logger.debug("Received input: %s", input)
    logger.debug("Type of email_text: %s", type(input.email_text))

    text = _extract_text(input.email_text)
    # Validate outside the try below, so this 400 is not turned into a 500.
    if not text:
        logger.warning("Empty email text received")
        raise HTTPException(status_code=400, detail="Email text is empty")

    # NOTE(review): inference runs synchronously inside an async handler and
    # blocks the event loop for the forward pass. Consider a plain `def`
    # route (run in FastAPI's threadpool) if concurrent requests matter.
    try:
        response = _classify(text)
    except Exception as e:
        logger.exception("Error processing request: %s", e)
        # NOTE(review): this echoes the internal error text to the client.
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e

    logger.debug("Prediction response: %s", response)
    return response
if __name__ == "__main__":
    # When executed directly, serve the app with uvicorn on every
    # network interface at port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )