Spaces:
Sleeping
Sleeping
File size: 1,836 Bytes
2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb 2275a4b 98d74fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import functools
import os
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from .model import MalwareNet, malware_classes
# FastAPI application instance; the /predict route below is registered on it.
app = FastAPI()
@functools.lru_cache(maxsize=1)
def _get_preprocess_pipeline():
    """Build the torchvision preprocessing pipeline once and reuse it.

    The transforms are stateless, so a single shared instance is safe and
    avoids re-constructing the pipeline on every request.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std; presumably what MalwareNet
        # was trained with — verify against the training code.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def preprocess_image(image_data):
    """Decode raw image bytes into a normalized, batched tensor.

    Args:
        image_data: Encoded image bytes in any format Pillow can open.

    Returns:
        A float tensor of shape (1, 3, 224, 224): the image converted to RGB,
        resized to 224x224, normalized per channel, with a leading batch
        dimension of size 1.

    Raises:
        OSError: Pillow raises ``UnidentifiedImageError`` (an ``OSError``
            subclass) for unrecognized data, and ``OSError`` for truncated files.
    """
    # Close the decoder handle deterministically; convert() loads the pixel
    # data and returns an independent image, so it stays valid after the with.
    with Image.open(BytesIO(image_data)) as img:
        rgb = img.convert("RGB")
    return _get_preprocess_pipeline()(rgb).unsqueeze(0)
def load_model():
    """Construct a MalwareNet and load the packaged checkpoint onto the CPU.

    The checkpoint path is resolved relative to this file as
    ``<this file's directory>/../model/malwareNet.pt`` and read with
    ``weights_only=True``, which restricts what ``torch.load`` will unpickle.

    Returns:
        A newly constructed ``MalwareNet`` switched to evaluation mode.
    """
    net = MalwareNet()
    here = os.path.dirname(os.path.abspath(__file__))
    checkpoint = torch.load(
        os.path.join(here, '../model/malwareNet.pt'),
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    net.load_state_dict(checkpoint)
    net.eval()
    return net
@functools.lru_cache(maxsize=1)
def _get_model():
    """Load the classifier once per process and reuse it for every request.

    A failed load raises and is not cached, so the next request retries.
    """
    return load_model()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted malware class.

    Args:
        file: Multipart upload containing an encoded image.

    Returns:
        ``JSONResponse`` with body ``{"prediction": <class name>}``, where the
        class name is the ``malware_classes`` entry at the arg-max of the
        (flattened) model output.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure (reading the upload, loading the
            model, inference).
    """
    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop; consider a threadpool if concurrency matters.
    try:
        image_data = await file.read()
        try:
            img_tensor = preprocess_image(image_data)
        except OSError as e:
            # Undecodable/truncated image data is a client error, not a 500.
            raise HTTPException(
                status_code=400,
                detail=f"Could not decode the uploaded file as an image: {e}",
            ) from e
        model = _get_model()
        with torch.no_grad():
            prediction = model(img_tensor)
        predicted_class = malware_classes[torch.argmax(prediction).item()]
    except HTTPException:
        # Already mapped to a specific status above; don't rewrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing the image: {e}") from e
    return JSONResponse(content={"prediction": predicted_class})
if __name__ == "__main__":
    import uvicorn

    # The app is passed as an import string because uvicorn requires that
    # form when reload is enabled. This presumes the server is launched from
    # the directory containing the `src` package — confirm.
    uvicorn.run(
        "src.serve:app",
        host=os.environ.get("HOST", "localhost"),
        port=int(os.environ.get("PORT", 5000)),
        # Auto-reload is a development convenience; set RELOAD=0/false/no to
        # disable it (e.g. in production). Defaults to enabled, as before.
        reload=os.environ.get("RELOAD", "1").strip().lower() not in ("0", "false", "no"),
    )