"""FastAPI service exposing a ``/predict`` endpoint backed by MalwareNet."""

import os
from functools import lru_cache
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from .model import MalwareNet, malware_classes

app = FastAPI()

# Built once at import time: the pipeline is stateless, so there is no need
# to reconstruct it for every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_data: bytes) -> torch.Tensor:
    """Decode raw image bytes into a normalized, batched tensor.

    The image is converted to RGB, resized to 224x224, turned into a tensor,
    normalized per channel, and given a leading batch dimension of size 1.

    Args:
        image_data: Encoded image file contents (any format Pillow can open).

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.

    Raises:
        OSError: If Pillow cannot identify or decode the data (this includes
            ``PIL.UnidentifiedImageError``).
    """
    # Close the decoded source image deterministically; ``convert`` returns
    # an independent copy that outlives the ``with`` block.
    with Image.open(BytesIO(image_data)) as raw_image:
        image = raw_image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)


def load_model() -> MalwareNet:
    """Build a MalwareNet, load its weights from disk, and set eval mode.

    The weights are read from ``../model/malwareNet.pt`` relative to this
    file and mapped onto the CPU. Every call constructs a fresh model.

    Returns:
        The model in evaluation mode.
    """
    model = MalwareNet()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_location = os.path.join(base_dir, '../model/malwareNet.pt')
    state_dict = torch.load(
        model_location,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


@lru_cache(maxsize=1)
def _get_model() -> MalwareNet:
    """Return a process-wide model instance, loading it on first use."""
    return load_model()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class.

    Args:
        file: The uploaded image file.

    Returns:
        A JSON response of the form ``{"prediction": <class>}`` where the
        class is the ``malware_classes`` entry at the argmax of the model
        output.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure while reading or running inference.
    """
    try:
        # Read file bytes
        image_data = await file.read()

        # Preprocess the image. A payload that cannot be decoded is the
        # client's fault, so report it as 400 rather than a server error.
        try:
            img_tensor = preprocess_image(image_data)
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            raise HTTPException(
                status_code=400,
                detail=f"Error processing the image: {e}",
            ) from e

        # Reuse the cached model instead of reloading weights per request.
        model = _get_model()
        with torch.no_grad():
            prediction = model(img_tensor)

        # Get the predicted class
        predicted_class = malware_classes[torch.argmax(prediction).item()]
        return JSONResponse(content={"prediction": predicted_class})
    except HTTPException:
        # Already carries the intended status code; do not rewrap as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing the image: {e}",
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "src.serve:app",
        host=os.environ.get("HOST", "localhost"),
        port=int(os.environ.get("PORT", 5000)),
        reload=True,
    )