pablorodriper committed on
Commit
9b2dc59
·
1 Parent(s): 33f824c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +39 -16
  2. predict.py +54 -0
app.py CHANGED
@@ -1,25 +1,48 @@
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  from huggingface_hub import from_pretrained_keras
4
 
5
- description = "Keras implementation for Video Vision Transformer trained with OrganMNIST3D (CT videos)"
6
- article = "Classes: liver, kidney-right, kidney-left, femur-right, femur-left, bladder, heart, lung-right, lung-left, spleen, pancreas.\n\nAuthor:<a href=\"https://huggingface.co/pablorodriper/\"> Pablo Rodríguez</a>; Based on the keras example by <a href=\"https://keras.io/examples/vision/vivit/\">Aritra Roy Gosthipaty and Ayush Thakur</a>"
7
- title = "Video Vision Transformer on OrganMNIST3D"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- def infer(x):
10
- return model.predict(tf.expand_dims(x, axis=0))[0]
11
 
12
- model = from_pretrained_keras("keras-io/video-vision-transformer")
 
13
 
14
- labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
15
 
16
- iface = gr.Interface(
17
- fn = infer,
18
- inputs = "video",
19
- outputs = "number",
20
- description = description,
21
- title = title,
22
- article = article
23
- )
24
 
25
- iface.launch()
 
 
 
 
 
 
 
 
1
+ import glob
2
+
3
  import gradio as gr
4
  import tensorflow as tf
5
  from huggingface_hub import from_pretrained_keras
6
 
7
+ from predict import predict_label
8
+
9
# Build the list of example inputs shown in the demo.
# Each entry is wrapped in its own single-element list (one value per input
# component). The paths are sorted because glob returns them in arbitrary,
# platform-dependent order and the examples should appear in a stable order.
example_list = [[example_path] for example_path in sorted(glob.glob("examples/*"))]
12
+
13
# Root container for the whole demo page.
# NOTE(review): the nesting below was reconstructed from a flattened diff view
# in which indentation was lost -- confirm the layout against the real app.py.
demo = gr.Blocks()

with demo:
    gr.Markdown("# **<p align='center'>Video Vision Transformer on medmnist</p>**")

    with gr.Tabs():
        with gr.TabItem("Upload & Predict"):
            # Input video and model output share one row.
            with gr.Box():
                with gr.Row():
                    input_video = gr.Video(label="Input Video", show_label=True)
                    output_label = gr.Label(label="Model Output", show_label=True)

            gr.Markdown("**Predict**")

            with gr.Box():
                with gr.Row():
                    submit_button = gr.Button("Submit")

            gr.Markdown("Examples")
            gr.Markdown("The model is trained to classify videos belonging to the following classes: liver, kidney-right, kidney-left, femur-right, femur-left, bladder, heart, lung-right, lung-left, spleen, pancreas")

            with gr.Column():
                # cache_examples=True makes Gradio run predict_label on the
                # examples ahead of time instead of on every click.
                gr.Examples(
                    examples=example_list,
                    inputs=[input_video],
                    outputs=[output_label],
                    fn=predict_label,
                    cache_examples=True,
                )

            # Run the prediction on the uploaded video when Submit is pressed.
            submit_button.click(fn=predict_label, inputs=input_video, outputs=output_label)

    gr.Markdown('\n Demo created by: <a href=\"https://huggingface.co/pablorodriper\"> Pablo Rodríguez</a> Based on the Keras example by <a href=\"https://keras.io/examples/vision/vivit/\">Aritra Roy Gosthipaty and Ayush Thakur</a>')
 
 
 
 
 
 
 
40
 
41
# Start the Gradio server. The method has to be called: a bare `demo.launch`
# only references the bound method and the app is never served.
demo.launch()
42
+
43
+
44
+
45
+
46
+
47
+
48
+
predict.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ # import imageio
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from huggingface_hub import from_pretrained_keras
6
+ from tensorflow.keras.optimizers import Adam
7
+
8
+ from .constants import LEARNING_RATE
9
+
10
+
11
def predict_label(path):
    """
    Classify the video stored at `path` and return the predicted class index.

    The frames returned by `load_video` get a leading batch axis, are passed
    to the model returned by `get_model`, and the index of the highest score
    in the first (and only) prediction of that batch is returned.

    Args:
        path: Location of the video file to classify.

    Returns:
        The index of the highest-scoring class as a built-in `int`.
    """
    frames = load_video(path)
    model = get_model()

    # The model is given a batch that contains just this one video.
    # NOTE(review): assumes the loaded frames already match the shape and
    # value range the model was trained on -- confirm against training code.
    prediction = model.predict(tf.expand_dims(frames, axis=0))[0]
    label = np.argmax(prediction, axis=0)

    # Hand back a plain Python int rather than a NumPy scalar so the value is
    # easy to serialise and display.
    # NOTE(review): presumably `prediction` is a 1-D score vector; verify.
    return int(label)
18
+
19
+
20
def load_video(path):
    """
    Read every frame of the video at `path` and return them as a NumPy array.

    Each frame is converted from BGR to grayscale because that is the format
    expected by the model. The frames are stacked along the first axis of the
    returned array.

    Args:
        path: Location of the video file to read.

    Returns:
        A NumPy array holding all grayscale frames, in reading order.

    Raises:
        ValueError: If the video cannot be opened or contains no readable
            frame. Failing here gives a clear message instead of passing an
            empty array on to the model.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {path!r}")

        while True:
            ret, frame = cap.read()
            if not ret:
                # No more frames (or a read error): stop reading.
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
    finally:
        # Always free the capture handle, even if decoding fails midway.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from video: {path!r}")

    return np.array(frames)
37
+
38
+
39
# Process-wide cache: the model is fetched and compiled on the first call only,
# so later predictions do not pay the loading cost again.
_MODEL = None


def get_model():
    """
    Return the compiled model, loading it from the Hugging Face Hub on first use.

    The first call loads "pablorodriper/video-vision-transformer" with
    `from_pretrained_keras` and compiles it with an Adam optimizer and sparse
    categorical cross-entropy loss. Later calls return the same model object.

    Returns:
        The compiled Keras model.
    """
    global _MODEL

    if _MODEL is None:
        model = from_pretrained_keras("pablorodriper/video-vision-transformer")

        # NOTE(review): LEARNING_RATE comes from `from .constants import ...`;
        # that relative import looks incompatible with this module being
        # imported as a top-level module (`from predict import ...`) -- confirm.
        model.compile(
            optimizer=Adam(learning_rate=LEARNING_RATE),
            loss="sparse_categorical_crossentropy",
        )

        # Store the model only after it has been compiled successfully.
        _MODEL = model

    return _MODEL