| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import torch |
| import uvicorn |
| import logging |
| from typing import Union |
|
|
| |
# Configure root logging at DEBUG so this module's debug-level traces are emitted,
# then obtain a module-scoped logger.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
|
|
# The single ASGI application instance; routes are attached to it with decorators.
app: FastAPI = FastAPI()
|
|
| |
# Hugging Face Hub identifier for both the tokenizer and the classifier weights.
MODEL_NAME = "Yesser4/Last_version_phishing"

# Load once at import time so every request reuses the same tokenizer and model.
# Any failure aborts startup: the service cannot answer requests without them.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Failed to load model or tokenizer: %s", e)
    # RuntimeError is an Exception subclass, so existing `except Exception`
    # handlers still match; `from e` keeps the original cause attached.
    raise RuntimeError(f"Model initialization failed: {e}") from e
|
|
| |
class EmailText(BaseModel):
    """Object form of an email body, sent as ``{"text": "..."}``."""

    # Raw email body; surrounding whitespace is not removed here.
    text: str
|
|
class EmailInput(BaseModel):
    """Request body for ``POST /predict``.

    ``email_text`` accepts either a bare string or an ``EmailText`` object.
    """

    # Union order is kept as-is: plain string first, then the wrapper object.
    email_text: Union[str, EmailText]
|
|
def _classify_text(text: str) -> dict:
    """Score *text* with the loaded classifier and build the JSON response.

    Returns a dict with ``prediction`` ("phishing" or "legitimate"),
    ``confidence`` (the larger of the two class probabilities) and the
    per-class ``probabilities``.

    NOTE(review): assumes the model has exactly two output labels with
    index 0 = legitimate and index 1 = phishing; confirm against the
    model's ``config.id2label``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,  # longer emails are truncated to this many tokens
        padding=True,
    )
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # A single text was tokenized, so row 0 holds its class probabilities.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    return {
        "prediction": "phishing" if probs[1] > probs[0] else "legitimate",
        "confidence": float(probs.max()),
        "probabilities": {
            "legitimate": float(probs[0]),
            "phishing": float(probs[1]),
        },
    }


@app.post("/predict")
async def predict_email(input: EmailInput):
    """Classify an email as phishing or legitimate.

    Args:
        input: Request body. ``email_text`` is either a plain string or an
            ``EmailText`` object whose ``text`` field holds the email body.

    Returns:
        The response dict built by ``_classify_text``.

    Raises:
        HTTPException: 400 if the email text is empty after stripping
            whitespace; 500 for any other failure while processing.
    """
    # NOTE(review): this ``async def`` runs blocking, CPU-bound inference, so
    # each request blocks the event loop. A plain ``def`` would let FastAPI
    # use its threadpool; first confirm the shared tokenizer/model are safe
    # to call concurrently from multiple threads.
    try:
        # NOTE(review): at DEBUG level this logs the full email body (PII).
        logger.debug("Received input: %s", input)
        logger.debug("Type of email_text: %s", type(input.email_text))

        if isinstance(input.email_text, str):
            text = input.email_text.strip()
        else:
            text = input.email_text.text.strip()

        if not text:
            logger.warning("Empty email text received")
            raise HTTPException(status_code=400, detail="Email text is empty")

        response = _classify_text(text)
        logger.debug("Prediction response: %s", response)
        return response

    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause
        # the generic handler below would turn deliberate client errors such
        # as the 400 above into 500 responses.
        raise
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        logger.exception("Error processing request: %s", e)
        # NOTE(review): echoing the exception text to the client may leak
        # internal details; consider a generic message.
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {e}"
        ) from e
|
|
if __name__ == "__main__":
    # When executed as a script, serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")