SANJAYV10 commited on
Commit
bfbcab4
1 Parent(s): fa91d9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import fastapi
3
  import numpy as np
4
  from PIL import Image
5
- from typing import Any, Type
6
 
7
  class TorchTensor(torch.Tensor):
8
  pass
@@ -10,23 +9,21 @@ class TorchTensor(torch.Tensor):
10
  class Prediction:
11
  prediction: TorchTensor
12
 
13
- app = fastapi.FastAPI()
14
-
15
- model = torch.load("model67.bin", map_location='cpu')
16
 
 
17
  # Define a function to preprocess the input image
18
- def preprocess_input(input: Any):
19
- image = Image.open(BytesIO(input))
20
  image = image.resize((224, 224))
21
  input = np.array(image)
22
  input = torch.from_numpy(input).float()
23
- input = input.permute(2, 0, 1)
24
  input = input.unsqueeze(0)
25
  return input
26
 
27
  # Define an endpoint to make predictions
28
  @app.post("/predict")
29
- async def predict_endpoint(input: Any):
30
  """Make a prediction on an image uploaded by the user."""
31
 
32
  # Preprocess the input image
@@ -35,12 +32,8 @@ async def predict_endpoint(input: Any):
35
  # Make a prediction
36
  prediction = model(input)
37
 
38
- # Get the predicted class
39
  predicted_class = prediction.argmax(1).item()
40
 
41
  # Return the predicted class in JSON format
42
- return {"prediction": predicted_class}
43
-
44
- if __name__ == "__main__":
45
- import uvicorn
46
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  import fastapi
3
  import numpy as np
4
  from PIL import Image
 
5
 
6
  class TorchTensor(torch.Tensor):
7
  pass
 
9
  class Prediction:
10
  prediction: TorchTensor
11
 
12
+ app = fastapi.FastAPI(docs_url="/")
 
 
13
 
14
+ model = torch.load("best_model.pth", map_location='cpu')
15
  # Define a function to preprocess the input image
16
+ def preprocess_input(input: fastapi.UploadFile):
17
+ image = Image.open(input.file)
18
  image = image.resize((224, 224))
19
  input = np.array(image)
20
  input = torch.from_numpy(input).float()
 
21
  input = input.unsqueeze(0)
22
  return input
23
 
24
  # Define an endpoint to make predictions
25
  @app.post("/predict")
26
+ async def predict_endpoint(input:fastapi.UploadFile):
27
  """Make a prediction on an image uploaded by the user."""
28
 
29
  # Preprocess the input image
 
32
  # Make a prediction
33
  prediction = model(input)
34
 
35
+
36
  predicted_class = prediction.argmax(1).item()
37
 
38
  # Return the predicted class in JSON format
39
+ return {"prediction": predicted_class}