# ImageAPI / app.py
# Last commit: d20336c "Update app.py" by LazyBoss (verified)
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from sklearn.neural_network import MLPClassifier
import torchvision.datasets as datasets
import numpy as np
from PIL import Image
from io import BytesIO
import seaborn as sns
app = FastAPI()

# Apply seaborn's "darkgrid" theme (a grey grid background, not a dark mode).
# NOTE(review): nothing visible in this file draws a plot, so this call and the
# seaborn import may be removable — confirm before deleting.
sns.set_style("darkgrid")

# Load MNIST into ./data (downloaded on first run). transform=None because the
# raw tensors are read directly through .data / .targets below.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
X_train = mnist_trainset.data.numpy()
X_test = mnist_testset.data.numpy()
y_train = mnist_trainset.targets.numpy()
y_test = mnist_testset.targets.numpy()

# Flatten each image into one row and divide by 255 so pixel values lie in
# [0, 1] (assuming 8-bit input). The row count is taken from the data rather
# than hard-coded (60000 / 10000), so a subset or resized split still works.
X_train = X_train.reshape(len(X_train), -1) / 255.0
X_test = X_test.reshape(len(X_test), -1) / 255.0

# Train a small two-hidden-layer MLP at import time. This blocks startup until
# fitting finishes, and the fitted `mlp` is what the /predict endpoint uses.
# NOTE(review): no random_state is passed, so the trained weights (and
# predictions) can differ between restarts.
mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
mlp.fit(X_train, y_train)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as a single digit with the module-level ``mlp``.

    The upload is decoded with Pillow, converted to 8-bit grayscale ("L"),
    resized to 28x28, flattened, and divided by 255. This matches how the
    training rows were prepared.

    Returns:
        A JSONResponse.

        - Success: ``{"prediction": <int>}`` with status 200.
        - Upload cannot be decoded as an image, or exceeds Pillow's
          decompression-bomb limit: ``{"error": <message>}`` with status 400.
        - Any other unexpected failure: ``{"error": <message>}`` with status
          500. The traceback is logged.
    """
    try:
        contents = await file.read()
        # NOTE(review): MNIST digits are light strokes on a dark background;
        # dark-on-light uploads may need inverting (255 - x) first — confirm
        # the expected input format with API consumers.
        image = Image.open(BytesIO(contents)).convert("L").resize((28, 28))
        # One row of 28*28 values scaled into [0, 1], same layout as X_train.
        img_array = np.asarray(image).flatten() / 255.0
        prediction = mlp.predict(img_array.reshape(1, -1))[0]
    except (OSError, Image.DecompressionBombError) as e:
        # Unidentifiable, truncated or oversized image data: a client error.
        # UnidentifiedImageError is a subclass of OSError.
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:  # top-level request boundary: never crash the worker
        import logging

        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={"prediction": int(prediction)})