File size: 3,579 Bytes
df8cba4
 
 
 
 
 
 
 
967933e
 
df8cba4
967933e
 
 
 
df8cba4
 
 
 
967933e
 
df8cba4
 
967933e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df8cba4
967933e
df8cba4
967933e
 
df8cba4
 
 
967933e
df8cba4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- coding: utf-8 -*-
"""🎬 Keras Video Classification CNN-RNN model

Gradio Space demonstrating how to use the model.

Author:
    - Thomas Chaigneau @ChainYo
"""
import os
import cv2
import gradio as gr
import numpy as np

from tensorflow import keras
from tensorflow_docs.vis import embed

from huggingface_hub import from_pretrained_keras


# Side length, in pixels, of the square frames fed to the feature extractor.
IMG_SIZE = 224
# Length of the feature vector stored for each frame.
NUM_FEATURES = 2048

# Sequence classifier loaded through huggingface_hub.
model = from_pretrained_keras("ChainYo/video-classification-cnn-rnn")

# One example row per file found in ./samples. Each row holds a single value
# because the interface declares a single (video) input. Sorting keeps the
# example order stable, since os.listdir returns entries in arbitrary order.
samples = [[f"samples/{file}"] for file in sorted(os.listdir("samples"))]


def crop_center_square(frame):
    """Return the largest centered square region of *frame*.

    Args:
        frame: Array whose first two axes are (rows, columns); any further
            axes are passed through untouched.

    Returns:
        A slice (view) of *frame* whose first two axes both have the length
        of the shorter of the two original axes.
    """
    height, width = frame.shape[:2]
    side = min(height, width)
    # Offsets that center a window of `side` pixels along each axis.
    top = height // 2 - side // 2
    left = width // 2 - side // 2
    return frame[top : top + side, left : left + side]


def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
    """Read a video file into an array of square, resized frames.

    Args:
        path: Location of the video, handed to ``cv2.VideoCapture``.
        max_frames: Stop after this many frames. With the default of 0 the
            limit never triggers, so the whole video is read.
        resize: Target size passed to ``cv2.resize`` for every frame.

    Returns:
        ``np.array`` built from the list of processed frames, in read order.
    """
    capture = cv2.VideoCapture(path)
    collected = []
    try:
        while True:
            ok, image = capture.read()
            if not ok:
                # No further frame could be read.
                break
            image = cv2.resize(crop_center_square(image), resize)
            # Reverse the channel axis (presumably BGR -> RGB; confirm).
            collected.append(image[:, :, [2, 1, 0]])

            if len(collected) == max_frames:
                break
    finally:
        # Release the capture handle even if frame processing raised.
        capture.release()
    return np.array(collected)


def build_feature_extractor():
    """Build the per-frame feature extractor model.

    An InceptionV3 network (ImageNet weights, no classification head,
    average pooling) is wrapped in a model that first applies the matching
    ``preprocess_input`` to its ``(IMG_SIZE, IMG_SIZE, 3)`` input.

    Returns:
        A ``keras.Model`` named ``"feature_extractor"``.
    """
    frame_shape = (IMG_SIZE, IMG_SIZE, 3)
    backbone = keras.applications.InceptionV3(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        input_shape=frame_shape,
    )

    frame_input = keras.Input(frame_shape)
    scaled = keras.applications.inception_v3.preprocess_input(frame_input)
    return keras.Model(frame_input, backbone(scaled), name="feature_extractor")


# Built once at import time; prepare_video reads this module-level instance.
feature_extractor = build_feature_extractor()

def prepare_video(frames, max_seq_length: int = 20):
    """Turn a video's frames into padded feature and mask arrays.

    At most ``max_seq_length`` leading frames are run through the
    module-level ``feature_extractor``; the remaining positions stay
    zero-filled and masked.

    Args:
        frames: Array whose first axis indexes the frames of one video
            (assumed to be what ``load_video`` returns — TODO confirm for
            other callers).
        max_seq_length: Number of time steps in the returned arrays.

    Returns:
        A ``(frame_features, frame_mask)`` tuple where ``frame_features`` is
        float32 with shape ``(1, max_seq_length, NUM_FEATURES)`` and
        ``frame_mask`` is bool with shape ``(1, max_seq_length)``.
    """
    frame_mask = np.zeros(shape=(1, max_seq_length), dtype="bool")
    frame_features = np.zeros(shape=(1, max_seq_length, NUM_FEATURES), dtype="float32")

    length = min(max_seq_length, frames.shape[0])
    if length > 0:
        # One batched predict call instead of one call per frame: predict is
        # meant for batches and is slow when invoked repeatedly in a loop.
        frame_features[0, :length, :] = feature_extractor.predict(frames[:length])
        frame_mask[0, :length] = True  # True = not masked, False = masked

    return frame_features, frame_mask


def sequence_prediction(path):
    """Classify the video at *path* and print the ranked class probabilities.

    Args:
        path: Location of the video file, forwarded to ``load_video``.

    Returns:
        The frames array produced by ``load_video``. The probabilities are
        only printed, not returned.
    """
    labels = ["CricketShot", "PlayingCello", "Punch", "ShavingBeard", "TennisSwing"]

    clip = load_video(path)
    features, mask = prepare_video(clip)
    scores = model.predict([features, mask])[0]

    # Highest score first.
    ranking = np.argsort(scores)[::-1]
    for idx in ranking:
        print(f"  {labels[idx]}: {scores[idx] * 100:5.2f}%")
    return clip


def to_gif(images):
    """Cast *images* to ``uint8`` and hand them to ``embed.embed_file``.

    NOTE(review): ``embed.embed_file`` may expect a file path rather than an
    array and may not accept a ``format`` keyword — confirm against
    tensorflow_docs before relying on this helper.

    Args:
        images: Array of frames; must provide ``astype``.

    Returns:
        Whatever ``embed.embed_file`` returns.
    """
    return embed.embed_file(images.astype(np.uint8), format="gif")


# Footer HTML shown under the interface. Adjacent literals are concatenated
# into a single string.
article = (
    "<div style='text-align: center;'>"
    "<a href='https://github.com/ChainYo' target='_blank'>Space by Thomas Chaigneau</a>"
    "<br>"
    "<a href='https://keras.io/examples/vision/video_classification/' target='_blank'>Keras example by Sayak Paul</a>"
    "</div>"
)
# NOTE(review): `outputs` is empty, so the value returned by
# sequence_prediction is presumably not rendered in the UI — confirm intent.
# `app` holds the return value of launch(), not the Interface object.
app = gr.Interface(
    sequence_prediction,
    inputs=[gr.inputs.Video(label="Video", type="mp4")],
    outputs=[],
    title="Keras Video Classification CNN-RNN model",
    description="Keras Working Group",
    article=article,
    examples=samples
).launch(enable_queue=True, cache_examples=True)