Spaces:
Build error
Build error
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
from huggingface_hub import from_pretrained_keras | |
from tensorflow.keras.optimizers import Adam | |
from constants import LEARNING_RATE | |
def predict_label(path):
    """Classify the video at ``path`` and return the predicted class index.

    The video is decoded with ``load_video`` and scored by the module-level
    ``model``; the class with the highest score is returned.

    Args:
        path: Path to a video file that OpenCV can decode.

    Returns:
        The index of the highest-scoring class, as produced by ``np.argmax``.

    Raises:
        ValueError: If no frames could be read from ``path`` (missing,
            unreadable, or empty video).
    """
    frames = load_video(path)
    if frames.size == 0:
        # load_video returns an empty array when nothing could be decoded;
        # fail here with a clear message instead of an opaque shape error
        # raised from inside the model.
        raise ValueError(f"No frames could be read from video: {path!r}")
    # Add a leading batch axis of size 1 for model.predict, then take the
    # single sample's scores back out with [0].
    # NOTE(review): assumes the model accepts (1, num_frames, height, width)
    # input; the pretrained model may expect a fixed frame count/size or a
    # trailing channel axis -- confirm against the model's input signature.
    scores = model.predict(tf.expand_dims(frames, axis=0))[0]
    return np.argmax(scores, axis=0)


def load_video(path):
    """Load the video at ``path`` and return its frames as one grayscale array.

    Every frame OpenCV decodes is converted from BGR to grayscale, because
    grayscale is the format expected by the model, and the frames are
    stacked in order.

    Args:
        path: Path to a video file that OpenCV can decode.

    Returns:
        A NumPy array with one 2-D grayscale image per decoded frame, i.e.
        shape ``(num_frames, height, width)``. If nothing can be decoded
        (e.g. the file is missing or unreadable) an empty array is returned.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                # End of stream (or nothing readable): stop decoding.
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
    finally:
        # Always free the capture handle, even if decoding raises.
        cap.release()
    return np.array(frames)


def get_model():
    """Download the pretrained video vision transformer and compile it.

    The model is fetched from the Hugging Face Hub repository
    ``pablorodriper/video-vision-transformer``.

    Returns:
        The Keras model, compiled with Adam (learning rate ``LEARNING_RATE``)
        and sparse categorical cross-entropy loss.
    """
    loaded = from_pretrained_keras("pablorodriper/video-vision-transformer")
    # No metrics are attached: in this module the model is only used for
    # inference via model.predict in predict_label.
    loaded.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss="sparse_categorical_crossentropy",
    )
    return loaded


# Load the model once, at import time, so every predict_label call reuses it.
# This must come after get_model() is defined: calling it above the `def`
# raises NameError when the module is imported.
model = get_model()