thinh-huynh-re committed on
Commit
e94e369
·
1 Parent(s): 1bcf2a0
Files changed (2) hide show
  1. stream.py +181 -0
  2. utils/frame_rate.py +5 -3
stream.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from streamlit_webrtc import webrtc_streamer
2
+ import numpy as np
3
+ import streamlit as st
4
+
5
+ import numpy as np
6
+ import av
7
+ import threading
8
+
9
+ import multiprocessing
10
+ from typing import List, Optional, Tuple
11
+
12
+ from pandas import DataFrame
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import streamlit as st
17
+ import torch
18
+ from torch import Tensor
19
+ from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
20
+
21
+
22
+ from utils.frame_rate import FrameRate
23
+
24
# Fix NumPy's global RNG seed for this script.
np.random.seed(0)

# Entries passed to Streamlit's page menu.
_MENU_ITEMS = {
    "Get Help": "https://www.extremelycoolapp.com/help",
    "Report a bug": "https://www.extremelycoolapp.com/bug",
    "About": "# This is a header. This is an *extremely* cool app!",
}

# Page-level settings, collected in one mapping and applied in a single call.
_PAGE_SETTINGS = {
    "page_title": "TimeSFormer",
    "page_icon": "🧊",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
    "menu_items": _MENU_ITEMS,
}

st.set_page_config(**_PAGE_SETTINGS)
37
+
38
+
39
# Alternative decorator noted by the original author: @st.experimental_singleton
@st.cache_resource
def load_model(model_name: str):
    """Load the feature extractor and the TimeSformer classifier for a checkpoint.

    Args:
        model_name: Hub identifier of the checkpoint to load.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    # Checkpoints whose name contains one of these markers take their feature
    # extractor from the VideoMAE Kinetics checkpoint instead of their own repo.
    videomae_markers = ("base-finetuned-k400", "base-finetuned-k600")
    borrows_videomae_extractor = any(
        marker in model_name for marker in videomae_markers
    )

    extractor_source = (
        "MCG-NJU/videomae-base-finetuned-kinetics"
        if borrows_videomae_extractor
        else model_name
    )

    feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_source)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model
50
+
51
+
52
# Guards access to the shared frame container between the video callback and
# any other reader.
lock = threading.Lock()


def _turn_server_config() -> dict:
    """Build the TURN server entry for the WebRTC ICE configuration.

    Each field can be overridden through an environment variable
    (``TURN_SERVER_URL``, ``TURN_SERVER_USERNAME``, ``TURN_SERVER_CREDENTIAL``).
    When a variable is unset the value that was previously hard-coded here is
    used, so existing deployments keep working unchanged.

    Returns:
        A dict with the ``urls``, ``username`` and ``credential`` keys.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # NOTE(review): these fallback credentials are committed to the repository
    # and must be treated as public. Rotate them and supply the real values
    # through the environment, then remove the defaults.
    return {
        "urls": os.environ.get(
            "TURN_SERVER_URL", "turn:relay1.expressturn.com:3478"
        ),
        "username": os.environ.get("TURN_SERVER_USERNAME", "efBRTY571ATWBRMP36"),
        "credential": os.environ.get("TURN_SERVER_CREDENTIAL", "pGcX1BPH5fMmZJc5"),
    }


rtc_configuration = {
    "iceServers": [
        _turn_server_config(),
        # Disabled alternative kept from the original: public Google STUN servers.
        # {
        #     "urls": [
        #         "stun:stun1.l.google.com:19302",
        #         "stun:stun2.l.google.com:19302",
        #         "stun:stun3.l.google.com:19302",
        #         "stun:stun4.l.google.com:19302",
        #     ]
        # },
    ],
}
71
+
72
+
73
def inference():
    """Classify the buffered frames and publish the result on ``img_container``.

    Returns immediately while ``img_container.ready`` is false. Otherwise the
    buffered frames are passed through the module-level ``feature_extractor``
    and ``model``, and two attributes are updated:

    * ``img_container.frame_rate.label`` receives ``"<label>_<score>%"`` for
      the best class.
    * ``img_container.rs`` receives a ``DataFrame`` with the ``Label`` and
      ``Confidence`` columns for the ``TOP_K`` best classes, best first.

    Scores are softmax probabilities expressed in percent.
    """
    if not img_container.ready:
        return

    inputs = feature_extractor(list(img_container.imgs), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits: Tensor = outputs.logits

    # The model output is unbounded logits. The label carries a "%" suffix and
    # the table column is named "Confidence", so convert to a probability
    # distribution first. Row 0 is used because one clip is submitted per call.
    confidences = torch.softmax(logits, dim=-1)[0].numpy() * 100.0

    max_index = int(confidences.argmax())
    predicted_label = model.config.id2label[max_index]
    img_container.frame_rate.label = (
        f"{predicted_label}_{confidences[max_index]:.2f}%"
    )

    TOP_K = 12
    # argsort is ascending, so reverse it before keeping the best TOP_K entries.
    indices = np.argsort(confidences)[::-1][:TOP_K]

    # Cast to built-in int/float so dictionary lookups and the resulting table
    # do not depend on NumPy scalar types.
    results: List[Tuple[str, float]] = [
        (model.config.id2label[int(index)], float(confidences[index]))
        for index in indices
    ]

    img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
102
+
103
+
104
class ImgContainer:
    """Rolling buffer of recent video frames plus the latest prediction state.

    The buffer keeps at most ``frames_per_video`` frames. Once it is full,
    adding a frame evicts the oldest one.
    """

    def __init__(self, frames_per_video: int = 8) -> None:
        """Create an empty container.

        Args:
            frames_per_video: Number of frames that make up one clip.
        """
        # Most recent raw frame.
        self.img: Optional[np.ndarray] = None
        # Frame-rate tracker, reset at construction time.
        self.frame_rate: FrameRate = FrameRate()
        self.frame_rate.reset()
        # Buffered frames, oldest first.
        self.imgs: List[np.ndarray] = []
        self.frames_per_video = frames_per_video
        # Latest results table, or None until one has been stored.
        self.rs: Optional[DataFrame] = None

    def add_frame(self, frame: np.ndarray) -> None:
        """Append ``frame``, evicting the oldest frame when the buffer is full.

        Args:
            frame: The frame to append.
        """
        # Use this instance's own state rather than module-level globals, so
        # the class behaves correctly for any instance.
        if len(self.imgs) >= self.frames_per_video:
            self.imgs.pop(0)
        self.imgs.append(frame)

    @property
    def ready(self) -> bool:
        """Whether the buffer holds exactly ``frames_per_video`` frames."""
        return len(self.imgs) == self.frames_per_video
121
+
122
+
123
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Process one incoming video frame and return the frame to display.

    The frame is converted to a BGR array. While holding ``lock`` it is
    recorded on ``img_container``, counted by the frame-rate tracker, appended
    to the frame buffer, and ``inference()`` is run. The tracker's
    ``show_fps`` output is then converted back into a video frame.

    Args:
        frame: The incoming video frame.

    Returns:
        A new video frame built from the output of ``show_fps``.
    """
    bgr_image = frame.to_ndarray(format="bgr24")

    with lock:
        img_container.img = bgr_image
        img_container.frame_rate.count()
        img_container.add_frame(bgr_image)
        inference()
        annotated_image = img_container.frame_rate.show_fps(bgr_image)

    return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")
133
+
134
+
135
def get_frames_per_video(model_name: str) -> int:
    """Return the clip length, in frames, associated with a checkpoint name.

    Args:
        model_name: Checkpoint identifier to inspect.

    Returns:
        8 when the name contains ``"base-finetuned"``, 16 when it contains
        ``"hr-finetuned"``, and 96 otherwise. The first matching marker wins.
    """
    # Ordered (marker, frame count) pairs, checked in sequence.
    clip_lengths = (
        ("base-finetuned", 8),
        ("hr-finetuned", 16),
    )
    for marker, frame_count in clip_lengths:
        if marker in model_name:
            return frame_count
    return 96
142
+
143
+
144
# Script-local import: only the polling loop at the end of this section uses it.
import time

st.title("TimeSFormer")

with st.expander("INTRODUCTION"):
    st.text(
        f"""Streamlit demo for TimeSFormer.
        Number of CPU(s): {multiprocessing.cpu_count()}
        """
    )

model_name = st.selectbox(
    "model_name",
    (
        "facebook/timesformer-base-finetuned-k400",
        "facebook/timesformer-base-finetuned-k600",
        "facebook/timesformer-base-finetuned-ssv2",
        "facebook/timesformer-hr-finetuned-k600",
        "facebook/timesformer-hr-finetuned-k400",
        "facebook/timesformer-hr-finetuned-ssv2",
        "fcakyon/timesformer-large-finetuned-k400",
        "fcakyon/timesformer-large-finetuned-k600",
    ),
)
feature_extractor, model = load_model(model_name)


frames_per_video = get_frames_per_video(model_name)
st.info(f"Frames per video: {frames_per_video}")

# Shared state written by video_frame_callback() / inference().
img_container = ImgContainer(frames_per_video)

ctx = st.session_state.ctx = webrtc_streamer(
    key="snapshot",
    video_frame_callback=video_frame_callback,
    rtc_configuration=rtc_configuration,
)

# The container is created afresh on every script run and its `rs` attribute is
# only filled by inference(), which runs from the video callback. A single
# check right after creation therefore always sees None. Poll while the stream
# is playing and refresh one placeholder instead.
results_placeholder = st.empty()
while ctx.state.playing:
    # Read under the same lock the callback holds while it updates the container.
    with lock:
        latest_results = img_container.rs
    if latest_results is not None:
        results_placeholder.dataframe(latest_results)
    # Short pause so the loop does not spin or starve the callback of the lock.
    time.sleep(0.2)
utils/frame_rate.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  import time, cv2
3
 
@@ -5,9 +6,10 @@ import time, cv2
5
  class FrameRate:
6
  def __init__(self) -> None:
7
  self.c: int = 0
8
- self.start_time: float = None
9
- self.NO_FRAMES = 100
10
  self.fps: float = -1
 
11
 
12
  def reset(self) -> None:
13
  self.start_time = time.time()
@@ -26,7 +28,7 @@ class FrameRate:
26
  if self.fps != -1:
27
  return cv2.putText(
28
  image,
29
- f"FPS {self.fps:.0f}",
30
  (50, 50),
31
  cv2.FONT_HERSHEY_SIMPLEX,
32
  fontScale=1,
 
1
+ from typing import Optional
2
  import numpy as np
3
  import time, cv2
4
 
 
6
  class FrameRate:
7
  def __init__(self) -> None:
8
  self.c: int = 0
9
+ self.start_time: Optional[float] = None
10
+ self.NO_FRAMES = 10
11
  self.fps: float = -1
12
+ self.label: str = ""
13
 
14
  def reset(self) -> None:
15
  self.start_time = time.time()
 
28
  if self.fps != -1:
29
  return cv2.putText(
30
  image,
31
+ f"FPS {self.fps:.0f} _ {self.label}",
32
  (50, 50),
33
  cv2.FONT_HERSHEY_SIMPLEX,
34
  fontScale=1,