Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,13 @@
|
|
| 1 |
-
import
|
| 2 |
-
from
|
| 3 |
-
|
| 4 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Dark mode seaborn
|
| 7 |
sns.set_style("darkgrid")
|
|
@@ -10,7 +16,7 @@ sns.set_style("darkgrid")
|
|
| 10 |
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
|
| 11 |
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
|
| 12 |
|
| 13 |
-
X_train = mnist_trainset.data.numpy()
|
| 14 |
X_test = mnist_testset.data.numpy()
|
| 15 |
y_train = mnist_trainset.targets.numpy()
|
| 16 |
y_test = mnist_testset.targets.numpy()
|
|
@@ -23,15 +29,14 @@ X_test = X_test.reshape(10000, 784) / 255.0
|
|
| 23 |
mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
|
| 24 |
mlp.fit(X_train, y_train)
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
gr.Interface(fn=predict, inputs="sketchpad", outputs="label").launch()
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from sklearn.neural_network import MLPClassifier
|
| 4 |
+
import torchvision.datasets as datasets
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
|
| 10 |
+
app = FastAPI()
|
| 11 |
|
| 12 |
# Dark mode seaborn
|
| 13 |
sns.set_style("darkgrid")
|
|
|
|
| 16 |
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
|
| 17 |
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
|
| 18 |
|
| 19 |
+
X_train = mnist_trainset.data.numpy()
|
| 20 |
X_test = mnist_testset.data.numpy()
|
| 21 |
y_train = mnist_trainset.targets.numpy()
|
| 22 |
y_test = mnist_testset.targets.numpy()
|
|
|
|
| 29 |
mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
|
| 30 |
mlp.fit(X_train, y_train)
|
| 31 |
|
| 32 |
+
@app.post("/predict")
|
| 33 |
+
async def predict(file: UploadFile = File(...)):
|
| 34 |
+
try:
|
| 35 |
+
contents = await file.read()
|
| 36 |
+
image = Image.open(BytesIO(contents)).convert("L").resize((28, 28))
|
| 37 |
+
img_array = np.array(image)
|
| 38 |
+
img_array = img_array.flatten() / 255.0
|
| 39 |
+
prediction = mlp.predict(img_array.reshape(1, -1))[0]
|
| 40 |
+
return JSONResponse(content={"prediction": int(prediction)})
|
| 41 |
+
except Exception as e:
|
| 42 |
+
return JSONResponse(content={"error": str(e)})
|
|
|