TimeSFormer / app.py
thinh-huynh-re's picture
Update
0885c5d
raw
history blame
No virus
5.07 kB
import multiprocessing
import os
import time
from typing import List, Tuple
import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch import Tensor
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
# Seed NumPy's global RNG. sample_frame_indices draws its random clip window
# from it, so frame sampling is deterministic for a given sequence of calls.
np.random.seed(0)
# Entries for Streamlit's app menu.
# NOTE(review): these URLs/text look like placeholder boilerplate — replace
# with real project links.
_MENU_ITEMS = {
    "Get Help": "https://www.extremelycoolapp.com/help",
    "Report a bug": "https://www.extremelycoolapp.com/bug",
    "About": "# This is a header. This is an *extremely* cool app!",
}
# Page-level settings: browser title/icon, wide layout, sidebar open by default.
st.set_page_config(
    page_title="TimeSFormer",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items=_MENU_ITEMS,
)
def sample_frame_indices(
    clip_len: int, frame_sample_rate: float, seg_len: int
) -> np.ndarray:
    """Pick ``clip_len`` frame indices from a video of ``seg_len`` frames.

    A window of ``int(clip_len * frame_sample_rate)`` frames is placed at a
    random position (drawn from NumPy's global RNG), and ``clip_len`` evenly
    spaced indices are taken from it. When the video is not longer than that
    window, the whole video is used instead. Short clips then repeat frames
    rather than crash.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Approximate stride between sampled frames.
        seg_len: Total number of frames in the video.

    Returns:
        ``int64`` array of ``clip_len`` non-decreasing indices in
        ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is not positive.
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")
    # A zero-length window would invert the clip bounds below.
    converted_len = max(int(clip_len * frame_sample_rate), 1)
    if seg_len > converted_len:
        # randint's upper bound is exclusive, so end_idx <= seg_len - 1 and
        # the window [start_idx, end_idx) always lies inside the video.
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    else:
        # np.random.randint(low, high) raises when low >= high, so a video no
        # longer than the window is sampled across its full length instead.
        start_idx, end_idx = 0, seg_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # linspace includes end_idx itself; clip it back inside the window.
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
# Alternative decorator kept for reference: @st.cache_resource
@st.experimental_singleton
def load_model(model_name: str):
    """Load the feature extractor and TimeSformer classifier for ``model_name``.

    Checkpoints whose name contains ``base-finetuned-k400`` or
    ``base-finetuned-k600`` are paired with the
    ``MCG-NJU/videomae-base-finetuned-kinetics`` feature extractor. Every other
    checkpoint loads the extractor from its own repo.

    NOTE(review): the borrowed extractor is presumably needed because those base
    repos ship no preprocessing config of their own — TODO confirm.

    Returns:
        ``(feature_extractor, model)`` tuple.
    """
    borrows_videomae_extractor = any(
        tag in model_name for tag in ("base-finetuned-k400", "base-finetuned-k600")
    )
    extractor_repo = (
        "MCG-NJU/videomae-base-finetuned-kinetics"
        if borrows_videomae_extractor
        else model_name
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_repo)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model
def read_video(file_path: str) -> np.ndarray:
    """Decode a short RGB clip from the video at ``file_path`` with OpenCV.

    Frame positions come from ``sample_frame_indices`` (8 frames, sample rate
    4) over the frame count reported by the container. Frames that fail to
    decode are skipped.

    Args:
        file_path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        Array stacking the decoded frames, converted from BGR to RGB.
        Presumably shaped (n_frames, height, width, 3) — confirm for exotic
        inputs.

    Raises:
        ValueError: If the file cannot be opened or no frame could be decoded.
    """
    cap = cv2.VideoCapture(file_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {file_path}")
        # The frame count comes from container metadata and may be
        # approximate for some formats.
        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print("Number of frames", length)
        indices = sample_frame_indices(
            clip_len=8, frame_sample_rate=4, seg_len=length
        )
        frames: List[np.ndarray] = []
        for i in indices:
            # Seek to frame index i (CAP_PROP_POS_FRAMES is property id 1).
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if not ret:
                continue
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the decoder and file handle even if sampling/decoding fails.
        cap.release()
    if not frames:
        raise ValueError(f"No frames could be decoded from: {file_path}")
    return np.array(frames)
def read_video_decord(file_path: str) -> np.ndarray:
    """Decode a short clip from ``file_path`` using decord instead of OpenCV.

    Uses the same sampling as ``read_video`` (8 frames, sample rate 4).
    ``decord`` is imported inside the function, so it is only required when
    this reader is actually called.

    Args:
        file_path: Path to the video file to decode.

    Returns:
        The sampled frames as returned by decord's ``asnumpy()``. An earlier
        debug print showed shape ``(8, 720, 1280, 3)`` for a 720p input.
    """
    from decord import VideoReader, cpu

    # Open the caller-supplied path rather than a module-level global.
    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
    videoreader.seek(0)
    indices = sample_frame_indices(
        clip_len=8, frame_sample_rate=4, seg_len=len(videoreader)
    )
    return videoreader.get_batch(indices).asnumpy()
def inference(file_path: str, top_k: int = 12) -> pd.DataFrame:
    """Classify the video at ``file_path`` and return its ``top_k`` labels.

    Relies on the module-level ``feature_extractor`` and ``model`` that the
    Streamlit script loads via ``load_model``.

    Args:
        file_path: Path to the video to classify.
        top_k: Number of highest-scoring labels to return (default 12).

    Returns:
        DataFrame with columns ``Label`` and ``Confidence``, sorted by
        descending confidence. ``Confidence`` is the softmax probability of the
        label over all of the model's classes, in [0, 1].
    """
    video = read_video(file_path)
    inputs = feature_extractor(list(video), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits: Tensor = outputs.logits
    # Log the top-1 label to stdout for debugging.
    predicted_label = logits.argmax(-1).item()
    print(model.config.id2label[predicted_label])
    # One clip is classified per call, so keep the first (only) batch row.
    # NOTE(review): assumes logits is (1, num_labels) — confirm.
    # Softmax turns raw logits into probabilities without changing the
    # ranking, so "Confidence" really is a confidence.
    probs = torch.softmax(logits, dim=-1)[0].numpy()
    top_indices = np.argsort(probs)[::-1][:top_k]
    results: List[Tuple[str, float]] = [
        (model.config.id2label[int(i)], float(probs[i])) for i in top_indices
    ]
    return pd.DataFrame(results, columns=("Label", "Confidence"))
# Page header plus a collapsible intro reporting the host's CPU count.
st.title("TimeSFormer")
with st.expander("INTRODUCTION"):
    st.text(
        "Streamlit demo for TimeSFormer.\n"
        f"Number of CPU(s): {multiprocessing.cpu_count()}\n"
    )
# Checkpoints the user can choose from in the "model_name" select box.
MODEL_NAMES = (
    "facebook/timesformer-base-finetuned-k400",
    "facebook/timesformer-base-finetuned-k600",
    "facebook/timesformer-base-finetuned-ssv2",
    "facebook/timesformer-hr-finetuned-k600",
    "facebook/timesformer-hr-finetuned-k400",
    "facebook/timesformer-hr-finetuned-ssv2",
    "fcakyon/timesformer-large-finetuned-k400",
    "fcakyon/timesformer-large-finetuned-k600",
)
model_name = st.selectbox("model_name", MODEL_NAMES)
# Module-level globals read by inference(). load_model is wrapped by
# st.experimental_singleton, Streamlit's resource cache.
feature_extractor, model = load_model(model_name)
# Uploaded videos are written to this fixed path before inference.
# NOTE(review): every session of the app shares this path, so concurrent
# uploads can overwrite each other — consider a per-upload tempfile.
VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")
uploadedfile = st.file_uploader("Upload file", type=["mp4"])
if uploadedfile is not None:
    with st.spinner():
        # open(..., "wb") raises FileNotFoundError if "tmp/" does not exist.
        os.makedirs(os.path.dirname(VIDEO_TMP_PATH), exist_ok=True)
        with open(VIDEO_TMP_PATH, "wb") as f:
            f.write(uploadedfile.getbuffer())
    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    start_time = time.perf_counter()
    with st.spinner("Processing..."):
        df = inference(VIDEO_TMP_PATH)
    end_time = time.perf_counter()
    st.info(f"{end_time - start_time} seconds")
    st.dataframe(df)
    st.video(VIDEO_TMP_PATH)
# Camera snapshot demo: decode the captured image with OpenCV and display its
# Python type and array shape.
img_file_buffer = st.camera_input("Take a picture")
if img_file_buffer is not None:
    # Decode the snapshot bytes into an OpenCV image array.
    bytes_data = img_file_buffer.getvalue()
    cv2_img = cv2.imdecode(np.frombuffer(bytes_data, np.uint8), cv2.IMREAD_COLOR)
    if cv2_img is None:
        # cv2.imdecode returns None (rather than raising) on undecodable data.
        st.warning("Could not decode the captured image.")
    else:
        # Expected output: <class 'numpy.ndarray'>
        st.write(type(cv2_img))
        # Expected output: (height, width, channels)
        st.write(cv2_img.shape)