LazyBoss commited on
Commit
d20336c
·
verified ·
1 Parent(s): 2473120

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,7 +1,13 @@
1
- import gradio as gr
2
- from sklearn.neural_network import MLPClassifier
3
- import torchvision.datasets as datasets
4
- import seaborn as sns
 
 
 
 
 
 
5
 
6
  # Dark mode seaborn
7
  sns.set_style("darkgrid")
@@ -10,7 +16,7 @@ sns.set_style("darkgrid")
10
  mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
11
  mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
12
 
13
- X_train = mnist_trainset.data.numpy()
14
  X_test = mnist_testset.data.numpy()
15
  y_train = mnist_trainset.targets.numpy()
16
  y_test = mnist_testset.targets.numpy()
@@ -23,15 +29,14 @@ X_test = X_test.reshape(10000, 784) / 255.0
23
  mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
24
  mlp.fit(X_train, y_train)
25
 
26
- # Print the accuracies
27
- print("Training Accuracy: ", mlp.score(X_train, y_train))
28
- print("Testing Accuracy: ", mlp.score(X_test, y_test))
29
-
30
- # Define prediction function
31
- def predict(img):
32
- img = img.reshape(1, 784) / 255.0
33
- prediction = mlp.predict(img)[0]
34
- return int(prediction)
35
-
36
- # Launch Gradio interface
37
- gr.Interface(fn=predict, inputs="sketchpad", outputs="label").launch()
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from sklearn.neural_network import MLPClassifier
4
+ import torchvision.datasets as datasets
5
+ import numpy as np
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import seaborn as sns
9
+
10
+ app = FastAPI()
11
 
12
  # Dark mode seaborn
13
  sns.set_style("darkgrid")
 
16
  mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
17
  mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
18
 
19
+ X_train = mnist_trainset.data.numpy()
20
  X_test = mnist_testset.data.numpy()
21
  y_train = mnist_trainset.targets.numpy()
22
  y_test = mnist_testset.targets.numpy()
 
29
  mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
30
  mlp.fit(X_train, y_train)
31
 
32
+ @app.post("/predict")
33
+ async def predict(file: UploadFile = File(...)):
34
+ try:
35
+ contents = await file.read()
36
+ image = Image.open(BytesIO(contents)).convert("L").resize((28, 28))
37
+ img_array = np.array(image)
38
+ img_array = img_array.flatten() / 255.0
39
+ prediction = mlp.predict(img_array.reshape(1, -1))[0]
40
+ return JSONResponse(content={"prediction": int(prediction)})
41
+ except Exception as e:
42
+ return JSONResponse(content={"error": str(e)})