File size: 1,836 Bytes
2275a4b
 
 
98d74fb
 
2275a4b
98d74fb
2275a4b
98d74fb
 
 
2275a4b
 
98d74fb
2275a4b
 
 
98d74fb
2275a4b
98d74fb
 
 
 
2275a4b
98d74fb
 
2275a4b
98d74fb
 
 
2275a4b
98d74fb
2275a4b
 
 
 
 
98d74fb
 
 
2275a4b
98d74fb
 
 
 
 
2275a4b
98d74fb
 
 
2275a4b
98d74fb
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.responses import JSONResponse
from io import BytesIO
from PIL import Image
import torch
from torchvision import transforms
import os
from .model import MalwareNet, malware_classes

# FastAPI application instance; the endpoint(s) in this module register on it
# through route decorators such as `@app.post`.
app = FastAPI()

# Preprocessing applied to every uploaded image. Resize, ToTensor and Normalize
# hold no per-call state, so the pipeline is built once at import time rather
# than on every request. The mean/std values are the commonly used ImageNet
# channel statistics.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_data):
    """Decode raw image bytes into a normalized, batched model input tensor.

    Args:
        image_data: Encoded image file contents (any format Pillow can open).

    Returns:
        A float tensor of shape (1, 3, 224, 224): the image converted to RGB,
        resized to 224x224, scaled to [0, 1] by ToTensor, normalized per
        channel, and given a leading batch dimension of size 1.

    Raises:
        PIL.UnidentifiedImageError: if Pillow cannot identify the bytes as an
            image (other decoding errors may also propagate).
    """
    # Context manager releases the decoder's source once the pixels have been
    # copied; convert() always returns a new, fully loaded image.
    with Image.open(BytesIO(image_data)) as img:
        rgb = img.convert("RGB")
    return _PREPROCESS(rgb).unsqueeze(0)

def load_model():
    """Instantiate MalwareNet and load its trained weights for CPU inference.

    The checkpoint is read from ``../model/malwareNet.pt`` relative to the
    directory containing this module. Tensors are mapped onto the CPU, and
    ``weights_only=True`` restricts ``torch.load`` to plain tensor/container
    data instead of arbitrary pickled objects.

    Returns:
        The network with the checkpoint's state dict loaded, switched to
        evaluation mode via ``eval()``.
    """
    net = MalwareNet()

    module_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = os.path.join(module_dir, '../model/malwareNet.pt')
    weights = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )

    net.load_state_dict(weights)
    net.eval()
    return net

# Model shared by all requests, populated lazily by _get_model(). Loading it
# once avoids rebuilding the network and re-reading/deserializing the
# checkpoint from disk on every prediction.
_model = None


def _get_model():
    """Return the shared inference model, loading it on first use.

    If load_model() raises, the cache stays empty and the next call retries.
    The check-and-set contains no ``await``, so coroutines running on the
    same event loop cannot interleave inside it.
    """
    global _model
    if _model is None:
        _model = load_model()
    return _model


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted malware class.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        JSONResponse whose body is ``{"prediction": <label>}``, where the label
        is taken from ``malware_classes``.

    Raises:
        HTTPException: 400 if the upload is not a recognizable image;
            500 for any other failure while preprocessing or predicting.
    """
    # Local import: the exception type is only needed by this handler.
    from PIL import UnidentifiedImageError

    try:
        image_data = await file.read()
        img_tensor = preprocess_image(image_data)

        # NOTE(review): inference runs synchronously inside an async handler
        # and blocks the event loop while it runs; consider run_in_threadpool.
        model = _get_model()
        with torch.no_grad():
            prediction = model(img_tensor)

        # argmax over the flattened output. Assumes the model returns
        # (1, num_classes) scores for the single-image batch, in which case the
        # flat index is the class index -- confirm against MalwareNet.
        predicted_class = malware_classes[torch.argmax(prediction).item()]
    except UnidentifiedImageError as e:
        # An undecodable upload is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail=f"Error processing the image: {e}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing the image: {e}"
        ) from e

    return JSONResponse(content={"prediction": predicted_class})


if __name__ == "__main__":
    # Development entry point. HOST/PORT come from the environment, falling
    # back to localhost:5000. uvicorn needs the app as an import string (not
    # the object) for reload=True to work.
    import uvicorn

    bind_host = os.environ.get("HOST", "localhost")
    bind_port = int(os.environ.get("PORT", 5000))
    uvicorn.run("src.serve:app", host=bind_host, port=bind_port, reload=True)