# TimeSFormer / app.py
# Last commit: 1bcf2a0 "Update" by thinh-huynh-re
import multiprocessing
import os
import time
from typing import List, Tuple
import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch import Tensor
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
# Seed NumPy's global RNG so the random clip window picked by
# sample_frame_indices is reproducible.
np.random.seed(0)

# NOTE(review): these menu entries look like placeholder values
# (extremelycoolapp.com) — replace them with real links/text for this app.
_PAGE_MENU_ITEMS = {
    "Get Help": "https://www.extremelycoolapp.com/help",
    "Report a bug": "https://www.extremelycoolapp.com/bug",
    "About": "# This is a header. This is an *extremely* cool app!",
}

# Page-level configuration; this is the first Streamlit command the script
# issues.
st.set_page_config(
    page_title="TimeSFormer",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items=_PAGE_MENU_ITEMS,
)
def sample_frame_indices(
    clip_len: int, frame_sample_rate: float, seg_len: int
) -> np.ndarray:
    """Pick ``clip_len`` evenly spaced frame indices from a random window.

    The window spans ``int(clip_len * frame_sample_rate)`` frames and its end
    is drawn with NumPy's global RNG. When the video is not longer than that
    window, the whole video is used as the window instead of failing.

    Args:
        clip_len: Number of frame indices to return.
        frame_sample_rate: Spacing factor; the window covers
            ``clip_len * frame_sample_rate`` frames.
        seg_len: Total number of frames in the video.

    Returns:
        int64 array of ``clip_len`` non-decreasing frame indices.

    Raises:
        ValueError: If ``seg_len`` is not positive.
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len > converted_len:
        # randint's upper bound is exclusive, so end_idx <= seg_len - 1.
        end_idx = np.random.randint(converted_len, seg_len)
    else:
        # Video is too short for a random window (randint would raise
        # "low >= high"); use the whole video as the window instead.
        converted_len = seg_len
        end_idx = seg_len
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # Keep indices inside [start_idx, end_idx - 1]; the max() guards the
    # degenerate zero-length window, where end_idx - 1 < start_idx.
    upper = max(end_idx - 1, start_idx)
    indices = np.clip(indices, start_idx, upper).astype(np.int64)
    return indices
# NOTE(review): newer Streamlit releases provide st.cache_resource as the
# replacement for st.experimental_singleton; switch once the deployed
# Streamlit version supports it.
@st.experimental_singleton
def load_model(model_name: str):
    """Load the feature extractor and TimeSformer classifier for a checkpoint.

    Wrapped in a Streamlit caching decorator so repeated calls with the same
    ``model_name`` reuse the already-loaded objects.

    For the base K400/K600 checkpoints the feature extractor is loaded from
    "MCG-NJU/videomae-base-finetuned-kinetics" rather than from the checkpoint
    itself. NOTE(review): presumably a workaround for missing/unsuitable
    preprocessing config in those repos — confirm before changing.

    Args:
        model_name: Hugging Face repo id of the TimeSformer checkpoint.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    borrows_videomae_extractor = any(
        tag in model_name for tag in ("base-finetuned-k400", "base-finetuned-k600")
    )
    extractor_repo = (
        "MCG-NJU/videomae-base-finetuned-kinetics"
        if borrows_videomae_extractor
        else model_name
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_repo)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model
def read_video(
    file_path: str, frames_per_video: int = 8, frame_sample_rate: int = 4
) -> np.ndarray:
    """Decode a random clip of RGB frames from a video file with OpenCV.

    Frame positions come from ``sample_frame_indices``. Frames that OpenCV
    fails to decode are skipped, so the result may hold fewer than
    ``frames_per_video`` frames.

    Args:
        file_path: Path to the video file.
        frames_per_video: Number of frames to sample.
        frame_sample_rate: Spacing factor passed to ``sample_frame_indices``.

    Returns:
        Array of shape ``(n, H, W, 3)`` with the decoded frames in RGB order.

    Raises:
        ValueError: If the file cannot be opened, reports no frames, or none
            of the sampled frames could be decoded.
    """
    cap = cv2.VideoCapture(file_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {file_path}")
        # Frame count comes from container metadata and can be unreliable
        # (e.g. 0) for some files.
        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print("Number of frames", length)
        if length <= 0:
            raise ValueError(f"Video reports no frames: {file_path}")
        indices = sample_frame_indices(
            clip_len=frames_per_video,
            frame_sample_rate=frame_sample_rate,
            seg_len=length,
        )
        frames: List[np.ndarray] = []
        for i in indices:
            # Seek to the sampled frame position before reading it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV decodes to BGR; the model pipeline expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always release the native capture handle, even on errors.
        cap.release()
    if not frames:
        raise ValueError(f"No frames could be decoded from: {file_path}")
    return np.array(frames)
def read_video_decord(file_path: str, frames_per_video: int = 8) -> np.ndarray:
    """Decode a random clip of frames with decord (alternative to read_video).

    Args:
        file_path: Path to the video file to read.
        frames_per_video: Number of frames to sample (defaults to 8).

    Returns:
        The sampled frames as returned by decord's ``asnumpy()``, e.g.
        ``(8, 720, 1280, 3)`` for eight frames of a 720p video.
    """
    # decord is imported here, so it is only needed when this reader is used.
    from decord import VideoReader, cpu

    # Open the caller-supplied path rather than any module-level temp path.
    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
    videoreader.seek(0)
    indices = sample_frame_indices(
        clip_len=frames_per_video, frame_sample_rate=4, seg_len=len(videoreader)
    )
    video = videoreader.get_batch(indices).asnumpy()
    return video
def inference(file_path: str, frames_per_video: int = 8):
    """Classify a video file and return the highest-scoring labels.

    Uses the module-level ``feature_extractor`` and ``model`` loaded by the
    app script.

    Args:
        file_path: Path to the video file.
        frames_per_video: Number of frames to sample from the video.

    Returns:
        DataFrame with columns ``Label`` and ``Confidence`` for the top
        ``TOP_K`` classes, best first. ``Confidence`` is the softmax
        probability (0-1) of that class.
    """
    TOP_K = 12
    video = read_video(file_path, frames_per_video)
    inputs = feature_extractor(list(video), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # One clip is passed in, so the batch presumably has size 1; take row 0
    # explicitly instead of squeeze(), which would also drop other size-1 dims.
    logits: Tensor = outputs.logits[0]
    # Softmax turns raw logits into probabilities so "Confidence" is 0-1.
    probs = torch.softmax(logits, dim=-1).numpy()
    scores = logits.numpy()

    predicted_label = int(scores.argmax())
    print(model.config.id2label[predicted_label])

    # Rank by logits: same order as the probabilities, but not affected by
    # ties that softmax underflow could introduce.
    top_indices = np.argsort(scores)[::-1][:TOP_K]
    results: List[Tuple[str, float]] = [
        (model.config.id2label[int(index)], float(probs[index]))
        for index in top_indices
    ]
    return pd.DataFrame(results, columns=("Label", "Confidence"))
def get_frames_per_video(model_name: str) -> int:
    """Return the number of frames to sample for a checkpoint name.

    Names containing "base-finetuned" map to 8 frames and names containing
    "hr-finetuned" map to 16 (checked in that order). Anything else falls
    back to 96.
    """
    frame_counts_by_marker = (("base-finetuned", 8), ("hr-finetuned", 16))
    for marker, frame_count in frame_counts_by_marker:
        if marker in model_name:
            return frame_count
    return 96
st.title("TimeSFormer")

with st.expander("INTRODUCTION"):
    st.text(
        f"""Streamlit demo for TimeSFormer.
Number of CPU(s): {multiprocessing.cpu_count()}
"""
    )

model_name = st.selectbox(
    "model_name",
    (
        "facebook/timesformer-base-finetuned-k400",
        "facebook/timesformer-base-finetuned-k600",
        "facebook/timesformer-base-finetuned-ssv2",
        "facebook/timesformer-hr-finetuned-k600",
        "facebook/timesformer-hr-finetuned-k400",
        "facebook/timesformer-hr-finetuned-ssv2",
        "fcakyon/timesformer-large-finetuned-k400",
        "fcakyon/timesformer-large-finetuned-k600",
    ),
)

# load_model is wrapped in a Streamlit caching decorator, so re-selecting a
# model reuses the already-loaded objects.
feature_extractor, model = load_model(model_name)
frames_per_video = get_frames_per_video(model_name)
st.info(f"Frames per video: {frames_per_video}")

# Every upload is written to this single path (overwriting the previous one)
# before being decoded.
VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")

uploadedfile = st.file_uploader("Upload file", type=["mp4"])
if uploadedfile is not None:
    with st.spinner():
        # The "tmp" directory may not exist yet (e.g. fresh deploy); create it
        # so open() does not fail with FileNotFoundError.
        os.makedirs(os.path.dirname(VIDEO_TMP_PATH), exist_ok=True)
        with open(VIDEO_TMP_PATH, "wb") as f:
            f.write(uploadedfile.getbuffer())

    # perf_counter is monotonic, so it is the right clock for durations.
    start_time = time.perf_counter()
    try:
        with st.spinner("Processing..."):
            df = inference(VIDEO_TMP_PATH, frames_per_video)
    except ValueError as err:
        # Unreadable or unusable videos surface as ValueError from frame
        # sampling/decoding; show it in the UI instead of a raw traceback.
        st.error(f"Could not process the uploaded video: {err}")
    else:
        end_time = time.perf_counter()
        st.info(f"{end_time - start_time} seconds")
        st.dataframe(df)
    st.video(VIDEO_TMP_PATH)