JackWong0911 committed
Commit ba5befc
1 Parent(s): d999fef

Update app.py

Files changed (1):
  app.py +491 -10
app.py CHANGED
@@ -18,13 +18,492 @@ from torchvision.transforms import (
18
  RandomHorizontalFlip,
19
  Resize,
20
  )
21
- from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
22
 
23
- MODEL_CKPT = "sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset"
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
- MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
27
- PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
28
 
29
  RESIZE_TO = PROCESSOR.size["shortest_edge"]
30
  NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
@@ -122,17 +601,19 @@ gr.Interface(
122
  inputs=gr.Video(type="file"),
123
  outputs=gr.Label(num_top_classes=3),
124
  examples=[
125
- ["examples/babycrawling.mp4"],
126
- ["examples/baseball.mp4"],
127
- ["examples/balancebeam.mp4"],
128
  ],
129
- title="VideoMAE fine-tuned on a subset of UCF-101",
130
  description=(
131
- "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
132
  " examples to load them. Read more at the links below."
133
  ),
134
  article=(
135
- "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
136
  " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
137
  ),
138
  allow_flagging=False,

@@ -18,13 +18,492 @@ from torchvision.transforms import (
18
  RandomHorizontalFlip,
19
  Resize,
20
  )
21
+ # my code below
22
+ # import transformers.models.timesformer.modeling_timesformer
23
+ from transformers.models.timesformer.modeling_timesformer import TimeSformerDropPath, TimeSformerAttention, TimesformerIntermediate, TimesformerOutput, TimesformerLayer, TimesformerEncoder, TimesformerModel, TIMESFORMER_INPUTS_DOCSTRING, _CONFIG_FOR_DOC, TimesformerEmbeddings, TimesformerForVideoClassification
24
+ from transformers import TimesformerConfig
25
+ configuration = TimesformerConfig()
26
+ import collections
27
+ from typing import Optional, Tuple, Union
28
 
29
+ import torch
30
+ import torch.nn.functional
31
+ import torch.utils.checkpoint
32
+ from torch import nn
33
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
34
+
35
+ from transformers.activations import ACT2FN
36
+ from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutput
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
39
+ from transformers.models.timesformer.configuration_timesformer import TimesformerConfig
40
+ class MyTimesformerLayer(TimesformerLayer):
41
+ def __init__(self, config: TimesformerConfig, layer_index: int) -> None:
42
+ super().__init__(config, layer_index)
43
+
44
+ attention_type = config.attention_type
45
+
46
+ drop_path_rates = [
47
+ x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
48
+ ] # stochastic depth decay rule
49
+ drop_path_rate = drop_path_rates[layer_index]
50
+
51
+ self.drop_path = TimeSformerDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
52
+ self.attention = TimeSformerAttention(config)
53
+ self.intermediate = TimesformerIntermediate(config)
54
+ self.output = TimesformerOutput(config)
55
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
56
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
57
+
58
+ self.config = config
59
+ self.attention_type = attention_type
60
+ if attention_type not in ["divided_space_time", "space_only", "joint_space_time"]:
61
+ raise ValueError("Unknown attention type: {}".format(attention_type))
62
+
63
+ # Temporal Attention Parameters
64
+ if self.attention_type == "divided_space_time":
65
+ self.temporal_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.temporal_attention = TimeSformerAttention(config)
67
+ self.temporal_dense = nn.Linear(config.hidden_size, config.hidden_size)
68
+
69
+ def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False):
70
+ num_frames = self.config.num_frames
71
+ num_patch_width = self.config.image_size // self.config.patch_size
72
+ batch_size = hidden_states.shape[0]
73
+ num_spatial_tokens = (hidden_states.size(1) - 1) // num_frames
74
+ num_patch_height = num_spatial_tokens // num_patch_width
75
+
76
+ if self.attention_type in ["space_only", "joint_space_time"]:
77
+ self_attention_outputs = self.attention(
78
+ self.layernorm_before(hidden_states), output_attentions=output_attentions
79
+ )
80
+ attention_output = self_attention_outputs[0]
81
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
82
+
83
+ hidden_states = hidden_states + self.drop_path(attention_output)
84
+
85
+ layer_output = self.layernorm_after(hidden_states)
86
+ layer_output = self.intermediate(layer_output)
87
+ layer_output = self.output(layer_output)
88
+ layer_output = hidden_states + self.drop_path(layer_output)
89
+
90
+ outputs = (layer_output,) + outputs
91
+
92
+ return outputs
93
+
94
+ elif self.attention_type == "divided_space_time":
95
+ # Spatial
96
+ init_cls_token = hidden_states[:, 0, :].unsqueeze(1)
97
+ cls_token = init_cls_token.repeat(1, num_frames, 1)
98
+ cls_token = cls_token.reshape(batch_size * num_frames, 1, cls_token.shape[2])
99
+ spatial_embedding = hidden_states[:, 1:, :]
100
+ spatial_embedding = (
101
+ spatial_embedding.reshape(
102
+ batch_size, num_patch_height, num_patch_width, num_frames, spatial_embedding.shape[2]
103
+ )
104
+ .permute(0, 3, 1, 2, 4)
105
+ .reshape(batch_size * num_frames, num_patch_height * num_patch_width, spatial_embedding.shape[2])
106
+ )
107
+ spatial_embedding = torch.cat((cls_token, spatial_embedding), 1)
108
+
109
+ spatial_attention_outputs = self.attention(
110
+ self.layernorm_before(spatial_embedding), output_attentions=output_attentions
111
+ )
112
+ attention_output = spatial_attention_outputs[0]
113
+ outputs = spatial_attention_outputs[1:] # add self attentions if we output attention weights
114
+
115
+ residual_spatial = self.drop_path(attention_output)
116
+
117
+ # Taking care of CLS token
118
+ cls_token = residual_spatial[:, 0, :]
119
+ cls_token = cls_token.reshape(batch_size, num_frames, cls_token.shape[1])
120
+ cls_token = torch.mean(cls_token, 1, True) # averaging for every frame
121
+ residual_spatial = residual_spatial[:, 1:, :]
122
+ residual_spatial = (
123
+ residual_spatial.reshape(
124
+ batch_size, num_frames, num_patch_height, num_patch_width, residual_spatial.shape[2]
125
+ )
126
+ .permute(0, 2, 3, 1, 4)
127
+ .reshape(batch_size, num_patch_height * num_patch_width * num_frames, residual_spatial.shape[2])
128
+ )
129
+ residual = residual_spatial
130
+ hidden_states = hidden_states[:, 1:, :] + residual_spatial
131
+
132
+ # Temporal
133
+ temporal_embedding = hidden_states
134
+ temporal_embedding = temporal_embedding.reshape(
135
+ batch_size, num_patch_height, num_patch_width, num_frames, temporal_embedding.shape[2]
136
+ ).reshape(batch_size * num_patch_height * num_patch_width, num_frames, temporal_embedding.shape[2])
137
+
138
+ temporal_attention_outputs = self.temporal_attention(
139
+ self.temporal_layernorm(temporal_embedding),
140
+ )
141
+ attention_output = temporal_attention_outputs[0]
142
+
143
+ residual_temporal = self.drop_path(attention_output)
144
+
145
+ residual_temporal = residual_temporal.reshape(
146
+ batch_size, num_patch_height, num_patch_width, num_frames, residual_temporal.shape[2]
147
+ ).reshape(batch_size, num_patch_height * num_patch_width * num_frames, residual_temporal.shape[2])
148
+ residual_temporal = self.temporal_dense(residual_temporal)
149
+ hidden_states = hidden_states + residual_temporal
150
+
151
+ # Mlp
152
+ hidden_states = torch.cat((init_cls_token, hidden_states), 1) + torch.cat((cls_token, residual_temporal), 1)
153
+ layer_output = self.layernorm_after(hidden_states)
154
+ layer_output = self.intermediate(layer_output)
155
+ layer_output = self.output(layer_output)
156
+ layer_output = hidden_states + self.drop_path(layer_output)
157
+
158
+ outputs = (layer_output,) + outputs
159
+
160
+ return outputs
161
+ import transformers.models.timesformer.modeling_timesformer
162
+ class MyTimesformerEncoder(TimesformerEncoder):
163
+ def __init__(self, config: TimesformerConfig) -> None:
164
+ super().__init__(config)
165
+ self.config = config
166
+ self.layer = nn.ModuleList([MyTimesformerLayer(config, ind) for ind in range(config.num_hidden_layers)])
167
+ self.gradient_checkpointing = False
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ output_attentions: bool = False,
173
+ output_hidden_states: bool = False,
174
+ return_dict: bool = True,
175
+ ) -> Union[tuple, BaseModelOutput]:
176
+ all_hidden_states = () if output_hidden_states else None
177
+ all_self_attentions = () if output_attentions else None
178
+
179
+ for i, layer_module in enumerate(self.layer):
180
+ if output_hidden_states:
181
+ all_hidden_states = all_hidden_states + (hidden_states,)
182
+
183
+ if self.gradient_checkpointing and self.training:
184
+ layer_outputs = self._gradient_checkpointing_func(
185
+ layer_module.__call__,
186
+ hidden_states,
187
+ output_attentions,
188
+ )
189
+ else:
190
+ layer_outputs = layer_module(hidden_states, output_attentions)
191
+
192
+ hidden_states = layer_outputs[0]
193
+
194
+ if output_attentions:
195
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
196
+
197
+ if output_hidden_states:
198
+ all_hidden_states = all_hidden_states + (hidden_states,)
199
+
200
+ if not return_dict:
201
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
202
+ return BaseModelOutput(
203
+ last_hidden_state=hidden_states,
204
+ hidden_states=all_hidden_states,
205
+ attentions=all_self_attentions,
206
+ )
207
+
208
+
209
+ class MyTimesformerModel(TimesformerModel):
210
+ def __init__(self, config: TimesformerConfig):
211
+ super().__init__(config)
212
+ self.config = config
213
+
214
+ self.embeddings = TimesformerEmbeddings(config)
215
+ self.encoder = MyTimesformerEncoder(config)  # use the customized encoder/layers defined above
216
+
217
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
218
+
219
+ # Initialize weights and apply final processing
220
+ self.post_init()
221
+
222
+ def get_input_embeddings(self):
223
+ return self.embeddings.patch_embeddings
224
+
225
+ def _prune_heads(self, heads_to_prune):
226
+ """
227
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
228
+ class PreTrainedModel
229
+ """
230
+ for layer, heads in heads_to_prune.items():
231
+ self.encoder.layer[layer].attention.prune_heads(heads)
232
+
233
+ @add_start_docstrings_to_model_forward(TIMESFORMER_INPUTS_DOCSTRING)
234
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
235
+ def forward(
236
+ self,
237
+ pixel_values: torch.FloatTensor,
238
+ output_attentions: Optional[bool] = None,
239
+ output_hidden_states: Optional[bool] = None,
240
+ return_dict: Optional[bool] = None,
241
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
242
+ r"""
243
+ Returns:
244
+
245
+ Examples:
246
+
247
+ ```python
248
+ >>> import av
249
+ >>> import numpy as np
250
+
251
+ >>> from transformers import AutoImageProcessor, TimesformerModel
252
+ >>> from huggingface_hub import hf_hub_download
253
+
254
+ >>> np.random.seed(0)
255
+
256
+
257
+ >>> def read_video_pyav(container, indices):
258
+ ... '''
259
+ ... Decode the video with PyAV decoder.
260
+ ... Args:
261
+ ... container (`av.container.input.InputContainer`): PyAV container.
262
+ ... indices (`List[int]`): List of frame indices to decode.
263
+ ... Returns:
264
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
265
+ ... '''
266
+ ... frames = []
267
+ ... container.seek(0)
268
+ ... start_index = indices[0]
269
+ ... end_index = indices[-1]
270
+ ... for i, frame in enumerate(container.decode(video=0)):
271
+ ... if i > end_index:
272
+ ... break
273
+ ... if i >= start_index and i in indices:
274
+ ... frames.append(frame)
275
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
276
+
277
+
278
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
279
+ ... '''
280
+ ... Sample a given number of frame indices from the video.
281
+ ... Args:
282
+ ... clip_len (`int`): Total number of frames to sample.
283
+ ... frame_sample_rate (`int`): Sample every n-th frame.
284
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
285
+ ... Returns:
286
+ ... indices (`List[int]`): List of sampled frame indices
287
+ ... '''
288
+ ... converted_len = int(clip_len * frame_sample_rate)
289
+ ... end_idx = np.random.randint(converted_len, seg_len)
290
+ ... start_idx = end_idx - converted_len
291
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
292
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
293
+ ... return indices
294
+
295
+
296
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
297
+ >>> file_path = hf_hub_download(
298
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
299
+ ... )
300
+ >>> container = av.open(file_path)
301
+
302
+ >>> # sample 8 frames
303
+ >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
304
+ >>> video = read_video_pyav(container, indices)
305
+
306
+ >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
307
+ >>> model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")
308
+
309
+ >>> # prepare video for the model
310
+ >>> inputs = image_processor(list(video), return_tensors="pt")
311
+
312
+ >>> # forward pass
313
+ >>> outputs = model(**inputs)
314
+ >>> last_hidden_states = outputs.last_hidden_state
315
+ >>> list(last_hidden_states.shape)
316
+ [1, 1569, 768]
317
+ ```"""
318
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
319
+ output_hidden_states = (
320
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
321
+ )
322
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
323
+
324
+ embedding_output = self.embeddings(pixel_values)
325
+
326
+ encoder_outputs = self.encoder(
327
+ embedding_output,
328
+ output_attentions=output_attentions,
329
+ output_hidden_states=output_hidden_states,
330
+ return_dict=return_dict,
331
+ )
332
+ sequence_output = encoder_outputs[0]
333
+ if self.layernorm is not None:
334
+ sequence_output = self.layernorm(sequence_output)
335
+
336
+ if not return_dict:
337
+ return (sequence_output,) + encoder_outputs[1:]
338
+
339
+ return BaseModelOutput(
340
+ last_hidden_state=sequence_output,
341
+ hidden_states=encoder_outputs.hidden_states,
342
+ attentions=encoder_outputs.attentions,
343
+ )
344
+
345
+ class MyTimesformerForVideoClassification(TimesformerForVideoClassification):
346
+ def __init__(self, config):
347
+ super().__init__(config)
348
+
349
+ self.num_labels = config.num_labels
350
+ self.timesformer = MyTimesformerModel(config)
351
+
352
+ # Classifier head
353
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
354
+
355
+ # Initialize weights and apply final processing
356
+ self.post_init()
357
+
358
+ @add_start_docstrings_to_model_forward(TIMESFORMER_INPUTS_DOCSTRING)
359
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
360
+ def forward(
361
+ self,
362
+ pixel_values: Optional[torch.Tensor] = None,
363
+ labels: Optional[torch.Tensor] = None,
364
+ output_attentions: Optional[bool] = None,
365
+ output_hidden_states: Optional[bool] = None,
366
+ return_dict: Optional[bool] = None,
367
+ ) -> Union[Tuple, ImageClassifierOutput]:
368
+ r"""
369
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
370
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
371
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
372
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
373
+
374
+ Returns:
375
+
376
+ Examples:
377
+
378
+ ```python
379
+ >>> import av
380
+ >>> import torch
381
+ >>> import numpy as np
382
+
383
+ >>> from transformers import AutoImageProcessor, TimesformerForVideoClassification
384
+ >>> from huggingface_hub import hf_hub_download
385
+
386
+ >>> np.random.seed(0)
387
+
388
+
389
+ >>> def read_video_pyav(container, indices):
390
+ ... '''
391
+ ... Decode the video with PyAV decoder.
392
+ ... Args:
393
+ ... container (`av.container.input.InputContainer`): PyAV container.
394
+ ... indices (`List[int]`): List of frame indices to decode.
395
+ ... Returns:
396
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
397
+ ... '''
398
+ ... frames = []
399
+ ... container.seek(0)
400
+ ... start_index = indices[0]
401
+ ... end_index = indices[-1]
402
+ ... for i, frame in enumerate(container.decode(video=0)):
403
+ ... if i > end_index:
404
+ ... break
405
+ ... if i >= start_index and i in indices:
406
+ ... frames.append(frame)
407
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
408
+
409
+
410
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
411
+ ... '''
412
+ ... Sample a given number of frame indices from the video.
413
+ ... Args:
414
+ ... clip_len (`int`): Total number of frames to sample.
415
+ ... frame_sample_rate (`int`): Sample every n-th frame.
416
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
417
+ ... Returns:
418
+ ... indices (`List[int]`): List of sampled frame indices
419
+ ... '''
420
+ ... converted_len = int(clip_len * frame_sample_rate)
421
+ ... end_idx = np.random.randint(converted_len, seg_len)
422
+ ... start_idx = end_idx - converted_len
423
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
424
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
425
+ ... return indices
426
+
427
+
428
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
429
+ >>> file_path = hf_hub_download(
430
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
431
+ ... )
432
+ >>> container = av.open(file_path)
433
+
434
+ >>> # sample 8 frames
435
+ >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
436
+ >>> video = read_video_pyav(container, indices)
437
+
438
+ >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
439
+ >>> model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400")
440
+
441
+ >>> inputs = image_processor(list(video), return_tensors="pt")
442
+
443
+ >>> with torch.no_grad():
444
+ ... outputs = model(**inputs)
445
+ ... logits = outputs.logits
446
+
447
+ >>> # model predicts one of the 400 Kinetics-400 classes
448
+ >>> predicted_label = logits.argmax(-1).item()
449
+ >>> print(model.config.id2label[predicted_label])
450
+ eating spaghetti
451
+ ```"""
452
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
453
+
454
+ outputs = self.timesformer(
455
+ pixel_values,
456
+ output_attentions=output_attentions,
457
+ output_hidden_states=output_hidden_states,
458
+ return_dict=return_dict,
459
+ )
460
+
461
+ sequence_output = outputs[0][:, 0]
462
+
463
+ logits = self.classifier(sequence_output)
464
+
465
+ loss = None
466
+ if labels is not None:
467
+ if self.config.problem_type is None:
468
+ if self.num_labels == 1:
469
+ self.config.problem_type = "regression"
470
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
471
+ self.config.problem_type = "single_label_classification"
472
+ else:
473
+ self.config.problem_type = "multi_label_classification"
474
+
475
+ if self.config.problem_type == "regression":
476
+ loss_fct = MSELoss()
477
+ if self.num_labels == 1:
478
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
479
+ else:
480
+ loss = loss_fct(logits, labels)
481
+ elif self.config.problem_type == "single_label_classification":
482
+ loss_fct = CrossEntropyLoss()
483
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
484
+ elif self.config.problem_type == "multi_label_classification":
485
+ loss_fct = BCEWithLogitsLoss()
486
+ loss = loss_fct(logits, labels)
487
+
488
+ if not return_dict:
489
+ output = (logits,) + outputs[1:]
490
+ return ((loss,) + output) if loss is not None else output
491
+
492
+ return ImageClassifierOutput(
493
+ loss=loss,
494
+ logits=logits,
495
+ hidden_states=outputs.hidden_states,
496
+ attentions=outputs.attentions,
497
+ )
498
+
499
+
500
+ from transformers import AutoImageProcessor
501
+
502
+ MODEL_CKPT = "JackWong0911/timesformer-base-finetuned-k400-kinetic400-subset-epoch6real-num_frame_10_myViT2_more_data"
503
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
504
 
505
+ MODEL = MyTimesformerForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
506
+ PROCESSOR = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
507
 
508
  RESIZE_TO = PROCESSOR.size["shortest_edge"]
509
  NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames

@@ -122,17 +601,19 @@ gr.Interface(
601
  inputs=gr.Video(type="file"),
602
  outputs=gr.Label(num_top_classes=3),
603
  examples=[
604
+ ["examples/archery.mp4"],
605
+ ["examples/bowling.mp4"],
606
+ ["examples/flying_kite.mp4"],
607
+ ["examples/high_jump.mp4"],
608
+ ["examples/marching.mp4"],
609
  ],
610
+ title="MyViT fine-tuned on a subset of Kinetics400",
611
  description=(
612
+ "Gradio demo for MyViT for video classification. To use it, simply upload your video or click one of the"
613
  " examples to load them. Read more at the links below."
614
  ),
615
  article=(
616
+ "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>MyViT</a>"
617
  " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
618
  ),
619
  allow_flagging=False,
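
For anyone who wants to sanity-check the classes this commit adds outside of the Gradio app, the sketch below loads the custom classifier the same way the new app.py does and runs one forward pass on dummy frames. This is a minimal sketch, not part of the commit: the `from app import ...` line, the random frames, and the top-3 printout are illustrative assumptions; the checkpoint name and the `MCG-NJU/videomae-base-finetuned-kinetics` processor are taken from the diff above.

```python
# Minimal smoke test (sketch). Assumes MyTimesformerForVideoClassification is
# importable from app.py; the random frames stand in for a real sampled clip.
import numpy as np
import torch
from transformers import AutoImageProcessor

from app import MyTimesformerForVideoClassification  # hypothetical import path

MODEL_CKPT = "JackWong0911/timesformer-base-finetuned-k400-kinetic400-subset-epoch6real-num_frame_10_myViT2_more_data"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MyTimesformerForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE).eval()
processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

# Build a fake clip: model.config.num_frames RGB frames, each 224x224, passed as a
# list of HxWx3 uint8 arrays (the layout AutoImageProcessor accepts for one video).
frames = [
    np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    for _ in range(model.config.num_frames)
]

inputs = processor(frames, return_tensors="pt").to(DEVICE)
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, num_labels)

# Report the top-3 classes, mirroring gr.Label(num_top_classes=3) in the app.
top = logits.softmax(-1)[0].topk(3)
print({model.config.id2label[i.item()]: round(p.item(), 4) for p, i in zip(top.values, top.indices)})
```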