gullalc commited on
Commit
23e38ea
1 Parent(s): 6c3fbe4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +53 -1
README.md CHANGED
@@ -30,4 +30,56 @@ Model achieves:
30
 
31
  Class-wise accuracies:
32
  - *shot scale*: ECS - 90.92%, CS - 83.2%, MS - 85.0%, FS - 89.71%, LS - 94.55%
33
- - *shot movement*: Static - 94.6%, Motion - 87.7%, Pull - 57.5%, Push - 66.82%
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  Class-wise accuracies:
32
  - *shot scale*: ECS - 90.92%, CS - 83.2%, MS - 85.0%, FS - 89.71%, LS - 94.55%
33
+ - *shot movement*: Static - 94.6%, Motion - 87.7%, Pull - 57.5%, Push - 66.82%
34
+
35
+
36
+ ## Model Definition
37
+ ```python
38
+ from transformers import VideoMAEImageProcessor, VideoMAEModel, VideoMAEConfig, PreTrainedModel
39
+
40
+ class CustomVideoMAEConfig(VideoMAEConfig):
41
+ def __init__(self, scale_label2id=None, scale_id2label=None, movement_label2id=None, movement_id2label=None, **kwargs):
42
+ super().__init__(**kwargs)
43
+ self.scale_label2id = scale_label2id if scale_label2id is not None else {}
44
+ self.scale_id2label = scale_id2label if scale_id2label is not None else {}
45
+ self.movement_label2id = movement_label2id if movement_label2id is not None else {}
46
+ self.movement_id2label = movement_id2label if movement_id2label is not None else {}
47
+
48
+
49
+ class CustomModel(PreTrainedModel):
50
+ config_class = CustomVideoMAEConfig
51
+
52
+ def __init__(self, config, model_name, scale_num_classes, movement_num_classes):
53
+ super().__init__(config)
54
+ self.vmae = VideoMAEModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
55
+ self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
56
+ self.scale_cf = nn.Linear(config.hidden_size, scale_num_classes)
57
+ self.movement_cf = nn.Linear(config.hidden_size, movement_num_classes)
58
+
59
+ def forward(self, pixel_values, scale_labels=None, movement_labels=None):
60
+
61
+ vmae_outputs = self.vmae(pixel_values)
62
+ sequence_output = vmae_outputs[0]
63
+
64
+ if self.fc_norm is not None:
65
+ sequence_output = self.fc_norm(sequence_output.mean(1))
66
+ else:
67
+ sequence_output = sequence_output[:, 0]
68
+
69
+ scale_logits = self.scale_cf(sequence_output)
70
+ movement_logits = self.movement_cf(sequence_output)
71
+
72
+ if scale_labels is not None and movement_labels is not None:
73
+ loss = F.cross_entropy(scale_logits, scale_labels) + F.cross_entropy(movement_logits, movement_labels)
74
+ return {"loss": loss, "scale_logits": scale_logits, "movement_logits": movement_logits}
75
+ return {"scale_logits": scale_logits, "movement_logits": movement_logits}
76
+
77
+
78
+ scale_lab2id = {"ECS": 0, "CS": 1, "MS": 2, "FS": 3, "LS": 4}
79
+ scale_id2lab = {v:k for k,v in scale_lab2id.items()}
80
+ movement_lab2id = {"Static": 0, "Motion": 1, "Pull": 2, "Push": 3}
81
+ movement_id2lab = {v:k for k,v in movement_lab2id.items()}
82
+
83
+ config = CustomVideoMAEConfig(scale_lab2id, scale_id2lab, movement_lab2id, movement_id2lab)
84
+ model = CustomModel(config, model_name, 5, 4)
85
+ ```