File size: 2,400 Bytes
2f806a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import uvicorn
import logging
from typing import Union

# Set up logging.
# NOTE(review): DEBUG on the root logger also enables debug output from every
# third-party library in the process; consider INFO for production.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face Hub repository that provides both the tokenizer and the model.
MODEL_NAME = "Yesser4/Last_version_phishing"

# Load the tokenizer and model once at import time so every request reuses them.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Failed to load model or tokenizer: %s", e)
    # Fail fast: the service cannot answer requests without a model. Chain the
    # original error so its cause and traceback are preserved. RuntimeError is
    # a subclass of Exception, so existing `except Exception` handlers still work.
    raise RuntimeError(f"Model initialization failed: {str(e)}") from e

# Input models
class EmailText(BaseModel):
    # Object form of the email payload, e.g. {"text": "<email body>"}.
    text: str

class EmailInput(BaseModel):
    # Request body for POST /predict. `email_text` may be either a plain string
    # or an object with a 'text' field (the EmailText model).
    # NOTE(review): pydantic v1 tries Union members left to right, while v2 uses
    # "smart" union matching; confirm which version is installed if inputs that
    # could match both forms ever matter.
    email_text: Union[str, EmailText]  # Accept either a string or a dictionary with 'text'

@app.post("/predict")
async def predict_email(input: EmailInput):
    """Classify an email as phishing or legitimate.

    Args:
        input: Request body. ``email_text`` is either a plain string or an
            object with a ``text`` field.

    Returns:
        A dict with ``prediction`` ("phishing" or "legitimate"),
        ``confidence`` (the largest class probability), and
        ``probabilities`` for both classes.

    Raises:
        HTTPException: 400 if the email text is empty or whitespace-only;
            500 for any other failure while processing the request.
    """
    # NOTE(review): tokenization and inference below are synchronous CPU work
    # inside an `async def`, so they block the event loop while they run.
    # Offloading to a thread would need care: sharing one tokenizer across
    # threads may not be safe — confirm before changing.
    try:
        email_text = input.email_text
        logger.debug("Type of email_text: %s", type(email_text))

        # Extract the string from email_text.
        if isinstance(email_text, str):
            text = email_text.strip()
        else:  # EmailText
            text = email_text.text.strip()

        # Log the size, not the content: email bodies may contain personal data.
        logger.debug("Received email text of %d characters", len(text))

        if not text:
            logger.warning("Empty email text received")
            raise HTTPException(status_code=400, detail="Email text is empty")

        # Tokenize (truncated to 512 tokens) and predict without tracking gradients.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
            # Softmax over the class dimension; [0] selects the single batch item.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

        # NOTE(review): assumes a two-class head where index 0 = legitimate and
        # index 1 = phishing — confirm against model.config.id2label.
        response = {
            "prediction": "phishing" if probs[1] > probs[0] else "legitimate",
            "confidence": float(probs.max()),
            "probabilities": {
                "legitimate": float(probs[0]),
                "phishing": float(probs[1])
            }
        }
        logger.debug("Prediction response: %s", response)
        return response

    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        # logger.exception includes the traceback for server-side debugging.
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

if __name__ == "__main__":
    # Serve the app directly when executed as a script, listening on every
    # network interface. Port 7860 is presumably what the hosting environment
    # expects — confirm before changing it.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)