# app.py — last commit 3fa79b6 ("Update app.py", verified) by nathan ayers; 936 bytes.
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import pickle
import numpy as np
from PIL import Image
# ASGI application object; route handlers below register themselves on it.
app = FastAPI()

# Load the pickled classifier once at import time so every request reuses it.
# SECURITY: unpickling executes arbitrary code embedded in the file, so
# "mnist_model.pkl" must be a trusted artifact shipped with the app and never
# user-supplied data.
with open("mnist_model.pkl", "rb") as model_file:
    model = pickle.load(model_file)
def preprocess_image(file_bytes) -> np.ndarray:
    """Turn an image into a single flattened 28x28 grayscale sample.

    Args:
        file_bytes: Whatever ``PIL.Image.open`` accepts; the caller in this
            file passes the binary file object of an upload.

    Returns:
        A ``uint8`` array of shape ``(1, 784)``: one row holding the 28*28
        grayscale pixel values in row-major order.
    """
    # 1) Load into PIL and convert to 8-bit grayscale ('L').
    img = Image.open(file_bytes).convert("L")

    # 2) Resize to 28x28. ``Image.ANTIALIAS`` was removed in Pillow 10;
    #    ``Image.LANCZOS`` is the same filter under its current name.
    img = img.resize((28, 28), Image.LANCZOS)

    # 3) Convert to a uint8 numpy array and flatten to one row of length 784.
    arr = np.array(img).astype("uint8").reshape(1, -1)

    # NOTE: no color inversion is applied. If the model expects white-on-black
    # digits and uploads are dark-on-light, invert with ``arr = 255 - arr``.
    return arr
@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image with the module-level ``model``.

    Args:
        file: The multipart upload; its underlying binary file object is
            handed straight to ``preprocess_image``.

    Returns:
        ``{"prediction": <int>}`` on success, or a 400 response with a
        ``detail`` message when the upload cannot be read as an image.
    """
    try:
        # Pass the upload's file object directly; PIL reads from it.
        arr = preprocess_image(file.file)
    except OSError:
        # Pillow signals unreadable/corrupt image data with OSError (including
        # its UnidentifiedImageError subclass). That is a client error, so
        # report it as 400 rather than letting it surface as an opaque 500.
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a readable image."},
        )

    # The model receives a single-row batch; take its only prediction.
    pred = model.predict(arr)[0]
    # Cast to a plain int so the value is JSON-serializable.
    return JSONResponse({"prediction": int(pred)})