chainyo committed on
Commit 967933e
1 Parent(s): df8cba4

Update app.py

Files changed (1)
  1. app.py +96 -7
app.py CHANGED
@@ -6,24 +6,113 @@ Spaces for showing the model usage.
Author:
    - Thomas Chaigneau @ChainYo
"""
+ import os
+ import cv2
+ import imageio  # needed by to_gif below
import gradio as gr
+ import numpy as np
+
+ from tensorflow import keras
+ from tensorflow_docs.vis import embed

from huggingface_hub import from_pretrained_keras


- def inference():
-     """
-     Inference function.
-     """
+ IMG_SIZE = 224
+ NUM_FEATURES = 2048

model = from_pretrained_keras("ChainYo/video-classification-cnn-rnn")
- samples = None
+ samples = []
+ for file in os.listdir("samples"):
+     tag = file.split("_")[1]  # class label parsed from the filename (not used below)
+     samples.append([f"samples/{file}", 25])
+
+
+ def crop_center_square(frame):
+     # Crop the largest centered square out of the frame.
+     y, x = frame.shape[0:2]
+     min_dim = min(y, x)
+     start_x = (x // 2) - (min_dim // 2)
+     start_y = (y // 2) - (min_dim // 2)
+     return frame[start_y : start_y + min_dim, start_x : start_x + min_dim]
+
+
+ def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
+     # Decode a video, center-crop and resize each frame, and convert BGR -> RGB.
+     cap = cv2.VideoCapture(path)
+     frames = []
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = crop_center_square(frame)
+             frame = cv2.resize(frame, resize)
+             frame = frame[:, :, [2, 1, 0]]
+             frames.append(frame)
+
+             if len(frames) == max_frames:
+                 break
+     finally:
+         cap.release()
+     return np.array(frames)
+
+
+ def build_feature_extractor():
+     # InceptionV3 (ImageNet weights, average-pooled) embeds each frame
+     # into a NUM_FEATURES-dimensional vector.
+     feature_extractor = keras.applications.InceptionV3(
+         weights="imagenet",
+         include_top=False,
+         pooling="avg",
+         input_shape=(IMG_SIZE, IMG_SIZE, 3),
+     )
+     preprocess_input = keras.applications.inception_v3.preprocess_input
+
+     inputs = keras.Input((IMG_SIZE, IMG_SIZE, 3))
+     preprocessed = preprocess_input(inputs)
+
+     outputs = feature_extractor(preprocessed)
+     return keras.Model(inputs, outputs, name="feature_extractor")
+
+
+ feature_extractor = build_feature_extractor()
+
+
+ def prepare_video(frames, max_seq_length: int = 20):
+     # Embed up to max_seq_length frames and build the boolean padding mask.
+     frames = frames[None, ...]
+     frame_mask = np.zeros(shape=(1, max_seq_length), dtype="bool")
+     frame_features = np.zeros(shape=(1, max_seq_length, NUM_FEATURES), dtype="float32")
+
+     for i, batch in enumerate(frames):
+         video_length = batch.shape[0]
+         length = min(max_seq_length, video_length)
+         for j in range(length):
+             frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
+         frame_mask[i, :length] = 1  # 1 = not masked, 0 = masked
+
+     return frame_features, frame_mask
+
+
+ def sequence_prediction(path):
+     class_vocab = ["CricketShot", "PlayingCello", "Punch", "ShavingBeard", "TennisSwing"]
+
+     frames = load_video(path)
+     frame_features, frame_mask = prepare_video(frames)
+     probabilities = model.predict([frame_features, frame_mask])[0]
+
+     for i in np.argsort(probabilities)[::-1]:
+         print(f"  {class_vocab[i]}: {probabilities[i] * 100:5.2f}%")
+     return frames
+
+
+ def to_gif(images):
+     # Write the frames to a GIF on disk, then embed it.
+     converted_images = images.astype(np.uint8)
+     imageio.mimsave("animation.gif", converted_images, fps=10)
+     return embed.embed_file("animation.gif")
+
+
+ article = "<div style='text-align: center;'><a href='https://github.com/ChainYo' target='_blank'>Space by Thomas Chaigneau</a><br><a href='https://keras.io/examples/vision/video_classification/' target='_blank'>Keras example by Sayak Paul</a></div>"
app = gr.Interface(
-     inference,
-     inputs=[],
+     sequence_prediction,
+     inputs=[gr.inputs.Video(label="Video", type="mp4")],
    outputs=[],
    title="Keras Video Classification CNN-RNN model",
    description="Keras Working Group",
+     article=article,
    examples=samples
).launch(enable_queue=True, cache_examples=True)
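
For reference, here is a minimal sketch of how the functions added in this commit chain together outside Gradio. It is not part of the commit: it assumes the definitions from app.py above are in scope, and the sample path is hypothetical (any short clip works):

    import numpy as np

    VIDEO_PATH = "samples/v_TennisSwing_g01_c01.avi"  # hypothetical bundled clip

    frames = load_video(VIDEO_PATH, max_frames=20)       # (num_frames, 224, 224, 3), RGB
    frame_features, frame_mask = prepare_video(frames)   # (1, 20, 2048) and (1, 20)
    probabilities = model.predict([frame_features, frame_mask])[0]

    class_vocab = ["CricketShot", "PlayingCello", "Punch", "ShavingBeard", "TennisSwing"]
    for i in np.argsort(probabilities)[::-1]:
        print(f"{class_vocab[i]}: {probabilities[i] * 100:5.2f}%")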