Bitirme commited on
Commit
0cfb5b9
1 Parent(s): e6735d2

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +39 -28
api.py CHANGED
@@ -1,24 +1,13 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- import numpy as np
3
- import tensorflow as tf
4
- from fastapi.middleware.cors import CORSMiddleware
5
  import cv2
 
 
 
6
  from pydantic import BaseModel
7
- from huggingface_hub import from_pretrained_keras
8
 
9
  app = FastAPI()
10
 
11
- origins = ["*"]
12
-
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=origins,
16
- allow_credentials=True,
17
- allow_methods=["*"],
18
- allow_headers=["*"],
19
- )
20
 
21
- # Filtre kısmı
22
  def crop_image_from_gray(img, tol=7):
23
  if img.ndim == 2:
24
  mask = img > tol
@@ -36,6 +25,7 @@ def crop_image_from_gray(img, tol=7):
36
  img = np.stack([img1, img2, img3], axis=-1)
37
  return img
38
 
 
39
  def load_ben_color(image, sigmaX=10):
40
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
41
  image = crop_image_from_gray(image)
@@ -43,6 +33,7 @@ def load_ben_color(image, sigmaX=10):
43
  image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), sigmaX), -4, 128)
44
  return image
45
 
 
46
  def clahe(image):
47
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
48
  r, g, b = cv2.split(image)
@@ -52,52 +43,72 @@ def clahe(image):
52
  result = cv2.merge((r, g, b))
53
  return result
54
 
 
55
  def filter1(image):
56
  image = load_ben_color(image)
57
  return image
58
 
 
59
  def filter2(image):
60
  image = clahe(image)
61
  image = cv2.resize(image, (224, 224))
62
  return image
63
 
64
- def predict(image, model, filter_func):
65
- model_image = filter_func(image)
 
66
  model_image = np.array([model_image], dtype=np.float32) / 255.0
67
- predictions = model(tf.constant(model_image))
68
- return predictions.numpy()
 
 
69
 
70
  def result(predictions):
71
  class_labels = ["Age related Macular Degeneration", "Cataract", "Diabetes", "Glaucoma", "Hypertension", "Normal", "Others", "Pathological Myopia"]
72
  predictions = np.array(predictions)
73
  predictions = predictions.tolist()[0]
74
  predictions_index = np.argmax(predictions)
75
- return class_labels[predictions_index]
 
 
 
 
 
 
76
 
77
  # Model tanımlamaları
78
- models_names = ["ODIR-B-2K-5Class-LastTrain-Xception", "ODIR-B-2K-6Class-LastTrain-Xception"]
79
- model_paths = ["Bitirme/odirmodel/ODIR-B-2K-5Class-LastTrain-Xception", "Bitirme/odirmodel/ODIR-B-2K-6Class-LastTrain-Xception"]
 
 
 
80
 
81
- filters = [filter1, filter2] # tanımlandı
82
 
83
  models = []
84
- for model_path in model_paths:
85
- model = from_pretrained_keras(model_path)
 
86
  models.append(model)
87
 
88
  class PredictionResponse(BaseModel):
89
- predictions: dict
90
 
91
  @app.post("/predict", response_model=PredictionResponse)
92
  async def predict_endpoint(file: UploadFile = File(...)):
 
93
  contents = await file.read()
94
- nparr = np.frombuffer(contents, np.uint8)
95
  image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
96
 
97
  result_json = {}
 
98
  for i in range(len(models)):
99
  model = models[i]
100
  prediction = predict(image, model, filters[i])
101
  result_json[models_names[i]] = result(prediction)
102
 
103
- return {"predictions": result_json}
 
 
 
 
 
 
 
 
1
  import cv2
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from fastapi import FastAPI, UploadFile, File
5
  from pydantic import BaseModel
6
+ import uvicorn
7
 
8
  app = FastAPI()
9
 
 
 
 
 
 
 
 
 
 
10
 
 
11
  def crop_image_from_gray(img, tol=7):
12
  if img.ndim == 2:
13
  mask = img > tol
 
25
  img = np.stack([img1, img2, img3], axis=-1)
26
  return img
27
 
28
+
29
  def load_ben_color(image, sigmaX=10):
30
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
31
  image = crop_image_from_gray(image)
 
33
  image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), sigmaX), -4, 128)
34
  return image
35
 
36
+
37
  def clahe(image):
38
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
39
  r, g, b = cv2.split(image)
 
43
  result = cv2.merge((r, g, b))
44
  return result
45
 
46
+
47
  def filter1(image):
48
  image = load_ben_color(image)
49
  return image
50
 
51
+
52
  def filter2(image):
53
  image = clahe(image)
54
  image = cv2.resize(image, (224, 224))
55
  return image
56
 
57
+
58
+ def predict(image, model, filter):
59
+ model_image = filter(image)
60
  model_image = np.array([model_image], dtype=np.float32) / 255.0
61
+ infer = model.signatures["serving_default"]
62
+ predictions = infer(tf.constant(model_image))[next(iter(infer.structured_outputs.keys()))].numpy()
63
+ return predictions
64
+
65
 
66
  def result(predictions):
67
  class_labels = ["Age related Macular Degeneration", "Cataract", "Diabetes", "Glaucoma", "Hypertension", "Normal", "Others", "Pathological Myopia"]
68
  predictions = np.array(predictions)
69
  predictions = predictions.tolist()[0]
70
  predictions_index = np.argmax(predictions)
71
+
72
+ result_json = {
73
+ "class": class_labels[predictions_index],
74
+ "probability": predictions[predictions_index]
75
+ }
76
+
77
+ return result_json
78
 
79
  # Model tanımlamaları
80
+ models_names = ["ODIR-B-2K-5Class-LastTrain-Xception"]
81
+
82
+ models_paths = [
83
+ "ODIR-B-2K-5Class-LastTrain-Xception"
84
+ ]
85
 
86
+ filters = [filter1, filter1, filter1, filter1]
87
 
88
  models = []
89
+
90
+ for model_path in models_paths:
91
+ model = tf.saved_model.load(model_path)
92
  models.append(model)
93
 
94
  class PredictionResponse(BaseModel):
95
+ ODIR-B-2K-5Class-LastTrain-Xception: dict
96
 
97
  @app.post("/predict", response_model=PredictionResponse)
98
  async def predict_endpoint(file: UploadFile = File(...)):
99
+
100
  contents = await file.read()
101
+ nparr = np.fromstring(contents, np.uint8)
102
  image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
103
 
104
  result_json = {}
105
+
106
  for i in range(len(models)):
107
  model = models[i]
108
  prediction = predict(image, model, filters[i])
109
  result_json[models_names[i]] = result(prediction)
110
 
111
+ return result_json
112
+
113
+ if __name__ == "__main__":
114
+ uvicorn.run(app, host="localhost", port=8000)