Jon Solow committed on
Commit
f68d424
·
1 Parent(s): 0dc3e49

Get app functional locally

Browse files
src/huggingface/__init__.py ADDED
File without changes
src/huggingface/handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from io import BytesIO
# NOTE(review): `import urllib` alone does not reliably expose the
# `urllib.request` submodule used in handle_url(); import it explicitly.
import urllib.request

import numpy as np
from PIL import Image
from skimage.io import imread
from tensorflow.keras.utils import array_to_img, img_to_array
8
+
9
def preprocess(img: np.ndarray, size: tuple = (224, 224)) -> np.ndarray:
    """Resize an image array and scale its pixel values into [0, 1].

    Args:
        img: Image data as an array (H, W, C); passed through
            ``array_to_img`` without rescaling.
        size: Target (width, height) for the resize. Defaults to
            (224, 224), the input size the downstream model expects.

    Returns:
        A float array of shape (size[1], size[0], C) with values in [0, 1].
    """
    pil_img = array_to_img(img, scale=False)
    pil_img = pil_img.resize(size)
    # img_to_array yields float pixel values in [0, 255]; normalize to [0, 1].
    return img_to_array(pil_img) / 255.0
14
+
15
+
16
def handle_url(url: str) -> np.ndarray:
    """Download the image at *url* and return it as a 1-image model batch."""
    try:
        raw_image = imread(url)
    except Exception:
        # Some hosts reject urllib's default user agent; retry the fetch
        # with a browser-like header and decode the open stream instead.
        request = urllib.request.Request(url, headers={"User-Agent": "Magic Browser"})
        raw_image = imread(urllib.request.urlopen(request))
    # Prepend a batch axis so the result is shaped (1, H, W, C).
    return np.expand_dims(preprocess(raw_image), axis=0)
26
+
27
+
28
def read_imagefile(file):
    """Decode raw uploaded image bytes into a PIL Image."""
    return Image.open(BytesIO(file))
32
+
33
+
34
def handle_file(file) -> np.ndarray:
    """Turn raw uploaded image bytes into a 1-image preprocessed batch.

    Args:
        file: Raw image file contents (bytes) as read from an upload.

    Returns:
        Array of shape (1, 224, 224, C) with pixel values in [0, 1].
    """
    image = read_imagefile(file)
    # preprocess() is annotated to take an ndarray and hands its input to
    # array_to_img, so convert the PIL Image explicitly rather than relying
    # on implicit coercion. Assumes an RGB upload (3 channels) — TODO confirm
    # behavior for grayscale/RGBA inputs against the model's expected shape.
    img_data = np.asarray(image)
    processed_img = preprocess(img_data)
    return np.array([processed_img])
src/huggingface/labels.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# The 101 food class names, index-aligned with the model's output vector:
# predictions index i corresponds to CLASS_LABELS[i].
CLASS_LABELS = [
    "Apple pie",
    "Baby back ribs",
    "Baklava",
    "Beef carpaccio",
    "Beef tartare",
    "Beet salad",
    "Beignets",
    "Bibimbap",
    "Bread pudding",
    "Breakfast burrito",
    "Bruschetta",
    "Caesar salad",
    "Cannoli",
    "Caprese salad",
    "Carrot cake",
    "Ceviche",
    "Cheesecake",
    "Cheese plate",
    "Chicken curry",
    "Chicken quesadilla",
    "Chicken wings",
    "Chocolate cake",
    "Chocolate mousse",
    "Churros",
    "Clam chowder",
    "Club sandwich",
    "Crab cakes",
    "Creme brulee",
    "Croque madame",
    "Cup cakes",
    "Deviled eggs",
    "Donuts",
    "Dumplings",
    "Edamame",
    "Eggs benedict",
    "Escargots",
    "Falafel",
    "Filet mignon",
    "Fish and chips",
    "Foie gras",
    "French fries",
    "French onion soup",
    "French toast",
    "Fried calamari",
    "Fried rice",
    "Frozen yogurt",
    "Garlic bread",
    "Gnocchi",
    "Greek salad",
    "Grilled cheese sandwich",
    "Grilled salmon",
    "Guacamole",
    "Gyoza",
    "Hamburger",
    "Hot and sour soup",
    "Hot dog",
    "Huevos rancheros",
    "Hummus",
    "Ice cream",
    "Lasagna",
    "Lobster bisque",
    "Lobster roll sandwich",
    "Macaroni and cheese",
    "Macarons",
    "Miso soup",
    "Mussels",
    "Nachos",
    "Omelette",
    "Onion rings",
    "Oysters",
    "Pad thai",
    "Paella",
    "Pancakes",
    "Panna cotta",
    "Peking duck",
    "Pho",
    "Pizza",
    "Pork chop",
    "Poutine",
    "Prime rib",
    "Pulled pork sandwich",
    "Ramen",
    "Ravioli",
    "Red velvet cake",
    "Risotto",
    "Samosa",
    "Sashimi",
    "Scallops",
    "Seaweed salad",
    "Shrimp and grits",
    "Spaghetti bolognese",
    "Spaghetti carbonara",
    "Spring rolls",
    "Steak",
    "Strawberry shortcake",
    "Sushi",
    "Tacos",
    "Takoyaki",
    "Tiramisu",
    "Tuna tartare",
    "Waffles",
]
src/huggingface/model.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from huggingface_hub import from_pretrained_keras


# Loaded once at import time. This downloads the Keras model from the
# Hugging Face Hub (network access required on first run; cached after),
# so importing this module has a significant side effect.
MODEL = from_pretrained_keras("jsolow/grubguesser")
src/huggingface/predict.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .model import MODEL
2
+ from .handler import handle_file, handle_url
3
+ from .labels import CLASS_LABELS
4
+
5
+
6
+
7
def predict_model(img_array, n_top_guesses: int = 10):
    """Return the n_top_guesses most probable class labels for a batch.

    *img_array* is the (1, H, W, C) batch produced by the handlers; the
    labels come back title-cased, most probable first.
    """
    probabilities = MODEL.predict(img_array)
    # Negating the scores makes an ascending argsort yield a descending rank.
    ranked_indices = (-probabilities).argsort()[0][:n_top_guesses]
    return [CLASS_LABELS[index].title() for index in ranked_indices]
12
+
13
+
14
def predict_file(file):
    """Predict the top class labels for raw uploaded image bytes."""
    return predict_model(handle_file(file))
17
+
18
+
19
+
20
def predict_url(url: str):
    """Predict the top class labels for the image at *url*."""
    return predict_model(handle_url(url))
23
+
src/labels.txt DELETED
@@ -1,101 +0,0 @@
1
- Apple pie
2
- Baby back ribs
3
- Baklava
4
- Beef carpaccio
5
- Beef tartare
6
- Beet salad
7
- Beignets
8
- Bibimbap
9
- Bread pudding
10
- Breakfast burrito
11
- Bruschetta
12
- Caesar salad
13
- Cannoli
14
- Caprese salad
15
- Carrot cake
16
- Ceviche
17
- Cheesecake
18
- Cheese plate
19
- Chicken curry
20
- Chicken quesadilla
21
- Chicken wings
22
- Chocolate cake
23
- Chocolate mousse
24
- Churros
25
- Clam chowder
26
- Club sandwich
27
- Crab cakes
28
- Creme brulee
29
- Croque madame
30
- Cup cakes
31
- Deviled eggs
32
- Donuts
33
- Dumplings
34
- Edamame
35
- Eggs benedict
36
- Escargots
37
- Falafel
38
- Filet mignon
39
- Fish and chips
40
- Foie gras
41
- French fries
42
- French onion soup
43
- French toast
44
- Fried calamari
45
- Fried rice
46
- Frozen yogurt
47
- Garlic bread
48
- Gnocchi
49
- Greek salad
50
- Grilled cheese sandwich
51
- Grilled salmon
52
- Guacamole
53
- Gyoza
54
- Hamburger
55
- Hot and sour soup
56
- Hot dog
57
- Huevos rancheros
58
- Hummus
59
- Ice cream
60
- Lasagna
61
- Lobster bisque
62
- Lobster roll sandwich
63
- Macaroni and cheese
64
- Macarons
65
- Miso soup
66
- Mussels
67
- Nachos
68
- Omelette
69
- Onion rings
70
- Oysters
71
- Pad thai
72
- Paella
73
- Pancakes
74
- Panna cotta
75
- Peking duck
76
- Pho
77
- Pizza
78
- Pork chop
79
- Poutine
80
- Prime rib
81
- Pulled pork sandwich
82
- Ramen
83
- Ravioli
84
- Red velvet cake
85
- Risotto
86
- Samosa
87
- Sashimi
88
- Scallops
89
- Seaweed salad
90
- Shrimp and grits
91
- Spaghetti bolognese
92
- Spaghetti carbonara
93
- Spring rolls
94
- Steak
95
- Strawberry shortcake
96
- Sushi
97
- Tacos
98
- Takoyaki
99
- Tiramisu
100
- Tuna tartare
101
- Waffles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/model_client.py CHANGED
@@ -9,45 +9,8 @@ from django.core.files.uploadedfile import InMemoryUploadedFile, TemporaryUpload
9
 
10
  import logging
11
 
 
12
 
13
- MODEL_ENDPOINT_URL = os.getenv("MODEL_ENDPOINT_URL", "https://0.0.0.0:2000")
14
-
15
-
16
- def try_make_request(request_kwargs, request_type: str):
17
- try:
18
- request_kwargs["timeout"] = 3
19
- if request_type.lower() == "get":
20
- response = requests.get(**request_kwargs)
21
- elif request_type.lower() == "post":
22
- response = requests.post(**request_kwargs)
23
- else:
24
- raise Exception("Request Type not Supported. Only get, post supported.")
25
-
26
- return json.loads(response.content)
27
- except requests.exceptions.ConnectionError:
28
- logging.warning("Failed Model prediction", exc_info=True)
29
- return ["Image Failed to Predict", "Try Another Image", "", "", ""]
30
- except Exception:
31
- logging.warning("Failed Model prediction", exc_info=True)
32
- return ["Image Failed to Predict", "Try Another Image", "", "", ""]
33
-
34
-
35
- def predict_url(url: str) -> List[str]:
36
- params = {"url": url}
37
- headers = {"content-type": "application/json", "Accept-Charset": "UTF-8"}
38
- request_url = urljoin(MODEL_ENDPOINT_URL, "predict_url")
39
- request_kwargs = dict(url=request_url, params=params, headers=headers)
40
- return try_make_request(request_kwargs, "get")
41
-
42
-
43
- def predict_file(image_file) -> List[str]:
44
- image_file.seek(0)
45
- file_ob = {
46
- "upload_file": (image_file.name, image_file.read(), image_file.content_type)
47
- }
48
- request_url = urljoin(MODEL_ENDPOINT_URL, "predict_file")
49
- request_kwargs = dict(url=request_url, files=file_ob)
50
- return try_make_request(request_kwargs, "post")
51
 
52
 
53
  def get_color_labels(guesses: List[str], actual_label: Optional[str]) -> List[str]:
@@ -66,7 +29,8 @@ def url_image_vars(
66
  elif isinstance(input_img, str):
67
  top_guesses = predict_url(input_img)
68
  elif isinstance(input_img, (InMemoryUploadedFile, TemporaryUploadedFile)):
69
- top_guesses = predict_file(input_img)
 
70
  else:
71
  logging.error(f"Unknown input type: {type(input_img)=}")
72
  top_guesses = ["Unknown Input Type", "", "", "", ""]
@@ -75,18 +39,4 @@ def url_image_vars(
75
 
76
 
77
  def is_healthy() -> bool:
78
- request_url = urljoin(MODEL_ENDPOINT_URL, "healthcheck")
79
- try:
80
- response = requests.get(url=request_url, timeout=1)
81
- except Exception:
82
- logging.error("Failed to make healthcheck request")
83
- return False
84
- if response.status_code == 200:
85
- try:
86
- response_content = json.loads(response.content)
87
- except Exception:
88
- logging.error("Failed to load healthcheck content")
89
- return False
90
- if response_content == {"status": "alive"}:
91
- return True
92
- return False
 
9
 
10
  import logging
11
 
12
+ from huggingface.predict import predict_file, predict_url
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  def get_color_labels(guesses: List[str], actual_label: Optional[str]) -> List[str]:
 
29
  elif isinstance(input_img, str):
30
  top_guesses = predict_url(input_img)
31
  elif isinstance(input_img, (InMemoryUploadedFile, TemporaryUploadedFile)):
32
+ input_img.seek(0)
33
+ top_guesses = predict_file(input_img.read())
34
  else:
35
  logging.error(f"Unknown input type: {type(input_img)=}")
36
  top_guesses = ["Unknown Input Type", "", "", "", ""]
 
39
 
40
 
41
def is_healthy() -> bool:
    # The model now runs in-process (the remote endpoint client was removed
    # in this commit), so there is no external service to probe: the app is
    # healthy whenever it is running.
    return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/requirements.in CHANGED
@@ -6,3 +6,5 @@ joblib
6
  whitenoise
7
  pillow
8
  transformers
 
 
 
6
  whitenoise
7
  pillow
8
  transformers
9
+ tensorflow
10
+ scikit-image
src/requirements.txt CHANGED
@@ -4,10 +4,18 @@
4
  #
5
  # pip-compile --resolver=backtracking requirements.in
6
  #
 
 
 
 
7
  asgiref==3.7.1
8
  # via django
 
 
9
  backports-zoneinfo==0.2.1
10
  # via django
 
 
11
  certifi==2023.5.7
12
  # via requests
13
  charset-normalizer==3.1.0
@@ -18,26 +26,96 @@ filelock==3.12.0
18
  # via
19
  # huggingface-hub
20
  # transformers
 
 
21
  fsspec==2023.5.0
22
  # via huggingface-hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  gunicorn==20.1.0
24
  # via -r requirements.in
 
 
25
  huggingface-hub==0.14.1
26
  # via transformers
27
  idna==3.4
28
  # via requests
 
 
 
 
 
 
29
  joblib==1.2.0
30
  # via -r requirements.in
 
 
 
 
 
 
 
 
 
 
 
 
31
  numpy==1.23.5
32
  # via
33
  # -r requirements.in
 
 
 
 
 
 
 
 
 
 
 
34
  # transformers
 
 
 
 
 
 
35
  packaging==23.1
36
  # via
37
  # huggingface-hub
 
 
38
  # transformers
39
  pillow==9.5.0
40
- # via -r requirements.in
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  pyyaml==6.0
42
  # via
43
  # huggingface-hub
@@ -48,9 +126,41 @@ requests==2.30.0
48
  # via
49
  # -r requirements.in
50
  # huggingface-hub
 
 
51
  # transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  sqlparse==0.4.4
53
  # via django
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  tokenizers==0.13.3
55
  # via transformers
56
  tqdm==4.65.0
@@ -63,10 +173,23 @@ typing-extensions==4.5.0
63
  # via
64
  # asgiref
65
  # huggingface-hub
 
66
  urllib3==1.26.15
67
- # via requests
 
 
 
 
 
 
 
 
68
  whitenoise==6.4.0
69
  # via -r requirements.in
 
 
 
 
70
 
71
  # The following packages are considered to be unsafe in a requirements file:
72
  # setuptools
 
4
  #
5
  # pip-compile --resolver=backtracking requirements.in
6
  #
7
+ absl-py==1.4.0
8
+ # via
9
+ # tensorboard
10
+ # tensorflow
11
  asgiref==3.7.1
12
  # via django
13
+ astunparse==1.6.3
14
+ # via tensorflow
15
  backports-zoneinfo==0.2.1
16
  # via django
17
+ cachetools==5.3.0
18
+ # via google-auth
19
  certifi==2023.5.7
20
  # via requests
21
  charset-normalizer==3.1.0
 
26
  # via
27
  # huggingface-hub
28
  # transformers
29
+ flatbuffers==23.5.26
30
+ # via tensorflow
31
  fsspec==2023.5.0
32
  # via huggingface-hub
33
+ gast==0.4.0
34
+ # via tensorflow
35
+ google-auth==2.19.0
36
+ # via
37
+ # google-auth-oauthlib
38
+ # tensorboard
39
+ google-auth-oauthlib==1.0.0
40
+ # via tensorboard
41
+ google-pasta==0.2.0
42
+ # via tensorflow
43
+ grpcio==1.54.2
44
+ # via
45
+ # tensorboard
46
+ # tensorflow
47
  gunicorn==20.1.0
48
  # via -r requirements.in
49
+ h5py==3.8.0
50
+ # via tensorflow
51
  huggingface-hub==0.14.1
52
  # via transformers
53
  idna==3.4
54
  # via requests
55
+ imageio==2.29.0
56
+ # via scikit-image
57
+ importlib-metadata==6.6.0
58
+ # via markdown
59
+ jax==0.4.10
60
+ # via tensorflow
61
  joblib==1.2.0
62
  # via -r requirements.in
63
+ keras==2.12.0
64
+ # via tensorflow
65
+ libclang==16.0.0
66
+ # via tensorflow
67
+ markdown==3.4.3
68
+ # via tensorboard
69
+ markupsafe==2.1.2
70
+ # via werkzeug
71
+ ml-dtypes==0.1.0
72
+ # via jax
73
+ networkx==3.1
74
+ # via scikit-image
75
  numpy==1.23.5
76
  # via
77
  # -r requirements.in
78
+ # h5py
79
+ # imageio
80
+ # jax
81
+ # ml-dtypes
82
+ # opt-einsum
83
+ # pywavelets
84
+ # scikit-image
85
+ # scipy
86
+ # tensorboard
87
+ # tensorflow
88
+ # tifffile
89
  # transformers
90
+ oauthlib==3.2.2
91
+ # via requests-oauthlib
92
+ opt-einsum==3.3.0
93
+ # via
94
+ # jax
95
+ # tensorflow
96
  packaging==23.1
97
  # via
98
  # huggingface-hub
99
+ # scikit-image
100
+ # tensorflow
101
  # transformers
102
  pillow==9.5.0
103
+ # via
104
+ # -r requirements.in
105
+ # imageio
106
+ # scikit-image
107
+ protobuf==4.23.2
108
+ # via
109
+ # tensorboard
110
+ # tensorflow
111
+ pyasn1==0.5.0
112
+ # via
113
+ # pyasn1-modules
114
+ # rsa
115
+ pyasn1-modules==0.3.0
116
+ # via google-auth
117
+ pywavelets==1.4.1
118
+ # via scikit-image
119
  pyyaml==6.0
120
  # via
121
  # huggingface-hub
 
126
  # via
127
  # -r requirements.in
128
  # huggingface-hub
129
+ # requests-oauthlib
130
+ # tensorboard
131
  # transformers
132
+ requests-oauthlib==1.3.1
133
+ # via google-auth-oauthlib
134
+ rsa==4.9
135
+ # via google-auth
136
+ scikit-image==0.19.3
137
+ # via -r requirements.in
138
+ scipy==1.10.1
139
+ # via
140
+ # jax
141
+ # scikit-image
142
+ six==1.16.0
143
+ # via
144
+ # astunparse
145
+ # google-auth
146
+ # google-pasta
147
+ # tensorflow
148
  sqlparse==0.4.4
149
  # via django
150
+ tensorboard==2.12.3
151
+ # via tensorflow
152
+ tensorboard-data-server==0.7.0
153
+ # via tensorboard
154
+ tensorflow==2.12.0
155
+ # via -r requirements.in
156
+ tensorflow-estimator==2.12.0
157
+ # via tensorflow
158
+ tensorflow-io-gcs-filesystem==0.32.0
159
+ # via tensorflow
160
+ termcolor==2.3.0
161
+ # via tensorflow
162
+ tifffile==2023.4.12
163
+ # via scikit-image
164
  tokenizers==0.13.3
165
  # via transformers
166
  tqdm==4.65.0
 
173
  # via
174
  # asgiref
175
  # huggingface-hub
176
+ # tensorflow
177
  urllib3==1.26.15
178
+ # via
179
+ # google-auth
180
+ # requests
181
+ werkzeug==2.3.4
182
+ # via tensorboard
183
+ wheel==0.40.0
184
+ # via
185
+ # astunparse
186
+ # tensorboard
187
  whitenoise==6.4.0
188
  # via -r requirements.in
189
+ wrapt==1.14.1
190
+ # via tensorflow
191
+ zipp==3.15.0
192
+ # via importlib-metadata
193
 
194
  # The following packages are considered to be unsafe in a requirements file:
195
  # setuptools
src/start.sh DELETED
@@ -1 +0,0 @@
1
- python3 -m uvicorn main:app --workers 1 --host 0.0.0.0 --port 7860