Commit 1b92e8f by guyyariv
Parent: 00d2193

AudioTokenDemo
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.idea/demo.iml ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,21 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="false" level="WARNING" enabled_by_default="false">
+      <option name="ignoredPackages">
+        <value>
+          <list size="1">
+            <item index="0" class="java.lang.String" itemvalue="tensorflow" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyStubPackagesAdvertiser" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <list>
+          <option value="pyspark-stubs==3.0.0.post3" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/demo.iml" filepath="$PROJECT_DIR$/.idea/demo.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: AudioToken
-emoji: 🌖
-colorFrom: purple
-colorTo: purple
+emoji: 🏆
+colorFrom: indigo
+colorTo: indigo
 sdk: gradio
 sdk_version: 3.29.0
 app_file: app.py
app.py ADDED
@@ -0,0 +1,147 @@
+import torch
+from diffusers.loaders import AttnProcsLayers
+from transformers import CLIPTextModel, CLIPTokenizer
+from modules.beats.BEATs import BEATs, BEATsConfig
+from modules.AudioToken.embedder import FGAEmbedder
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers import StableDiffusionPipeline
+import numpy as np
+import gradio as gr
+
+
+class AudioTokenWrapper(torch.nn.Module):
+    """Simple wrapper module for Stable Diffusion that holds all the models together"""
+
+    def __init__(
+            self,
+            lora,
+            device,
+    ):
+
+        super().__init__()
+        # Load scheduler and models
+        self.tokenizer = CLIPTokenizer.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
+        )
+        self.text_encoder = CLIPTextModel.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", revision=None
+        )
+        self.unet = UNet2DConditionModel.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", subfolder="unet", revision=None
+        )
+        self.vae = AutoencoderKL.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", subfolder="vae", revision=None
+        )
+
+        checkpoint = torch.load(
+            'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
+        cfg = BEATsConfig(checkpoint['cfg'])
+        self.aud_encoder = BEATs(cfg)
+        self.aud_encoder.load_state_dict(checkpoint['model'])
+        self.aud_encoder.predictor = None
+        input_size = 768 * 3
+        self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
+
+        self.vae.eval()
+        self.unet.eval()
+        self.text_encoder.eval()
+        self.aud_encoder.eval()
+
+        if lora:
+            # Set correct lora layers
+            lora_attn_procs = {}
+            for name in self.unet.attn_processors.keys():
+                cross_attention_dim = None if name.endswith(
+                    "attn1.processor") else self.unet.config.cross_attention_dim
+                if name.startswith("mid_block"):
+                    hidden_size = self.unet.config.block_out_channels[-1]
+                elif name.startswith("up_blocks"):
+                    block_id = int(name[len("up_blocks.")])
+                    hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+                elif name.startswith("down_blocks"):
+                    block_id = int(name[len("down_blocks.")])
+                    hidden_size = self.unet.config.block_out_channels[block_id]
+
+                lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
+                                                          cross_attention_dim=cross_attention_dim)
+
+            self.unet.set_attn_processor(lora_attn_procs)
+            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
+            self.lora_layers.eval()
+            lora_layers_learned_embeds = 'models/lora_layers_learned_embeds.bin'
+            self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
+            self.unet.load_attn_procs(lora_layers_learned_embeds)
+
+        self.embedder.eval()
+        embedder_learned_embeds = 'models/embedder_learned_embeds.bin'
+        self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
+
+        self.placeholder_token = '<*>'
+        num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
+        if num_added_tokens == 0:
+            raise ValueError(
+                f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
+                " `placeholder_token` that is not already in the tokenizer."
+            )
+        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(self.placeholder_token)
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
+        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+
+
+def greet(audio):
+    audio = audio[-1].astype(np.float32, order='C') / 32768.0
+    weight_dtype = torch.float32
+    prompt = 'a photo of <*>'
+
+    audio_values = torch.unsqueeze(torch.tensor(audio), dim=0).to(device).to(dtype=weight_dtype)
+    aud_features = model.aud_encoder.extract_features(audio_values)[1]
+    audio_token = model.embedder(aud_features)
+
+    token_embeds = model.text_encoder.get_input_embeddings().weight.data
+    token_embeds[model.placeholder_token_id] = audio_token.clone()
+
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4",
+        tokenizer=model.tokenizer,
+        text_encoder=model.text_encoder,
+        vae=model.vae,
+        unet=model.unet,
+    ).to(device)
+    image = pipeline(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+    return image
+
+description = """
+This is a demo of [AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation](https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken/)
+"""
+
+
+if __name__ == "__main__":
+
+    lora = True
+    device = 'cpu'
+    model = AudioTokenWrapper(lora, device)
+    print('here')
+
+    description = """
+This is a demo of [AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation](https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken/).<br>
+Simply upload an audio to test your own case.<br>
+For more information, please see the original [paper](https://arxiv.org/abs/2305.13050) and [repo](https://github.com/guyyariv/AudioToken/).
+"""
+
+    examples = [
+        ["assets/train.wav"],
+        ["assets/dog barking.wav"],
+        ["assets/airplane.wav"]
+    ]
+
+    demo = gr.Interface(
+        fn=greet,
+        inputs="audio",
+        outputs="image",
+        title='AudioToken',
+        description=description,
+        examples=examples
+    )
+    demo.launch()
+
assets/airplane.wav ADDED
Binary file (320 kB).
assets/dog barking.wav ADDED
Binary file (320 kB).
assets/train.wav ADDED
Binary file (320 kB).
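The greet() function in app.py is wired to Gradio's "audio" input, which delivers a (sample_rate, int16 numpy array) tuple. A minimal sketch of driving it with one of the bundled example files, without the UI, assuming app.py's __main__ block has already run (so model and device exist), the wav is 16 kHz mono as BEATs' fbank front end expects, and scipy is available:

```python
# Sketch only: exercises greet() directly, bypassing gr.Interface.
from scipy.io import wavfile

sr, data = wavfile.read("assets/train.wav")   # int16 samples, like Gradio's "audio" component
image = greet((sr, data))                     # greet() rescales int16 to [-1, 1] internally
image.save("train_generated.png")             # the pipeline output is a PIL image
```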
models/lora_layers_learned_embeds.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24c971f2e526f4f13331fb0fb85272d3efb3befc17f0d671b422360b26d864a5
+size 3294091
modules/AudioToken/embedder.py ADDED
@@ -0,0 +1,19 @@
+import torch
+import torch.nn as nn
+from modules.fga.atten import Atten
+
+
+class FGAEmbedder(nn.Module):
+    def __init__(self, input_size=768*3, output_size=768):
+        super(FGAEmbedder, self).__init__()
+        self.fc1 = nn.Linear(input_size, input_size)
+        self.fc2 = nn.Linear(input_size, output_size)
+        self.gelu = nn.GELU()
+        self.fga = Atten(util_e=[output_size], pairwise_flag=False)
+
+    def forward(self, audio_embs):
+        audio_embs = self.fc1(audio_embs)
+        audio_embs = self.gelu(audio_embs)
+        audio_embs = self.fc2(audio_embs)
+        attend = self.fga([audio_embs])[0]
+        return attend
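As a rough shape check, a sketch of how app.py uses this embedder, assuming the repo's modules.fga.atten.Atten is importable and that the factor-graph attention pools the frame axis down to a single clip-level vector (which is how app.py writes its output into one CLIP token slot):

```python
import torch
from modules.AudioToken.embedder import FGAEmbedder

embedder = FGAEmbedder(input_size=768 * 3, output_size=768).eval()
feats = torch.randn(1, 250, 768 * 3)   # (batch, frames, concatenated BEATs layer features)
with torch.no_grad():
    token = embedder(feats)            # expected to match CLIP's 768-d token embedding size
print(token.shape)
```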
modules/beats/BEATs.py ADDED
@@ -0,0 +1,178 @@
+# --------------------------------------------------------
+# beats: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+import torchaudio.compliance.kaldi as ta_kaldi
+
+from modules.beats.backbone import (
+    TransformerEncoder,
+)
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class BEATsConfig:
+    def __init__(self, cfg=None):
+        self.input_patch_size: int = -1  # patch size of patch embedding
+        self.embed_dim: int = 512  # patch embedding dimension
+        self.conv_bias: bool = False  # include bias in conv encoder
+
+        self.encoder_layers: int = 12  # num encoder layers in the transformer
+        self.encoder_embed_dim: int = 768  # encoder embedding dimension
+        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
+        self.encoder_attention_heads: int = 12  # num encoder attention heads
+        self.activation_fn: str = "gelu"  # activation function to use
+
+        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
+        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
+        self.deep_norm: bool = False  # apply deep_norm first in the transformer
+
+        # dropouts
+        self.dropout: float = 0.1  # dropout probability for the transformer
+        self.attention_dropout: float = 0.1  # dropout probability for attention weights
+        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
+        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
+        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
+
+        # positional embeddings
+        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
+        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding
+
+        # relative position embedding
+        self.relative_position_embedding: bool = False  # apply relative position embedding
+        self.num_buckets: int = 320  # number of buckets for relative position embedding
+        self.max_distance: int = 1280  # maximum distance for relative position embedding
+        self.gru_rel_pos: bool = False  # apply gated relative position embedding
+
+        # label predictor
+        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
+        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
+        self.predictor_class: int = 527  # target class number for the predictor
+
+        if cfg is not None:
+            self.update(cfg)
+
+    def update(self, cfg: dict):
+        self.__dict__.update(cfg)
+
+
+class BEATs(nn.Module):
+    def __init__(
+            self,
+            cfg: BEATsConfig,
+    ) -> None:
+        super().__init__()
+        logger.info(f"BEATs Config: {cfg.__dict__}")
+
+        self.cfg = cfg
+
+        self.embed = cfg.embed_dim
+        self.post_extract_proj = (
+            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            if self.embed != cfg.encoder_embed_dim
+            else None
+        )
+
+        self.input_patch_size = cfg.input_patch_size
+        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
+                                         bias=cfg.conv_bias)
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+        assert not cfg.deep_norm or not cfg.layer_norm_first
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+        if cfg.finetuned_model:
+            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
+            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
+        else:
+            self.predictor = None
+
+    def forward_padding_mask(
+            self,
+            features: torch.Tensor,
+            padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        extra = padding_mask.size(1) % features.size(1)
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = padding_mask.view(
+            padding_mask.size(0), features.size(1), -1
+        )
+        padding_mask = padding_mask.all(-1)
+        return padding_mask
+
+    def preprocess(
+            self,
+            source: torch.Tensor,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ) -> torch.Tensor:
+        fbanks = []
+        for waveform in source:
+            waveform = waveform.unsqueeze(0) * 2 ** 15
+            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+            fbanks.append(fbank)
+        fbank = torch.stack(fbanks, dim=0)
+        fbank = (fbank - fbank_mean) / (2 * fbank_std)
+        return fbank
+
+    def extract_features(
+            self,
+            source: torch.Tensor,
+            padding_mask: Optional[torch.Tensor] = None,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ):
+        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(fbank, padding_mask)
+
+        fbank = fbank.unsqueeze(1)
+        features = self.patch_embedding(fbank)
+        features = features.reshape(features.shape[0], features.shape[1], -1)
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        x = self.dropout_input(features)
+
+        x, layers_sum, layers = self.encoder(
+            x,
+            padding_mask=padding_mask,
+        )
+
+        if self.predictor is not None:
+            x = self.predictor_dropout(x)
+            logits = self.predictor(x)
+
+            if padding_mask is not None and padding_mask.any():
+                logits[padding_mask] = 0
+                logits = logits.sum(dim=1)
+                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
+            else:
+                logits = logits.mean(dim=1)
+
+            lprobs = torch.sigmoid(logits)
+
+            return lprobs, padding_mask
+        else:
+            return x, layers_sum, layers
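For reference, a minimal sketch of the audio-encoding path that AudioTokenWrapper relies on, assuming the fine-tuned BEATs checkpoint referenced by app.py has been placed under models/:

```python
import torch
from modules.beats.BEATs import BEATs, BEATsConfig

ckpt = torch.load('models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt', map_location='cpu')
beats = BEATs(BEATsConfig(ckpt['cfg']))
beats.load_state_dict(ckpt['model'])
beats.predictor = None                 # drop the AudioSet head, as app.py does
beats.eval()

wav = torch.randn(1, 16000 * 5)        # 5 s of 16 kHz audio scaled to [-1, 1]
with torch.no_grad():
    # with predictor=None, extract_features returns (x, layers_sum, layers);
    # app.py keeps index [1], the concatenation of three intermediate layers (768*3 dims)
    feats = beats.extract_features(wav)[1]
print(feats.shape)
```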
modules/beats/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
+# --------------------------------------------------------
+# beats: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+import torchaudio.compliance.kaldi as ta_kaldi
+
+from modules.beats.backbone import (
+    TransformerEncoder,
+)
+from modules.beats.quantizer import (
+    NormEMAVectorQuantizer,
+)
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class TokenizersConfig:
+    def __init__(self, cfg=None):
+        self.input_patch_size: int = -1  # patch size of patch embedding
+        self.embed_dim: int = 512  # patch embedding dimension
+        self.conv_bias: bool = False  # include bias in conv encoder
+
+        self.encoder_layers: int = 12  # num encoder layers in the transformer
+        self.encoder_embed_dim: int = 768  # encoder embedding dimension
+        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
+        self.encoder_attention_heads: int = 12  # num encoder attention heads
+        self.activation_fn: str = "gelu"  # activation function to use
+
+        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
+        self.deep_norm: bool = False  # apply deep_norm first in the transformer
+
+        # dropouts
+        self.dropout: float = 0.1  # dropout probability for the transformer
+        self.attention_dropout: float = 0.1  # dropout probability for attention weights
+        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
+        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
+        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
+
+        # positional embeddings
+        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
+        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding
+
+        # relative position embedding
+        self.relative_position_embedding: bool = False  # apply relative position embedding
+        self.num_buckets: int = 320  # number of buckets for relative position embedding
+        self.max_distance: int = 1280  # maximum distance for relative position embedding
+        self.gru_rel_pos: bool = False  # apply gated relative position embedding
+
+        # quantizer
+        self.quant_n: int = 1024  # codebook number in quantizer
+        self.quant_dim: int = 256  # codebook dimension in quantizer
+
+        if cfg is not None:
+            self.update(cfg)
+
+    def update(self, cfg: dict):
+        self.__dict__.update(cfg)
+
+
+class Tokenizers(nn.Module):
+    def __init__(
+            self,
+            cfg: TokenizersConfig,
+    ) -> None:
+        super().__init__()
+        logger.info(f"Tokenizers Config: {cfg.__dict__}")
+
+        self.cfg = cfg
+
+        self.embed = cfg.embed_dim
+        self.post_extract_proj = (
+            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            if self.embed != cfg.encoder_embed_dim
+            else None
+        )
+
+        self.input_patch_size = cfg.input_patch_size
+        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
+                                         bias=cfg.conv_bias)
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+        assert not cfg.deep_norm or not cfg.layer_norm_first
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+        self.quantize = NormEMAVectorQuantizer(
+            n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
+        )
+        self.quant_n = cfg.quant_n
+        self.quantize_layer = nn.Sequential(
+            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
+            nn.Tanh(),
+            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
+        )
+
+    def forward_padding_mask(
+            self,
+            features: torch.Tensor,
+            padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        extra = padding_mask.size(1) % features.size(1)
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = padding_mask.view(
+            padding_mask.size(0), features.size(1), -1
+        )
+        padding_mask = padding_mask.all(-1)
+        return padding_mask
+
+    def preprocess(
+            self,
+            source: torch.Tensor,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ) -> torch.Tensor:
+        fbanks = []
+        for waveform in source:
+            waveform = waveform.unsqueeze(0) * 2 ** 15
+            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+            fbanks.append(fbank)
+        fbank = torch.stack(fbanks, dim=0)
+        fbank = (fbank - fbank_mean) / (2 * fbank_std)
+        return fbank
+
+    def extract_labels(
+            self,
+            source: torch.Tensor,
+            padding_mask: Optional[torch.Tensor] = None,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ):
+        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(fbank, padding_mask)
+
+        fbank = fbank.unsqueeze(1)
+        features = self.patch_embedding(fbank)
+        features = features.reshape(features.shape[0], features.shape[1], -1)
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        x = self.dropout_input(features)
+
+        x, layer_results = self.encoder(
+            x,
+            padding_mask=padding_mask,
+        )
+
+        quantize_input = self.quantize_layer(x)
+        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
+
+        return embed_ind
modules/beats/backbone.py ADDED
@@ -0,0 +1,787 @@
1
+ # --------------------------------------------------------
2
+ # beats: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import numpy as np
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import LayerNorm, Parameter
17
+ from modules.beats.modules import (
18
+ GradMultiply,
19
+ SamePad,
20
+ get_activation_fn,
21
+ GLU_Linear,
22
+ quant_noise,
23
+ )
24
+
25
+
26
+ class TransformerEncoder(nn.Module):
27
+ def __init__(self, args):
28
+ super().__init__()
29
+
30
+ self.dropout = args.dropout
31
+ self.embedding_dim = args.encoder_embed_dim
32
+
33
+ self.pos_conv = nn.Conv1d(
34
+ self.embedding_dim,
35
+ self.embedding_dim,
36
+ kernel_size=args.conv_pos,
37
+ padding=args.conv_pos // 2,
38
+ groups=args.conv_pos_groups,
39
+ )
40
+ dropout = 0
41
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
42
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
43
+ nn.init.constant_(self.pos_conv.bias, 0)
44
+
45
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
46
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
47
+
48
+ if hasattr(args, "relative_position_embedding"):
49
+ self.relative_position_embedding = args.relative_position_embedding
50
+ self.num_buckets = args.num_buckets
51
+ self.max_distance = args.max_distance
52
+ else:
53
+ self.relative_position_embedding = False
54
+ self.num_buckets = 0
55
+ self.max_distance = 0
56
+
57
+ self.layers = nn.ModuleList(
58
+ [
59
+ TransformerSentenceEncoderLayer(
60
+ embedding_dim=self.embedding_dim,
61
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
62
+ num_attention_heads=args.encoder_attention_heads,
63
+ dropout=self.dropout,
64
+ attention_dropout=args.attention_dropout,
65
+ activation_dropout=args.activation_dropout,
66
+ activation_fn=args.activation_fn,
67
+ layer_norm_first=args.layer_norm_first,
68
+ deep_norm=args.deep_norm,
69
+ has_relative_attention_bias=self.relative_position_embedding,
70
+ num_buckets=self.num_buckets,
71
+ max_distance=self.max_distance,
72
+ gru_rel_pos=args.gru_rel_pos,
73
+ encoder_layers=args.encoder_layers,
74
+ )
75
+ for i in range(args.encoder_layers)
76
+ ]
77
+ )
78
+ if self.relative_position_embedding:
79
+ for i in range(1, args.encoder_layers):
80
+ del self.layers[i].self_attn.relative_attention_bias
81
+ self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
82
+
83
+ self.layer_norm_first = args.layer_norm_first
84
+ self.layer_norm = LayerNorm(self.embedding_dim)
85
+ self.layerdrop = args.encoder_layerdrop
86
+
87
+ self.apply(init_bert_params)
88
+
89
+ if args.deep_norm:
90
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
91
+ for i in range(args.encoder_layers):
92
+ nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
93
+ nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
94
+ nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
95
+ nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
96
+ nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
97
+ nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
98
+
99
+ self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
100
+
101
+ def forward(self, x, padding_mask=None, layer=None):
102
+ x, layers_sum, layers = self.extract_features(x, padding_mask, layer)
103
+
104
+ if self.layer_norm_first and layer is None:
105
+ x = self.layer_norm(x)
106
+
107
+ return x, layers_sum, layers
108
+
109
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
110
+
111
+ if padding_mask is not None:
112
+ x[padding_mask] = 0
113
+
114
+ x_conv = self.pos_conv(x.transpose(1, 2))
115
+ x_conv = x_conv.transpose(1, 2)
116
+ x += x_conv
117
+
118
+ if not self.layer_norm_first:
119
+ x = self.layer_norm(x)
120
+
121
+ x = F.dropout(x, p=self.dropout, training=self.training)
122
+
123
+ # B x T x C -> T x B x C
124
+ x = x.transpose(0, 1)
125
+ layers = []
126
+
127
+ layer_results = []
128
+ z = None
129
+ if tgt_layer is not None:
130
+ layer_results.append((x, z))
131
+ r = None
132
+ pos_bias = None
133
+ for i, layer in enumerate(self.layers):
134
+ if self.layer_wise_gradient_decay_ratio != 1.0:
135
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
136
+ dropout_probability = np.random.random()
137
+ if not self.training or (dropout_probability > self.layerdrop):
138
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
139
+ if tgt_layer is not None:
140
+ layer_results.append((x, z))
141
+ if i == tgt_layer:
142
+ r = x
143
+ break
144
+ if i in [3, 7, 11]:
145
+ layers.append(x.transpose(0, 1))
146
+
147
+ if r is not None:
148
+ x = r
149
+
150
+ # T x B x C -> B x T x C
151
+ x = x.transpose(0, 1)
152
+ layers_cat = torch.cat(layers, dim=2)
153
+
154
+ return x, layers_cat, layers
155
+
156
+
157
+ class TransformerSentenceEncoderLayer(nn.Module):
158
+ def __init__(
159
+ self,
160
+ embedding_dim: float = 768,
161
+ ffn_embedding_dim: float = 3072,
162
+ num_attention_heads: float = 8,
163
+ dropout: float = 0.1,
164
+ attention_dropout: float = 0.1,
165
+ activation_dropout: float = 0.1,
166
+ activation_fn: str = "relu",
167
+ layer_norm_first: bool = False,
168
+ deep_norm: bool = False,
169
+ has_relative_attention_bias: bool = False,
170
+ num_buckets: int = 0,
171
+ max_distance: int = 0,
172
+ rescale_init: bool = False,
173
+ gru_rel_pos: bool = False,
174
+ encoder_layers: int = 0,
175
+ ) -> None:
176
+
177
+ super().__init__()
178
+ self.embedding_dim = embedding_dim
179
+ self.dropout = dropout
180
+ self.activation_dropout = activation_dropout
181
+
182
+ self.activation_name = activation_fn
183
+ self.activation_fn = get_activation_fn(activation_fn)
184
+ self.self_attn = MultiheadAttention(
185
+ self.embedding_dim,
186
+ num_attention_heads,
187
+ dropout=attention_dropout,
188
+ self_attention=True,
189
+ has_relative_attention_bias=has_relative_attention_bias,
190
+ num_buckets=num_buckets,
191
+ max_distance=max_distance,
192
+ rescale_init=rescale_init,
193
+ gru_rel_pos=gru_rel_pos,
194
+ )
195
+
196
+ self.dropout1 = nn.Dropout(dropout)
197
+ self.dropout2 = nn.Dropout(self.activation_dropout)
198
+ self.dropout3 = nn.Dropout(dropout)
199
+
200
+ self.layer_norm_first = layer_norm_first
201
+
202
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
203
+
204
+ if self.activation_name == "glu":
205
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
206
+ else:
207
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
208
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
209
+
210
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
211
+
212
+ self.deep_norm = deep_norm
213
+ if self.deep_norm:
214
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
215
+ else:
216
+ self.deep_norm_alpha = 1
217
+
218
+ def forward(
219
+ self,
220
+ x: torch.Tensor,
221
+ self_attn_mask: torch.Tensor = None,
222
+ self_attn_padding_mask: torch.Tensor = None,
223
+ need_weights: bool = False,
224
+ pos_bias=None
225
+ ):
226
+ residual = x
227
+
228
+ if self.layer_norm_first:
229
+ x = self.self_attn_layer_norm(x)
230
+ x, attn, pos_bias = self.self_attn(
231
+ query=x,
232
+ key=x,
233
+ value=x,
234
+ key_padding_mask=self_attn_padding_mask,
235
+ need_weights=False,
236
+ attn_mask=self_attn_mask,
237
+ position_bias=pos_bias
238
+ )
239
+ x = self.dropout1(x)
240
+ x = residual + x
241
+
242
+ residual = x
243
+ x = self.final_layer_norm(x)
244
+ if self.activation_name == "glu":
245
+ x = self.fc1(x)
246
+ else:
247
+ x = self.activation_fn(self.fc1(x))
248
+ x = self.dropout2(x)
249
+ x = self.fc2(x)
250
+ x = self.dropout3(x)
251
+ x = residual + x
252
+ else:
253
+ x, attn, pos_bias = self.self_attn(
254
+ query=x,
255
+ key=x,
256
+ value=x,
257
+ key_padding_mask=self_attn_padding_mask,
258
+ need_weights=need_weights,
259
+ attn_mask=self_attn_mask,
260
+ position_bias=pos_bias
261
+ )
262
+
263
+ x = self.dropout1(x)
264
+ x = residual * self.deep_norm_alpha + x
265
+
266
+ x = self.self_attn_layer_norm(x)
267
+
268
+ residual = x
269
+ if self.activation_name == "glu":
270
+ x = self.fc1(x)
271
+ else:
272
+ x = self.activation_fn(self.fc1(x))
273
+ x = self.dropout2(x)
274
+ x = self.fc2(x)
275
+ x = self.dropout3(x)
276
+ x = residual * self.deep_norm_alpha + x
277
+ x = self.final_layer_norm(x)
278
+
279
+ return x, attn, pos_bias
280
+
281
+
282
+ class MultiheadAttention(nn.Module):
283
+ """Multi-headed attention.
284
+
285
+ See "Attention Is All You Need" for more details.
286
+ """
287
+
288
+ def __init__(
289
+ self,
290
+ embed_dim,
291
+ num_heads,
292
+ kdim=None,
293
+ vdim=None,
294
+ dropout=0.0,
295
+ bias=True,
296
+ add_bias_kv=False,
297
+ add_zero_attn=False,
298
+ self_attention=False,
299
+ encoder_decoder_attention=False,
300
+ q_noise=0.0,
301
+ qn_block_size=8,
302
+ has_relative_attention_bias=False,
303
+ num_buckets=32,
304
+ max_distance=128,
305
+ gru_rel_pos=False,
306
+ rescale_init=False,
307
+ ):
308
+ super().__init__()
309
+ self.embed_dim = embed_dim
310
+ self.kdim = kdim if kdim is not None else embed_dim
311
+ self.vdim = vdim if vdim is not None else embed_dim
312
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
313
+
314
+ self.num_heads = num_heads
315
+ self.dropout_module = nn.Dropout(dropout)
316
+
317
+ self.has_relative_attention_bias = has_relative_attention_bias
318
+ self.num_buckets = num_buckets
319
+ self.max_distance = max_distance
320
+ if self.has_relative_attention_bias:
321
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
322
+
323
+ self.head_dim = embed_dim // num_heads
324
+ self.q_head_dim = self.head_dim
325
+ self.k_head_dim = self.head_dim
326
+ assert (
327
+ self.head_dim * num_heads == self.embed_dim
328
+ ), "embed_dim must be divisible by num_heads"
329
+ self.scaling = self.head_dim ** -0.5
330
+
331
+ self.self_attention = self_attention
332
+ self.encoder_decoder_attention = encoder_decoder_attention
333
+
334
+ assert not self.self_attention or self.qkv_same_dim, (
335
+ "Self-attention requires query, key and " "value to be of the same size"
336
+ )
337
+
338
+ k_bias = True
339
+ if rescale_init:
340
+ k_bias = False
341
+
342
+ k_embed_dim = embed_dim
343
+ q_embed_dim = embed_dim
344
+
345
+ self.k_proj = quant_noise(
346
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
347
+ )
348
+ self.v_proj = quant_noise(
349
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
350
+ )
351
+ self.q_proj = quant_noise(
352
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
353
+ )
354
+
355
+ self.out_proj = quant_noise(
356
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
357
+ )
358
+
359
+ if add_bias_kv:
360
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
361
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
362
+ else:
363
+ self.bias_k = self.bias_v = None
364
+
365
+ self.add_zero_attn = add_zero_attn
366
+
367
+ self.gru_rel_pos = gru_rel_pos
368
+ if self.gru_rel_pos:
369
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
370
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
371
+
372
+ self.reset_parameters()
373
+
374
+ def reset_parameters(self):
375
+ if self.qkv_same_dim:
376
+ # Empirically observed the convergence to be much better with
377
+ # the scaled initialization
378
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
379
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
380
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
381
+ else:
382
+ nn.init.xavier_uniform_(self.k_proj.weight)
383
+ nn.init.xavier_uniform_(self.v_proj.weight)
384
+ nn.init.xavier_uniform_(self.q_proj.weight)
385
+
386
+ nn.init.xavier_uniform_(self.out_proj.weight)
387
+ if self.out_proj.bias is not None:
388
+ nn.init.constant_(self.out_proj.bias, 0.0)
389
+ if self.bias_k is not None:
390
+ nn.init.xavier_normal_(self.bias_k)
391
+ if self.bias_v is not None:
392
+ nn.init.xavier_normal_(self.bias_v)
393
+ if self.has_relative_attention_bias:
394
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
395
+
396
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
397
+ num_buckets = self.num_buckets
398
+ max_distance = self.max_distance
399
+ relative_buckets = 0
400
+
401
+ if bidirectional:
402
+ num_buckets = num_buckets // 2
403
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
404
+ relative_positions = torch.abs(relative_positions)
405
+ else:
406
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
407
+
408
+ max_exact = num_buckets // 2
409
+ is_small = relative_positions < max_exact
410
+
411
+ relative_postion_if_large = max_exact + (
412
+ torch.log(relative_positions.float() / max_exact)
413
+ / math.log(max_distance / max_exact)
414
+ * (num_buckets - max_exact)
415
+ ).to(torch.long)
416
+ relative_postion_if_large = torch.min(
417
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
418
+ )
419
+
420
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
421
+ return relative_buckets
422
+
423
+ def compute_bias(self, query_length, key_length):
424
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
425
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
426
+ relative_position = memory_position - context_position
427
+ relative_position_bucket = self._relative_positions_bucket(
428
+ relative_position,
429
+ bidirectional=True
430
+ )
431
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
432
+ values = self.relative_attention_bias(relative_position_bucket)
433
+ values = values.permute([2, 0, 1])
434
+ return values
435
+
436
+ def forward(
437
+ self,
438
+ query,
439
+ key: Optional[Tensor],
440
+ value: Optional[Tensor],
441
+ key_padding_mask: Optional[Tensor] = None,
442
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
443
+ need_weights: bool = True,
444
+ static_kv: bool = False,
445
+ attn_mask: Optional[Tensor] = None,
446
+ before_softmax: bool = False,
447
+ need_head_weights: bool = False,
448
+ position_bias: Optional[Tensor] = None
449
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
450
+ """Input shape: Time x Batch x Channel
451
+
452
+ Args:
453
+ key_padding_mask (ByteTensor, optional): mask to exclude
454
+ keys that are pads, of shape `(batch, src_len)`, where
455
+ padding elements are indicated by 1s.
456
+ need_weights (bool, optional): return the attention weights,
457
+ averaged over heads (default: False).
458
+ attn_mask (ByteTensor, optional): typically used to
459
+ implement causal attention, where the mask prevents the
460
+ attention from looking forward in time (default: None).
461
+ before_softmax (bool, optional): return the raw attention
462
+ weights and values before the attention softmax.
463
+ need_head_weights (bool, optional): return the attention
464
+ weights for each head. Implies *need_weights*. Default:
465
+ return the average attention weights over all heads.
466
+ """
467
+ if need_head_weights:
468
+ need_weights = True
469
+
470
+ is_tpu = query.device.type == "xla"
471
+
472
+ tgt_len, bsz, embed_dim = query.size()
473
+ src_len = tgt_len
474
+ assert embed_dim == self.embed_dim
475
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
476
+ if key is not None:
477
+ src_len, key_bsz, _ = key.size()
478
+ if not torch.jit.is_scripting():
479
+ assert key_bsz == bsz
480
+ assert value is not None
481
+ assert src_len, bsz == value.shape[:2]
482
+
483
+ if self.has_relative_attention_bias and position_bias is None:
484
+ position_bias = self.compute_bias(tgt_len, src_len)
485
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
486
+
487
+ if incremental_state is not None:
488
+ saved_state = self._get_input_buffer(incremental_state)
489
+ if saved_state is not None and "prev_key" in saved_state:
490
+ # previous time steps are cached - no need to recompute
491
+ # key and value if they are static
492
+ if static_kv:
493
+ assert self.encoder_decoder_attention and not self.self_attention
494
+ key = value = None
495
+ else:
496
+ saved_state = None
497
+
498
+ if self.self_attention:
499
+ q = self.q_proj(query)
500
+ k = self.k_proj(query)
501
+ v = self.v_proj(query)
502
+ elif self.encoder_decoder_attention:
503
+ # encoder-decoder attention
504
+ q = self.q_proj(query)
505
+ if key is None:
506
+ assert value is None
507
+ k = v = None
508
+ else:
509
+ k = self.k_proj(key)
510
+ v = self.v_proj(key)
511
+
512
+ else:
513
+ assert key is not None and value is not None
514
+ q = self.q_proj(query)
515
+ k = self.k_proj(key)
516
+ v = self.v_proj(value)
517
+ q *= self.scaling
518
+ alpha = 32
519
+ q *= 1 / alpha
520
+
521
+ if self.bias_k is not None:
522
+ assert self.bias_v is not None
523
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
524
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
525
+ if attn_mask is not None:
526
+ attn_mask = torch.cat(
527
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
528
+ )
529
+ if key_padding_mask is not None:
530
+ key_padding_mask = torch.cat(
531
+ [
532
+ key_padding_mask,
533
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
534
+ ],
535
+ dim=1,
536
+ )
537
+
538
+ q = (
539
+ q.contiguous()
540
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
541
+ .transpose(0, 1)
542
+ )
543
+ if k is not None:
544
+ k = (
545
+ k.contiguous()
546
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
547
+ .transpose(0, 1)
548
+ )
549
+ if v is not None:
550
+ v = (
551
+ v.contiguous()
552
+ .view(-1, bsz * self.num_heads, self.head_dim)
553
+ .transpose(0, 1)
554
+ )
555
+
556
+ if saved_state is not None:
557
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
558
+ if "prev_key" in saved_state:
559
+ _prev_key = saved_state["prev_key"]
560
+ assert _prev_key is not None
561
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
562
+ if static_kv:
563
+ k = prev_key
564
+ else:
565
+ assert k is not None
566
+ k = torch.cat([prev_key, k], dim=1)
567
+ src_len = k.size(1)
568
+ if "prev_value" in saved_state:
569
+ _prev_value = saved_state["prev_value"]
570
+ assert _prev_value is not None
571
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
572
+ if static_kv:
573
+ v = prev_value
574
+ else:
575
+ assert v is not None
576
+ v = torch.cat([prev_value, v], dim=1)
577
+ prev_key_padding_mask: Optional[Tensor] = None
578
+ if "prev_key_padding_mask" in saved_state:
579
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
580
+ assert k is not None and v is not None
581
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
582
+ key_padding_mask=key_padding_mask,
583
+ prev_key_padding_mask=prev_key_padding_mask,
584
+ batch_size=bsz,
585
+ src_len=k.size(1),
586
+ static_kv=static_kv,
587
+ )
588
+
589
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
590
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
591
+ saved_state["prev_key_padding_mask"] = key_padding_mask
592
+ # In this branch incremental_state is never None
593
+ assert incremental_state is not None
594
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
595
+ assert k is not None
596
+ assert k.size(1) == src_len
597
+
598
+ # This is part of a workaround to get around fork/join parallelism
599
+ # not supporting Optional types.
600
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
601
+ key_padding_mask = None
602
+
603
+ if key_padding_mask is not None:
604
+ assert key_padding_mask.size(0) == bsz
605
+ assert key_padding_mask.size(1) == src_len
606
+
607
+ if self.add_zero_attn:
608
+ assert v is not None
609
+ src_len += 1
610
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
611
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
612
+ if attn_mask is not None:
613
+ attn_mask = torch.cat(
614
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
615
+ )
616
+ if key_padding_mask is not None:
617
+ key_padding_mask = torch.cat(
618
+ [
619
+ key_padding_mask,
620
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
621
+ key_padding_mask
622
+ ),
623
+ ],
624
+ dim=1,
625
+ )
626
+
627
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
628
+ attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
629
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
630
+
631
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
632
+
633
+ if attn_mask is not None:
634
+ attn_mask = attn_mask.unsqueeze(0)
635
+ attn_weights += attn_mask
636
+
637
+ if key_padding_mask is not None:
638
+ # don't attend to padding symbols
639
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
640
+ if not is_tpu:
641
+ attn_weights = attn_weights.masked_fill(
642
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
643
+ float("-inf"),
644
+ )
645
+ else:
646
+ attn_weights = attn_weights.transpose(0, 2)
647
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
648
+ attn_weights = attn_weights.transpose(0, 2)
649
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
650
+
651
+ if before_softmax:
652
+ return attn_weights, v, position_bias
653
+
654
+ if position_bias is not None:
655
+ attn_mask_rel_pos = position_bias
656
+ if self.gru_rel_pos == 1:
657
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
658
+ _B, _H, _L, __ = query_layer.size()
659
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
660
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
661
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
662
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
663
+
664
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
665
+
666
+ attn_weights = attn_weights + attn_mask_rel_pos
667
+
668
+ attn_weights_float = F.softmax(
669
+ attn_weights, dim=-1
670
+ )
671
+ attn_weights = attn_weights_float.type_as(attn_weights)
672
+ attn_probs = self.dropout_module(attn_weights)
673
+
674
+ assert v is not None
675
+ attn = torch.bmm(attn_probs, v)
676
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
677
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
678
+ attn = self.out_proj(attn)
679
+ attn_weights: Optional[Tensor] = None
680
+ if need_weights:
681
+ attn_weights = attn_weights_float.view(
682
+ bsz, self.num_heads, tgt_len, src_len
683
+ ).transpose(1, 0)
684
+ if not need_head_weights:
685
+ # average attention weights over heads
686
+ attn_weights = attn_weights.mean(dim=0)
687
+
688
+ return attn, attn_weights, position_bias
689
+
690
+ @staticmethod
691
+ def _append_prev_key_padding_mask(
692
+ key_padding_mask: Optional[Tensor],
693
+ prev_key_padding_mask: Optional[Tensor],
694
+ batch_size: int,
695
+ src_len: int,
696
+ static_kv: bool,
697
+ ) -> Optional[Tensor]:
698
+ # saved key padding masks have shape (bsz, seq_len)
699
+ if prev_key_padding_mask is not None and static_kv:
700
+ new_key_padding_mask = prev_key_padding_mask
701
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
702
+ new_key_padding_mask = torch.cat(
703
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
704
+ )
705
+ # During incremental decoding, as the padding token enters and
706
+ # leaves the frame, there will be a time when prev or current
707
+ # is None
708
+ elif prev_key_padding_mask is not None:
709
+ if src_len > prev_key_padding_mask.size(1):
710
+ filler = torch.zeros(
711
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
712
+ device=prev_key_padding_mask.device,
713
+ )
714
+ new_key_padding_mask = torch.cat(
715
+ [prev_key_padding_mask.float(), filler.float()], dim=1
716
+ )
717
+ else:
718
+ new_key_padding_mask = prev_key_padding_mask.float()
719
+ elif key_padding_mask is not None:
720
+ if src_len > key_padding_mask.size(1):
721
+ filler = torch.zeros(
722
+ (batch_size, src_len - key_padding_mask.size(1)),
723
+ device=key_padding_mask.device,
724
+ )
725
+ new_key_padding_mask = torch.cat(
726
+ [filler.float(), key_padding_mask.float()], dim=1
727
+ )
728
+ else:
729
+ new_key_padding_mask = key_padding_mask.float()
730
+ else:
731
+ new_key_padding_mask = prev_key_padding_mask
732
+ return new_key_padding_mask
733
+
734
+ def _get_input_buffer(
735
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
736
+ ) -> Dict[str, Optional[Tensor]]:
737
+ result = self.get_incremental_state(incremental_state, "attn_state")
738
+ if result is not None:
739
+ return result
740
+ else:
741
+ empty_result: Dict[str, Optional[Tensor]] = {}
742
+ return empty_result
743
+
744
+ def _set_input_buffer(
745
+ self,
746
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
747
+ buffer: Dict[str, Optional[Tensor]],
748
+ ):
749
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
750
+
751
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
752
+ return attn_weights
753
+
754
+
755
+ def init_bert_params(module):
756
+ """
757
+ Initialize the weights specific to the BERT Model.
758
+ This overrides the default initializations depending on the specified arguments.
759
+ 1. If normal_init_linear_weights is set then weights of linear
760
+ layer will be initialized using the normal distribution and
761
+ bais will be set to the specified value.
762
+ 2. If normal_init_embed_weights is set then weights of embedding
763
+ layer will be initialized using the normal distribution.
764
+ 3. If normal_init_proj_weights is set then weights of
765
+ in_project_weight for MultiHeadAttention initialized using
766
+ the normal distribution (to be validated).
767
+ """
768
+
769
+ def normal_(data):
770
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
771
+ # so that the RNG is consistent with and without FSDP
772
+ data.copy_(
773
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
774
+ )
775
+
776
+ if isinstance(module, nn.Linear):
777
+ normal_(module.weight.data)
778
+ if module.bias is not None:
779
+ module.bias.data.zero_()
780
+ if isinstance(module, nn.Embedding):
781
+ normal_(module.weight.data)
782
+ if module.padding_idx is not None:
783
+ module.weight.data[module.padding_idx].zero_()
784
+ if isinstance(module, MultiheadAttention):
785
+ normal_(module.q_proj.weight.data)
786
+ normal_(module.k_proj.weight.data)
787
+ normal_(module.v_proj.weight.data)
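A small sketch tying this modified backbone to the rest of the pipeline: with the default 12-layer BEATsConfig, the encoder returns, besides the final hidden states, the concatenation of layers 4, 8 and 12 (768*3 = 2304 dims), which is exactly the input size FGAEmbedder expects. This assumes the repo modules are importable; weights are random here, so only the shapes are meaningful:

```python
import torch
from modules.beats.BEATs import BEATsConfig
from modules.beats.backbone import TransformerEncoder

encoder = TransformerEncoder(BEATsConfig()).eval()
x = torch.randn(2, 100, 768)               # (batch, frames, encoder_embed_dim)
with torch.no_grad():
    out, layers_sum, layers = encoder(x)
print(out.shape, layers_sum.shape)         # (2, 100, 768) and (2, 100, 2304)
```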
modules/beats/modules.py ADDED
@@ -0,0 +1,218 @@
+# --------------------------------------------------------
+# beats: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import warnings
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class GradMultiply(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, scale):
+        ctx.scale = scale
+        res = x.new(x)
+        return res
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad * ctx.scale, None
+
+
+class SamePad(nn.Module):
+    def __init__(self, kernel_size, causal=False):
+        super().__init__()
+        if causal:
+            self.remove = kernel_size - 1
+        else:
+            self.remove = 1 if kernel_size % 2 == 0 else 0
+
+    def forward(self, x):
+        if self.remove > 0:
+            x = x[:, :, : -self.remove]
+        return x
+
+
+class Swish(nn.Module):
+    def __init__(self):
+        super(Swish, self).__init__()
+        self.act = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+        super(GLU_Linear, self).__init__()
+
+        self.glu_type = glu_type
+        self.output_dim = output_dim
+
+        if glu_type == "sigmoid":
+            self.glu_act = torch.nn.Sigmoid()
+        elif glu_type == "swish":
+            self.glu_act = Swish()
+        elif glu_type == "relu":
+            self.glu_act = torch.nn.ReLU()
+        elif glu_type == "gelu":
+            self.glu_act = torch.nn.GELU()
+
+        if bias_in_glu:
+            self.linear = nn.Linear(input_dim, output_dim * 2, True)
+        else:
+            self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+    def forward(self, x):
+        # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
+        x = self.linear(x)
+
+        if self.glu_type == "bilinear":
+            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+        else:
+            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+        return x
+
+
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+    """Returns the activation function corresponding to `activation`"""
+
+    if activation == "relu":
+        return F.relu
+    elif activation == "gelu":
+        return gelu
+    elif activation == "gelu_fast":
+        warnings.warn(
+            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+        )
+        return gelu_accurate
+    elif activation == "gelu_accurate":
+        return gelu_accurate
+    elif activation == "tanh":
+        return torch.tanh
+    elif activation == "linear":
+        return lambda x: x
+    elif activation == "glu":
+        return lambda x: x
+    else:
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def quant_noise(module, p, block_size):
+    """
+    Wraps modules and applies quantization noise to the weights for
+    subsequent quantization with Iterative Product Quantization as
+    described in "Training with Quantization Noise for Extreme Model Compression"
+
+    Args:
+        - module: nn.Module
+        - p: amount of Quantization Noise
+        - block_size: size of the blocks for subsequent quantization with iPQ
+
+    Remarks:
+        - Module weights must have the right sizes wrt the block size
+        - Only Linear, Embedding and Conv2d modules are supported for the moment
+        - For more detail on how to quantize by blocks with convolutional weights,
+          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+        - We implement the simplest form of noise here as stated in the paper
+          which consists in randomly dropping blocks
+    """
+
+    # if no quantization noise, don't register hook
+    if p <= 0:
+        return module
+
+    # supported modules
+    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+    # test whether module.weight has the right sizes wrt block_size
+    is_conv = module.weight.ndim == 4
+
+    # 2D matrix
+    if not is_conv:
+        assert (
+            module.weight.size(1) % block_size == 0
+        ), "Input features must be a multiple of block sizes"
+
+    # 4D matrix
+    else:
+        # 1x1 convolutions
+        if module.kernel_size == (1, 1):
+            assert (
+                module.in_channels % block_size == 0
+            ), "Input channels must be a multiple of block sizes"
+        # regular convolutions
+        else:
+            k = module.kernel_size[0] * module.kernel_size[1]
+            assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+    def _forward_pre_hook(mod, input):
+        # no noise for evaluation
+        if mod.training:
+            if not is_conv:
+                # gather weight and sizes
+                weight = mod.weight
+                in_features = weight.size(1)
+                out_features = weight.size(0)
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
+                mask.bernoulli_(p)
+                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+            else:
+                # gather weight and sizes
+                weight = mod.weight
+                in_channels = mod.in_channels
+                out_channels = mod.out_channels
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                if mod.kernel_size == (1, 1):
+                    mask = torch.zeros(
+                        int(in_channels // block_size * out_channels),
+                        device=weight.device,
+                    )
+                    mask.bernoulli_(p)
+                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+                else:
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
+                    mask.bernoulli_(p)
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )
+
+            # scale weights and apply mask
+            mask = mask.to(
+                torch.bool
+            )  # x.bool() is not currently supported in TorchScript
+            s = 1 / (1 - p)
215
+ mod.weight.data = s * weight.masked_fill(mask, 0)
216
+
217
+ module.register_forward_pre_hook(_forward_pre_hook)
218
+ return module
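The helpers in this hunk (GradMultiply, SamePad, Swish, GLU_Linear, the GELU variants and quant_noise) come from the BEATs/fairseq code base. Below is a minimal usage sketch, not part of the commit, assuming the hunk lives at modules/beats/modules.py (import path assumed) and using arbitrary shapes:

import torch
from torch import nn
from modules.beats.modules import quant_noise, SamePad, GLU_Linear  # import path assumed

# quant_noise: in_features (64) must be divisible by block_size (8).
layer = quant_noise(nn.Linear(64, 128), p=0.1, block_size=8)
layer.train()
out = layer(torch.randn(4, 64))              # random weight blocks are zeroed, the rest rescaled by 1/(1-p)
print(out.shape)                             # torch.Size([4, 128])

# SamePad trims the extra frame that "same" padding produces for even kernel sizes.
pad = SamePad(kernel_size=4)
print(pad(torch.randn(2, 16, 11)).shape)     # torch.Size([2, 16, 10])

# GLU_Linear doubles the projection and gates one half with the activated other half.
glu = GLU_Linear(32, 16, glu_type="swish")
print(glu(torch.randn(2, 5, 32)).shape)      # torch.Size([2, 5, 16])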
modules/beats/quantizer.py ADDED
@@ -0,0 +1,215 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as distributed
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass  # einops is only needed for the optional k-means codebook initialization
19
+
20
+
21
+ def l2norm(t):
22
+ return F.normalize(t, p=2, dim=-1)
23
+
24
+
25
+ def ema_inplace(moving_avg, new, decay):
26
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, 'n d -> n () d') \
50
+ - rearrange(means, 'c d -> () c d')
51
+ dists = -(diffs ** 2).sum(dim=-1)
52
+
53
+ buckets = dists.max(dim=-1).indices
54
+ bins = torch.bincount(buckets, minlength=num_clusters)
55
+ zero_mask = bins == 0
56
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
57
+
58
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
59
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
60
+ new_means = new_means / bins_min_clamped[..., None]
61
+
62
+ if use_cosine_sim:
63
+ new_means = l2norm(new_means)
64
+
65
+ means = torch.where(zero_mask[..., None], means, new_means)
66
+
67
+ return means, bins
68
+
69
+
70
+ class EmbeddingEMA(nn.Module):
71
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
72
+ super().__init__()
73
+ self.num_tokens = num_tokens
74
+ self.codebook_dim = codebook_dim
75
+ self.decay = decay
76
+ self.eps = eps
77
+ if codebook_init_path == '':
78
+ if not kmeans_init:
79
+ weight = torch.randn(num_tokens, codebook_dim)
80
+ weight = l2norm(weight)
81
+ else:
82
+ weight = torch.zeros(num_tokens, codebook_dim)
83
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
84
+ else:
85
+ print(f"load init codebook weight from {codebook_init_path}")
86
+ codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
87
+ weight = codebook_ckpt_weight.clone()
88
+ self.register_buffer('initted', torch.Tensor([True]))
89
+
90
+ self.weight = nn.Parameter(weight, requires_grad=False)
91
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
92
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
93
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
94
+ self.update = True
95
+
96
+ @torch.jit.ignore
97
+ def init_embed_(self, data):
98
+ if self.initted:
99
+ return
100
+ print("Performing Kemans init for codebook")
101
+ embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
102
+ self.weight.data.copy_(embed)
103
+ self.cluster_size.data.copy_(cluster_size)
104
+ self.initted.data.copy_(torch.Tensor([True]))
105
+
106
+ def forward(self, embed_id):
107
+ return F.embedding(embed_id, self.weight)
108
+
109
+ def cluster_size_ema_update(self, new_cluster_size):
110
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
111
+
112
+ def embed_avg_ema_update(self, new_embed_avg):
113
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
114
+
115
+ def weight_update(self, num_tokens):
116
+ n = self.cluster_size.sum()
117
+ smoothed_cluster_size = (
118
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
119
+ )
120
+ # normalize embedding average with smoothed cluster size
121
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
122
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
123
+ self.weight.data.copy_(embed_normalized)
124
+
125
+
126
+ def norm_ema_inplace(moving_avg, new, decay):
127
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
128
+ moving_avg.data.copy_(l2norm(moving_avg.data))
129
+
130
+
131
+ class NormEMAVectorQuantizer(nn.Module):
132
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
133
+ statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
134
+ super().__init__()
135
+ self.codebook_dim = embedding_dim
136
+ self.num_tokens = n_embed
137
+ self.beta = beta
138
+ self.decay = decay
139
+
140
+ # learnable = True if orthogonal_reg_weight > 0 else False
141
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
142
+
143
+ self.statistic_code_usage = statistic_code_usage
144
+ if statistic_code_usage:
145
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
146
+ if distributed.is_available() and distributed.is_initialized():
147
+ print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
148
+ self.all_reduce_fn = distributed.all_reduce
149
+ else:
150
+ self.all_reduce_fn = nn.Identity()
151
+
152
+ def reset_cluster_size(self, device):
153
+ if self.statistic_code_usage:
154
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
155
+ self.cluster_size = self.cluster_size.to(device)
156
+
157
+ def forward(self, z):
158
+ # reshape z -> (batch, height, width, channel) and flatten
159
+ # z, 'b c h w -> b h w c'
160
+ # z = rearrange(z, 'b c h w -> b h w c')
161
+ # z = z.transpose(1, 2)
162
+ z = l2norm(z)
163
+ z_flattened = z.reshape(-1, self.codebook_dim)
164
+
165
+ self.embedding.init_embed_(z_flattened)
166
+
167
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
168
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
169
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
170
+
171
+ encoding_indices = torch.argmin(d, dim=1)
172
+
173
+ z_q = self.embedding(encoding_indices).view(z.shape)
174
+
175
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
176
+
177
+ if not self.training:
178
+ with torch.no_grad():
179
+ cluster_size = encodings.sum(0)
180
+ self.all_reduce_fn(cluster_size)
181
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
182
+
183
+ if self.training and self.embedding.update:
184
+ # EMA cluster size
185
+
186
+ bins = encodings.sum(0)
187
+ self.all_reduce_fn(bins)
188
+
189
+ # self.embedding.cluster_size_ema_update(bins)
190
+ ema_inplace(self.cluster_size, bins, self.decay)
191
+
192
+ zero_mask = (bins == 0)
193
+ bins = bins.masked_fill(zero_mask, 1.)
194
+
195
+ embed_sum = z_flattened.t() @ encodings
196
+ self.all_reduce_fn(embed_sum)
197
+
198
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
199
+ embed_normalized = l2norm(embed_normalized)
200
+
201
+ embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
202
+ embed_normalized)
203
+ norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
204
+
205
+ # compute loss for embedding
206
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
207
+
208
+ # preserve gradients
209
+ z_q = z + (z_q - z).detach()
210
+
211
+ # reshape back to match original input shape
212
+ # z_q, 'b h w c -> b c h w'
213
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
214
+ # z_q = z_q.transpose(1, 2)
215
+ return z_q, loss, encoding_indices
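A minimal sketch of how the EMA codebook above can be exercised, not part of the commit; codebook size, embedding dimension, and batch shapes are arbitrary:

import torch
from modules.beats.quantizer import NormEMAVectorQuantizer

quantizer = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0)
z = torch.randn(8, 100, 256)            # (batch, frames, dim) frame embeddings
z_q, commit_loss, codes = quantizer(z)  # straight-through quantized output, commitment loss, code indices
print(z_q.shape, codes.shape)           # torch.Size([8, 100, 256]) torch.Size([800])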
modules/fga/atten.py ADDED
@@ -0,0 +1,303 @@
1
+ #!/usr/bin/env python
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ from itertools import product, permutations, combinations_with_replacement, chain
7
+
8
+
9
+ class Unary(nn.Module):
10
+ def __init__(self, embed_size):
11
+ """
12
+ Captures local entity information
13
+ :param embed_size: the embedding dimension
14
+ """
15
+ super(Unary, self).__init__()
16
+ self.embed = nn.Conv1d(embed_size, embed_size, 1)
17
+ self.feature_reduce = nn.Conv1d(embed_size, 1, 1)
18
+
19
+ def forward(self, X):
20
+ X = X.transpose(1, 2)
21
+
22
+ X_embed = self.embed(X)
23
+
24
+ X_nl_embed = F.dropout(F.relu(X_embed))
25
+ X_poten = self.feature_reduce(X_nl_embed)
26
+ return X_poten.squeeze(1)
27
+
28
+
29
+ class Pairwise(nn.Module):
30
+ def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None):
31
+ """
32
+ Captures interaction between utilities or entities of the same utility
33
+ :param embed_x_size: the embedding dimension of the first utility
34
+ :param x_spatial_dim: the spatial dimension of the first utility for batch norm and weighted marginalization
35
+ :param embed_y_size: the embedding dimension of the second utility (none for self-interactions)
36
+ :param y_spatial_dim: the spatial dimension of the second utility for batch norm and weighted marginalization
37
+ """
38
+
39
+ super(Pairwise, self).__init__()
40
+ embed_y_size = embed_y_size if embed_y_size is not None else embed_x_size
41
+ self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim
42
+
43
+ self.embed_size = max(embed_x_size, embed_y_size)
44
+ self.x_spatial_dim = x_spatial_dim
45
+
46
+ self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1)
47
+ self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1)
48
+ if x_spatial_dim is not None:
49
+ self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim)
50
+
51
+ self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1)
52
+ self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1)
53
+
54
+ def forward(self, X, Y=None):
55
+
56
+ X_t = X.transpose(1, 2)
57
+ Y_t = Y.transpose(1, 2) if Y is not None else X_t
58
+
59
+
60
+ X_embed = self.embed_X(X_t)
61
+ Y_embed = self.embed_Y(Y_t)
62
+
63
+ X_norm = F.normalize(X_embed)
64
+ Y_norm = F.normalize(Y_embed)
65
+
66
+ S = X_norm.transpose(1, 2).bmm(Y_norm)
67
+ if self.x_spatial_dim is not None:
68
+ S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \
69
+ .view(-1, self.x_spatial_dim, self.y_spatial_dim)
70
+
71
+ X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2)
72
+ Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2)
73
+ else:
74
+ X_poten = S.mean(dim=2, keepdim=False)
75
+ Y_poten = S.mean(dim=1, keepdim=False)
76
+
77
+ if Y is None:
78
+ return X_poten
79
+ else:
80
+ return X_poten, Y_poten
81
+
82
+
83
+ class Atten(nn.Module):
84
+ def __init__(self, util_e, sharing_factor_weights=[], prior_flag=False,
85
+ sizes=[], size_force=False, pairwise_flag=True,
86
+ unary_flag=True, self_flag=True):
87
+ """
88
+ Performs attention over a given list of utility representations.
89
+ :param util_e: the embedding dimensions
90
+ :param sharing_factor_weights: To share weights, provide a dict of tuples:
91
+ {idx: (num_utils, connected utils)}.
92
+ Note that, for efficiency, the shared utils (i.e., history) are connected to the answer
93
+ and question utilities only.
94
+ TODO: connections between shared utils
95
+ :param prior_flag: is prior factor provided
96
+ :param sizes: the spatial dimension of each utility (used for batch norm and weighted marginalization)
97
+ :param size_force: force spatial size with adaptive avg pooling.
98
+ :param pairwise_flag: use pairwise interaction between utilities
99
+ :param unary_flag: use local information
100
+ :param self_flag: use self-interactions between a utility's entities
101
+ """
102
+ super(Atten, self).__init__()
103
+ self.util_e = util_e
104
+
105
+ self.prior_flag = prior_flag
106
+
107
+ self.n_utils = len(util_e)
108
+
109
+ self.spatial_pool = nn.ModuleDict()
110
+
111
+ self.un_models = nn.ModuleList()
112
+
113
+ self.self_flag = self_flag
114
+ self.pairwise_flag = pairwise_flag
115
+ self.unary_flag = unary_flag
116
+ self.size_force = size_force
117
+
118
+ if len(sizes) == 0:
119
+ sizes = [None for _ in util_e]
120
+
121
+ self.sharing_factor_weights = sharing_factor_weights
122
+
123
+ #force the provided size
124
+ for idx, e_dim in enumerate(util_e):
125
+ self.un_models.append(Unary(e_dim))
126
+ if self.size_force:
127
+ self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx])
128
+
129
+ #Pairwise
130
+ self.pp_models = nn.ModuleDict()
131
+ for ((idx1, e_dim_1), (idx2, e_dim_2)) \
132
+ in combinations_with_replacement(enumerate(util_e), 2):
133
+ # self
134
+ if self.self_flag and idx1 == idx2:
135
+ self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1])
136
+ else:
137
+ if pairwise_flag:
138
+ if idx1 in self.sharing_factor_weights:
139
+ # not connected
140
+ if idx2 not in self.sharing_factor_weights[idx1][1]:
141
+ continue
142
+ if idx2 in self.sharing_factor_weights:
143
+ # not connected
144
+ if idx1 not in self.sharing_factor_weights[idx2][1]:
145
+ continue
146
+ self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1], e_dim_2, sizes[idx2])
147
+
148
+ # Handle reduce potentials (with scalars)
149
+ self.reduce_potentials = nn.ModuleList()
150
+
151
+ self.num_of_potentials = dict()
152
+
153
+ self.default_num_of_potentials = 0
154
+
155
+ if self.self_flag:
156
+ self.default_num_of_potentials += 1
157
+ if self.unary_flag:
158
+ self.default_num_of_potentials += 1
159
+ if self.prior_flag:
160
+ self.default_num_of_potentials += 1
161
+ for idx in range(self.n_utils):
162
+ self.num_of_potentials[idx] = self.default_num_of_potentials
163
+
164
+ '''
165
+ All other utilities
166
+ '''
167
+ if pairwise_flag:
168
+ for idx, (num_utils, connected_utils) in sharing_factor_weights.items():
169
+ for c_u in connected_utils:
170
+ self.num_of_potentials[c_u] += num_utils
171
+ self.num_of_potentials[idx] += 1
172
+ for k in self.num_of_potentials:
173
+ if k not in self.sharing_factor_weights:
174
+ self.num_of_potentials[k] += (self.n_utils - 1) \
175
+ - len(sharing_factor_weights)
176
+
177
+ for idx in range(self.n_utils):
178
+ self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx],
179
+ 1, 1, bias=False))
180
+
181
+ def forward(self, utils, priors=None):
182
+ assert self.n_utils == len(utils)
183
+ assert (priors is None and not self.prior_flag) \
184
+ or (priors is not None
185
+ and self.prior_flag
186
+ and len(priors) == self.n_utils)
187
+ b_size = utils[0].size(0)
188
+ util_factors = dict()
189
+ attention = list()
190
+
191
+ #Force size, constant size is used for pairwise batch normalization
192
+ if self.size_force:
193
+ for i, (num_utils, _) in self.sharing_factor_weights.items():
194
+ if str(i) not in self.spatial_pool.keys():
195
+ continue
196
+ else:
197
+ high_util = utils[i]
198
+ high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3))
199
+ high_util = high_util.transpose(1, 2)
200
+ utils[i] = self.spatial_pool[str(i)](high_util).transpose(1, 2)
201
+
202
+ for i in range(self.n_utils):
203
+ if i in self.sharing_factor_weights \
204
+ or str(i) not in self.spatial_pool.keys():
205
+ continue
206
+ utils[i] = utils[i].transpose(1, 2)
207
+ utils[i] = self.spatial_pool[str(i)](utils[i]).transpose(1, 2)
208
+ if self.prior_flag and priors[i] is not None:
209
+ priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1)
210
+
211
+ # handle Shared weights
212
+ for i, (num_utils, connected_list) in self.sharing_factor_weights.items():
213
+ if self.unary_flag:
214
+ util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
215
+
216
+ if self.self_flag:
217
+ util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
218
+
219
+ if self.pairwise_flag:
220
+ for j in connected_list:
221
+ other_util = utils[j]
222
+ expanded_util = other_util.unsqueeze(1).expand(b_size,
223
+ num_utils,
224
+ other_util.size(1),
225
+ other_util.size(2)).contiguous().view(
226
+ b_size * num_utils,
227
+ other_util.size(1),
228
+ other_util.size(2))
229
+
230
+ if i < j:
231
+ factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util)
232
+ else:
233
+ factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i])
234
+ util_factors[i].append(factor_ij)
235
+ util_factors.setdefault(j, []).append(factor_ji.view(b_size, num_utils, factor_ji.size(1)))
236
+
237
+ # handle local factors
238
+ for i in range(self.n_utils):
239
+ if i in self.sharing_factor_weights:
240
+ continue
241
+ if self.unary_flag:
242
+ util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
243
+ if self.self_flag:
244
+ util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
245
+
246
+ # joint
247
+ if self.pairwise_flag:
248
+ for (i, j) in combinations_with_replacement(range(self.n_utils), 2):
249
+ if i in self.sharing_factor_weights \
250
+ or j in self.sharing_factor_weights:
251
+ continue
252
+ if i == j:
253
+ continue
254
+ else:
255
+ factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j])
256
+ util_factors.setdefault(i, []).append(factor_ij)
257
+ util_factors.setdefault(j, []).append(factor_ji)
258
+
259
+ # perform attention
260
+ for i in range(self.n_utils):
261
+ if self.prior_flag:
262
+ prior = priors[i] \
263
+ if priors[i] is not None \
264
+ else torch.zeros_like(util_factors[i][0], requires_grad=False)
265
+
266
+ util_factors[i].append(prior)
267
+
268
+ util_factors[i] = torch.cat([p if len(p.size()) == 3 else p.unsqueeze(1)
269
+ for p in util_factors[i]], dim=1)
270
+ util_factors[i] = self.reduce_potentials[i](util_factors[i]).squeeze(1)
271
+ util_factors[i] = F.softmax(util_factors[i], dim=1).unsqueeze(2)
272
+ attention.append(torch.bmm(utils[i].transpose(1, 2), util_factors[i]).squeeze(2))
273
+
274
+ return attention
275
+
276
+
277
+ class NaiveAttention(nn.Module):
278
+ def __init__(self):
279
+ """
280
+ Used for ablation analysis - removing attention.
281
+ """
282
+ super(NaiveAttention, self).__init__()
283
+
284
+ def forward(self, utils, priors):
285
+ atten = []
286
+ spatial_atten = []
287
+ for u, p in zip(utils, priors):
288
+ if type(u) is tuple:
289
+ u = u[1]
290
+ num_elements = u.shape[0]
291
+ if p is not None:
292
+ u = u.view(-1, u.shape[-2], u.shape[-1])
293
+ p = p.view(-1, p.shape[-2], p.shape[-1])
294
+ spatial_atten.append(
295
+ torch.bmm(p.transpose(1, 2), u).squeeze(2).view(num_elements, -1, u.shape[-2], u.shape[-1]))
296
+ else:
297
+ spatial_atten.append(u.mean(2))
298
+ continue
299
+ if p is not None:
300
+ atten.append(torch.bmm(u.transpose(1, 2), p.unsqueeze(2)).squeeze(2))
301
+ else:
302
+ atten.append(u.mean(1))
303
+ return atten, spatial_atten
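A minimal sketch of the Atten module above, not part of the commit: two utilities with no weight sharing and no priors; the embedding size and entity counts are arbitrary (equal embedding sizes are used here):

import torch
from modules.fga.atten import Atten

atten = Atten(util_e=[768, 768], sharing_factor_weights={}, pairwise_flag=True)
audio_tokens = torch.randn(4, 10, 768)   # (batch, entities, dim)
other_tokens = torch.randn(4, 7, 768)
audio_vec, other_vec = atten([audio_tokens, other_tokens])
print(audio_vec.shape, other_vec.shape)  # torch.Size([4, 768]) torch.Size([4, 768])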
modules/fga/fga_model.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from modules.fga.atten import Atten
5
+
6
+
7
+ class FGA(nn.Module):
8
+ def __init__(self, vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim,
9
+ hidden_hist_dim, hidden_cap_dim, hidden_img_dim):
10
+ '''
11
+ Factor Graph Attention
12
+ :param vocab_size: vocabulary size
13
+ :param word_embed_dim
14
+ :param hidden_ques_dim:
15
+ :param hidden_ans_dim:
16
+ :param hidden_hist_dim:
17
+ :param img_features_dim:
18
+ '''
19
+ super(FGA, self).__init__()
20
+
21
+ print("Init FGA with vocab size %s, word embed %s, hidden ques %s, hidden ans %s,"
22
+ " hidden hist %s, hidden cap %s, hidden img %s" % (vocab_size, word_embed_dim,
23
+ hidden_ques_dim,
24
+ hidden_ans_dim,
25
+ hidden_hist_dim,
26
+ hidden_cap_dim,
27
+ hidden_img_dim))
28
+ self.hidden_ques_dim = hidden_ques_dim
29
+ self.hidden_ans_dim = hidden_ans_dim
30
+ self.hidden_cap_dim = hidden_cap_dim
31
+ self.hidden_img_dim = hidden_img_dim
32
+ self.hidden_hist_dim = hidden_hist_dim
33
+
34
+ # Vocab of History LSTMs is one more as we are keeping a stop id (the last id)
35
+ self.word_embedddings = nn.Embedding(vocab_size+1+1, word_embed_dim, padding_idx=0)
36
+
37
+ self.lstm_ques = nn.LSTM(word_embed_dim, self.hidden_ques_dim, batch_first=True)
38
+ self.lstm_ans = nn.LSTM(word_embed_dim, self.hidden_ans_dim, batch_first=True)
39
+
40
+ self.lstm_hist_ques = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
41
+ self.lstm_hist_ans = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
42
+
43
+ self.lstm_hist_cap = nn.LSTM(word_embed_dim, self.hidden_cap_dim, batch_first=True)
44
+
45
+
46
+ self.qahistnet = nn.Sequential(
47
+ nn.Linear(self.hidden_hist_dim*2, self.hidden_hist_dim),
48
+ nn.ReLU(inplace=True)
49
+ )
50
+
51
+ self.concat_dim = self.hidden_ques_dim + self.hidden_ans_dim + \
52
+ self.hidden_ans_dim + self.hidden_img_dim + \
53
+ self.hidden_cap_dim + self.hidden_hist_dim*9
54
+
55
+ self.simnet = nn.Sequential(
56
+ nn.Linear(self.concat_dim, (self.concat_dim)//2, bias=False),
57
+ nn.BatchNorm1d((self.concat_dim) // 2),
58
+ nn.ReLU(inplace=True),
59
+ nn.Linear((self.concat_dim)//2, (self.concat_dim)//4, bias=False),
60
+ nn.BatchNorm1d((self.concat_dim) // 4),
61
+ nn.ReLU(inplace=True),
62
+ nn.Dropout(0.5),
63
+ nn.Linear((self.concat_dim)//4, 1)
64
+ )
65
+
66
+ # To share weights, provide list of tuples: (idx, list of connected utils)
67
+ # Note, for efficiency, the shared utils (i.e., history, are connected to ans and question only.
68
+ # connecting shared factors is not supported (!)
69
+ sharing_factor_weights = {4: (9, [0, 1]),
70
+ 5: (9, [0, 1])}
71
+
72
+ self.mul_atten = Atten(util_e=[self.hidden_ans_dim, # Answer modal
73
+ self.hidden_ques_dim, # Question modal
74
+ self.hidden_cap_dim, # Caption modal
75
+ self.hidden_img_dim, # Image modal
76
+ self.hidden_hist_dim, # Question-history modal
77
+ self.hidden_hist_dim # Answer-history modal
78
+ ],
79
+ sharing_factor_weights=sharing_factor_weights,
80
+ sizes=[100, # 100 Answers
81
+ 21, # Question length
82
+ 41, # Caption length
83
+ 37, # 36 Image regions
84
+ 21, # History-Question length
85
+ 21 # History-Answer length
86
+ ] # The spatial dim used for pairwise normalization (use force for adaptive)
87
+ , prior_flag=True,
88
+ pairwise_flag=True)
89
+
90
+
91
+
92
+ def forward(self, input_ques, input_ans, input_hist_ques, input_hist_ans, input_hist_cap,
93
+ input_ques_length, input_ans_length, input_cap_length, i_e):
94
+ """
95
+
96
+ :param input_ques:
97
+ :param input_ans:
98
+ :param input_hist_ques:
99
+ :param input_hist_ans:
100
+ :param input_hist_cap:
101
+ :param input_ques_length:
102
+ :param input_ans_length:
103
+ :param input_cap_length:
104
+ :param i_e:
105
+ :return:
106
+ """
107
+
108
+
109
+ n_options = input_ans.size()[1]
110
+ batch_size = input_ques.size()[0]
111
+
112
+
113
+
114
+ nqa_per_dial, nwords_per_qa = input_hist_ques.size()[1], input_hist_ques.size()[2]
115
+ nwords_per_cap = input_hist_cap.size()[1]
116
+ max_length_input_ans = input_ans.size()[-1]
117
+
118
+ assert batch_size == input_hist_ques.size()[0] == input_hist_ans.size()[0] == input_ques.size()[0] == \
119
+ input_ans.size()[0] == input_hist_cap.size()[0]
120
+ assert nqa_per_dial == input_hist_ques.size()[1] == input_hist_ans.size()[1]
121
+ assert nwords_per_qa == input_hist_ques.size()[2] == input_hist_ans.size()[2]
122
+
123
+ q_we = self.word_embedddings(input_ques)
124
+ a_we = self.word_embedddings(input_ans.view(-1, max_length_input_ans))
125
+ hq_we = self.word_embedddings(input_hist_ques.view(-1, nwords_per_qa))
126
+ ha_we = self.word_embedddings(input_hist_ans.view(-1, nwords_per_qa))
127
+ c_we = self.word_embedddings(input_hist_cap.view(-1, nwords_per_cap))
128
+
129
+
130
+
131
+ '''
132
+ q_we = batch x 20 x embed_ques_dim
133
+ a_we = 100*batch x 20 x embed_ans_dim
134
+ hq_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim
135
+ ha_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim
136
+ c_we = batch*ncap_per_dial, nwords_per_cap, embed_hist_dim
137
+ '''
138
+ self.lstm_ques.flatten_parameters()
139
+ self.lstm_ans.flatten_parameters()
140
+ self.lstm_hist_ques.flatten_parameters()
141
+ self.lstm_hist_ans.flatten_parameters()
142
+ self.lstm_hist_cap.flatten_parameters()
143
+
144
+
145
+ i_feat = i_e
146
+
147
+ q_seq, self.hidden_ques = self.lstm_ques(q_we)
148
+ a_seq, self.hidden_ans = self.lstm_ans(a_we)
149
+ hq_seq, self.hidden_hist_ques = self.lstm_hist_ques(hq_we)
150
+ ha_seq, self.hidden_hist_ans = self.lstm_hist_ans(ha_we)
151
+ cap_seq, self.hidden_cap = self.lstm_hist_cap(c_we)
152
+
153
+
154
+ '''
155
+ length is used for attention prior
156
+ '''
157
+ q_len = input_ques_length.data - 1
158
+ c_len = input_cap_length.data.view(-1) - 1
159
+
160
+
161
+ ans_index = torch.arange(0, n_options * batch_size).long().cuda()
162
+ ans_len = input_ans_length.data.view(-1) - 1
163
+ ans_seq = a_seq[ans_index, ans_len, :]
164
+ ans_seq = ans_seq.view(batch_size, n_options, self.hidden_ans_dim)
165
+
166
+ batch_index = torch.arange(0, batch_size).long().cuda()
167
+ q_prior = torch.zeros(batch_size, q_seq.size(1)).cuda()
168
+ q_prior[batch_index, q_len] = 100
169
+ c_prior = torch.zeros(batch_size, cap_seq.size(1)).cuda()
170
+ c_prior[batch_index, c_len] = 100
171
+ ans_prior = torch.ones(batch_size, ans_seq.size(1)).cuda()
172
+ img_prior = torch.ones(batch_size, i_feat.size(1)).cuda()
173
+
174
+ (ans_atten, ques_atten, cap_atten, img_atten, hq_atten, ha_atten) = \
175
+ self.mul_atten([ans_seq, q_seq, cap_seq, i_feat, hq_seq, ha_seq],
176
+ priors=[ans_prior, q_prior, c_prior, img_prior, None, None])
177
+
178
+ '''
179
+ expand to answers based
180
+ '''
181
+ ques_atten = torch.unsqueeze(ques_atten, 1).expand(batch_size,
182
+ n_options,
183
+ self.hidden_ques_dim)
184
+ cap_atten = torch.unsqueeze(cap_atten, 1).expand(batch_size,
185
+ n_options,
186
+ self.hidden_cap_dim)
187
+ img_atten = torch.unsqueeze(img_atten, 1).expand(batch_size, n_options,
188
+ self.hidden_img_dim)
189
+ ans_atten = torch.unsqueeze(ans_atten, 1).expand(batch_size, n_options,
190
+ self.hidden_ans_dim)
191
+
192
+
193
+ '''
194
+ combine history
195
+ '''
196
+
197
+ input_qahistnet = torch.cat((hq_atten, ha_atten), 1)
198
+ # input_qahistnet: (nqa_per_dial*batch x 2*hidden_hist_dim)
199
+ output_qahistnet = self.qahistnet(input_qahistnet)
200
+ # output_qahistnet: (nqa_per_dial*batch x hidden_hist_dim)
201
+ output_qahistnet = output_qahistnet.view(batch_size,
202
+ nqa_per_dial * self.hidden_hist_dim)
203
+ # output_qahistnet: (batch x nqa_per_dial*hidden_hist_dim)
204
+ output_qahistnet = torch.unsqueeze(output_qahistnet, 1)\
205
+ .expand(batch_size,
206
+ n_options,
207
+ nqa_per_dial * self.hidden_hist_dim)
208
+
209
+ input_qa = torch.cat((ans_seq, ques_atten, ans_atten, img_atten,
210
+ output_qahistnet, cap_atten), 2) # Concatenate last dimension
211
+
212
+ input_qa = input_qa.view(batch_size * n_options, self.concat_dim)
213
+
214
+ out_scores = self.simnet(input_qa)
215
+
216
+ out_scores = out_scores.squeeze(dim=1)
217
+ out_scores = out_scores.view(batch_size, n_options)
218
+
219
+ return out_scores
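FGA appears to be the original Factor Graph Attention answer-scoring model for visual dialog; its forward pass hard-codes .cuda() calls, so only construction is sketched here. Not part of the commit; all dimensions are illustrative, and the snippet assumes it is run from the repository root so the modules package is importable:

from modules.fga.fga_model import FGA

model = FGA(vocab_size=10000, word_embed_dim=300,
            hidden_ques_dim=512, hidden_ans_dim=512,
            hidden_hist_dim=512, hidden_cap_dim=512, hidden_img_dim=2048)
print(sum(p.numel() for p in model.parameters()), "parameters")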
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ diffusers
2
+ accelerate
3
+ torchvision
4
+ transformers>=4.25.1
5
+ ftfy
6
+ tensorboard
7
+ opencv-python
8
+ Pillow
9
+ pandas
10
+ torchaudio
11
+ datasets