pablorodriper committed
Commit e0cf56d
Parent: 199b9a9
app.py CHANGED
@@ -1,26 +1,44 @@
+ import glob
+
  import gradio as gr
  import tensorflow as tf
- from huggingface_hub import from_pretrained_keras
-
- description = "Keras implementation for Video Vision Transformer trained with OrganMNIST3D (CT videos)"
- article = "Classes: liver, kidney-right, kidney-left, femur-right, femur-left, bladder, heart, lung-right, lung-left, spleen, pancreas.\n\nAuthor:<a href=\"https://huggingface.co/pablorodriper/\"> Pablo Rodríguez</a>; Based on the keras example by <a href=\"https://keras.io/examples/vision/vivit/\">Aritra Roy Gosthipaty and Ayush Thakur</a>"
- title = "Video Vision Transformer on OrganMNIST3D"
-
- def infer(video):
-     return model.predict(tf.expand_dims(video, axis=0))[0]
-
- model = from_pretrained_keras("keras-io/video-vision-transformer")
-
- labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
-
- iface = gr.Interface(
-     fn = infer,
-     inputs = "video",
-     outputs = "number",
-     description = description,
-     title = title,
-     article = article,
-     examples=["example_1.mp4", "example_2.mp4"]
- )
-
- iface.launch()
+
+ from utils.predict import predict_label
+
+ ##Create list of examples to be loaded
+ example_list = glob.glob("examples/*.mp4")
+ example_list = list(map(lambda el:[el], example_list))
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown("# **<p align='center'>Video Vision Transformer on medmnist</p>**")
+
+     with gr.Tab("Upload & Predict"):
+         with gr.Box():
+             with gr.Row():
+                 input_video = gr.Video(label="Input Video", show_label=True)
+                 output_label = gr.Label(label="Model Output", show_label=True)
+
+         gr.Markdown("**Predict**")
+
+         with gr.Box():
+             with gr.Row():
+                 submit_button = gr.Button("Submit")
+
+     gr.Markdown("The model is trained to classify videos belonging to the following classes: liver, kidney-right, kidney-left, femur-right, femur-left, bladder, heart, lung-right, lung-left, spleen and pancreas")
+
+     gr.Examples(example_list, [input_video], [output_label], predict_label, cache_examples=True)
+
+     submit_button.click(predict_label, inputs=input_video, outputs=output_label)
+
+     gr.Markdown('\n Demo created by: <a href=\"https://huggingface.co/pablorodriper\"> Pablo Rodríguez</a> <br> Based on the Keras example by <a href=\"https://keras.io/examples/vision/vivit/\">Aritra Roy Gosthipaty and Ayush Thakur</a>')
+
+ demo.launch()
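A note on the example wiring above: gr.Examples expects one inner list per example row (one entry per input component), which is why each path returned by glob is wrapped in a singleton list. A minimal sketch of the resulting structure, assuming only the two clips added in this commit are present:

import glob

# glob returns flat paths for the bundled clips ...
paths = glob.glob("examples/*.mp4")

# ... and wrapping each path in its own list yields one row per example,
# the shape gr.Examples expects for a single input component.
example_list = [[p] for p in paths]
# e.g. [['examples/femur-right.mp4'], ['examples/kidney-left.mp4']]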
examples/femur-right.mp4 ADDED
Binary file (7.7 kB)

examples/kidney-left.mp4 ADDED
Binary file (7.19 kB)
 
requirements.txt CHANGED
@@ -1 +1,5 @@
- tensorflow>2.6
+ transformers==4.23
+ huggingface_hub>0.10
+ tensorflow>2.6
+ gradio
+ opencv-python
utils/__init__.py ADDED
Empty file
utils/constants.py ADDED
@@ -0,0 +1,27 @@
+ import tensorflow as tf
+
+ # DATA
+ DATASET_NAME = "organmnist3d"
+ BATCH_SIZE = 32
+ AUTO = tf.data.AUTOTUNE
+ INPUT_SHAPE = (28, 28, 28, 1)
+ NUM_CLASSES = 11
+
+ # OPTIMIZER
+ LEARNING_RATE = 1e-4
+ WEIGHT_DECAY = 1e-5
+
+ # TRAINING
+ EPOCHS = 80
+
+ # TUBELET EMBEDDING
+ PATCH_SIZE = (8, 8, 8)
+ NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 2
+
+ # ViViT ARCHITECTURE
+ LAYER_NORM_EPS = 1e-6
+ PROJECTION_DIM = 128
+ NUM_HEADS = 8
+ NUM_LAYERS = 8
+
+ labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
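For orientation, these constants follow the linked keras.io ViViT example, where PATCH_SIZE and PROJECTION_DIM parameterize a Conv3D-based tubelet embedding. A sketch of that layer, adapted from the Keras example (it is not part of this commit):

import tensorflow as tf
from tensorflow.keras import layers

class TubeletEmbedding(layers.Layer):
    """Adapted from the keras.io ViViT example; not added by this commit."""

    def __init__(self, embed_dim, patch_size, **kwargs):
        super().__init__(**kwargs)
        # Non-overlapping 3D patches: kernel size == stride == patch size.
        self.projection = layers.Conv3D(
            filters=embed_dim,       # PROJECTION_DIM = 128
            kernel_size=patch_size,  # PATCH_SIZE = (8, 8, 8)
            strides=patch_size,
            padding="VALID",
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        # (B, 28, 28, 28, 1) -> (B, 3, 3, 3, 128): 28 // 8 = 3 tubelets per axis
        projected_patches = self.projection(videos)
        # (B, 3, 3, 3, 128) -> (B, 27, 128): one token per tubelet
        return self.flatten(projected_patches)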
utils/predict.py ADDED
@@ -0,0 +1,84 @@
+ import cv2
+ import numpy as np
+ import tensorflow as tf
+ from huggingface_hub import from_pretrained_keras
+ from tensorflow.keras.optimizers import Adam
+
+ from .constants import LEARNING_RATE
+
+ def get_model():
+     """
+     Download the model from the Hugging Face Hub and compile it.
+     """
+     model = from_pretrained_keras("pablorodriper/video-vision-transformer")
+
+     model.compile(
+         optimizer=Adam(learning_rate=LEARNING_RATE),
+         loss="sparse_categorical_crossentropy",
+         # metrics=[
+         #     keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+         #     keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+         # ],
+     )
+
+     return model
+
+
+ model = get_model()
+ labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
+
+
+ def predict_label(path):
+     frames = load_video(path)
+     dataloader = prepare_dataloader(frames)
+     prediction = model.predict(dataloader)[0]
+     label = np.argmax(prediction, axis=0)
+     label = labels[label]
+
+     return label
+
+
+ def load_video(path):
+     """
+     Load video from path and return a list of frames.
+     The video is converted to grayscale because it is the format expected by the model.
+     """
+     cap = cv2.VideoCapture(path)
+     frames = []
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+             frames.append(frame)
+     finally:
+         cap.release()
+     return np.array(frames)
+
+
+ def prepare_dataloader(video):
+     video = tf.expand_dims(video, axis=0)
+     dataset = tf.data.Dataset.from_tensor_slices((video, np.array([0])))
+
+     dataloader = (
+         dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
+         .batch(1)
+         .prefetch(tf.data.AUTOTUNE)
+     )
+     return dataloader
+
+
+ @tf.function
+ def preprocess(frames: tf.Tensor, label: tf.Tensor):
+     """Preprocess the frames tensors and parse the labels."""
+     # Preprocess images
+     frames = tf.image.convert_image_dtype(
+         frames[
+             ..., tf.newaxis
+         ],  # The new axis is to help for further processing with Conv3D layers
+         tf.float32,
+     )
+     # Parse label
+     label = tf.cast(label, tf.float32)
+     return frames, label
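To close the loop, a hypothetical smoke test for this module (the clip is one of the examples added in this commit; the printed label is illustrative):

from utils.predict import predict_label

# load_video yields a (num_frames, 28, 28) grayscale array for the bundled
# 28x28 clips; prepare_dataloader then adds the batch and channel axes so
# the model input matches INPUT_SHAPE = (28, 28, 28, 1).
label = predict_label("examples/kidney-left.mp4")
print(label)  # one of the eleven class names, e.g. "kidney-left"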