# conv-lstm / app.py
# Hugging Face Space file by nouamanetazi (HF staff).
# Commit fa8b0cf: "add credits" (3.99 kB).
import os
import yaml
import gradio as gr
import numpy as np
import imageio, cv2
from moviepy.editor import *
from skimage.transform import resize
from skimage import img_as_ubyte
from skimage.color import rgb2gray
from huggingface_hub.keras_mixin import from_pretrained_keras
# Load the pretrained Keras model published on the Hugging Face Hub as
# "keras-io/conv-lstm". This runs once, at import time.
model = from_pretrained_keras("keras-io/conv-lstm")

# Examples: one [video path, prediction start, predict-one-frame flag] row per
# bundled video, in the same order as the interface inputs.
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# example list stable between runs, and skip hidden entries (e.g. .DS_Store)
# that are not videos.
example_source = sorted(
    video for video in os.listdir('asset/source') if not video.startswith('.')
)
samples = [[f'asset/source/{video}', 0.5, True] for video in example_source]
def inference(source,
              split_pred = 0.4,  # fraction of the video kept as context; the rest is predicted
              predict_one = False,  # Whether to predict a sliding one frame or all frames at once
              output_name = 'output.mp4',
              output_path = 'asset/output',
              cpu = False,
              ):
    """Predict the tail of a video with the module-level ``model`` and save the result.

    The video is read with imageio, every frame is resized to 64x64 and
    converted to a single grayscale channel, and the frames from index
    ``int(split_pred * n_frames)`` onward are predicted one at a time.

    Args:
        source: Path (or anything ``imageio.get_reader`` accepts) of the input video.
        split_pred: Fraction of the frames that precede the predicted part.
        predict_one: If true, every prediction is conditioned on ground-truth
            frames only (sliding window over the real video). If false, the
            model is fed its own earlier predictions after the initial context.
        output_name: File name of the predicted video inside ``output_path``.
        output_path: Directory the two videos are written to (created if missing).
        cpu: Unused; kept so existing callers passing it keep working.

    Returns:
        A tuple ``(predicted_video_path, ground_truth_video_path)``.

    Raises:
        ValueError: If no frame could be decoded from ``source``.
    """
    # Read all frames. The reader is closed even if metadata lookup or
    # decoding fails, so the underlying file handle never leaks.
    reader = imageio.get_reader(source)
    try:
        fps = reader.get_meta_data()['fps']
        source_video = []
        try:
            for im in reader:
                source_video.append(im)
        except RuntimeError:
            # Best effort: keep whatever frames were decoded before the error
            # (presumably raised at the end of some streams -- TODO confirm).
            pass
    finally:
        reader.close()

    if not source_video:
        raise ValueError(f'no frames could be read from {source!r}')

    # Down-scale to 64x64 and keep one grayscale channel as a trailing axis.
    example = np.array(
        [rgb2gray(resize(frame, (64, 64)))[..., np.newaxis] for frame in source_video]
    )
    print(example.shape)

    # Frames from start_pred_id onward are the ones to be predicted.
    total_frames = example.shape[0]
    start_pred_id = int(split_pred * total_frames)
    num_predictions = total_frames - start_pred_id
    original_frames = example[start_pred_id:, ...]
    # Take the per-frame shape from the array itself: indexing the context
    # slice would fail when start_pred_id is 0 (very short video).
    new_predictions = np.zeros(shape=(num_predictions, *example.shape[1:]))

    # Predict the remaining frames one by one.
    # NOTE(review): the context below includes frame start_pred_id itself, so
    # prediction i presumably corresponds to frame start_pred_id + i + 1 while
    # the ground-truth clip starts at start_pred_id -- confirm this one-frame
    # offset is intended.
    for i in range(num_predictions):
        if predict_one:
            # Condition on real frames only.
            frames = example[: start_pred_id + i + 1, ...]
        else:
            # Condition on the real context followed by earlier predictions.
            frames = np.concatenate(
                (example[: start_pred_id + 1, ...], new_predictions[:i, ...]), axis=0
            )
        new_prediction = model.predict(np.expand_dims(frames, axis=0))
        new_prediction = np.squeeze(new_prediction, axis=0)
        # The last time step of the output is the newly predicted frame.
        new_predictions[i] = new_prediction[-1, ...]

    # Create and save a video for a set of ground truth/prediction frames.
    def postprocess(frame_set, save_file):
        # Drop only the trailing channel axis. A bare squeeze would also drop
        # the frame axis when there is exactly one frame and corrupt the output.
        current_frames = np.squeeze(frame_set, axis=-1)
        # Replicate the gray channel to RGB and convert to 8-bit.
        current_frames = current_frames[..., np.newaxis] * np.ones(3)
        current_frames = (current_frames * 255).astype(np.uint8)
        current_frames = list(current_frames)
        print(f'{output_path}/{save_file}')
        imageio.mimsave(f'{output_path}/{save_file}', current_frames, fps=fps)

    # save video
    os.makedirs(output_path, exist_ok=True)
    postprocess(original_frames, "original.mp4")
    postprocess(new_predictions, output_name)
    return f'{output_path}/{output_name}', f'{output_path}/original.mp4'
# Credits rendered below the demo (HTML).
article = (
    "<div style='text-align: center;'>"
    "<a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a>"
    "<br>"
    "<a href='https://keras.io/examples/vision/conv_lstm/' target='_blank'>Keras example by Amogh Joshi</a>"
    "</div>"
)

# Input widgets, in the order of inference()'s first three parameters:
# source video, split_pred, predict_one.
demo_inputs = [
    gr.inputs.Video(label='Video', type='mp4'),
    gr.inputs.Slider(minimum=.1, maximum=.9, default=.5, step=.001, label="prediction start"),
    gr.inputs.Checkbox(label="predict one frame only", default=True),
]

# Output widgets, in the order of inference()'s return tuple.
demo_outputs = [
    gr.outputs.Video(label='result'),  # generated video
    gr.outputs.Video(label='ground truth'),  # same part of original video
]

# Build the interface and start serving immediately; note that `iface` holds
# whatever launch() returns, not the Interface object itself.
iface = gr.Interface(
    inference,  # main function
    inputs=demo_inputs,
    outputs=demo_outputs,
    title='Next-Frame Video Prediction with Convolutional LSTMs',
    # description = "This app is an unofficial demo web app of the Next-Frame Video Prediction with Convolutional LSTMs by Keras.",
    article=article,
    examples=samples,
).launch(enable_queue=True, cache_examples=True)