JackVines commited on
Commit
279d47e
·
verified ·
1 Parent(s): c4f0172

Upload 13 files

Browse files
Files changed (13) hide show
  1. Dockerfile +11 -0
  2. README.md +54 -3
  3. __init__.py +0 -0
  4. app/__init__.py +0 -0
  5. app/main.py +34 -0
  6. app/model.py +41 -0
  7. app/saved_model.pb +3 -0
  8. convert_model.py +33 -0
  9. main.py +34 -0
  10. model.py +41 -0
  11. requirements.txt +7 -0
  12. saved_model.pb +3 -0
  13. streamlit_viz.py +71 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY ./app /code/app
10
+
11
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
README.md CHANGED
@@ -1,3 +1,54 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Saliency Inference API Template
2
+
3
+ This is an API and Streamlit app to interact with a saliency model. The API is built using FastAPI and the Streamlit app is built using Streamlit. The API is built to be run in a Docker container.
4
+
5
+ ## Setup
6
+
7
+ ### Install dependencies
8
+
9
+ ```bash
10
+ pip install -r requirements.txt
11
+ ```
12
+
13
+ ### Run the API
14
+
15
+ ```bash
16
+ uvicorn main:app --reload --workers 1 --host 0.0.0.0 --port 8080
17
+ ```
18
+
19
+ This will run the FastAPI server on port 8080.
20
+
21
+ ### (Alternative) Run the API in a Docker container
22
+
23
+
24
+ ```bash
25
+ docker build -t ds-api-template .
26
+ docker run -p 8080:8080 ds-api-template
27
+ ```
28
+
29
+ You can test this is running by executing the same `curl` command as above, which should return the same response.
30
+
31
+ NOTE: You will need to have Docker installed on your machine. To install Docker, follow the instructions [here](https://docs.docker.com/get-docker/).
32
+
33
+ ## Run the Streamlit App
34
+
35
+ Once you've set up the API, you can run the Streamlit app to interact with the API.
36
+
37
+ To run the Streamlit app, run the following command:
38
+
39
+ ```bash
40
+ streamlit run app.py
41
+ ```
42
+
43
+ You will need to have Streamlit installed on your machine. To install Streamlit, run the following command:
44
+
45
+ ```bash
46
+ pip install streamlit
47
+ ```
48
+
49
+ You will also need to update a `secrets.toml` file in a `.streamlit` directory at the root of the repo. This file should contain the following:
50
+
51
+ ```toml
52
+ api_host = "http://localhost:8501"
53
+ password = "<INSERT DESIRED PASSWORD HERE>"
54
+ ```
__init__.py ADDED
File without changes
app/__init__.py ADDED
File without changes
app/main.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+
4
+ from .model import predict
5
+ import json
6
+
7
+ app = FastAPI()
8
+
9
+ # CORS
10
+ origins = [
11
+ "http://localhost:8080",
12
+ "http://localhost"
13
+ ]
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=origins,
18
+ allow_credentials=True,
19
+ allow_methods=["POST"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ @app.post("/predict")
24
+ def img_object_detection_to_img(file: bytes = File(...)):
25
+ """
26
+ Object Detection from an image plot bbox on image
27
+
28
+ Args:
29
+ file (bytes): The image file in bytes format.
30
+ Returns:
31
+ The json representation of the prediction
32
+ """
33
+ prediction = predict(file)
34
+ return json.dumps(prediction.tolist())
app/model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from PIL import Image
3
+ import io
4
+
5
+ imported = tf.saved_model.load("./app")
6
+ imported = imported.signatures["serving_default"]
7
+
8
+ def get_image_from_bytes(binary_image: bytes) -> Image:
9
+ """Convert image from bytes to PIL RGB format
10
+
11
+ Args:
12
+ binary_image (bytes): The binary representation of the image
13
+
14
+ Returns:
15
+ PIL.Image: The image in PIL RGB format
16
+ """
17
+ input_image = Image.open(io.BytesIO(binary_image)).convert("RGB")
18
+ return input_image
19
+
20
+ def predict(input_image):
21
+ """Reads file and returns prediction
22
+
23
+ Args:
24
+ x (_type_): _description_
25
+
26
+ Returns:
27
+ _type_: _description_
28
+ """
29
+ tensor = tf.io.decode_image(input_image, channels=3)
30
+
31
+ inference_shape = (240, 320)
32
+ original_shape = tensor.shape[:2]
33
+
34
+ input_tensor = tf.expand_dims(tensor, axis=0)
35
+
36
+ input_tensor = tf.image.resize(input_tensor, inference_shape,
37
+ preserve_aspect_ratio=True)
38
+ saliency = imported(input_tensor)["output"]
39
+
40
+ saliency = tf.image.resize(saliency, original_shape)
41
+ return saliency.numpy()[0]
app/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:646e0f343c4357e828f2569bef2f2bf288449fe68f7e4fb43e076f2e3b094e3d
3
+ size 99858975
convert_model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # use this script to convert any of the models saved to be
2
+ # compatible with tf2: https://drive.google.com/drive/folders/1GI7i6GpfI-FoklP3vCc6vxe3T9nk3V2n
3
+
4
+ import tensorflow as tf
5
+ from tensorflow.python.saved_model import signature_constants, tag_constants
6
+
7
+ export_dir = "./app/"
8
+ # update the below line to point at the desired model downloaded
9
+ # from the above google drive link
10
+ graph_pb = "./app/model_salicon_cpu.pb"
11
+
12
+ with tf.io.gfile.GFile(graph_pb, "rb") as f:
13
+ graph_def = tf.compat.v1.GraphDef()
14
+ graph_def.ParseFromString(f.read())
15
+
16
+ sig = {}
17
+
18
+ builder = tf.compat.v1.saved_model.Builder(export_dir)
19
+
20
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
21
+ tf.import_graph_def(graph_def, name="")
22
+ g = tf.compat.v1.get_default_graph()
23
+
24
+ input = g.get_tensor_by_name("input:0")
25
+ output = g.get_tensor_by_name("output:0")
26
+
27
+ sig_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
28
+ sig[sig_key] = tf.compat.v1.saved_model.predict_signature_def({"input": input},
29
+ {"output": output})
30
+ builder.add_meta_graph_and_variables(sess,
31
+ [tag_constants.SERVING],
32
+ signature_def_map=sig)
33
+ builder.save()
main.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+
4
+ from .model import predict
5
+ import json
6
+
7
+ app = FastAPI()
8
+
9
+ # CORS
10
+ origins = [
11
+ "http://localhost:8080",
12
+ "http://localhost"
13
+ ]
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=origins,
18
+ allow_credentials=True,
19
+ allow_methods=["POST"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ @app.post("/predict")
24
+ def img_object_detection_to_img(file: bytes = File(...)):
25
+ """
26
+ Object Detection from an image plot bbox on image
27
+
28
+ Args:
29
+ file (bytes): The image file in bytes format.
30
+ Returns:
31
+ The json representation of the prediction
32
+ """
33
+ prediction = predict(file)
34
+ return json.dumps(prediction.tolist())
model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from PIL import Image
3
+ import io
4
+
5
+ imported = tf.saved_model.load("./app")
6
+ imported = imported.signatures["serving_default"]
7
+
8
+ def get_image_from_bytes(binary_image: bytes) -> Image:
9
+ """Convert image from bytes to PIL RGB format
10
+
11
+ Args:
12
+ binary_image (bytes): The binary representation of the image
13
+
14
+ Returns:
15
+ PIL.Image: The image in PIL RGB format
16
+ """
17
+ input_image = Image.open(io.BytesIO(binary_image)).convert("RGB")
18
+ return input_image
19
+
20
+ def predict(input_image):
21
+ """Reads file and returns prediction
22
+
23
+ Args:
24
+ x (_type_): _description_
25
+
26
+ Returns:
27
+ _type_: _description_
28
+ """
29
+ tensor = tf.io.decode_image(input_image, channels=3)
30
+
31
+ inference_shape = (240, 320)
32
+ original_shape = tensor.shape[:2]
33
+
34
+ input_tensor = tf.expand_dims(tensor, axis=0)
35
+
36
+ input_tensor = tf.image.resize(input_tensor, inference_shape,
37
+ preserve_aspect_ratio=True)
38
+ saliency = imported(input_tensor)["output"]
39
+
40
+ saliency = tf.image.resize(saliency, original_shape)
41
+ return saliency.numpy()[0]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi==0.103.2
2
+ uvicorn==0.23.2
3
+ tensorflow
4
+ python-multipart
5
+ Pillow
6
+ streamlit
7
+ matplotlib
saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:646e0f343c4357e828f2569bef2f2bf288449fe68f7e4fb43e076f2e3b094e3d
3
+ size 99858975
streamlit_viz.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """App to visualize saliency maps for images.
2
+ To run, use:
3
+ streamlit run streamlit_viz.py
4
+ """
5
+
6
+ import streamlit as st
7
+ import pandas as pd
8
+ import numpy as np
9
+ import requests
10
+ import hmac
11
+ import json
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.image as mpimg
14
+
15
+ from PIL import Image
16
+
17
+ st.set_option('deprecation.showPyplotGlobalUse', False)
18
+
19
+ def check_password():
20
+ """Returns `True` if the user had the correct password."""
21
+
22
+ def password_entered():
23
+ """Checks whether a password entered by the user is correct."""
24
+ if hmac.compare_digest(st.session_state["password"], st.secrets["password"]):
25
+ st.session_state["password_correct"] = True
26
+ del st.session_state["password"] # Don't store the password.
27
+ else:
28
+ st.session_state["password_correct"] = False
29
+
30
+ # Return True if the passward is validated.
31
+ if st.session_state.get("password_correct", False):
32
+ return True
33
+
34
+ # Show input for password.
35
+ st.text_input(
36
+ "Password", type="password", on_change=password_entered, key="password"
37
+ )
38
+ if "password_correct" in st.session_state:
39
+ st.error("😕 Password incorrect")
40
+ return False
41
+
42
+
43
+ if not check_password():
44
+ st.stop() # Do not continue if check_password is not True.
45
+
46
+ st.title("Saliency Map Visualizer")
47
+
48
+ st.markdown(
49
+ """
50
+ This is a demo of the Saliency Map Visualizer. To use it, upload an image
51
+ and click the button below. Please note, it may take up to 20 seconds to visualise.
52
+ """
53
+ )
54
+
55
+ # get host from secrets
56
+ api_host = st.secrets["api_host"]
57
+
58
+ uploaded_file = st.file_uploader("Choose an image...", type=(["jpg", "jpeg", "png"]))
59
+
60
+ if uploaded_file is not None:
61
+ file = {'file': uploaded_file.read()}
62
+ st.write("")
63
+ st.write("Classifying...")
64
+ response = requests.post(api_host, files=file)
65
+ arr = np.asarray(json.loads(response.json()))
66
+ st.write("Done!")
67
+ # Show plt plots
68
+ plt.imshow(Image.open(uploaded_file))
69
+ plt.imshow(arr, alpha=0.6)
70
+ plt.axis('off')
71
+ st.pyplot()