swimmiing committed on
Commit b20af9f
1 Parent(s): 581ac51

Upload model files
app.py CHANGED
@@ -1,7 +1,45 @@
 import gradio as gr
+import torch
+import numpy as np
+from modules.models import *
+from util import get_prompt_template
+from PIL import Image
+
 
 def greet(name):
     return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+
+def main():
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    # Get model
+    model_conf_file = './config/model/ACL_ViT16.yaml'
+    model = ACL(model_conf_file, device)
+    model.train(False)
+    model.load('./pretrain/Param_best.pth')
+
+    # Get placeholder text
+    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
+
+    # Input pre-processing
+
+    # Inference
+    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
+    # audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
+    #                                             prompt_length)
+
+    # Localization result
+    # out_dict = model(images.to(model.device), audio_driven_embedding, 352)
+    # seg = out_dict['heatmap'][j:j + 1]
+    # seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
+    # seg_image = Image.fromarray(seg_image)
+    # heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)  # needs `import cv2`; kept commented since its inputs above are commented out
+    # overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0)
+
+
+if __name__ == "__main__":
+    iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+    iface.launch()
+
+    main()
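Note on the new app.py: the inference path in main() is still commented out and greet() is the only function wired to the Gradio interface. As a rough illustration, the commented calls could be collected into a helper like the sketch below; the pre-processing of images/audios and anything about ACL beyond the calls already shown above are assumptions, not part of this commit.

# Sketch only -- mirrors the commented-out block in main(); `model`, `placeholder_tokens`,
# `text_pos_at_prompt` and `prompt_length` are the objects created in main() above, and
# `images`/`audios` are assumed to be pre-processed tensors with `original_image` already
# resized to the heatmap resolution.
import cv2

def localize(original_image, images, audios):
    audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens,
                                                text_pos_at_prompt, prompt_length)
    out_dict = model(images.to(model.device), audio_driven_embedding, 352)
    seg = out_dict['heatmap'][0:1]
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    heatmap_image = cv2.applyColorMap(seg_image, cv2.COLORMAP_JET)
    overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0)
    return Image.fromarray(overlaid_image)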
config/model/ACL_ViT16.yaml ADDED
@@ -0,0 +1,89 @@
model:
  clip: ViT16
  vision_backbone: null
  audio_backbone: BEATs
  audio_proj: FGA512

pretrain:
  vision_backbone: null
  audio_backbone: ./pretrain/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt
  audio_proj: null

fga_conf:
  FGA:
    input_size: 768
    output_size: 768

  FGA512:
    input_size: 768
    output_size: 512

clip_conf:
  RN50:
    name: RN50
    vision:
      image_resolution: 224
      vision_layers: [3, 4, 6, 3]
      vision_width: 64
      heads: 8
      vision_patch_size: null
    text:
      transformer_layers: 12
      transformer_width: 512
      transformer_heads: 8
      vocab_size: 49408
      context_length: 77
    embedding_dim: 1024

  ViT16:
    name: ViT-B/16
    vision:
      image_resolution: 224
      vision_layers: 12
      vision_width: 768
      heads: 12
      vision_patch_size: 16
    text:
      transformer_layers: 12
      transformer_width: 512
      transformer_heads: 8
      vocab_size: 49408
      context_length: 77
    embedding_dim: 512

  ViT14:
    name: ViT-L/14
    vision:
      image_resolution: 224
      vision_layers: 24
      vision_width: 1024
      heads: 16
      vision_patch_size: 14
    text:
      transformer_layers: 12
      transformer_width: 768
      transformer_heads: 12
      vocab_size: 49408
      context_length: 77
    embedding_dim: 768

vision_backbone_conf:
  maskclip_plus_rn50_512:
    name: maskclip_plus_rn50_512
    image_resolution: 512
    vision_layers: [ 3, 4, 6, 3 ]
    vision_width: 2048
    aspp:
      dilations: [ 6, 12, 18, 24 ]
      in_channels: 2048
      channels: 512

  maskclip_plus_rn101_512:
    name: maskclip_plus_rn101_512
    image_resolution: 512
    vision_layers: [ 3, 4, 23, 3 ]
    vision_width: 2048
    aspp:
      dilations: [ 6, 12, 18, 24 ]
      in_channels: 2048
      channels: 1024
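A minimal sketch of reading this file with PyYAML follows. The loader in modules/models.py is not shown in this commit, so treat the exact key paths as an assumption.

import yaml

with open('./config/model/ACL_ViT16.yaml') as f:
    conf = yaml.safe_load(f)

clip_name = conf['model']['clip']               # 'ViT16'
clip_cfg = conf['clip_conf'][clip_name]
print(clip_cfg['name'], clip_cfg['vision']['vision_patch_size'])   # ViT-B/16 16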
config/train/Exp_ACL_v1.yaml ADDED
@@ -0,0 +1,45 @@
model: ACL

common:
  train_data: vggss
  epoch: 20
  batch_size: 8
  input_resolution: 352
  num_workers: 4
  seed: 0
  loss:
    - acl_i
    - acl_f
    - area_reg
  loss_w:
    - 1
    - 1
    - 1

  optimizer: Adam
  scheduler: null
  amp: True

optim_conf:
  Adam:
    module_path: torch.optim
    module_name: Adam
    lr: 0.0001
    weight_decay: 0.0001

  AdamW:
    module_path: torch.optim
    module_name: AdamW
    lr: 0.001

  SGDR:
    module_path: torch.optim
    module_name: SGD
    lr: 0.5
    weight_decay: 0.00001

sched_conf:
  Cosine:
    module_path: torch.optim.lr_scheduler
    module_name: CosineAnnealingLR
    eta_ratio: 0.0
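The module_path/module_name pairs in optim_conf suggest the optimizer is constructed dynamically. The training script itself is not part of this commit, so build_optimizer below is a hypothetical helper that assumes the nesting shown in the file above.

import importlib
import yaml

def build_optimizer(train_conf, params):
    # Hypothetical helper: resolves the optimizer class from module_path/module_name.
    name = train_conf['common']['optimizer']              # e.g. 'Adam'
    opt_conf = dict(train_conf['optim_conf'][name])
    module = importlib.import_module(opt_conf.pop('module_path'))
    opt_cls = getattr(module, opt_conf.pop('module_name'))
    return opt_cls(params, **opt_conf)                    # Adam(lr=0.0001, weight_decay=0.0001)

with open('./config/train/Exp_ACL_v1.yaml') as f:
    train_conf = yaml.safe_load(f)
# optimizer = build_optimizer(train_conf, model.parameters())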
modules/AudioToken/AudioToken.py ADDED
@@ -0,0 +1,100 @@
import torch
from diffusers.loaders import AttnProcsLayers

from modules.BEATs.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor


class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together"""

    def __init__(
            self,
            args,
            accelerator,
    ):
        super().__init__()
        # Load scheduler and models
        from modules.clip_text_model.modeling_clip import CLIPTextModel
        self.text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
        self.vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        )

        checkpoint = torch.load(
            'models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
        cfg = BEATsConfig(checkpoint['cfg'])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint['model'])
        self.aud_encoder.predictor = None
        input_size = 768 * 3

        if args.pretrained_model_name_or_path == "CompVis/stable-diffusion-v1-4":
            self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
        else:
            self.embedder = FGAEmbedder(input_size=input_size, output_size=1024)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if 'lora' in args and args.lora:
            # Set correct lora layers
            lora_attn_procs = {}
            for name in self.unet.attn_processors.keys():
                cross_attention_dim = None if name.endswith(
                    "attn1.processor") else self.unet.config.cross_attention_dim
                if name.startswith("mid_block"):
                    hidden_size = self.unet.config.block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = self.unet.config.block_out_channels[block_id]

                lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
                                                          cross_attention_dim=cross_attention_dim)

            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)

        if args.data_set == 'train':
            # Freeze vae, unet, text_enc and aud_encoder
            self.vae.requires_grad_(False)
            self.unet.requires_grad_(False)
            self.text_encoder.requires_grad_(False)
            self.aud_encoder.requires_grad_(False)
            self.embedder.requires_grad_(True)
            self.embedder.train()

            if 'lora' in args and args.lora:
                self.unet.train()

        if args.data_set == 'test':
            from transformers import CLIPTextModel
            self.text_encoder = CLIPTextModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
            )

            self.embedder.eval()
            embedder_learned_embeds = args.learned_embeds
            self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=accelerator.device))

            if 'lora' in args and args.lora:
                self.lora_layers.eval()
                lora_layers_learned_embeds = args.lora_learned_embeds
                self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=accelerator.device))
                self.unet.load_attn_procs(lora_layers_learned_embeds)
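AudioTokenWrapper expects an args object carrying the attributes accessed above. A minimal sketch for the training branch follows; the real argument parser is not included in this commit, and constructing the wrapper fetches the Stable Diffusion weights and needs the local BEATs checkpoint referenced in __init__.

from argparse import Namespace
from accelerate import Accelerator

# Attribute names are taken from the accesses in AudioTokenWrapper.__init__ above;
# everything else about the original launch script is an assumption.
args = Namespace(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    revision=None,
    data_set="train",
    lora=False,
)
wrapper = AudioTokenWrapper(args, Accelerator())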
modules/AudioToken/embedder.py ADDED
@@ -0,0 +1,17 @@
import torch.nn as nn
from modules.FGA.atten import Atten


class FGAEmbedder(nn.Module):
    def __init__(self, input_size=768*3, output_size=768):
        super(FGAEmbedder, self).__init__()
        self.fc1 = nn.Linear(input_size, input_size)
        self.fc2 = nn.Linear(input_size, output_size)
        self.gelu = nn.GELU()
        self.fga = Atten(util_e=[output_size], pairwise_flag=False)

    def forward(self, audio_embs):
        audio_embs = self.fc1(audio_embs)
        audio_embs = self.gelu(audio_embs)
        audio_embs = self.fc2(audio_embs)
        attend = self.fga([audio_embs])[0]
        return attend
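A rough shape check for FGAEmbedder; note that modules/FGA/atten.py is not part of this upload, and the (batch, frames, 768 * 3) layout below is an assumption based on how BEATs features are concatenated elsewhere in the repo.

import torch

embedder = FGAEmbedder(input_size=768 * 3, output_size=768)
dummy = torch.randn(2, 64, 768 * 3)   # assumed (batch, frames, features) layout
out = embedder(dummy)                 # attended audio embedding from the FGA block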
modules/BEATs/BEATs.py ADDED
@@ -0,0 +1,180 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi
from torch.cuda.amp import autocast

from modules.BEATs.backbone import (
    TransformerEncoder,
)

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class BEATsConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class BEATs(nn.Module):
    def __init__(
            self,
            cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    @autocast(enabled=False)
    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)
        # ToDo Aug here
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layers_sum, layers = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, layers_sum, layers
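For reference, the same loading pattern used in AudioTokenWrapper can be used to pull features directly: with predictor set to None, extract_features returns the last hidden states together with the concatenated outputs of blocks 3, 7 and 11. The checkpoint path below follows the pretrain entry in config/model/ACL_ViT16.yaml; the audio tensor is a dummy.

import torch
from modules.BEATs.BEATs import BEATs, BEATsConfig

ckpt = torch.load('./pretrain/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt', map_location='cpu')
beats = BEATs(BEATsConfig(ckpt['cfg']))
beats.load_state_dict(ckpt['model'])
beats.predictor = None     # drop the classifier head so extract_features returns features, not label probabilities
beats.eval()

wav = torch.randn(1, 16000)    # dummy 1-second clip at 16 kHz
with torch.no_grad():
    x, layers_sum, layers = beats.extract_features(wav)
# x: (1, frames, 768); layers_sum: (1, frames, 2304); layers: 3 intermediate feature maps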
modules/BEATs/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from modules.BEATs.backbone import (
17
+ TransformerEncoder,
18
+ )
19
+ from modules.BEATs.quantizer import (
20
+ NormEMAVectorQuantizer,
21
+ )
22
+
23
+ import logging
24
+ from typing import Optional
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class TokenizersConfig:
30
+ def __init__(self, cfg=None):
31
+ self.input_patch_size: int = -1 # path size of patch embedding
32
+ self.embed_dim: int = 512 # patch embedding dimension
33
+ self.conv_bias: bool = False # include bias in conv encoder
34
+
35
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
36
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
37
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
38
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
39
+ self.activation_fn: str = "gelu" # activation function to use
40
+
41
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
42
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
43
+
44
+ # dropouts
45
+ self.dropout: float = 0.1 # dropout probability for the transformer
46
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
47
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
48
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
49
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
50
+
51
+ # positional embeddings
52
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
53
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
54
+
55
+ # relative position embedding
56
+ self.relative_position_embedding: bool = False # apply relative position embedding
57
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
58
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
59
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
60
+
61
+ # quantizer
62
+ self.quant_n: int = 1024 # codebook number in quantizer
63
+ self.quant_dim: int = 256 # codebook dimension in quantizer
64
+
65
+ if cfg is not None:
66
+ self.update(cfg)
67
+
68
+ def update(self, cfg: dict):
69
+ self.__dict__.update(cfg)
70
+
71
+
72
+ class Tokenizers(nn.Module):
73
+ def __init__(
74
+ self,
75
+ cfg: TokenizersConfig,
76
+ ) -> None:
77
+ super().__init__()
78
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
79
+
80
+ self.cfg = cfg
81
+
82
+ self.embed = cfg.embed_dim
83
+ self.post_extract_proj = (
84
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
85
+ if self.embed != cfg.encoder_embed_dim
86
+ else None
87
+ )
88
+
89
+ self.input_patch_size = cfg.input_patch_size
90
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
91
+ bias=cfg.conv_bias)
92
+
93
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
94
+
95
+ assert not cfg.deep_norm or not cfg.layer_norm_first
96
+ self.encoder = TransformerEncoder(cfg)
97
+ self.layer_norm = LayerNorm(self.embed)
98
+
99
+ self.quantize = NormEMAVectorQuantizer(
100
+ n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
101
+ )
102
+ self.quant_n = cfg.quant_n
103
+ self.quantize_layer = nn.Sequential(
104
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
105
+ nn.Tanh(),
106
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
107
+ )
108
+
109
+ def forward_padding_mask(
110
+ self,
111
+ features: torch.Tensor,
112
+ padding_mask: torch.Tensor,
113
+ ) -> torch.Tensor:
114
+ extra = padding_mask.size(1) % features.size(1)
115
+ if extra > 0:
116
+ padding_mask = padding_mask[:, :-extra]
117
+ padding_mask = padding_mask.view(
118
+ padding_mask.size(0), features.size(1), -1
119
+ )
120
+ padding_mask = padding_mask.all(-1)
121
+ return padding_mask
122
+
123
+ def preprocess(
124
+ self,
125
+ source: torch.Tensor,
126
+ fbank_mean: float = 15.41663,
127
+ fbank_std: float = 6.55582,
128
+ ) -> torch.Tensor:
129
+ fbanks = []
130
+ for waveform in source:
131
+ waveform = waveform.unsqueeze(0) * 2 ** 15
132
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
133
+ fbanks.append(fbank)
134
+ fbank = torch.stack(fbanks, dim=0)
135
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
136
+ return fbank
137
+
138
+ def extract_labels(
139
+ self,
140
+ source: torch.Tensor,
141
+ padding_mask: Optional[torch.Tensor] = None,
142
+ fbank_mean: float = 15.41663,
143
+ fbank_std: float = 6.55582,
144
+ ):
145
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
146
+
147
+ if padding_mask is not None:
148
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
149
+
150
+ fbank = fbank.unsqueeze(1)
151
+ features = self.patch_embedding(fbank)
152
+ features = features.reshape(features.shape[0], features.shape[1], -1)
153
+ features = features.transpose(1, 2)
154
+ features = self.layer_norm(features)
155
+
156
+ if padding_mask is not None:
157
+ padding_mask = self.forward_padding_mask(features, padding_mask)
158
+
159
+ if self.post_extract_proj is not None:
160
+ features = self.post_extract_proj(features)
161
+
162
+ x = self.dropout_input(features)
163
+
164
+ x, layer_results = self.encoder(
165
+ x,
166
+ padding_mask=padding_mask,
167
+ )
168
+
169
+ quantize_input = self.quantize_layer(x)
170
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
171
+
172
+ return embed_ind
modules/BEATs/backbone.py ADDED
@@ -0,0 +1,788 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import numpy as np
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import LayerNorm, Parameter
17
+ from modules.BEATs.modules import (
18
+ GradMultiply,
19
+ SamePad,
20
+ get_activation_fn,
21
+ GLU_Linear,
22
+ quant_noise,
23
+ )
24
+
25
+
26
+ class TransformerEncoder(nn.Module):
27
+ def __init__(self, args):
28
+ super().__init__()
29
+
30
+ self.dropout = args.dropout
31
+ self.embedding_dim = args.encoder_embed_dim
32
+
33
+ self.pos_conv = nn.Conv1d(
34
+ self.embedding_dim,
35
+ self.embedding_dim,
36
+ kernel_size=args.conv_pos,
37
+ padding=args.conv_pos // 2,
38
+ groups=args.conv_pos_groups,
39
+ )
40
+ dropout = 0
41
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
42
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
43
+ nn.init.constant_(self.pos_conv.bias, 0)
44
+
45
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
46
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
47
+
48
+ if hasattr(args, "relative_position_embedding"):
49
+ self.relative_position_embedding = args.relative_position_embedding
50
+ self.num_buckets = args.num_buckets
51
+ self.max_distance = args.max_distance
52
+ else:
53
+ self.relative_position_embedding = False
54
+ self.num_buckets = 0
55
+ self.max_distance = 0
56
+
57
+ self.layers = nn.ModuleList(
58
+ [
59
+ TransformerSentenceEncoderLayer(
60
+ embedding_dim=self.embedding_dim,
61
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
62
+ num_attention_heads=args.encoder_attention_heads,
63
+ dropout=self.dropout,
64
+ attention_dropout=args.attention_dropout,
65
+ activation_dropout=args.activation_dropout,
66
+ activation_fn=args.activation_fn,
67
+ layer_norm_first=args.layer_norm_first,
68
+ deep_norm=args.deep_norm,
69
+ has_relative_attention_bias=self.relative_position_embedding,
70
+ num_buckets=self.num_buckets,
71
+ max_distance=self.max_distance,
72
+ gru_rel_pos=args.gru_rel_pos,
73
+ encoder_layers=args.encoder_layers,
74
+ )
75
+ for i in range(args.encoder_layers)
76
+ ]
77
+ )
78
+ if self.relative_position_embedding:
79
+ for i in range(1, args.encoder_layers):
80
+ del self.layers[i].self_attn.relative_attention_bias
81
+ self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
82
+
83
+ self.layer_norm_first = args.layer_norm_first
84
+ self.layer_norm = LayerNorm(self.embedding_dim)
85
+ self.layerdrop = args.encoder_layerdrop
86
+
87
+ self.apply(init_bert_params)
88
+
89
+ if args.deep_norm:
90
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
91
+ for i in range(args.encoder_layers):
92
+ nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
93
+ nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
94
+ nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
95
+ nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
96
+ nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
97
+ nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
98
+
99
+ self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
100
+
101
+ def forward(self, x, padding_mask=None, layer=None):
102
+ x, layers_sum, layers = self.extract_features(x, padding_mask, layer)
103
+
104
+ if self.layer_norm_first and layer is None:
105
+ x = self.layer_norm(x)
106
+
107
+ return x, layers_sum, layers
108
+
109
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
110
+
111
+ if padding_mask is not None:
112
+ x[padding_mask] = 0
113
+
114
+ x_conv = self.pos_conv(x.transpose(1, 2))
115
+ x_conv = x_conv.transpose(1, 2)
116
+ x += x_conv
117
+
118
+ if not self.layer_norm_first:
119
+ x = self.layer_norm(x)
120
+
121
+ x = F.dropout(x, p=self.dropout, training=self.training)
122
+
123
+ # B x T x C -> T x B x C
124
+ x = x.transpose(0, 1)
125
+ layers = []
126
+
127
+ layer_results = []
128
+ z = None
129
+ if tgt_layer is not None:
130
+ layer_results.append((x, z))
131
+ r = None
132
+ pos_bias = None
133
+ for i, layer in enumerate(self.layers):
134
+ if self.layer_wise_gradient_decay_ratio != 1.0:
135
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
136
+ dropout_probability = np.random.random()
137
+ if not self.training or (dropout_probability > self.layerdrop):
138
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
139
+ if tgt_layer is not None:
140
+ layer_results.append((x, z))
141
+ if i == tgt_layer:
142
+ r = x
143
+ break
144
+ if i in [3, 7, 11]:
145
+ layers.append(x.transpose(0, 1))
146
+
147
+ if r is not None:
148
+ x = r
149
+
150
+ # T x B x C -> B x T x C
151
+ x = x.transpose(0, 1)
152
+ layers_cat = torch.cat(layers, dim=2)
153
+ # layers = layers[0] + layers[1] + layers[2]
154
+
155
+ return x, layers_cat, layers
156
+
157
+
158
+ class TransformerSentenceEncoderLayer(nn.Module):
159
+ def __init__(
160
+ self,
161
+ embedding_dim: float = 768,
162
+ ffn_embedding_dim: float = 3072,
163
+ num_attention_heads: float = 8,
164
+ dropout: float = 0.1,
165
+ attention_dropout: float = 0.1,
166
+ activation_dropout: float = 0.1,
167
+ activation_fn: str = "relu",
168
+ layer_norm_first: bool = False,
169
+ deep_norm: bool = False,
170
+ has_relative_attention_bias: bool = False,
171
+ num_buckets: int = 0,
172
+ max_distance: int = 0,
173
+ rescale_init: bool = False,
174
+ gru_rel_pos: bool = False,
175
+ encoder_layers: int = 0,
176
+ ) -> None:
177
+
178
+ super().__init__()
179
+ self.embedding_dim = embedding_dim
180
+ self.dropout = dropout
181
+ self.activation_dropout = activation_dropout
182
+
183
+ self.activation_name = activation_fn
184
+ self.activation_fn = get_activation_fn(activation_fn)
185
+ self.self_attn = MultiheadAttention(
186
+ self.embedding_dim,
187
+ num_attention_heads,
188
+ dropout=attention_dropout,
189
+ self_attention=True,
190
+ has_relative_attention_bias=has_relative_attention_bias,
191
+ num_buckets=num_buckets,
192
+ max_distance=max_distance,
193
+ rescale_init=rescale_init,
194
+ gru_rel_pos=gru_rel_pos,
195
+ )
196
+
197
+ self.dropout1 = nn.Dropout(dropout)
198
+ self.dropout2 = nn.Dropout(self.activation_dropout)
199
+ self.dropout3 = nn.Dropout(dropout)
200
+
201
+ self.layer_norm_first = layer_norm_first
202
+
203
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
204
+
205
+ if self.activation_name == "glu":
206
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
207
+ else:
208
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
209
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
210
+
211
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
212
+
213
+ self.deep_norm = deep_norm
214
+ if self.deep_norm:
215
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
216
+ else:
217
+ self.deep_norm_alpha = 1
218
+
219
+ def forward(
220
+ self,
221
+ x: torch.Tensor,
222
+ self_attn_mask: torch.Tensor = None,
223
+ self_attn_padding_mask: torch.Tensor = None,
224
+ need_weights: bool = False,
225
+ pos_bias=None
226
+ ):
227
+ residual = x
228
+
229
+ if self.layer_norm_first:
230
+ x = self.self_attn_layer_norm(x)
231
+ x, attn, pos_bias = self.self_attn(
232
+ query=x,
233
+ key=x,
234
+ value=x,
235
+ key_padding_mask=self_attn_padding_mask,
236
+ need_weights=False,
237
+ attn_mask=self_attn_mask,
238
+ position_bias=pos_bias
239
+ )
240
+ x = self.dropout1(x)
241
+ x = residual + x
242
+
243
+ residual = x
244
+ x = self.final_layer_norm(x)
245
+ if self.activation_name == "glu":
246
+ x = self.fc1(x)
247
+ else:
248
+ x = self.activation_fn(self.fc1(x))
249
+ x = self.dropout2(x)
250
+ x = self.fc2(x)
251
+ x = self.dropout3(x)
252
+ x = residual + x
253
+ else:
254
+ x, attn, pos_bias = self.self_attn(
255
+ query=x,
256
+ key=x,
257
+ value=x,
258
+ key_padding_mask=self_attn_padding_mask,
259
+ need_weights=need_weights,
260
+ attn_mask=self_attn_mask,
261
+ position_bias=pos_bias
262
+ )
263
+
264
+ x = self.dropout1(x)
265
+ x = residual * self.deep_norm_alpha + x
266
+
267
+ x = self.self_attn_layer_norm(x)
268
+
269
+ residual = x
270
+ if self.activation_name == "glu":
271
+ x = self.fc1(x)
272
+ else:
273
+ x = self.activation_fn(self.fc1(x))
274
+ x = self.dropout2(x)
275
+ x = self.fc2(x)
276
+ x = self.dropout3(x)
277
+ x = residual * self.deep_norm_alpha + x
278
+ x = self.final_layer_norm(x)
279
+
280
+ return x, attn, pos_bias
281
+
282
+
283
+ class MultiheadAttention(nn.Module):
284
+ """Multi-headed attention.
285
+
286
+ See "Attention Is All You Need" for more details.
287
+ """
288
+
289
+ def __init__(
290
+ self,
291
+ embed_dim,
292
+ num_heads,
293
+ kdim=None,
294
+ vdim=None,
295
+ dropout=0.0,
296
+ bias=True,
297
+ add_bias_kv=False,
298
+ add_zero_attn=False,
299
+ self_attention=False,
300
+ encoder_decoder_attention=False,
301
+ q_noise=0.0,
302
+ qn_block_size=8,
303
+ has_relative_attention_bias=False,
304
+ num_buckets=32,
305
+ max_distance=128,
306
+ gru_rel_pos=False,
307
+ rescale_init=False,
308
+ ):
309
+ super().__init__()
310
+ self.embed_dim = embed_dim
311
+ self.kdim = kdim if kdim is not None else embed_dim
312
+ self.vdim = vdim if vdim is not None else embed_dim
313
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
314
+
315
+ self.num_heads = num_heads
316
+ self.dropout_module = nn.Dropout(dropout)
317
+
318
+ self.has_relative_attention_bias = has_relative_attention_bias
319
+ self.num_buckets = num_buckets
320
+ self.max_distance = max_distance
321
+ if self.has_relative_attention_bias:
322
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
323
+
324
+ self.head_dim = embed_dim // num_heads
325
+ self.q_head_dim = self.head_dim
326
+ self.k_head_dim = self.head_dim
327
+ assert (
328
+ self.head_dim * num_heads == self.embed_dim
329
+ ), "embed_dim must be divisible by num_heads"
330
+ self.scaling = self.head_dim ** -0.5
331
+
332
+ self.self_attention = self_attention
333
+ self.encoder_decoder_attention = encoder_decoder_attention
334
+
335
+ assert not self.self_attention or self.qkv_same_dim, (
336
+ "Self-attention requires query, key and " "value to be of the same size"
337
+ )
338
+
339
+ k_bias = True
340
+ if rescale_init:
341
+ k_bias = False
342
+
343
+ k_embed_dim = embed_dim
344
+ q_embed_dim = embed_dim
345
+
346
+ self.k_proj = quant_noise(
347
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
348
+ )
349
+ self.v_proj = quant_noise(
350
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
351
+ )
352
+ self.q_proj = quant_noise(
353
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
354
+ )
355
+
356
+ self.out_proj = quant_noise(
357
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
358
+ )
359
+
360
+ if add_bias_kv:
361
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
362
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
363
+ else:
364
+ self.bias_k = self.bias_v = None
365
+
366
+ self.add_zero_attn = add_zero_attn
367
+
368
+ self.gru_rel_pos = gru_rel_pos
369
+ if self.gru_rel_pos:
370
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
371
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
372
+
373
+ self.reset_parameters()
374
+
375
+ def reset_parameters(self):
376
+ if self.qkv_same_dim:
377
+ # Empirically observed the convergence to be much better with
378
+ # the scaled initialization
379
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
380
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
381
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
382
+ else:
383
+ nn.init.xavier_uniform_(self.k_proj.weight)
384
+ nn.init.xavier_uniform_(self.v_proj.weight)
385
+ nn.init.xavier_uniform_(self.q_proj.weight)
386
+
387
+ nn.init.xavier_uniform_(self.out_proj.weight)
388
+ if self.out_proj.bias is not None:
389
+ nn.init.constant_(self.out_proj.bias, 0.0)
390
+ if self.bias_k is not None:
391
+ nn.init.xavier_normal_(self.bias_k)
392
+ if self.bias_v is not None:
393
+ nn.init.xavier_normal_(self.bias_v)
394
+ if self.has_relative_attention_bias:
395
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
396
+
397
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
398
+ num_buckets = self.num_buckets
399
+ max_distance = self.max_distance
400
+ relative_buckets = 0
401
+
402
+ if bidirectional:
403
+ num_buckets = num_buckets // 2
404
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
405
+ relative_positions = torch.abs(relative_positions)
406
+ else:
407
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
408
+
409
+ max_exact = num_buckets // 2
410
+ is_small = relative_positions < max_exact
411
+
412
+ relative_postion_if_large = max_exact + (
413
+ torch.log(relative_positions.float() / max_exact)
414
+ / math.log(max_distance / max_exact)
415
+ * (num_buckets - max_exact)
416
+ ).to(torch.long)
417
+ relative_postion_if_large = torch.min(
418
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
419
+ )
420
+
421
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
422
+ return relative_buckets
423
+
424
+ def compute_bias(self, query_length, key_length):
425
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
426
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
427
+ relative_position = memory_position - context_position
428
+ relative_position_bucket = self._relative_positions_bucket(
429
+ relative_position,
430
+ bidirectional=True
431
+ )
432
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
433
+ values = self.relative_attention_bias(relative_position_bucket)
434
+ values = values.permute([2, 0, 1])
435
+ return values
436
+
437
+ def forward(
438
+ self,
439
+ query,
440
+ key: Optional[Tensor],
441
+ value: Optional[Tensor],
442
+ key_padding_mask: Optional[Tensor] = None,
443
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
444
+ need_weights: bool = True,
445
+ static_kv: bool = False,
446
+ attn_mask: Optional[Tensor] = None,
447
+ before_softmax: bool = False,
448
+ need_head_weights: bool = False,
449
+ position_bias: Optional[Tensor] = None
450
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
451
+ """Input shape: Time x Batch x Channel
452
+
453
+ Args:
454
+ key_padding_mask (ByteTensor, optional): mask to exclude
455
+ keys that are pads, of shape `(batch, src_len)`, where
456
+ padding elements are indicated by 1s.
457
+ need_weights (bool, optional): return the attention weights,
458
+ averaged over heads (default: False).
459
+ attn_mask (ByteTensor, optional): typically used to
460
+ implement causal attention, where the mask prevents the
461
+ attention from looking forward in time (default: None).
462
+ before_softmax (bool, optional): return the raw attention
463
+ weights and values before the attention softmax.
464
+ need_head_weights (bool, optional): return the attention
465
+ weights for each head. Implies *need_weights*. Default:
466
+ return the average attention weights over all heads.
467
+ """
468
+ if need_head_weights:
469
+ need_weights = True
470
+
471
+ is_tpu = query.device.type == "xla"
472
+
473
+ tgt_len, bsz, embed_dim = query.size()
474
+ src_len = tgt_len
475
+ assert embed_dim == self.embed_dim
476
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
477
+ if key is not None:
478
+ src_len, key_bsz, _ = key.size()
479
+ if not torch.jit.is_scripting():
480
+ assert key_bsz == bsz
481
+ assert value is not None
482
+ assert src_len, bsz == value.shape[:2]
483
+
484
+ if self.has_relative_attention_bias and position_bias is None:
485
+ position_bias = self.compute_bias(tgt_len, src_len)
486
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
487
+
488
+ if incremental_state is not None:
489
+ saved_state = self._get_input_buffer(incremental_state)
490
+ if saved_state is not None and "prev_key" in saved_state:
491
+ # previous time steps are cached - no need to recompute
492
+ # key and value if they are static
493
+ if static_kv:
494
+ assert self.encoder_decoder_attention and not self.self_attention
495
+ key = value = None
496
+ else:
497
+ saved_state = None
498
+
499
+ if self.self_attention:
500
+ q = self.q_proj(query)
501
+ k = self.k_proj(query)
502
+ v = self.v_proj(query)
503
+ elif self.encoder_decoder_attention:
504
+ # encoder-decoder attention
505
+ q = self.q_proj(query)
506
+ if key is None:
507
+ assert value is None
508
+ k = v = None
509
+ else:
510
+ k = self.k_proj(key)
511
+ v = self.v_proj(key)
512
+
513
+ else:
514
+ assert key is not None and value is not None
515
+ q = self.q_proj(query)
516
+ k = self.k_proj(key)
517
+ v = self.v_proj(value)
518
+ q *= self.scaling
519
+ alpha = 32
520
+ q *= 1 / alpha
521
+
522
+ if self.bias_k is not None:
523
+ assert self.bias_v is not None
524
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
525
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
526
+ if attn_mask is not None:
527
+ attn_mask = torch.cat(
528
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
529
+ )
530
+ if key_padding_mask is not None:
531
+ key_padding_mask = torch.cat(
532
+ [
533
+ key_padding_mask,
534
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
535
+ ],
536
+ dim=1,
537
+ )
538
+
539
+ q = (
540
+ q.contiguous()
541
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
542
+ .transpose(0, 1)
543
+ )
544
+ if k is not None:
545
+ k = (
546
+ k.contiguous()
547
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
548
+ .transpose(0, 1)
549
+ )
550
+ if v is not None:
551
+ v = (
552
+ v.contiguous()
553
+ .view(-1, bsz * self.num_heads, self.head_dim)
554
+ .transpose(0, 1)
555
+ )
556
+
557
+ if saved_state is not None:
558
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
559
+ if "prev_key" in saved_state:
560
+ _prev_key = saved_state["prev_key"]
561
+ assert _prev_key is not None
562
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
563
+ if static_kv:
564
+ k = prev_key
565
+ else:
566
+ assert k is not None
567
+ k = torch.cat([prev_key, k], dim=1)
568
+ src_len = k.size(1)
569
+ if "prev_value" in saved_state:
570
+ _prev_value = saved_state["prev_value"]
571
+ assert _prev_value is not None
572
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
573
+ if static_kv:
574
+ v = prev_value
575
+ else:
576
+ assert v is not None
577
+ v = torch.cat([prev_value, v], dim=1)
578
+ prev_key_padding_mask: Optional[Tensor] = None
579
+ if "prev_key_padding_mask" in saved_state:
580
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
581
+ assert k is not None and v is not None
582
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
583
+ key_padding_mask=key_padding_mask,
584
+ prev_key_padding_mask=prev_key_padding_mask,
585
+ batch_size=bsz,
586
+ src_len=k.size(1),
587
+ static_kv=static_kv,
588
+ )
589
+
590
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
591
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
592
+ saved_state["prev_key_padding_mask"] = key_padding_mask
593
+ # In this branch incremental_state is never None
594
+ assert incremental_state is not None
595
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
596
+ assert k is not None
597
+ assert k.size(1) == src_len
598
+
599
+ # This is part of a workaround to get around fork/join parallelism
600
+ # not supporting Optional types.
601
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
602
+ key_padding_mask = None
603
+
604
+ if key_padding_mask is not None:
605
+ assert key_padding_mask.size(0) == bsz
606
+ assert key_padding_mask.size(1) == src_len
607
+
608
+ if self.add_zero_attn:
609
+ assert v is not None
610
+ src_len += 1
611
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
612
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
613
+ if attn_mask is not None:
614
+ attn_mask = torch.cat(
615
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
616
+ )
617
+ if key_padding_mask is not None:
618
+ key_padding_mask = torch.cat(
619
+ [
620
+ key_padding_mask,
621
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
622
+ key_padding_mask
623
+ ),
624
+ ],
625
+ dim=1,
626
+ )
627
+
628
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
629
+ attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
630
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
631
+
632
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
633
+
634
+ if attn_mask is not None:
635
+ attn_mask = attn_mask.unsqueeze(0)
636
+ attn_weights += attn_mask
637
+
638
+ if key_padding_mask is not None:
639
+ # don't attend to padding symbols
640
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
641
+ if not is_tpu:
642
+ attn_weights = attn_weights.masked_fill(
643
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
644
+ float("-inf"),
645
+ )
646
+ else:
647
+ attn_weights = attn_weights.transpose(0, 2)
648
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
649
+ attn_weights = attn_weights.transpose(0, 2)
650
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
651
+
652
+ if before_softmax:
653
+ return attn_weights, v, position_bias
654
+
655
+ if position_bias is not None:
656
+ attn_mask_rel_pos = position_bias
657
+ if self.gru_rel_pos == 1:
658
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
659
+ _B, _H, _L, __ = query_layer.size()
660
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
661
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
662
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
663
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
664
+
665
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
666
+
667
+ attn_weights = attn_weights + attn_mask_rel_pos
668
+
669
+ attn_weights_float = F.softmax(
670
+ attn_weights, dim=-1
671
+ )
672
+ attn_weights = attn_weights_float.type_as(attn_weights)
673
+ attn_probs = self.dropout_module(attn_weights)
674
+
675
+ assert v is not None
676
+ attn = torch.bmm(attn_probs, v)
677
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
678
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
679
+ attn = self.out_proj(attn)
680
+ attn_weights: Optional[Tensor] = None
681
+ if need_weights:
682
+ attn_weights = attn_weights_float.view(
683
+ bsz, self.num_heads, tgt_len, src_len
684
+ ).transpose(1, 0)
685
+ if not need_head_weights:
686
+ # average attention weights over heads
687
+ attn_weights = attn_weights.mean(dim=0)
688
+
689
+ return attn, attn_weights, position_bias
690
+
691
+ @staticmethod
692
+ def _append_prev_key_padding_mask(
693
+ key_padding_mask: Optional[Tensor],
694
+ prev_key_padding_mask: Optional[Tensor],
695
+ batch_size: int,
696
+ src_len: int,
697
+ static_kv: bool,
698
+ ) -> Optional[Tensor]:
699
+ # saved key padding masks have shape (bsz, seq_len)
700
+ if prev_key_padding_mask is not None and static_kv:
701
+ new_key_padding_mask = prev_key_padding_mask
702
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
703
+ new_key_padding_mask = torch.cat(
704
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
705
+ )
706
+ # During incremental decoding, as the padding token enters and
707
+ # leaves the frame, there will be a time when prev or current
708
+ # is None
709
+ elif prev_key_padding_mask is not None:
710
+ if src_len > prev_key_padding_mask.size(1):
711
+ filler = torch.zeros(
712
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
713
+ device=prev_key_padding_mask.device,
714
+ )
715
+ new_key_padding_mask = torch.cat(
716
+ [prev_key_padding_mask.float(), filler.float()], dim=1
717
+ )
718
+ else:
719
+ new_key_padding_mask = prev_key_padding_mask.float()
720
+ elif key_padding_mask is not None:
721
+ if src_len > key_padding_mask.size(1):
722
+ filler = torch.zeros(
723
+ (batch_size, src_len - key_padding_mask.size(1)),
724
+ device=key_padding_mask.device,
725
+ )
726
+ new_key_padding_mask = torch.cat(
727
+ [filler.float(), key_padding_mask.float()], dim=1
728
+ )
729
+ else:
730
+ new_key_padding_mask = key_padding_mask.float()
731
+ else:
732
+ new_key_padding_mask = prev_key_padding_mask
733
+ return new_key_padding_mask
734
+
735
+ def _get_input_buffer(
736
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
737
+ ) -> Dict[str, Optional[Tensor]]:
738
+ result = self.get_incremental_state(incremental_state, "attn_state")
739
+ if result is not None:
740
+ return result
741
+ else:
742
+ empty_result: Dict[str, Optional[Tensor]] = {}
743
+ return empty_result
744
+
745
+ def _set_input_buffer(
746
+ self,
747
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
748
+ buffer: Dict[str, Optional[Tensor]],
749
+ ):
750
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
751
+
752
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
753
+ return attn_weights
754
+
755
+
756
+ def init_bert_params(module):
757
+ """
758
+ Initialize the weights specific to the BERT Model.
759
+ This overrides the default initializations depending on the specified arguments.
760
+ 1. If normal_init_linear_weights is set then weights of linear
761
+ layer will be initialized using the normal distribution and
762
+ bais will be set to the specified value.
763
+ 2. If normal_init_embed_weights is set then weights of embedding
764
+ layer will be initialized using the normal distribution.
765
+ 3. If normal_init_proj_weights is set then weights of
766
+ in_project_weight for MultiHeadAttention initialized using
767
+ the normal distribution (to be validated).
768
+ """
769
+
770
+ def normal_(data):
771
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
772
+ # so that the RNG is consistent with and without FSDP
773
+ data.copy_(
774
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
775
+ )
776
+
777
+ if isinstance(module, nn.Linear):
778
+ normal_(module.weight.data)
779
+ if module.bias is not None:
780
+ module.bias.data.zero_()
781
+ if isinstance(module, nn.Embedding):
782
+ normal_(module.weight.data)
783
+ if module.padding_idx is not None:
784
+ module.weight.data[module.padding_idx].zero_()
785
+ if isinstance(module, MultiheadAttention):
786
+ normal_(module.q_proj.weight.data)
787
+ normal_(module.k_proj.weight.data)
788
+ normal_(module.v_proj.weight.data)
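TransformerEncoder above returns the final hidden states together with the outputs of blocks 3, 7 and 11 concatenated channel-wise; that concatenation is where FGAEmbedder's default input size of 768 * 3 comes from. A small sketch with random features (shapes are illustrative only):

import torch
from modules.BEATs.BEATs import BEATsConfig
from modules.BEATs.backbone import TransformerEncoder

encoder = TransformerEncoder(BEATsConfig())
feats = torch.randn(2, 100, 768)                 # (batch, frames, encoder_embed_dim)
x, layers_cat, layers = encoder(feats)
print(x.shape, layers_cat.shape, len(layers))    # (2, 100, 768) (2, 100, 2304) 3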
modules/BEATs/modules.py ADDED
@@ -0,0 +1,218 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ import torch
13
+ from torch import Tensor, nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class GradMultiply(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x, scale):
20
+ ctx.scale = scale
21
+ res = x.new(x)
22
+ return res
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad):
26
+ return grad * ctx.scale, None
27
+
28
+
29
+ class SamePad(nn.Module):
30
+ def __init__(self, kernel_size, causal=False):
31
+ super().__init__()
32
+ if causal:
33
+ self.remove = kernel_size - 1
34
+ else:
35
+ self.remove = 1 if kernel_size % 2 == 0 else 0
36
+
37
+ def forward(self, x):
38
+ if self.remove > 0:
39
+ x = x[:, :, : -self.remove]
40
+ return x
41
+
42
+
43
+ class Swish(nn.Module):
44
+ def __init__(self):
45
+ super(Swish, self).__init__()
46
+ self.act = torch.nn.Sigmoid()
47
+
48
+ def forward(self, x):
49
+ return x * self.act(x)
50
+
51
+
52
+ class GLU_Linear(nn.Module):
53
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
54
+ super(GLU_Linear, self).__init__()
55
+
56
+ self.glu_type = glu_type
57
+ self.output_dim = output_dim
58
+
59
+ if glu_type == "sigmoid":
60
+ self.glu_act = torch.nn.Sigmoid()
61
+ elif glu_type == "swish":
62
+ self.glu_act = Swish()
63
+ elif glu_type == "relu":
64
+ self.glu_act = torch.nn.ReLU()
65
+ elif glu_type == "gelu":
66
+ self.glu_act = torch.nn.GELU()
67
+
68
+ if bias_in_glu:
69
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
70
+ else:
71
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
72
+
73
+ def forward(self, x):
74
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
75
+ x = self.linear(x)
76
+
77
+ if self.glu_type == "bilinear":
78
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
79
+ else:
80
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
81
+
82
+ return x
83
+
84
+
85
+ def gelu_accurate(x):
86
+ if not hasattr(gelu_accurate, "_a"):
87
+ gelu_accurate._a = math.sqrt(2 / math.pi)
88
+ return (
89
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
90
+ )
91
+
92
+
93
+ def gelu(x: torch.Tensor) -> torch.Tensor:
94
+ return torch.nn.functional.gelu(x.float()).type_as(x)
95
+
96
+
97
+ def get_activation_fn(activation: str):
98
+ """Returns the activation function corresponding to `activation`"""
99
+
100
+ if activation == "relu":
101
+ return F.relu
102
+ elif activation == "gelu":
103
+ return gelu
104
+ elif activation == "gelu_fast":
105
+ warnings.warn(
106
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
107
+ )
108
+ return gelu_accurate
109
+ elif activation == "gelu_accurate":
110
+ return gelu_accurate
111
+ elif activation == "tanh":
112
+ return torch.tanh
113
+ elif activation == "linear":
114
+ return lambda x: x
115
+ elif activation == "glu":
116
+ return lambda x: x
117
+ else:
118
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
119
+
120
+
121
+ def quant_noise(module, p, block_size):
122
+ """
123
+ Wraps modules and applies quantization noise to the weights for
124
+ subsequent quantization with Iterative Product Quantization as
125
+ described in "Training with Quantization Noise for Extreme Model Compression"
126
+
127
+ Args:
128
+ - module: nn.Module
129
+ - p: amount of Quantization Noise
130
+ - block_size: size of the blocks for subsequent quantization with iPQ
131
+
132
+ Remarks:
133
+ - Module weights must have the right sizes wrt the block size
134
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
135
+ - For more detail on how to quantize by blocks with convolutional weights,
136
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
137
+ - We implement the simplest form of noise here as stated in the paper
138
+ which consists in randomly dropping blocks
139
+ """
140
+
141
+ # if no quantization noise, don't register hook
142
+ if p <= 0:
143
+ return module
144
+
145
+ # supported modules
146
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
147
+
148
+ # test whether module.weight has the right sizes wrt block_size
149
+ is_conv = module.weight.ndim == 4
150
+
151
+ # 2D matrix
152
+ if not is_conv:
153
+ assert (
154
+ module.weight.size(1) % block_size == 0
155
+ ), "Input features must be a multiple of block sizes"
156
+
157
+ # 4D matrix
158
+ else:
159
+ # 1x1 convolutions
160
+ if module.kernel_size == (1, 1):
161
+ assert (
162
+ module.in_channels % block_size == 0
163
+ ), "Input channels must be a multiple of block sizes"
164
+ # regular convolutions
165
+ else:
166
+ k = module.kernel_size[0] * module.kernel_size[1]
167
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
168
+
169
+ def _forward_pre_hook(mod, input):
170
+ # no noise for evaluation
171
+ if mod.training:
172
+ if not is_conv:
173
+ # gather weight and sizes
174
+ weight = mod.weight
175
+ in_features = weight.size(1)
176
+ out_features = weight.size(0)
177
+
178
+ # split weight matrix into blocks and randomly drop selected blocks
179
+ mask = torch.zeros(
180
+ in_features // block_size * out_features, device=weight.device
181
+ )
182
+ mask.bernoulli_(p)
183
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
184
+
185
+ else:
186
+ # gather weight and sizes
187
+ weight = mod.weight
188
+ in_channels = mod.in_channels
189
+ out_channels = mod.out_channels
190
+
191
+ # split weight matrix into blocks and randomly drop selected blocks
192
+ if mod.kernel_size == (1, 1):
193
+ mask = torch.zeros(
194
+ int(in_channels // block_size * out_channels),
195
+ device=weight.device,
196
+ )
197
+ mask.bernoulli_(p)
198
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
199
+ else:
200
+ mask = torch.zeros(
201
+ weight.size(0), weight.size(1), device=weight.device
202
+ )
203
+ mask.bernoulli_(p)
204
+ mask = (
205
+ mask.unsqueeze(2)
206
+ .unsqueeze(3)
207
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
208
+ )
209
+
210
+ # scale weights and apply mask
211
+ mask = mask.to(
212
+ torch.bool
213
+ ) # x.bool() is not currently supported in TorchScript
214
+ s = 1 / (1 - p)
215
+ mod.weight.data = s * weight.masked_fill(mask, 0)
216
+
217
+ module.register_forward_pre_hook(_forward_pre_hook)
218
+ return module
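A minimal usage sketch for the helpers above; the import path is assumed to follow the upstream BEATs layout (modules/BEATs/backbone.py), since this file's header is not shown in this excerpt:

import torch
from modules.BEATs.backbone import GLU_Linear, get_activation_fn, quant_noise  # assumed path

glu = GLU_Linear(input_dim=768, output_dim=768, glu_type="swish")
x = torch.randn(2, 10, 768)              # (batch, seq, channels), channels-last as the comment above notes
print(glu(x).shape)                      # torch.Size([2, 10, 768])

act = get_activation_fn("gelu_accurate")
print(act(torch.randn(4)))

# quant_noise registers a forward pre-hook that, in training mode only, randomly
# drops weight blocks with probability p (block_size must divide in_features).
noisy = quant_noise(torch.nn.Linear(768, 768), p=0.1, block_size=8)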
modules/BEATs/quantizer.py ADDED
@@ -0,0 +1,215 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as distributed
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass
19
+
20
+
21
+ def l2norm(t):
22
+ return F.normalize(t, p=2, dim=-1)
23
+
24
+
25
+ def ema_inplace(moving_avg, new, decay):
26
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, 'n d -> n () d') \
50
+ - rearrange(means, 'c d -> () c d')
51
+ dists = -(diffs ** 2).sum(dim=-1)
52
+
53
+ buckets = dists.max(dim=-1).indices
54
+ bins = torch.bincount(buckets, minlength=num_clusters)
55
+ zero_mask = bins == 0
56
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
57
+
58
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
59
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
60
+ new_means = new_means / bins_min_clamped[..., None]
61
+
62
+ if use_cosine_sim:
63
+ new_means = l2norm(new_means)
64
+
65
+ means = torch.where(zero_mask[..., None], means, new_means)
66
+
67
+ return means, bins
68
+
69
+
70
+ class EmbeddingEMA(nn.Module):
71
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
72
+ super().__init__()
73
+ self.num_tokens = num_tokens
74
+ self.codebook_dim = codebook_dim
75
+ self.decay = decay
76
+ self.eps = eps
77
+ if codebook_init_path == '':
78
+ if not kmeans_init:
79
+ weight = torch.randn(num_tokens, codebook_dim)
80
+ weight = l2norm(weight)
81
+ else:
82
+ weight = torch.zeros(num_tokens, codebook_dim)
83
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
84
+ else:
85
+ print(f"load init codebook weight from {codebook_init_path}")
86
+ codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
87
+ weight = codebook_ckpt_weight.clone()
88
+ self.register_buffer('initted', torch.Tensor([True]))
89
+
90
+ self.weight = nn.Parameter(weight, requires_grad=False)
91
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
92
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
93
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
94
+ self.update = True
95
+
96
+ @torch.jit.ignore
97
+ def init_embed_(self, data):
98
+ if self.initted:
99
+ return
100
+ print("Performing Kemans init for codebook")
101
+ embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
102
+ self.weight.data.copy_(embed)
103
+ self.cluster_size.data.copy_(cluster_size)
104
+ self.initted.data.copy_(torch.Tensor([True]))
105
+
106
+ def forward(self, embed_id):
107
+ return F.embedding(embed_id, self.weight)
108
+
109
+ def cluster_size_ema_update(self, new_cluster_size):
110
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
111
+
112
+ def embed_avg_ema_update(self, new_embed_avg):
113
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
114
+
115
+ def weight_update(self, num_tokens):
116
+ n = self.cluster_size.sum()
117
+ smoothed_cluster_size = (
118
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
119
+ )
120
+ # normalize embedding average with smoothed cluster size
121
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
122
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
123
+ self.weight.data.copy_(embed_normalized)
124
+
125
+
126
+ def norm_ema_inplace(moving_avg, new, decay):
127
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
128
+ moving_avg.data.copy_(l2norm(moving_avg.data))
129
+
130
+
131
+ class NormEMAVectorQuantizer(nn.Module):
132
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
133
+ statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
134
+ super().__init__()
135
+ self.codebook_dim = embedding_dim
136
+ self.num_tokens = n_embed
137
+ self.beta = beta
138
+ self.decay = decay
139
+
140
+ # learnable = True if orthogonal_reg_weight > 0 else False
141
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
142
+
143
+ self.statistic_code_usage = statistic_code_usage
144
+ if statistic_code_usage:
145
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
146
+ if distributed.is_available() and distributed.is_initialized():
147
+ print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
148
+ self.all_reduce_fn = distributed.all_reduce
149
+ else:
150
+ self.all_reduce_fn = nn.Identity()
151
+
152
+ def reset_cluster_size(self, device):
153
+ if self.statistic_code_usage:
154
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
155
+ self.cluster_size = self.cluster_size.to(device)
156
+
157
+ def forward(self, z):
158
+ # reshape z -> (batch, height, width, channel) and flatten
159
+ # z, 'b c h w -> b h w c'
160
+ # z = rearrange(z, 'b c h w -> b h w c')
161
+ # z = z.transpose(1, 2)
162
+ z = l2norm(z)
163
+ z_flattened = z.reshape(-1, self.codebook_dim)
164
+
165
+ self.embedding.init_embed_(z_flattened)
166
+
167
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
168
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
169
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
170
+
171
+ encoding_indices = torch.argmin(d, dim=1)
172
+
173
+ z_q = self.embedding(encoding_indices).view(z.shape)
174
+
175
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
176
+
177
+ if not self.training:
178
+ with torch.no_grad():
179
+ cluster_size = encodings.sum(0)
180
+ self.all_reduce_fn(cluster_size)
181
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
182
+
183
+ if self.training and self.embedding.update:
184
+ # EMA cluster size
185
+
186
+ bins = encodings.sum(0)
187
+ self.all_reduce_fn(bins)
188
+
189
+ # self.embedding.cluster_size_ema_update(bins)
190
+ ema_inplace(self.cluster_size, bins, self.decay)
191
+
192
+ zero_mask = (bins == 0)
193
+ bins = bins.masked_fill(zero_mask, 1.)
194
+
195
+ embed_sum = z_flattened.t() @ encodings
196
+ self.all_reduce_fn(embed_sum)
197
+
198
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
199
+ embed_normalized = l2norm(embed_normalized)
200
+
201
+ embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
202
+ embed_normalized)
203
+ norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
204
+
205
+ # compute loss for embedding
206
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
207
+
208
+ # preserve gradients
209
+ z_q = z + (z_q - z).detach()
210
+
211
+ # reshape back to match original input shape
212
+ # z_q, 'b h w c -> b c h w'
213
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
214
+ # z_q = z_q.transpose(1, 2)
215
+ return z_q, loss, encoding_indices
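A minimal sketch of the quantizer on random features (no k-means init, single process, eval mode so the EMA codebook weights are left untouched):

import torch
from modules.BEATs.quantizer import NormEMAVectorQuantizer

vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0, kmeans_init=False)
vq.eval()
z = torch.randn(8, 196, 256)                 # e.g. (batch, patches, dim) features
z_q, commit_loss, codes = vq(z)
print(z_q.shape, codes.shape)                # torch.Size([8, 196, 256]) torch.Size([1568])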
modules/CLIPSeg/clipseg_for_audio.py ADDED
@@ -0,0 +1,182 @@
1
+ import transformers
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from typing import List, Tuple, Union, Optional
6
+ import numpy as np
7
+ from transformers.models.clipseg.modeling_clipseg import _expand_mask
8
+
9
+
10
+ class CLIPSeg(transformers.CLIPSegForImageSegmentation):
11
+ def __init__(self, *args, **kwargs):
12
+ super().__init__(*args, **kwargs)
13
+
14
+ def encode_text(self, text: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Encode textual input and return the text embeddings.
17
+
18
+ Args:
19
+ text (torch.Tensor): Input text tensor.
20
+
21
+ Returns:
22
+ torch.Tensor: Text embeddings.
23
+ """
24
+ tokens = text
25
+ if text.ndim == 3:
26
+ tokens = torch.squeeze(text, dim=1)
27
+ non_zero_index = torch.nonzero(tokens.sum(axis=0) == 0)[0]
28
+ input_ids = tokens[:, :non_zero_index]
29
+ attention_mask = (input_ids > 0).to(tokens.dtype)
30
+ input_ids += torch.max(input_ids) * (1 - attention_mask)
31
+ conditional_embeddings = self.clip.get_text_features(input_ids, attention_mask=attention_mask,
32
+ position_ids=None)
33
+
34
+ return conditional_embeddings
35
+
36
+ def similarity(self, image: torch.Tensor, embeddings: List[torch.Tensor]) -> torch.Tensor:
37
+ """
38
+ Calculate the similarity score between an image and a list of embeddings.
39
+
40
+ Args:
41
+ image (torch.Tensor): Input image tensor of shape (B, C, H, W).
42
+ embeddings (List[torch.Tensor]): List of N embedding tensors of shape (dim,).
43
+
44
+ Returns:
45
+ torch.Tensor: Similarity scores of shape (B, N) for each batch.
46
+ """
47
+ B, c, h, w = image.shape
48
+ if (h, w) != (352, 352):
49
+ vision_outputs = self.clip.vision_model(pixel_values=F.interpolate(image, 352, mode='bicubic'),
50
+ output_attentions=False,
51
+ output_hidden_states=False,
52
+ return_dict=False)
53
+ img_embedding = self.clip.visual_projection(vision_outputs[1])
54
+ else:
55
+ vision_outputs = self.clip.vision_model(pixel_values=image,
56
+ output_attentions=False,
57
+ output_hidden_states=False,
58
+ return_dict=False)
59
+ img_embedding = self.clip.visual_projection(vision_outputs[1])
60
+
61
+ paired_embedding = torch.cat(embeddings, dim=0)
62
+ paired_embedding = paired_embedding.repeat(B, 1) # Batch-wise replication of embeddings
63
+ paired_embedding = paired_embedding.view(B, -1, img_embedding.size(-1))
64
+
65
+ result = torch.matmul(F.normalize(paired_embedding, dim=-1), F.normalize(img_embedding, dim=-1).unsqueeze(-1))
66
+ result = result.squeeze(-1).view(B, -1)
67
+ return F.softmax(result, dim=-1)
68
+
69
+ def encode_audio(self, placeholder_token: torch.Tensor, audio_token: torch.Tensor, pos: int,
70
+ length: int) -> torch.Tensor:
71
+ """
72
+ Encode audio token into the audio-driven embeddings. (Audio-Driven Embedder)
73
+
74
+ Args:
75
+ placeholder_token (torch.Tensor): Placeholder text token tensor.
76
+ audio_token (torch.Tensor): Audio token tensor.
77
+ pos (int): Position index for audio token.
78
+ length (int): Length of the input token.
79
+
80
+ Returns:
81
+ torch.Tensor: Audio-driven embeddings.
82
+
83
+ Reference:
84
+ "Can CLIP Help Sound Source Localization?" WACV 2024
85
+ - https://arxiv.org/abs/2311.04066
86
+ """
87
+ tokens = placeholder_token
88
+ if placeholder_token.ndim == 3:
89
+ tokens = torch.squeeze(placeholder_token, dim=1)
90
+
91
+ inputs_embeds = self.clip.text_model.embeddings.token_embedding(tokens).type(
92
+ self.dtype) # [batch_size, n_ctx, d_model]
93
+ inputs_embeds = torch.cat((inputs_embeds[:, :pos, :], audio_token, inputs_embeds[:, pos:, :]),
94
+ dim=1) # Inject Audio token
95
+ inputs_embeds = inputs_embeds[:, :length, :]
96
+
97
+ bsz, seq_len, _ = inputs_embeds.shape
98
+ attention_mask = torch.ones((bsz, seq_len)).to(placeholder_token.device)
99
+ position_ids = torch.arange(length).unsqueeze(0).to(placeholder_token.device)
100
+
101
+ position_embeddings = self.clip.text_model.embeddings.position_embedding(position_ids)
102
+ hidden_states = inputs_embeds + position_embeddings
103
+
104
+ bsz, seq_len, _ = inputs_embeds.shape
105
+ # CLIPSeg's text model uses causal mask, prepare it here.
106
+ # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
107
+ causal_attention_mask = self.clip.text_model._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
108
+ hidden_states.device
109
+ )
110
+ # expand attention_mask
111
+ if attention_mask is not None:
112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
113
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
114
+
115
+ encoder_outputs = self.clip.text_model.encoder(
116
+ inputs_embeds=hidden_states,
117
+ attention_mask=attention_mask,
118
+ causal_attention_mask=causal_attention_mask,
119
+ output_attentions=False,
120
+ output_hidden_states=False,
121
+ return_dict=True,
122
+ )
123
+
124
+ last_hidden_state = encoder_outputs[0]
125
+ last_hidden_state = self.clip.text_model.final_layer_norm(last_hidden_state)
126
+
127
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
128
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
129
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
130
+ pooled_output = last_hidden_state[:, -1, :]
131
+ audio_driven_embeddings = self.clip.text_projection(pooled_output)
132
+ return audio_driven_embeddings
133
+
134
+ def get_pixels(self, image: torch.Tensor) -> torch.Tensor:
135
+ """
136
+ Extract spatial features (pixel-level) from the CLIP image encoder.
137
+
138
+ Args:
139
+ image (torch.Tensor): Input image tensor.
140
+
141
+ Returns:
142
+ torch.Tensor: Spatial visual features (pixel-level).
143
+ """
144
+ vision_outputs = self.clip.vision_model(pixel_values=image,
145
+ output_attentions=None,
146
+ output_hidden_states=True,
147
+ return_dict=True)
148
+ last_layer = self.clip.vision_model.encoder.layers[-1]
149
+
150
+ hidden_states = vision_outputs.hidden_states[-2]
151
+ residual = hidden_states
152
+
153
+ hidden_states = last_layer.layer_norm1(hidden_states)
154
+
155
+ bsz, tgt_len, embed_dim = hidden_states.size()
156
+
157
+ # get query proj
158
+ # query_states = last_layer.self_attn.q_proj(hidden_states) * last_layer.self_attn.scale
159
+ # key_states = last_layer.self_attn.k_proj(hidden_states)
160
+ value_states = last_layer.self_attn.v_proj(hidden_states)
161
+
162
+ value_states = last_layer.self_attn.out_proj(value_states)
163
+
164
+ value_states += residual
165
+
166
+ residual = value_states
167
+ value_states = last_layer.layer_norm2(value_states)
168
+ value_states = last_layer.mlp(value_states)
169
+ value_states += residual
170
+
171
+ value_states = self.clip.vision_model.post_layernorm(value_states)
172
+ output = self.clip.visual_projection(value_states)
173
+
174
+ width = int(np.sqrt(tgt_len - 1))
175
+ output = output[:, 1:]
176
+ if output.ndim == 2:
177
+ output = output.unsqueeze(0)
178
+
179
+ output = output.permute(0, 2, 1)
180
+ output = output.reshape(bsz, self.clip.visual_projection.out_features, width, width)
181
+
182
+ return output
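A minimal sketch of this wrapper (it assumes the CIDAS/clipseg-rd64-refined weights can be downloaded and that the installed transformers version still exposes _expand_mask as imported above); the tokens are zero-padded to CLIP's 77-token context the same way get_placeholder_token in modules/models.py does:

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from modules.CLIPSeg.clipseg_for_audio import CLIPSeg

model = CLIPSeg.from_pretrained("CIDAS/clipseg-rd64-refined").eval()
tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")

tokens = tokenizer("a photo of a barking dog", return_tensors="pt")["input_ids"]
tokens = F.pad(tokens, (0, 77 - tokens.shape[-1]))          # zero-pad, as encode_text expects

with torch.no_grad():
    text_emb = model.encode_text(tokens)                    # (1, 512) for the ViT-B/16 backbone
    pixels = model.get_pixels(torch.randn(1, 3, 352, 352))  # (1, 512, 22, 22) spatial features
print(text_emb.shape, pixels.shape)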
modules/FGA/atten.py ADDED
@@ -0,0 +1,303 @@
1
+ #!/usr/bin/env python
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ from itertools import product, permutations, combinations_with_replacement, chain
7
+
8
+
9
+ class Unary(nn.Module):
10
+ def __init__(self, embed_size):
11
+ """
12
+ Captures local entity information
13
+ :param embed_size: the embedding dimension
14
+ """
15
+ super(Unary, self).__init__()
16
+ self.embed = nn.Conv1d(embed_size, embed_size, 1)
17
+ self.feature_reduce = nn.Conv1d(embed_size, 1, 1)
18
+
19
+ def forward(self, X):
20
+ X = X.transpose(1, 2)
21
+
22
+ X_embed = self.embed(X)
23
+
24
+ X_nl_embed = F.dropout(F.relu(X_embed), training=self.training)
25
+ X_poten = self.feature_reduce(X_nl_embed)
26
+ return X_poten.squeeze(1)
27
+
28
+
29
+ class Pairwise(nn.Module):
30
+ def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None):
31
+ """
32
+ Captures interaction between utilities or entities of the same utility
33
+ :param embed_x_size: the embedding dimension of the first utility
34
+ :param x_spatial_dim: the spatial dimension of the first utility for batch norm and weighted marginalization
35
+ :param embed_y_size: the embedding dimension of the second utility (none for self-interactions)
36
+ :param y_spatial_dim: the spatial dimension of the second utility for batch norm and weighted marginalization
37
+ """
38
+
39
+ super(Pairwise, self).__init__()
40
+ embed_y_size = embed_y_size if y_spatial_dim is not None else embed_x_size
41
+ self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim
42
+
43
+ self.embed_size = max(embed_x_size, embed_y_size)
44
+ self.x_spatial_dim = x_spatial_dim
45
+
46
+ self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1)
47
+ self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1)
48
+ if x_spatial_dim is not None:
49
+ self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim)
50
+
51
+ self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1)
52
+ self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1)
53
+
54
+ def forward(self, X, Y=None):
55
+
56
+ X_t = X.transpose(1, 2)
57
+ Y_t = Y.transpose(1, 2) if Y is not None else X_t
58
+
59
+
60
+ X_embed = self.embed_X(X_t)
61
+ Y_embed = self.embed_Y(Y_t)
62
+
63
+ X_norm = F.normalize(X_embed)
64
+ Y_norm = F.normalize(Y_embed)
65
+
66
+ S = X_norm.transpose(1, 2).bmm(Y_norm)
67
+ if self.x_spatial_dim is not None:
68
+ S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \
69
+ .view(-1, self.x_spatial_dim, self.y_spatial_dim)
70
+
71
+ X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2)
72
+ Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2)
73
+ else:
74
+ X_poten = S.mean(dim=2, keepdim=False)
75
+ Y_poten = S.mean(dim=1, keepdim=False)
76
+
77
+ if Y is None:
78
+ return X_poten
79
+ else:
80
+ return X_poten, Y_poten
81
+
82
+
83
+ class Atten(nn.Module):
84
+ def __init__(self, util_e, sharing_factor_weights={}, prior_flag=False,
85
+ sizes=[], size_force=False, pairwise_flag=True,
86
+ unary_flag=True, self_flag=True):
87
+ """
88
+ The class performs attention over a given list of utility representations.
89
+ :param util_e: the embedding dimensions
90
+ :param sharing_factor_weights: To share weights, provide a dict of tuples:
91
+ {idx: (num_utils, connected utils)}
92
+ Note: for efficiency, the shared utils (i.e., history) are connected to ans
93
+ and question only.
94
+ TODO: connections between shared utils
95
+ :param prior_flag: is prior factor provided
96
+ :param sizes: the spatial dimension (used for batch-norm and weighted marginalization)
97
+ :param size_force: force spatial size with adaptive avg pooling.
98
+ :param pairwise_flag: use pairwise interaction between utilities
99
+ :param unary_flag: use local information
100
+ :param self_flag: use self-interactions between a utility's entities
101
+ """
102
+ super(Atten, self).__init__()
103
+ self.util_e = util_e
104
+
105
+ self.prior_flag = prior_flag
106
+
107
+ self.n_utils = len(util_e)
108
+
109
+ self.spatial_pool = nn.ModuleDict()
110
+
111
+ self.un_models = nn.ModuleList()
112
+
113
+ self.self_flag = self_flag
114
+ self.pairwise_flag = pairwise_flag
115
+ self.unary_flag = unary_flag
116
+ self.size_force = size_force
117
+
118
+ if len(sizes) == 0:
119
+ sizes = [None for _ in util_e]
120
+
121
+ self.sharing_factor_weights = sharing_factor_weights
122
+
123
+ #force the provided size
124
+ for idx, e_dim in enumerate(util_e):
125
+ self.un_models.append(Unary(e_dim))
126
+ if self.size_force:
127
+ self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx])
128
+
129
+ #Pairwise
130
+ self.pp_models = nn.ModuleDict()
131
+ for ((idx1, e_dim_1), (idx2, e_dim_2)) \
132
+ in combinations_with_replacement(enumerate(util_e), 2):
133
+ # self
134
+ if self.self_flag and idx1 == idx2:
135
+ self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1])
136
+ else:
137
+ if pairwise_flag:
138
+ if idx1 in self.sharing_factor_weights:
139
+ # not connected
140
+ if idx2 not in self.sharing_factor_weights[idx1][1]:
141
+ continue
142
+ if idx2 in self.sharing_factor_weights:
143
+ # not connected
144
+ if idx1 not in self.sharing_factor_weights[idx2][1]:
145
+ continue
146
+ self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1], e_dim_2, sizes[idx2])
147
+
148
+ # Handle reduce potentials (with scalars)
149
+ self.reduce_potentials = nn.ModuleList()
150
+
151
+ self.num_of_potentials = dict()
152
+
153
+ self.default_num_of_potentials = 0
154
+
155
+ if self.self_flag:
156
+ self.default_num_of_potentials += 1
157
+ if self.unary_flag:
158
+ self.default_num_of_potentials += 1
159
+ if self.prior_flag:
160
+ self.default_num_of_potentials += 1
161
+ for idx in range(self.n_utils):
162
+ self.num_of_potentials[idx] = self.default_num_of_potentials
163
+
164
+ '''
165
+ All other utilities
166
+ '''
167
+ if pairwise_flag:
168
+ for idx, (num_utils, connected_utils) in sharing_factor_weights.items():
169
+ for c_u in connected_utils:
170
+ self.num_of_potentials[c_u] += num_utils
171
+ self.num_of_potentials[idx] += 1
172
+ for k in self.num_of_potentials:
173
+ if k not in self.sharing_factor_weights:
174
+ self.num_of_potentials[k] += (self.n_utils - 1) \
175
+ - len(sharing_factor_weights)
176
+
177
+ for idx in range(self.n_utils):
178
+ self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx],
179
+ 1, 1, bias=False))
180
+
181
+ def forward(self, utils, priors=None):
182
+ assert self.n_utils == len(utils)
183
+ assert (priors is None and not self.prior_flag) \
184
+ or (priors is not None
185
+ and self.prior_flag
186
+ and len(priors) == self.n_utils)
187
+ b_size = utils[0].size(0)
188
+ util_factors = dict()
189
+ attention = list()
190
+
191
+ #Force size, constant size is used for pairwise batch normalization
192
+ if self.size_force:
193
+ for i, (num_utils, _) in self.sharing_factor_weights.items():
194
+ if str(i) not in self.spatial_pool.keys():
195
+ continue
196
+ else:
197
+ high_util = utils[i]
198
+ high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3))
199
+ high_util = high_util.transpose(1, 2)
200
+ utils[i] = self.spatial_pool[str(i)](high_util).transpose(1, 2)
201
+
202
+ for i in range(self.n_utils):
203
+ if i in self.sharing_factor_weights \
204
+ or str(i) not in self.spatial_pool.keys():
205
+ continue
206
+ utils[i] = utils[i].transpose(1, 2)
207
+ utils[i] = self.spatial_pool[str(i)](utils[i]).transpose(1, 2)
208
+ if self.prior_flag and priors[i] is not None:
209
+ priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1)
210
+
211
+ # handle Shared weights
212
+ for i, (num_utils, connected_list) in self.sharing_factor_weights.items():
213
+ if self.unary_flag:
214
+ util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
215
+
216
+ if self.self_flag:
217
+ util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
218
+
219
+ if self.pairwise_flag:
220
+ for j in connected_list:
221
+ other_util = utils[j]
222
+ expanded_util = other_util.unsqueeze(1).expand(b_size,
223
+ num_utils,
224
+ other_util.size(1),
225
+ other_util.size(2)).contiguous().view(
226
+ b_size * num_utils,
227
+ other_util.size(1),
228
+ other_util.size(2))
229
+
230
+ if i < j:
231
+ factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util)
232
+ else:
233
+ factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i])
234
+ util_factors[i].append(factor_ij)
235
+ util_factors.setdefault(j, []).append(factor_ji.view(b_size, num_utils, factor_ji.size(1)))
236
+
237
+ # handle local factors
238
+ for i in range(self.n_utils):
239
+ if i in self.sharing_factor_weights:
240
+ continue
241
+ if self.unary_flag:
242
+ util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
243
+ if self.self_flag:
244
+ util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
245
+
246
+ # joint
247
+ if self.pairwise_flag:
248
+ for (i, j) in combinations_with_replacement(range(self.n_utils), 2):
249
+ if i in self.sharing_factor_weights \
250
+ or j in self.sharing_factor_weights:
251
+ continue
252
+ if i == j:
253
+ continue
254
+ else:
255
+ factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j])
256
+ util_factors.setdefault(i, []).append(factor_ij)
257
+ util_factors.setdefault(j, []).append(factor_ji)
258
+
259
+ # perform attention
260
+ for i in range(self.n_utils):
261
+ if self.prior_flag:
262
+ prior = priors[i] \
263
+ if priors[i] is not None \
264
+ else torch.zeros_like(util_factors[i][0], requires_grad=False).cuda()
265
+
266
+ util_factors[i].append(prior)
267
+
268
+ util_factors[i] = torch.cat([p if len(p.size()) == 3 else p.unsqueeze(1)
269
+ for p in util_factors[i]], dim=1)
270
+ util_factors[i] = self.reduce_potentials[i](util_factors[i]).squeeze(1)
271
+ util_factors[i] = F.softmax(util_factors[i], dim=1).unsqueeze(2)
272
+ attention.append(torch.bmm(utils[i].transpose(1, 2), util_factors[i]).squeeze(2))
273
+
274
+ return attention
275
+
276
+
277
+ class NaiveAttention(nn.Module):
278
+ def __init__(self):
279
+ """
280
+ Used for ablation analysis - removing attention.
281
+ """
282
+ super(NaiveAttention, self).__init__()
283
+
284
+ def forward(self, utils, priors):
285
+ atten = []
286
+ spatial_atten = []
287
+ for u, p in zip(utils, priors):
288
+ if type(u) is tuple:
289
+ u = u[1]
290
+ num_elements = u.shape[0]
291
+ if p is not None:
292
+ u = u.view(-1, u.shape[-2], u.shape[-1])
293
+ p = p.view(-1, p.shape[-2], p.shape[-1])
294
+ spatial_atten.append(
295
+ torch.bmm(p.transpose(1, 2), u).squeeze(2).view(num_elements, -1, u.shape[-2], u.shape[-1]))
296
+ else:
297
+ spatial_atten.append(u.mean(2))
298
+ continue
299
+ if p is not None:
300
+ atten.append(torch.bmm(u.transpose(1, 2), p.unsqueeze(2)).squeeze(2))
301
+ else:
302
+ atten.append(u.mean(1))
303
+ return atten, spatial_atten
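A minimal sketch of Atten with two utilities of equal embedding size and no shared or size-forced utilities (the import path matches the one used in fga_model.py below); all shapes are illustrative:

import torch
from modules.FGA.atten import Atten

atten = Atten(util_e=[512, 512])                   # two utilities, defaults: unary + self + pairwise
utils = [torch.randn(4, 100, 512),                 # e.g. 100 audio tokens
         torch.randn(4, 49, 512)]                  # e.g. 49 image regions
att_audio, att_image = atten(utils)                # one attended vector per utility
print(att_audio.shape, att_image.shape)            # torch.Size([4, 512]) torch.Size([4, 512])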
modules/FGA/fga_model.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from modules.FGA.atten import Atten
5
+
6
+
7
+ class FGA(nn.Module):
8
+ def __init__(self, vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim,
9
+ hidden_hist_dim, hidden_cap_dim, hidden_img_dim):
10
+ '''
11
+ Factor Graph Attention
12
+ :param vocab_size: vocabulary size
13
+ :param word_embed_dim: word embedding dimension
14
+ :param hidden_ques_dim:
15
+ :param hidden_ans_dim:
16
+ :param hidden_hist_dim:
17
+ :param hidden_cap_dim:
+ :param hidden_img_dim:
18
+ '''
19
+ super(FGA, self).__init__()
20
+
21
+ print("Init FGA with vocab size %s, word embed %s, hidden ques %s, hidden ans %s,"
22
+ " hidden hist %s, hidden cap %s, hidden img %s" % (vocab_size, word_embed_dim,
23
+ hidden_ques_dim,
24
+ hidden_ans_dim,
25
+ hidden_hist_dim,
26
+ hidden_cap_dim,
27
+ hidden_img_dim))
28
+ self.hidden_ques_dim = hidden_ques_dim
29
+ self.hidden_ans_dim = hidden_ans_dim
30
+ self.hidden_cap_dim = hidden_cap_dim
31
+ self.hidden_img_dim = hidden_img_dim
32
+ self.hidden_hist_dim = hidden_hist_dim
33
+
34
+ # Vocab of History LSTMs is one more as we are keeping a stop id (the last id)
35
+ self.word_embedddings = nn.Embedding(vocab_size+1+1, word_embed_dim, padding_idx=0)
36
+
37
+ self.lstm_ques = nn.LSTM(word_embed_dim, self.hidden_ques_dim, batch_first=True)
38
+ self.lstm_ans = nn.LSTM(word_embed_dim, self.hidden_ans_dim, batch_first=True)
39
+
40
+ self.lstm_hist_ques = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
41
+ self.lstm_hist_ans = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
42
+
43
+ self.lstm_hist_cap = nn.LSTM(word_embed_dim, self.hidden_cap_dim, batch_first=True)
44
+
45
+
46
+ self.qahistnet = nn.Sequential(
47
+ nn.Linear(self.hidden_hist_dim*2, self.hidden_hist_dim),
48
+ nn.ReLU(inplace=True)
49
+ )
50
+
51
+ self.concat_dim = self.hidden_ques_dim + self.hidden_ans_dim + \
52
+ self.hidden_ans_dim + self.hidden_img_dim + \
53
+ self.hidden_cap_dim + self.hidden_hist_dim*9
54
+
55
+ self.simnet = nn.Sequential(
56
+ nn.Linear(self.concat_dim, (self.concat_dim)//2, bias=False),
57
+ nn.BatchNorm1d((self.concat_dim) // 2),
58
+ nn.ReLU(inplace=True),
59
+ nn.Linear((self.concat_dim)//2, (self.concat_dim)//4, bias=False),
60
+ nn.BatchNorm1d((self.concat_dim) // 4),
61
+ nn.ReLU(inplace=True),
62
+ nn.Dropout(0.5),
63
+ nn.Linear((self.concat_dim)//4, 1)
64
+ )
65
+
66
+ # To share weights, provide a dict: {idx: (num_utils, list of connected utils)}
67
+ # Note: for efficiency, the shared utils (i.e., history) are connected to ans and question only.
68
+ # connecting shared factors is not supported (!)
69
+ sharing_factor_weights = {4: (9, [0, 1]),
70
+ 5: (9, [0, 1])}
71
+
72
+ self.mul_atten = Atten(util_e=[self.hidden_ans_dim, # Answer modal
73
+ self.hidden_ques_dim, # Question modal
74
+ self.hidden_cap_dim, # Caption modal
75
+ self.hidden_img_dim, # Image modal
76
+ self.hidden_hist_dim, # Question-history modal
77
+ self.hidden_hist_dim # Answer-history modal
78
+ ],
79
+ sharing_factor_weights=sharing_factor_weights,
80
+ sizes=[100, # 100 Answers
81
+ 21, # Question length
82
+ 41, # Caption length
83
+ 37, # 36 Image regions
84
+ 21, # History-Question length
85
+ 21 # History-Answer length
86
+ ] # The spatial dim used for pairwise normalization (use force for adaptive)
87
+ , prior_flag=True,
88
+ pairwise_flag=True)
89
+
90
+
91
+
92
+ def forward(self, input_ques, input_ans, input_hist_ques, input_hist_ans, input_hist_cap,
93
+ input_ques_length, input_ans_length, input_cap_length, i_e):
94
+ """
95
+
96
+ :param input_ques:
97
+ :param input_ans:
98
+ :param input_hist_ques:
99
+ :param input_hist_ans:
100
+ :param input_hist_cap:
101
+ :param input_ques_length:
102
+ :param input_ans_length:
103
+ :param input_cap_length:
104
+ :param i_e:
105
+ :return:
106
+ """
107
+
108
+
109
+ n_options = input_ans.size()[1]
110
+ batch_size = input_ques.size()[0]
111
+
112
+
113
+
114
+ nqa_per_dial, nwords_per_qa = input_hist_ques.size()[1], input_hist_ques.size()[2]
115
+ nwords_per_cap = input_hist_cap.size()[1]
116
+ max_length_input_ans = input_ans.size()[-1]
117
+
118
+ assert batch_size == input_hist_ques.size()[0] == input_hist_ans.size()[0] == input_ques.size()[0] == \
119
+ input_ans.size()[0] == input_hist_cap.size()[0]
120
+ assert nqa_per_dial == input_hist_ques.size()[1] == input_hist_ans.size()[1]
121
+ assert nwords_per_qa == input_hist_ques.size()[2] == input_hist_ans.size()[2]
122
+
123
+ q_we = self.word_embedddings(input_ques)
124
+ a_we = self.word_embedddings(input_ans.view(-1, max_length_input_ans))
125
+ hq_we = self.word_embedddings(input_hist_ques.view(-1, nwords_per_qa))
126
+ ha_we = self.word_embedddings(input_hist_ans.view(-1, nwords_per_qa))
127
+ c_we = self.word_embedddings(input_hist_cap.view(-1, nwords_per_cap))
128
+
129
+
130
+
131
+ '''
132
+ q_we = batch x 20 x embed_ques_dim
133
+ a_we = 100*batch x 20 x embed_ans_dim
134
+ hq_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim
135
+ ha_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim
136
+ c_we = batch*ncap_per_dial, nwords_per_cap, embed_hist_dim
137
+ '''
138
+ self.lstm_ques.flatten_parameters()
139
+ self.lstm_ans.flatten_parameters()
140
+ self.lstm_hist_ques.flatten_parameters()
141
+ self.lstm_hist_ans.flatten_parameters()
142
+ self.lstm_hist_cap.flatten_parameters()
143
+
144
+
145
+ i_feat = i_e
146
+
147
+ q_seq, self.hidden_ques = self.lstm_ques(q_we)
148
+ a_seq, self.hidden_ans = self.lstm_ans(a_we)
149
+ hq_seq, self.hidden_hist_ques = self.lstm_hist_ques(hq_we)
150
+ ha_seq, self.hidden_hist_ans = self.lstm_hist_ans(ha_we)
151
+ cap_seq, self.hidden_cap = self.lstm_hist_cap(c_we)
152
+
153
+
154
+ '''
155
+ length is used for attention prior
156
+ '''
157
+ q_len = input_ques_length.data - 1
158
+ c_len = input_cap_length.data.view(-1) - 1
159
+
160
+
161
+ ans_index = torch.arange(0, n_options * batch_size).long().cuda()
162
+ ans_len = input_ans_length.data.view(-1) - 1
163
+ ans_seq = a_seq[ans_index, ans_len, :]
164
+ ans_seq = ans_seq.view(batch_size, n_options, self.hidden_ans_dim)
165
+
166
+ batch_index = torch.arange(0, batch_size).long().cuda()
167
+ q_prior = torch.zeros(batch_size, q_seq.size(1)).cuda()
168
+ q_prior[batch_index, q_len] = 100
169
+ c_prior = torch.zeros(batch_size, cap_seq.size(1)).cuda()
170
+ c_prior[batch_index, c_len] = 100
171
+ ans_prior = torch.ones(batch_size, ans_seq.size(1)).cuda()
172
+ img_prior = torch.ones(batch_size, i_feat.size(1)).cuda()
173
+
174
+ (ans_atten, ques_atten, cap_atten, img_atten, hq_atten, ha_atten) = \
175
+ self.mul_atten([ans_seq, q_seq, cap_seq, i_feat, hq_seq, ha_seq],
176
+ priors=[ans_prior, q_prior, c_prior, img_prior, None, None])
177
+
178
+ '''
179
+ expand to answers based
180
+ '''
181
+ ques_atten = torch.unsqueeze(ques_atten, 1).expand(batch_size,
182
+ n_options,
183
+ self.hidden_ques_dim)
184
+ cap_atten = torch.unsqueeze(cap_atten, 1).expand(batch_size,
185
+ n_options,
186
+ self.hidden_cap_dim)
187
+ img_atten = torch.unsqueeze(img_atten, 1).expand(batch_size, n_options,
188
+ self.hidden_img_dim)
189
+ ans_atten = torch.unsqueeze(ans_atten, 1).expand(batch_size, n_options,
190
+ self.hidden_ans_dim)
191
+
192
+
193
+ '''
194
+ combine history
195
+ '''
196
+
197
+ input_qahistnet = torch.cat((hq_atten, ha_atten), 1)
198
+ # input_qahistnet: (nqa_per_dial*batch x 2*hidden_hist_dim)
199
+ output_qahistnet = self.qahistnet(input_qahistnet)
200
+ # output_qahistnet: (nqa_per_dial*batch x hidden_hist_dim)
201
+ output_qahistnet = output_qahistnet.view(batch_size,
202
+ nqa_per_dial * self.hidden_hist_dim)
203
+ # output_qahistnet: (batch x nqa_per_dial*hidden_hist_dim)
204
+ output_qahistnet = torch.unsqueeze(output_qahistnet, 1)\
205
+ .expand(batch_size,
206
+ n_options,
207
+ nqa_per_dial * self.hidden_hist_dim)
208
+
209
+ input_qa = torch.cat((ans_seq, ques_atten, ans_atten, img_atten,
210
+ output_qahistnet, cap_atten), 2) # Concatenate last dimension
211
+
212
+ input_qa = input_qa.view(batch_size * n_options, self.concat_dim)
213
+
214
+ out_scores = self.simnet(input_qa)
215
+
216
+ out_scores = out_scores.squeeze(dim=1)
217
+ out_scores = out_scores.view(batch_size, n_options)
218
+
219
+ return out_scores
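A construction-only sketch of FGA (the forward pass as written moves tensors to CUDA, so it is omitted here); all dimensions below are illustrative:

from modules.FGA.fga_model import FGA

fga = FGA(vocab_size=1000, word_embed_dim=300,
          hidden_ques_dim=512, hidden_ans_dim=512,
          hidden_hist_dim=512, hidden_cap_dim=512, hidden_img_dim=512)
print(sum(p.numel() for p in fga.parameters() if p.requires_grad))  # trainable parameter count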
modules/arg_utils.py ADDED
@@ -0,0 +1,42 @@
1
+ import argparse
2
+ from typing import List, Optional, Union, Tuple
3
+
4
+
5
+ def int_or_int_list_or_none(value: Optional[Union[int, str]]) -> List[Optional[int]]:
6
+ """
7
+ Parse an input value into a list of integers or a single integer, or None.
8
+
9
+ Args:
10
+ value (Optional[Union[int, str]]): The input value to parse.
11
+
12
+ Returns:
13
+ List[Optional[int]]: A list containing a single integer, multiple integers,
14
+ or a single None value.
15
+
16
+ Raises:
17
+ argparse.ArgumentTypeError: If the input value cannot be parsed into the specified formats.
18
+ """
19
+ if value in ['None', 'null']:
20
+ return [None]
21
+ try:
22
+ # If the value contains commas, parse it as a comma-separated list of integers
23
+ if ',' in value:
24
+ return [int(x) for x in value.split(',')]
25
+ # If it's a single integer, pack it into a list
26
+ else:
27
+ return [int(value)]
28
+ except ValueError:
29
+ raise argparse.ArgumentTypeError("Invalid format. Use an integer, a comma-separated list of integers, or None.")
30
+
31
+
32
+ def int_or_float(value):
33
+ if '.' in value:
34
+ try:
35
+ return float(value)
36
+ except ValueError:
37
+ raise argparse.ArgumentTypeError("Quality level must be an integer or a float")
38
+ else:
39
+ try:
40
+ return int(value)
41
+ except ValueError:
42
+ raise argparse.ArgumentTypeError("Quality level must be an integer or a float")
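A minimal sketch of how these parsers plug into argparse; the flag names below (--epochs, --quality) are hypothetical:

import argparse
from modules.arg_utils import int_or_int_list_or_none, int_or_float

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int_or_int_list_or_none, default=[None])   # hypothetical flag
parser.add_argument('--quality', type=int_or_float, default=1)                  # hypothetical flag
args = parser.parse_args(['--epochs', '10,20,30', '--quality', '0.5'])
print(args.epochs, args.quality)                                                # [10, 20, 30] 0.5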
modules/mask_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False):
6
+ """Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
7
+ References:
8
+ - https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py
9
+ - https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
10
+ Note:
11
+ X - Y ~ Logistic(0,1) s.t. X, Y ~ Gumbel(0, 1).
12
+ That is, we can implement gumbel_sigmoid using Logistic distribution.
13
+ """
14
+ logistic = torch.rand_like(logits)
15
+ logistic = logistic.div_(1. - logistic).log_() # ~Logistic(0,1)
16
+
17
+ gumbels = (logits + logistic) / tau # ~Logistic(logits, tau)
18
+ y_soft = gumbels.sigmoid_()
19
+
20
+ if hard:
21
+ # Straight through.
22
+ y_hard = y_soft.gt(0.5).type(y_soft.dtype)
23
+ # gt_ break gradient flow
24
+ # y_hard = y_soft.gt_(0.5) # gt_() maintain dtype, different to gt()
25
+ ret = y_hard - y_soft.detach() + y_soft
26
+ else:
27
+ # Reparametrization trick.
28
+ ret = y_soft
29
+
30
+ return ret
31
+
32
+
33
+ class Sim2Mask(nn.Module):
34
+ def __init__(self, init_w: float = 1.0, init_b: float = 0.0, gumbel_tau: float = 1.0, learnable: bool = True):
35
+ """
36
+ Sim2Mask module for generating binary masks.
37
+
38
+ Args:
39
+ init_w (float): Initial value for weight.
40
+ init_b (float): Initial value for bias.
41
+ gumbel_tau (float): Gumbel-Softmax temperature.
42
+ learnable (bool): If True, weight and bias are learnable parameters.
43
+
44
+ Reference:
45
+ "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs" CVPR 2023
46
+ - https://github.com/kakaobrain/tcl
47
+ - https://arxiv.org/abs/2212.00785
48
+ """
49
+ super().__init__()
50
+ self.init_w = init_w
51
+ self.init_b = init_b
52
+ self.gumbel_tau = gumbel_tau
53
+ self.learnable = learnable
54
+
55
+ assert not ((init_w is None) ^ (init_b is None))
56
+ if learnable:
57
+ self.w = nn.Parameter(torch.full([], float(init_w)))
58
+ self.b = nn.Parameter(torch.full([], float(init_b)))
59
+ else:
60
+ self.w = init_w
61
+ self.b = init_b
62
+
63
+ def forward(self, x, deterministic=False):
64
+ logits = x * self.w + self.b
65
+
66
+ soft_mask = torch.sigmoid(logits)
67
+ if deterministic:
68
+ hard_mask = soft_mask.gt(0.5).type(logits.dtype)
69
+ else:
70
+ hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)
71
+
72
+ return hard_mask, soft_mask
73
+
74
+ def extra_repr(self):
75
+ return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
76
+
77
+
78
+ def norm_img_tensor(tensor: torch.Tensor) -> torch.Tensor:
79
+ """
80
+ Normalize image tensor to the range [0, 1].
81
+
82
+ Args:
83
+ tensor (torch.Tensor): Input image tensor.
84
+
85
+ Returns:
86
+ torch.Tensor: Normalized image tensor.
87
+ """
88
+ vmin = tensor.amin((2, 3), keepdims=True) - 1e-7
89
+ vmax = tensor.amax((2, 3), keepdims=True) + 1e-7
90
+ tensor = (tensor - vmin) / (vmax - vmin)
91
+ return tensor
92
+
93
+
94
+ class ImageMasker(Sim2Mask):
95
+ def forward(self, x: torch.Tensor, infer: bool = False) -> torch.Tensor:
96
+ """
97
+ Forward pass for generating image-level binary masks.
98
+
99
+ Args:
100
+ x (torch.Tensor): Input tensor.
101
+ infer (bool): If True (and not in training mode), return the smooth sigmoid map instead of a hard mask.
102
+
103
+ Returns:
104
+ torch.Tensor: Binary mask.
105
+
106
+ Reference:
107
+ "Can CLIP Help Sound Source Localization?" WACV 2024
108
+ - https://arxiv.org/abs/2311.04066
109
+ """
110
+ if self.training or not infer:
111
+ output = super().forward(x, False)[0]
112
+ else:
113
+ output = torch.sigmoid(x + self.b / self.w)
114
+ return output
115
+
116
+
117
+ class FeatureMasker(nn.Module):
118
+ def __init__(self, thr: float = 0.5, tau: float = 0.07):
119
+ """
120
+ Masker module for generating feature-level masks.
121
+
122
+ Args:
123
+ thr (float): Threshold for generating the mask.
124
+ tau (float): Temperature for the sigmoid function.
125
+
126
+ Reference:
127
+ "Can CLIP Help Sound Source Localization?" WACV 2024
128
+ - https://arxiv.org/abs/2311.04066
129
+ """
130
+ super().__init__()
131
+ self.thr = thr
132
+ self.tau = tau
133
+
134
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
135
+ """
136
+ Forward pass for generating feature-level masks
137
+
138
+ Args:
139
+ x (torch.Tensor): Input tensor.
140
+
141
+ Returns:
142
+ torch.Tensor: Generated mask.
143
+ """
144
+ return torch.sigmoid((norm_img_tensor(x) - self.thr) / self.tau)
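A minimal sketch of the two maskers on a random similarity map; the init values (10.0, 14.0) mirror the ones used for ImageMasker in modules/models.py below:

import torch
from modules.mask_utils import ImageMasker, FeatureMasker

sim = torch.randn(2, 1, 14, 14)             # e.g. audio-visual similarity logits

masker_i = ImageMasker(init_w=10.0, init_b=14.0, gumbel_tau=1.0)
masker_i.train()
hard = masker_i(sim)                        # stochastic hard mask via straight-through gumbel-sigmoid
masker_i.eval()
smooth = masker_i(sim, infer=True)          # smooth sigmoid heatmap at inference
print(hard.shape, smooth.shape)             # torch.Size([2, 1, 14, 14]) for both

masker_f = FeatureMasker(thr=0.5, tau=0.07)
soft = masker_f(sim)                        # min-max normalised soft mask in (0, 1)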
modules/models.py ADDED
@@ -0,0 +1,294 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ import yaml
6
+ import argparse
7
+
8
+ from modules.BEATs.BEATs import BEATs, BEATsConfig
9
+ from modules.AudioToken.embedder import FGAEmbedder
10
+ from modules.CLIPSeg.clipseg_for_audio import CLIPSeg
11
+ from modules.mask_utils import ImageMasker, FeatureMasker
12
+ from transformers import AutoTokenizer
13
+
14
+
15
+ class ACL(nn.Module):
16
+ def __init__(self, conf_file: str, device: str):
17
+ """
18
+ Audio-Grounded Contrastive Learning (ACL) model.
19
+
20
+ Args:
21
+ conf_file (str): Path to the configuration file.
22
+ device (str): Device to move the model to.
23
+ """
24
+ super(ACL, self).__init__()
25
+
26
+ # Get configuration
27
+ with open(conf_file) as f:
28
+ config = yaml.load(f, Loader=yaml.FullLoader)
29
+ self.args = argparse.Namespace()
30
+ self.args.model = argparse.Namespace(**config['model'])
31
+ self.args.clip_embedding_dim = config['clip_conf'][self.args.model.clip]['embedding_dim']
32
+ self.args.clip_name = config['clip_conf'][self.args.model.clip]['name']
33
+ self.pretrain = argparse.Namespace(**config['pretrain'])
34
+ self.args.audio_proj = argparse.Namespace(**config['fga_conf'][self.args.model.audio_proj])
35
+
36
+ # Init audio encoder
37
+ checkpoint = torch.load(self.pretrain.audio_backbone)
38
+ cfg = BEATsConfig(checkpoint['cfg'])
39
+ self.audio_backbone = BEATs(cfg)
40
+
41
+ # Text Tokenizer for placeholder prompt
42
+ self.tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
43
+
44
+ # Init audio projection layer
45
+ self.audio_proj = FGAEmbedder(input_size=self.args.audio_proj.input_size * 3,
46
+ output_size=self.args.audio_proj.output_size)
47
+
48
+ # Init audio-visual grounder (Grounder: CLIPSeg)
49
+ self.av_grounder = CLIPSeg.from_pretrained("CIDAS/clipseg-rd64-refined")
50
+
51
+ # Init maskers
52
+ self.masker_i = ImageMasker(10.0, 14.0, 1.0)
53
+ self.masker_f = FeatureMasker(0.5, 0.07)
54
+
55
+ # Load weights
56
+ self.audio_backbone.load_state_dict(checkpoint['model'])
57
+ self.audio_backbone.predictor = None
58
+
59
+ if self.pretrain.audio_proj is not None:
60
+ self.audio_proj.load_state_dict(torch.load(self.pretrain.audio_proj))
61
+
62
+ # Set device
63
+ self.device = device
64
+ self.audio_backbone.to(device=self.device)
65
+ self.av_grounder.to(device=self.device)
66
+ self.audio_proj.to(device=self.device)
67
+ self.masker_i.to(self.device)
68
+ self.masker_f.to(self.device)
69
+
70
+ def get_placeholder_token(self, prompt_text: str):
71
+ """
72
+ Get placeholder token from prompt text
73
+
74
+ Args:
75
+ prompt_text (str): prompt text without '{}'
76
+
77
+ Returns:
78
+ CLIPTokenizerFast result with prompt text
79
+ """
80
+ placeholder_token = self.tokenizer(prompt_text, return_tensors="pt").data['input_ids']
81
+ placeholder_token = F.pad(placeholder_token, (0, 77 - placeholder_token.shape[-1])).to(self.device)
82
+ return placeholder_token
83
+
84
+ def train(self, mode: bool = True):
85
+ """
86
+ Set the module in training mode.
87
+
88
+ Args:
89
+ mode (bool): If True, set the module in training mode.
90
+ """
91
+ super().train(mode)
92
+ self.av_grounder.requires_grad_(False)
93
+ self.audio_backbone.requires_grad_(False)
94
+
95
+ def encode_audio(self, audio: torch.Tensor, placeholder_token: torch.Tensor, pos: int,
96
+ prompt_size: int) -> torch.Tensor:
97
+ """
98
+ Encode audio input into audio-driven embedding (Audio-Driven Embedder)
99
+
100
+ Args:
101
+ audio (torch.Tensor): Input audio tensor.
102
+ placeholder_token (torch.Tensor): Placeholder token for CLIP Text encoder.
103
+ pos (int): Position of audio token.
104
+ prompt_size (int): Size of the placeholder prompt.
105
+
106
+ Returns:
107
+ torch.Tensor: Audio-driven embeddings.
108
+ """
109
+ audio_feat = self.audio_backbone.extract_features(audio)[1]
110
+ audio_token_emb = self.audio_proj(audio_feat).unsqueeze(1)
111
+ audio_driven_embedding = self.av_grounder.encode_audio(placeholder_token, audio_token_emb, pos,
112
+ prompt_size + audio_token_emb.shape[1])
113
+
114
+ return audio_driven_embedding
115
+
116
+ def encode_vision(self, image: torch.Tensor) -> torch.Tensor:
117
+ """
118
+ Encode visual input and generate visual embeddings.
119
+
120
+ Args:
121
+ image (torch.Tensor): Input image tensor.
122
+
123
+ Returns:
124
+ torch.Tensor: Visual embeddings.
125
+ """
126
+ vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image,
127
+ output_attentions=None,
128
+ output_hidden_states=True,
129
+ return_dict=True)
130
+ pooled_output = self.av_grounder.clip.visual_projection(vision_outputs[1])
131
+
132
+ return pooled_output
133
+
134
+ def forward_decoder(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> torch.Tensor:
135
+ """
136
+ Forward pass of audio-visual grounder
137
+
138
+ Args:
139
+ image (torch.Tensor): Input image tensor.
140
+ embedding (torch.Tensor): Condition embedding tensor for grounder.
141
+ resolution (int): Resolution of the output.
142
+ ignore_indices (list): List of indices to ignore.
143
+
144
+ Returns:
145
+ torch.Tensor: Logits from the decoder.
146
+ """
147
+ # step 1: forward the query images through the frozen CLIP vision encoder
148
+ vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image,
149
+ output_attentions=None,
150
+ output_hidden_states=True,
151
+ return_dict=True)
152
+
153
+ hidden_states = vision_outputs.hidden_states
154
+ # we add +1 here as the hidden states also include the initial embeddings
155
+ activations = [hidden_states[i + 1] for i in self.av_grounder.extract_layers]
156
+
157
+ # step 2: compute conditional embeddings, either from text, images or an own provided embedding
158
+ # Audio injected embedding from input argument
159
+
160
+ # step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
161
+ decoder_outputs = self.av_grounder.decoder(
162
+ activations,
163
+ embedding,
164
+ output_attentions=None,
165
+ output_hidden_states=None,
166
+ return_dict=True,
167
+ )
168
+ logits = decoder_outputs.logits
169
+
170
+ if logits.ndim == 2:
171
+ logits = logits.unsqueeze(0).unsqueeze(1)
172
+ else:
173
+ logits = logits.unsqueeze(1)
174
+
175
+ B, c, h, w = image.shape
176
+ if (h, w) != (resolution, resolution):
177
+ logits = F.interpolate(logits, resolution, mode='bicubic')
178
+
179
+ return logits
180
+
181
+ def forward_module(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224,
182
+ force_comb: bool = False) -> torch.Tensor:
183
+ """
184
+ Forward pass through the module.
185
+
186
+ Args:
187
+ image (torch.Tensor): Input image tensor.
188
+ embedding (torch.Tensor): Condition embedding tensor for grounder.
189
+ resolution (int): Resolution of the output tensor.
190
+ force_comb (bool): If True, force computing logits for every (audio, image) combination.
191
+
192
+ Returns:
193
+ torch.Tensor: Logits from the decoder.
194
+ """
195
+ # N image, 1 embedding case -> [B_i, h, w]
196
+ if embedding.shape[0] != image.shape[0] and embedding.shape[0] == 1:
197
+ embeddings = embedding.repeat(image.shape[0], 1)
198
+ logits = self.forward_decoder(image, embeddings, resolution)
199
+
200
+ # N image, M embedding case -> [B_i, B_e, h, w]
201
+ elif embedding.shape[0] != image.shape[0] and embedding.shape[0] != 1 and image.shape[0] != 1 or force_comb:
202
+ logit_list = []
203
+ for i in range(embedding.shape[0]):
204
+ embeddings = embedding[i].unsqueeze(0).repeat(image.shape[0], 1)
205
+ logit_list.append(self.forward_decoder(image, embeddings, resolution))
206
+ logits = torch.cat(logit_list, dim=1)
207
+
208
+ # N image, N embedding or 1 image, N embedding -> [B_e, h, w]
209
+ else:
210
+ logits = self.forward_decoder(image, embedding, resolution)
211
+
212
+ return logits
213
+
214
+ def encode_masked_vision(self, image: torch.Tensor, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, float, float]:
215
+ """
216
+ Encode masked visual feature both image-level and feature-level.
217
+
218
+ Args:
219
+ image (torch.Tensor): Input image tensor.
220
+ embedding (torch.Tensor): Condition embedding tensor for grounder.
221
+
222
+ Returns:
223
+ tuple[torch.Tensor, torch.Tensor, float, float]: Feature masked embeddings, masked image embeddings, positive area, negative area.
224
+ """
225
+ B, c, h, w = image.shape
226
+ maskclip_feat = self.av_grounder.get_pixels(image) # v^D: [B, c, h, w]
227
+ clipseg_mask = self.forward_module(image, embedding, h, force_comb=True) # M^G: [B, B, H, W]
228
+
229
+ # Area
230
+ area_matrix = self.masker_i(clipseg_mask).mean((2, 3))
231
+ positive_area = area_matrix.diagonal().mean()
232
+ negative_area = area_matrix.mean() - positive_area / B
233
+
234
+ # Feature level masker
235
+ feature_mask = F.interpolate(self.masker_f(clipseg_mask), maskclip_feat.shape[2])
236
+
237
+ # Image level masker
238
+ ind = torch.arange(B).to(image.device)
239
+ image_mask = self.masker_i(clipseg_mask[ind, ind].unsqueeze(1)) # Positive pair only
240
+ feature_masked_emb = torch.einsum('bchw,bnhw->bnc', maskclip_feat, feature_mask) / (feature_mask.sum() + 1e-6)
241
+
242
+ # step 1: forward the query images through the frozen CLIP vision encoder
243
+ masked_vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image * image_mask,
244
+ output_attentions=None,
245
+ output_hidden_states=True,
246
+ return_dict=True)
247
+ masked_image_emb = self.av_grounder.clip.visual_projection(masked_vision_outputs[1])
248
+
249
+ return feature_masked_emb, masked_image_emb, positive_area, negative_area
250
+
251
+ def forward(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> dict:
252
+ """
253
+ Forward pass of ACL model.
254
+
255
+ Args:
256
+ image (torch.Tensor): Input image tensor.
257
+ embedding (torch.Tensor): Condition embedding tensor for grounder.
258
+ resolution (int): Resolution of the output tensor.
259
+
260
+ Returns:
261
+ dict: Output dictionary containing relevant tensors.
262
+ """
263
+ if self.training:
264
+ # seg_logit = self.forward_module(image, embedding, resolution)
265
+ v_f, v_i, p_area, n_area = self.encode_masked_vision(image, embedding)
266
+ out_dict = {'v_f': v_f, 'v_i': v_i, 'p_area': p_area, 'n_area': n_area}
267
+
268
+ else:
269
+ seg_logit = self.forward_module(image, embedding, resolution)
270
+ heatmap = self.masker_i(seg_logit, infer=True)
271
+ out_dict = {'heatmap': heatmap}
272
+
273
+ return out_dict
274
+
275
+ def save(self, model_dir: str):
276
+ """
277
+ Save model parameters to a file. (Only trainable parts)
278
+
279
+ Args:
280
+ model_dir (str): Directory to save the model.
281
+ """
282
+ ckp = {'audio_proj': self.audio_proj.state_dict(), 'masker_i': self.masker_i.state_dict()}
283
+ torch.save(ckp, model_dir)
284
+
285
+ def load(self, model_dir: str):
286
+ """
287
+ Load model parameters from a file. (Only trainable parts)
288
+
289
+ Args:
290
+ model_dir (str): Directory to load the model from.
291
+ """
292
+ ckp = torch.load(model_dir)
293
+ self.audio_proj.load_state_dict(ckp['audio_proj'])
294
+ self.masker_i.load_state_dict(ckp['masker_i'])
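The image/embedding broadcasting in forward_module is easiest to see with a standalone sketch that swaps the CLIPSeg decoder for a dummy returning one (1, H, W) map per image; only the branching logic below mirrors the method above:

import torch

def dummy_decoder(image, emb, res=224):
    # stand-in for ACL.forward_decoder: one (1, res, res) logit map per image row
    return torch.zeros(image.shape[0], 1, res, res)

def forward_module_shapes(image, embedding, res=224, force_comb=False):
    if embedding.shape[0] != image.shape[0] and embedding.shape[0] == 1:
        # N images, 1 embedding -> [B_i, 1, H, W]
        return dummy_decoder(image, embedding.repeat(image.shape[0], 1), res)
    elif embedding.shape[0] != image.shape[0] and embedding.shape[0] != 1 and image.shape[0] != 1 or force_comb:
        # N images, M embeddings (or force_comb) -> [B_i, B_e, H, W]
        return torch.cat([dummy_decoder(image, embedding[i:i + 1].repeat(image.shape[0], 1), res)
                          for i in range(embedding.shape[0])], dim=1)
    else:
        # matched image/embedding batches -> [B, 1, H, W]
        return dummy_decoder(image, embedding, res)

imgs = torch.zeros(3, 3, 224, 224)
print(forward_module_shapes(imgs, torch.zeros(5, 512)).shape)                    # torch.Size([3, 5, 224, 224])
print(forward_module_shapes(imgs, torch.zeros(1, 512)).shape)                    # torch.Size([3, 1, 224, 224])
print(forward_module_shapes(imgs, torch.zeros(3, 512), force_comb=True).shape)   # torch.Size([3, 3, 224, 224])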