Superxixixi committed on
Commit
5b70063
1 Parent(s): 2e36228

Upload loconet

Files changed (4)
  1. config.json +17 -0
  2. config_loconet.py +23 -0
  3. modeling_loconet.py +49 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "adjust_attention": false,
+   "architectures": [
+     "loconet"
+   ],
+   "auto_map": {
+     "AutoConfig": "config_loconet.LoCoNetConfig",
+     "AutoModel": "modeling_loconet.loconet"
+   },
+   "av": "speaker_temporal",
+   "av_layers": 3,
+   "clip_length": 200,
+   "model_type": "loconet",
+   "num_speakers": 3,
+   "torch_dtype": "float32",
+   "transformers_version": "4.28.1"
+ }
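The auto_map block above wires the custom classes into the Auto* factories, so the checkpoint can be loaded with trust_remote_code=True. A minimal loading sketch, assuming a placeholder repo id and that the modules imported by modeling_loconet.py (loconet_encoder, loss_multi) are also available in the repo:

from transformers import AutoConfig, AutoModel

repo_id = "Superxixixi/loconet"   # placeholder; substitute the actual Hub repo id

# trust_remote_code=True lets transformers import config_loconet.py and
# modeling_loconet.py from the repo, as declared in auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.num_speakers, config.clip_length, config.av)   # 3 200 speaker_temporal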
config_loconet.py ADDED
@@ -0,0 +1,23 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class LoCoNetConfig(PretrainedConfig):
+     model_type = "loconet"
+
+     def __init__(
+         self,
+         num_speakers: int = 3,
+         clip_length: int = 200,
+         av: str = "speaker_temporal",
+         av_layers: int = 3,
+         adjust_attention: bool = False,
+         **kwargs,
+     ):
+
+         self.num_speakers = num_speakers
+         self.clip_length = clip_length
+         self.av = av
+         self.av_layers = av_layers
+         self.adjust_attention = adjust_attention
+         super().__init__(**kwargs)
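As a quick sanity check, the config round-trips through the standard PretrainedConfig serialization; the defaults mirror config.json above, and the output directory name below is arbitrary:

from config_loconet import LoCoNetConfig

config = LoCoNetConfig()                      # defaults mirror config.json above
config.save_pretrained("./loconet_config")    # writes config.json into that directory
reloaded = LoCoNetConfig.from_pretrained("./loconet_config")
assert reloaded.clip_length == 200 and reloaded.av == "speaker_temporal"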
modeling_loconet.py ADDED
@@ -0,0 +1,49 @@
+ from config_loconet import LoCoNetConfig
+ from transformers import PreTrainedModel
+ from loconet_encoder import locoencoder
+ from loss_multi import lossAV, lossA, lossV
+
+
+ class loconet(PreTrainedModel):
+     config_class = LoCoNetConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.model = locoencoder(config)
+         # Loss heads used in forward(); assumed to take no constructor arguments.
+         self.lossAV = lossAV()
+         self.lossA = lossA()
+         self.lossV = lossV()
+
+     def forward(self, audioFeature, visualFeature, masks, labels=None):
+         b, s, t = visualFeature.shape[:3]
+         # Fold the speaker dimension into the batch dimension.
+         visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
+         masks = masks.view(b * s, *masks.shape[2:])
+
+         audioEmbed = self.model.forward_audio_frontend(audioFeature)  # B, C, T, 4
+         visualEmbed = self.model.forward_visual_frontend(visualFeature)
+         # The audio track is shared by all speakers in the clip, so repeat it per speaker.
+         audioEmbed = audioEmbed.repeat(s, 1, 1)
+
+         audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
+         outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
+         outsA = self.model.forward_audio_backend(audioEmbed)
+         outsV = self.model.forward_visual_backend(visualEmbed)
+         num_frames = masks.sum()  # number of valid (unmasked) frames
+
+         if labels is not None:
+             # Reshaping labels only when they are provided avoids a crash at inference time.
+             labels = labels.view(b * s, *labels.shape[2:]).reshape((-1))
+             masks = masks.reshape((-1))
+             nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
+             nlossA = self.lossA.forward(outsA, labels, masks)
+             nlossV = self.lossV.forward(outsV, labels, masks)
+
+             nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
+
+             return {"loss": nloss, "logits": outsAV}
+
+         else:
+             return {"logits": outsAV}
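For local use without the Hub, the classes can also be registered with the Auto* factories directly. A sketch, assuming config_loconet.py, modeling_loconet.py, and the loconet_encoder / loss_multi modules they import are on the Python path; forward() returns {"loss", "logits"} when labels are passed and {"logits"} otherwise:

from transformers import AutoConfig, AutoModel
from config_loconet import LoCoNetConfig
from modeling_loconet import loconet

# Make the Auto* factories aware of the custom "loconet" model type.
AutoConfig.register("loconet", LoCoNetConfig)
AutoModel.register(LoCoNetConfig, loconet)

model = AutoModel.from_config(LoCoNetConfig(num_speakers=3, clip_length=200))
# Training step: model(audioFeature, visualFeature, masks, labels)["loss"]
# Inference:     model(audioFeature, visualFeature, masks)["logits"]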
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6918e8391d48c40cfd90b332687508bb4b2269879ba1303dacb5a26937ecda87
+ size 137464429
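pytorch_model.bin is committed as a Git LFS pointer (sha256 plus byte size, about 137 MB); the actual weights live in LFS storage. One way to fetch just that file, assuming the placeholder repo id from earlier:

from huggingface_hub import hf_hub_download

# repo_id is a placeholder; substitute the actual Hub repo id.
weights_path = hf_hub_download(repo_id="Superxixixi/loconet",
                               filename="pytorch_model.bin")
print(weights_path)   # local cache path of the ~137 MB checkpoint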