# Author: z4hid
# Commit: 2275a4b (verified) - "api code changed"
import functools
import os
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from .model import MalwareNet, malware_classes
# ASGI application instance. The /predict route below is registered on it, and
# the __main__ block serves it via the import string "src.serve:app".
app = FastAPI()
# Built once at import time. The pipeline holds no per-image state, so there is
# no reason to rebuild it for every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std values.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_data):
    """Decode raw image bytes into a normalized, batched model input tensor.

    The image is converted to RGB and resized to 224x224. It is then turned
    into a float tensor and normalized per channel. Finally it is given a
    leading batch dimension of size 1.

    Args:
        image_data: Encoded image file contents (any format Pillow can read).

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a recognizable image
            (a subclass of OSError).
        OSError: if the image data is truncated or otherwise unreadable.
        PIL.Image.DecompressionBombError: if the image is absurdly large.
    """
    # ``convert`` forces a full decode and returns an independent image, so
    # the decoder handle can be closed right away instead of waiting for GC.
    with Image.open(BytesIO(image_data)) as raw:
        rgb = raw.convert("RGB")
    return _PREPROCESS(rgb).unsqueeze(0)
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the trained MalwareNet checkpoint and return it in eval mode.

    The state dict is read from ``../model/malwareNet.pt``, resolved relative
    to this source file, and mapped onto the CPU. ``weights_only=True``
    restricts ``torch.load`` to tensors and primitive containers instead of
    arbitrary pickled objects.

    The result is cached, so the checkpoint is read from disk only once per
    process and every caller receives the same module instance. Treat it as
    read-only: do not call ``train()`` on it or mutate its parameters. A
    failed load is not cached, so the next call retries.

    Returns:
        A ``MalwareNet`` with the checkpoint's weights loaded, in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the checkpoint keys/shapes do not match MalwareNet.
    """
    model = MalwareNet()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # normpath collapses the '..' so any error message shows a clean path.
    model_location = os.path.normpath(
        os.path.join(base_dir, "..", "model", "malwareNet.pt")
    )
    state_dict = torch.load(
        model_location, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
def _classify(img_tensor):
    """Run the model on a preprocessed batch and return the top class name."""
    model = load_model()
    # torch.no_grad() is thread-local, so it must be entered in the thread
    # that actually runs the forward pass.
    with torch.no_grad():
        logits = model(img_tensor)
    # Assumes a single-image batch, so the flat argmax is the class index.
    return malware_classes[torch.argmax(logits).item()]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted malware class.

    Responds with ``{"prediction": <class name>}``.

    Raises:
        HTTPException(400): the upload cannot be decoded as an image.
        HTTPException(500): any other failure, e.g. the model checkpoint
            cannot be loaded.

    Preprocessing and inference are CPU-bound. They run in FastAPI's worker
    thread pool so they do not block the event loop for other requests.
    """
    try:
        image_data = await file.read()
        try:
            img_tensor = await run_in_threadpool(preprocess_image, image_data)
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable, truncated or oversized input is a client error.
            # PIL.UnidentifiedImageError subclasses OSError. The model is not
            # touched yet, so an OSError here cannot be a missing checkpoint.
            raise HTTPException(
                status_code=400, detail=f"Invalid image file: {e}"
            ) from e
        predicted_class = await run_in_threadpool(_classify, img_tensor)
    except HTTPException:
        # Let the deliberate 400 above through unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing the image: {e}"
        ) from e
    return JSONResponse(content={"prediction": predicted_class})
if __name__ == "__main__":
    import uvicorn

    # Development entry point. HOST and PORT may be overridden from the
    # environment. The app is passed as an import string, which is the form
    # uvicorn requires when auto-reload is enabled.
    bind_host = os.environ.get("HOST", "localhost")
    bind_port = int(os.environ.get("PORT", 5000))
    uvicorn.run("src.serve:app", host=bind_host, port=bind_port, reload=True)