Heinrich Dinkel committed on
Commit
a09cac7
·
1 Parent(s): 3fbdad6

Updated GLAP, removed all dependencies on SONAR.

Browse files
README.md CHANGED
@@ -42,8 +42,10 @@ library_name: glap_model
42
  ## Usage
43
 
44
 
45
- ```bash
46
- pip install glap_model
 
 
47
  ```
48
 
49
 
@@ -204,4 +206,4 @@ Title = {GLAP: General contrastive audio-text pretraining across domains and lan
204
  Year = {2025},
205
  Eprint = {arXiv:2506.11350},
206
  }
207
- ```
 
42
  ## Usage
43
 
44
 
45
+ ```python
46
+ import torch
+ from transformers import AutoModel
46
+ model = AutoModel.from_pretrained("mispeech/GLAP", trust_remote_code=True).eval()
47
+ print(model.score_forward(audio = torch.randn(1, 160000), text=['The sound of noise','The sound of a person']))
49
  ```
50
 
51
 
 
206
  Year = {2025},
207
  Eprint = {arXiv:2506.11350},
208
  }
209
+ ```
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .modeling_glap import GlapModel
2
+ from .configuration_glap import GlapConfig
3
+
4
+ __all__ = ["GlapModel", "GlapConfig"]
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "GlapModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_glap.GlapConfig",
7
+ "AutoModel": "modeling_glap.GlapModel"
8
+ },
9
+ "model_type": "glap",
10
+ "audio_embed_dim": 768,
11
+ "audio_depth": 12,
12
+ "audio_num_heads": 12,
13
+ "patch_size": [
14
+ 64,
15
+ 4
16
+ ],
17
+ "patch_stride": [
18
+ 64,
19
+ 4
20
+ ],
21
+ "target_length": 1008,
22
+ "sample_rate": 16000,
23
+ "text_vocab_size": 256206,
24
+ "text_model_dim": 1024,
25
+ "text_num_layers": 24,
26
+ "text_num_heads": 16,
27
+ "text_ffn_inner_dim": 8192,
28
+ "text_max_seq_len": 514,
29
+ "text_pad_idx": 0,
30
+ "text_dropout_p": 0.1,
31
+ "embed_size": 1024
32
+ }
configuration_glap.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GLAP (Generalized Language Audio Pretraining) configuration."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
class GlapConfig(PretrainedConfig):
    """Configuration for the GLAP audio-text model.

    Bundles the hyperparameters of both sub-encoders:

    * ``audio_*`` / ``patch_*`` / ``target_length`` / ``sample_rate`` —
      the Dasheng audio Transformer encoder.
    * ``text_*`` — the SONAR text Transformer encoder.
    * ``embed_size`` — dimensionality of the shared projection space.

    Defaults match the published pretrained checkpoint (Dasheng-base
    audio encoder, 24-layer SONAR text encoder, NLLB vocabulary).
    """

    model_type = "glap"

    def __init__(
        self,
        # Audio encoder (Dasheng)
        audio_embed_dim: int = 768,
        audio_depth: int = 12,
        audio_num_heads: int = 12,
        patch_size: list = None,
        patch_stride: list = None,
        target_length: int = 1008,
        sample_rate: int = 16000,
        # Text encoder (SONAR)
        text_vocab_size: int = 256206,
        text_model_dim: int = 1024,
        text_num_layers: int = 24,
        text_num_heads: int = 16,
        text_ffn_inner_dim: int = 8192,
        text_max_seq_len: int = 514,
        text_pad_idx: int = 0,
        text_dropout_p: float = 0.1,
        # Projection
        embed_size: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.audio_embed_dim = audio_embed_dim
        self.audio_depth = audio_depth
        self.audio_num_heads = audio_num_heads
        # None defaults avoid the mutable-default-argument pitfall; the
        # effective default is the Dasheng (freq, time) patch geometry.
        self.patch_size = patch_size or [64, 4]
        self.patch_stride = patch_stride or [64, 4]
        self.target_length = target_length
        self.sample_rate = sample_rate
        self.text_vocab_size = text_vocab_size
        self.text_model_dim = text_model_dim
        self.text_num_layers = text_num_layers
        self.text_num_heads = text_num_heads
        self.text_ffn_inner_dim = text_ffn_inner_dim
        self.text_max_seq_len = text_max_seq_len
        self.text_pad_idx = text_pad_idx
        self.text_dropout_p = text_dropout_p
        self.embed_size = embed_size
convert_checkpoint.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert GLAP checkpoint to HuggingFace safetensors format.
3
+
4
+ Usage:
5
+ python convert_checkpoint.py <input_checkpoint.pt> [output_dir]
6
+
7
+ The output_dir defaults to the current directory.
8
+ Produces: model.safetensors + config.json
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ import torch
17
+
18
+
19
def convert_state_dict(old_state_dict: dict) -> dict:
    """Map original GLAP state dict keys to HuggingFace format.

    Original: audio_encoder.model.* (DashengWrapper wrapping dasheng_base)
    HuggingFace: audio_encoder.* (DashengAudioEncoder directly)

    Original: text_encoder.model.* (TextEncoderSonarWrapper wrapping SonarTextEncoder)
    HuggingFace: text_encoder.* (SonarTextEncoder directly)
    """
    # Wrapper prefixes that collapse by dropping the inner ``.model``.
    prefix_map = (
        ("audio_encoder.model.", "audio_encoder."),
        ("text_encoder.model.", "text_encoder."),
    )
    converted = {}
    for key, value in old_state_dict.items():
        # The output layer is an Identity with no learnable parameters.
        if "outputlayer" in key:
            continue

        for old_prefix, new_prefix in prefix_map:
            if key.startswith(old_prefix):
                converted[new_prefix + key[len(old_prefix):]] = value
                break
        else:
            # Projection heads keep their names; anything else is
            # unexpected but preserved verbatim (with a warning).
            if not key.startswith(("audio_proj.", "text_proj.")):
                print(f" Warning: unrecognized key: {key}", file=sys.stderr)
            converted[key] = value

    return converted
54
+
55
+
56
def extract_config(old_config: dict) -> dict:
    """Build the HuggingFace config dict from the original training config.

    Only ``sample_rate`` and ``model_args.embed_size`` are read from the
    checkpoint; every other field is pinned to the pretrained defaults.
    """
    train_args = old_config.get("model_args", {})

    return {
        "architectures": ["GlapModel"],
        "auto_map": {
            "AutoConfig": "configuration_glap.GlapConfig",
            "AutoModel": "modeling_glap.GlapModel",
        },
        "model_type": "glap",
        # Dasheng-base audio encoder
        "audio_embed_dim": 768,
        "audio_depth": 12,
        "audio_num_heads": 12,
        "patch_size": [64, 4],
        "patch_stride": [64, 4],
        "target_length": 1008,
        "sample_rate": old_config.get("sample_rate", 16000),
        # 24-layer SONAR text encoder over the NLLB vocabulary
        "text_vocab_size": 256206,
        "text_model_dim": 1024,
        "text_num_layers": 24,
        "text_num_heads": 16,
        "text_ffn_inner_dim": 8192,
        "text_max_seq_len": 514,
        "text_pad_idx": 0,
        "text_dropout_p": 0.1,
        "embed_size": train_args.get("embed_size", 1024),
    }
86
+
87
+
88
def main():
    """CLI entry point: convert an original GLAP checkpoint to HF format.

    Loads the ``.pt`` checkpoint, remaps its state-dict keys via
    :func:`convert_state_dict`, writes ``model.safetensors`` (or
    ``pytorch_model.bin`` when safetensors is missing) plus
    ``config.json`` into the output directory.
    """
    parser = argparse.ArgumentParser(description="Convert GLAP checkpoint to HuggingFace format")
    parser.add_argument("input", help="Path to original glap_checkpoint.pt")
    parser.add_argument(
        "-o", "--output-dir", default=".", help="Output directory (default: current dir)"
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint from {input_path}...")
    # weights_only=False: the checkpoint carries a plain config dict
    # alongside the tensors, which the restricted loader would reject.
    # NOTE(review): this deserializes arbitrary pickled objects — only
    # run on trusted checkpoints.
    checkpoint = torch.load(str(input_path), map_location="cpu", weights_only=False)

    if "model" not in checkpoint:
        print("Error: checkpoint does not contain 'model' key", file=sys.stderr)
        sys.exit(1)

    print("Converting state dict...")
    old_state_dict = checkpoint["model"]
    new_state_dict = convert_state_dict(old_state_dict)

    print(f" Original keys: {len(old_state_dict)}")
    print(f" Converted keys: {len(new_state_dict)}")

    # Save as safetensors; fall back to the legacy .bin format when the
    # optional dependency is unavailable.
    try:
        from safetensors.torch import save_file

        safetensors_path = output_dir / "model.safetensors"
        print(f"Saving safetensors to {safetensors_path}...")
        save_file(new_state_dict, str(safetensors_path))
        print(" Done.")
    except ImportError:
        # Fall back to pytorch format
        pt_path = output_dir / "pytorch_model.bin"
        print(f"safetensors not installed, saving as {pt_path}...")
        torch.save(new_state_dict, str(pt_path))
        print(" Done. Install safetensors for HuggingFace compatibility: pip install safetensors")

    # Save config (defaults when the checkpoint carries none).
    if "config" in checkpoint:
        config = extract_config(checkpoint["config"])
    else:
        print("Warning: no config in checkpoint, using defaults", file=sys.stderr)
        config = extract_config({})

    config_path = output_dir / "config.json"
    print(f"Saving config to {config_path}...")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    print("Conversion complete!")
    print(f"Files in {output_dir}:")
    # Summarize only the artifacts this script may have produced.
    for p in sorted(output_dir.iterdir()):
        if p.suffix in (".safetensors", ".bin", ".json"):
            size = p.stat().st_size
            print(f" {p.name}: {size / 1024 / 1024:.1f} MB")
147
+
148
+
149
+ if __name__ == "__main__":
150
+ main()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee8fd92bba30d03b4c31c624739b1599a222506f0e045cde7cb51c34e3223864
3
+ size 3422036400
modeling_glap.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GLAP (Generalized Language Audio Pretraining) HuggingFace model.
2
+
3
+ Audio encoder adapted from dasheng-denoiser (Apache 2.0).
4
+ Text encoder adapted from SONAR standalone (Apache 2.0).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+ from pathlib import Path
11
+ from typing import List, Optional, Sequence
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch import Tensor
17
+
18
+ from einops import rearrange
19
+ from einops.layers.torch import Rearrange
20
+ from transformers import PreTrainedModel
21
+
22
+ from .configuration_glap import GlapConfig
23
+
24
+
25
+ # ============================================================================
26
+ # Audio Encoder (adapted from dasheng-denoiser/modeling_dasheng_encoder.py)
27
+ # ============================================================================
28
+
29
+
30
class FrontEnd(nn.Sequential):
    """Log-mel spectrogram front end: MelSpectrogram -> AmplitudeToDB.

    Turns a raw waveform batch into a log-mel spectrogram
    (B, n_mels, num_frames) and, when a sample-level attention mask is
    given, downsamples that mask to frame resolution.

    Args:
        f_min / f_max: mel filterbank frequency range in Hz.
        sample_rate: expected waveform sample rate.
        win_size / n_fft / hop_size: STFT window, FFT size, hop in samples.
        center: whether frames are centered on their timestamps.
        n_mels: number of mel bands (spectrogram height).
    """

    def __init__(
        self,
        f_min: int = 0,
        sample_rate: int = 16000,
        win_size: int = 512,
        center: bool = True,
        n_fft: int = 512,
        f_max: Optional[int] = 8000,
        hop_size: int = 160,
        n_mels: int = 64,
    ):
        # Lazy import: torchaudio is only needed when the front end is
        # actually constructed. (Replaces the previous opaque
        # __import__("importlib").import_module(...) indirection.)
        from torchaudio import transforms as audio_transforms

        # Plain attributes may be set before nn.Sequential.__init__ —
        # only Parameters/Modules require the initialized registry.
        self.f_min = f_min
        self.sample_rate = sample_rate
        self.win_size = win_size
        self.center = center
        self.n_fft = n_fft
        self.f_max = f_max
        self.hop_size = hop_size
        self.n_mels = n_mels

        # Build the transforms on CPU regardless of any ambient default
        # device; buffers move with the module later.
        with torch.device("cpu"):
            super().__init__(
                audio_transforms.MelSpectrogram(
                    f_min=self.f_min,
                    sample_rate=self.sample_rate,
                    win_length=self.win_size,
                    center=self.center,
                    n_fft=self.n_fft,
                    f_max=self.f_max,
                    hop_length=self.hop_size,
                    n_mels=self.n_mels,
                ),
                audio_transforms.AmplitudeToDB(top_db=120),
            )

    # Spectrograms are computed in full precision even under autocast.
    @torch.autocast(enabled=False, device_type="cuda")
    def forward(self, x, attention_mask=None):
        """Return (features, attention_mask) with the mask at frame rate.

        The mask, if provided, is assumed to be 1 for valid samples;
        valid lengths are divided by ``hop_size`` to count frames.
        """
        features = super().forward(x)
        if attention_mask is not None:
            lengths = attention_mask.float().sum(-1) // self.hop_size
            attention_mask = (
                torch.arange(features.shape[-1], device=features.device)
                < lengths.unsqueeze(-1)
            ).int()
        return features, attention_mask
80
+
81
+
82
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear,
    with dropout applied after each linear layer."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: type[nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features if hidden_features is not None else in_features
        out = out_features if out_features is not None else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
106
+
107
+
108
class AudioAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection and
    optional key-padding mask."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        bsz, seq_len, dim = x.shape
        head_dim = dim // self.num_heads

        # (B, T, 3D) -> (3, B, H, T, D/H), then unbind into q/k/v.
        qkv = (
            self.qkv(x)
            .reshape(bsz, seq_len, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Non-bool masks use the "1 = keep" convention; True in the
            # boolean padding mask marks positions to drop.
            pad = mask if mask.dtype == torch.bool else mask == 0
            scores = scores.masked_fill(pad.view(bsz, 1, 1, seq_len), float("-inf"))

        # nan_to_num guards rows where every key was masked out.
        weights = self.attn_drop(scores.softmax(dim=-1).nan_to_num())

        out = (weights @ v).transpose(1, 2).reshape(bsz, seq_len, dim)
        return self.proj_drop(self.proj(out))
150
+
151
+
152
class AudioBlock(nn.Module):
    """Pre-norm Transformer block: self-attention then MLP, each
    wrapped in a residual connection."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = AudioAttention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=hidden_dim,
            act_layer=nn.GELU,
            drop=drop,
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask=mask)
        x = x + self.mlp(self.norm2(x))
        return x
177
+
178
+
179
class AudioPatchEmbed(nn.Module):
    """Conv2d patch embedding over the (freq, time) spectrogram grid.

    All positional/keyword args are forwarded to ``nn.Conv2d``; the time
    stride is remembered so attention masks can be rescaled alongside.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Last element of the stride is the time-axis stride.
        self.stride = kwargs.get("stride", [None, 4])[-1]
        self.proj = nn.Conv2d(*args, **kwargs)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        patches = self.proj(x)
        if attention_mask is None:
            return patches, None
        # Downsample the frame-level mask to patch resolution.
        valid = attention_mask.float().sum(-1) // self.stride
        mask = (
            torch.arange(patches.shape[-1], device=patches.device)
            < valid.unsqueeze(-1)
        ).int()
        return patches, mask
193
+
194
+
195
class DashengAudioEncoder(nn.Module):
    """Dasheng audio encoder matching the original DashengWrapper.

    Produces a single (B, embed_dim) embedding per audio input.
    Pads spectrogram to a multiple of target_length, splits into chunks,
    processes each chunk independently through the Transformer, then
    mean-pools across chunks.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        patch_size: list = None,
        patch_stride: list = None,
        target_length: int = 1008,
    ):
        super().__init__()
        # None defaults sidestep mutable default arguments; the real
        # default is the Dasheng (freq, time) patch geometry.
        patch_size = patch_size or [64, 4]
        patch_stride = patch_stride or [64, 4]
        self.embed_dim = embed_dim
        self.target_length = target_length
        self.patch_stride = patch_stride
        # Time-axis stride, i.e. spectrogram frames per patch column.
        self.time_patches = patch_stride[-1]
        # Number of time tokens produced by one full-length chunk.
        self.max_t_tokens = target_length // self.time_patches

        self.front_end = FrontEnd()
        self.patch_embed = AudioPatchEmbed(
            1, embed_dim, kernel_size=patch_size, stride=patch_stride
        )
        # BatchNorm over the mel-frequency axis (rearranged into the
        # channel slot, then rearranged back).
        self.init_bn = nn.Sequential(
            Rearrange("b c f t -> b f c t"),
            nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
            Rearrange("b f c t -> b c f t"),
        )

        # Learned positional embeddings: one along time, one per-channel
        # scalar along frequency.
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, 1, target_length // self.time_patches) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)

        self.blocks = nn.ModuleList(
            [AudioBlock(embed_dim, num_heads) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def _forward_chunk(self, x, attention_mask=None):
        # One spectrogram chunk -> (B, num_patches, embed_dim) tokens.
        x, attention_mask = self.patch_embed(x, attention_mask)
        t = x.shape[-1]
        # Slice the time positional embedding for chunks shorter than
        # target_length.
        x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
        x = rearrange(x, "b c f t -> b (f t) c")
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        x = self.norm(x)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode waveforms (B, num_samples) into (B, embed_dim)."""
        # Compute spectrogram
        x, attention_mask = self.front_end(x, attention_mask)
        x = rearrange(x, "b f t -> b 1 f t")
        x = self.init_bn(x)

        # Pad spectrogram time dim to next multiple of target_length.
        # NOTE(review): inputs shorter than target_length are NOT padded
        # — they pass through as a single short chunk.
        if x.shape[-1] > self.target_length:
            remainder = x.shape[-1] % self.target_length
            if remainder != 0:
                pad_amount = self.target_length - remainder
                x = F.pad(x, (0, pad_amount))

        # Split into chunks along time dimension
        input_splits = x.split(self.target_length, dim=-1)
        # NOTE(review): masks is computed but never used — each chunk is
        # processed with attention_mask=None below.
        masks = [None for _ in range(len(input_splits))]

        # Process each chunk independently
        outputs = []
        chunk_size_in_patches = self.target_length // self.patch_stride[-1]
        for input_split_x in input_splits:
            output = self._forward_chunk(input_split_x, attention_mask=None)
            # Mean pool each chunk: (B, num_patches, embed_dim) -> (B, embed_dim)
            chunks = output.split(chunk_size_in_patches, dim=1)
            chunk_means = [c.mean(1) for c in chunks]
            outputs.append(torch.stack(chunk_means).mean(0))

        # Mean across all split outputs
        emb = torch.stack(outputs).mean(0)
        return emb
286
+
287
+
288
+ # ============================================================================
289
+ # Text Encoder (adapted from dasheng-glap SONAR standalone)
290
+ # ============================================================================
291
+
292
+
293
class SinusoidalPositionEncoder(nn.Module):
    """Additive sinusoidal position encoding (SONAR/fairseq flavour).

    Positions start at ``1 + _legacy_pad_idx``, mirroring the legacy
    fairseq numbering where the lowest indices were reserved for padding.
    The encoding table is a non-persistent buffer (not saved in
    checkpoints).
    """

    def __init__(self, encoding_dim: int, max_seq_len: int, _legacy_pad_idx: int = 1):
        super().__init__()
        assert encoding_dim % 2 == 0
        self.encoding_dim = encoding_dim
        self.max_seq_len = max_seq_len
        self._legacy_pad_idx = _legacy_pad_idx
        first_step = 1 + _legacy_pad_idx
        positions = torch.arange(
            first_step, first_step + max_seq_len, dtype=torch.float32
        )
        self.register_buffer(
            "freqs", self._build_freqs(positions, encoding_dim), persistent=False
        )

    @staticmethod
    def _build_freqs(steps: Tensor, encoding_dim: int) -> Tensor:
        half = encoding_dim // 2
        inv_freq = torch.exp(
            torch.arange(half, dtype=torch.float32) * -math.log(10000.0) / (half - 1)
        )
        table = torch.outer(steps, inv_freq)
        # encoding_dim is even, so the cosine half spans the same columns
        # as the sine half.
        return torch.cat([table.sin(), table[:, : encoding_dim - half].cos()], dim=-1)

    def forward(self, seqs: Tensor) -> Tensor:
        """Add position encodings to ``seqs`` (..., seq_len, encoding_dim)."""
        length = seqs.size(-2)
        return (seqs.float() + self.freqs[:length]).type_as(seqs)
318
+
319
+
320
class SonarMultiheadAttention(nn.Module):
    """Multi-head attention with separate q/k/v/output projections,
    matching the SONAR checkpoint layout."""

    def __init__(self, model_dim: int, num_heads: int, dropout_p: float = 0.0):
        super().__init__()
        assert model_dim % num_heads == 0
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads

        self.q_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.output_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.attn_dropout_p = dropout_p

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        values: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, q_len, _ = queries.shape

        def split_heads(proj: nn.Linear, x: Tensor) -> Tensor:
            # (B, T, D) -> (B, H, T, D/H)
            return proj(x).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj, queries)
        k = split_heads(self.k_proj, keys)
        v = split_heads(self.v_proj, values)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.head_dim**-0.5

        if padding_mask is not None:
            # True in padding_mask marks padded key positions to exclude.
            scores = scores.masked_fill(
                padding_mask[:, None, None, :], float("-inf")
            )

        # Softmax in float32 for numerical stability, then cast back.
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)

        if self.training and self.attn_dropout_p > 0.0:
            weights = F.dropout(weights, p=self.attn_dropout_p)

        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).contiguous().view(bsz, q_len, self.model_dim)
        return self.output_proj(context)
375
+
376
+
377
+ class _FeedForwardNetwork(nn.Module):
378
+ def __init__(self, model_dim: int, inner_dim: int, dropout_p: float = 0.1):
379
+ super().__init__()
380
+ self.inner_proj = nn.Linear(model_dim, inner_dim, bias=True)
381
+ self.output_proj = nn.Linear(inner_dim, model_dim, bias=True)
382
+ self.dropout = nn.Dropout(dropout_p)
383
+
384
+ def forward(self, x: Tensor) -> Tensor:
385
+ x = self.inner_proj(x)
386
+ x = F.relu(x)
387
+ x = self.dropout(x)
388
+ x = self.output_proj(x)
389
+ return x
390
+
391
+
392
class SonarTransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention then FFN,
    each with a residual connection."""

    def __init__(
        self, model_dim: int, num_heads: int, ffn_inner_dim: int, dropout_p: float = 0.1
    ):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(model_dim)
        self.self_attn = SonarMultiheadAttention(
            model_dim, num_heads, dropout_p=dropout_p
        )
        self.ffn_layer_norm = nn.LayerNorm(model_dim)
        self.ffn = _FeedForwardNetwork(model_dim, ffn_inner_dim, dropout_p)

    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        normed = self.self_attn_layer_norm(seqs)
        seqs = seqs + self.self_attn(normed, normed, normed, padding_mask)
        seqs = seqs + self.ffn(self.ffn_layer_norm(seqs))
        return seqs
415
+
416
+
417
+ class _SonarTransformerEncoder(nn.Module):
418
+ def __init__(self, layers: nn.ModuleList):
419
+ super().__init__()
420
+ self.layers = layers
421
+
422
+ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
423
+ for layer in self.layers:
424
+ seqs = layer(seqs, padding_mask)
425
+ return seqs
426
+
427
+
428
+ class _SonarEmbeddingFrontend(nn.Module):
429
+ def __init__(
430
+ self,
431
+ embed: nn.Embedding,
432
+ pos_encoder: SinusoidalPositionEncoder,
433
+ dropout_p: float = 0.1,
434
+ ):
435
+ super().__init__()
436
+ self.embed = embed
437
+ self.pos_encoder = pos_encoder
438
+ self.dropout = nn.Dropout(dropout_p)
439
+
440
+ def forward(self, token_ids: Tensor) -> Tensor:
441
+ seqs = self.embed(token_ids)
442
+ seqs = seqs * math.sqrt(seqs.size(-1))
443
+ seqs = self.pos_encoder(seqs)
444
+ seqs = self.dropout(seqs)
445
+ return seqs
446
+
447
+
448
class SonarTextEncoder(nn.Module):
    """24-layer SONAR text encoder with sinusoidal PE and mean pooling.

    Maps token IDs (B, T) to sentence embeddings (B, model_dim).
    Defaults correspond to the NLLB vocabulary / SONAR-base checkpoint.
    """

    def __init__(
        self,
        vocab_size: int = 256206,
        model_dim: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
        ffn_inner_dim: int = 8192,
        max_seq_len: int = 514,
        pad_idx: int = 0,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        self.model_dim = model_dim
        self.pad_idx = pad_idx

        embed = nn.Embedding(vocab_size, model_dim, padding_idx=pad_idx)
        # _legacy_pad_idx=1 reproduces the fairseq position numbering the
        # pretrained weights were trained with.
        pos_encoder = SinusoidalPositionEncoder(
            model_dim, max_seq_len, _legacy_pad_idx=1
        )
        self.encoder_frontend = _SonarEmbeddingFrontend(embed, pos_encoder, dropout_p)

        layers = nn.ModuleList(
            [
                SonarTransformerEncoderLayer(
                    model_dim, num_heads, ffn_inner_dim, dropout_p
                )
                for _ in range(num_layers)
            ]
        )
        self.encoder = _SonarTransformerEncoder(layers)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(
        self,
        token_ids: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode token IDs to (B, model_dim) sentence embeddings.

        ``padding_mask``: boolean (B, T), True marks padding positions.
        """
        seqs = self.encoder_frontend(token_ids)
        seqs = self.encoder(seqs, padding_mask)
        seqs = self.layer_norm(seqs)

        if padding_mask is None:
            # Mean over time; the epsilon guards against zero-length input.
            sentence_embeddings = seqs.sum(dim=1) / (seqs.size(1) + 1e-7)
        else:
            # Masked mean: zero out padded positions, divide by valid count.
            mask = (~padding_mask).unsqueeze(-1).float()
            seqs = seqs * mask
            lengths = mask.sum(dim=1).clamp(min=1e-7)
            sentence_embeddings = seqs.sum(dim=1) / lengths

        return sentence_embeddings
501
+
502
+
503
+ # ============================================================================
504
+ # Tokenizer
505
+ # ============================================================================
506
+
507
+
508
class NllbTokenizer:
    """Standalone NLLB tokenizer using sentencepiece.

    Loads the NLLB SentencePiece model and produces token IDs in the
    fairseq/SONAR numbering: control IDs 0-3 are reserved, SentencePiece
    piece IDs are shifted by +1, and language tokens live above the
    SentencePiece range (see ``_NLLB_LANG_TOKEN_IDS``).
    """

    def __init__(self, model_path: str | Path, langs: Optional[List[str]] = None):
        # sentencepiece is an optional dependency; fail with guidance.
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError("sentencepiece is required: pip install sentencepiece")

        self.sp = spm.SentencePieceProcessor()
        if not self.sp.load(str(model_path)):
            raise RuntimeError(f"Failed to load SentencePiece model from {model_path}")

        # Reserved control token IDs (fairseq convention).
        self.pad_idx = 0
        self.unk_idx = 1
        self.bos_idx = 2
        self.eos_idx = 3

        self._lang_token_to_idx = _NLLB_LANG_TOKEN_IDS

    @property
    def vocab_size(self) -> int:
        # +206 accounts for the extra NLLB special/language tokens beyond
        # the SentencePiece pieces — presumably matching text_vocab_size
        # (256206); TODO confirm against the pretrained vocabulary.
        return self.sp.get_piece_size() + 206

    def create_encoder(self, lang: str = "eng_Latn"):
        """Return a callable encoding text to IDs for language ``lang``.

        The sequence is ``[lang_token] + pieces + [eos]``; the language
        token is omitted for unknown language codes.
        """
        lang_idx = self._lang_token_to_idx.get(lang)
        eos_idx = self.eos_idx

        def encode(text: str) -> List[int]:
            spm_ids = self.sp.encode(text, out_type=int)
            # Shift piece IDs by +1 to make room for the reserved
            # control tokens at 0-3 (fairseq numbering).
            content_ids = [tid + 1 for tid in spm_ids]
            if lang_idx is not None:
                token_ids = [lang_idx] + content_ids
            else:
                token_ids = content_ids
            token_ids.append(eos_idx)
            return token_ids

        return encode
547
+
548
+
549
+ # Pre-computed NLLB language -> token ID mapping
550
+ _NLLB_LANG_TOKEN_IDS = {
551
+ "ace_Arab": 256001,
552
+ "ace_Latn": 256002,
553
+ "acm_Arab": 256003,
554
+ "acq_Arab": 256004,
555
+ "aeb_Arab": 256005,
556
+ "afr_Latn": 256006,
557
+ "ajp_Arab": 256007,
558
+ "aka_Latn": 256008,
559
+ "amh_Ethi": 256009,
560
+ "apc_Arab": 256010,
561
+ "arb_Arab": 256011,
562
+ "ars_Arab": 256012,
563
+ "ary_Arab": 256013,
564
+ "arz_Arab": 256014,
565
+ "asm_Beng": 256015,
566
+ "ast_Latn": 256016,
567
+ "awa_Deva": 256017,
568
+ "ayr_Latn": 256018,
569
+ "azb_Arab": 256019,
570
+ "azj_Latn": 256020,
571
+ "bak_Cyrl": 256021,
572
+ "bam_Latn": 256022,
573
+ "ban_Latn": 256023,
574
+ "bel_Cyrl": 256024,
575
+ "bem_Latn": 256025,
576
+ "ben_Beng": 256026,
577
+ "bho_Deva": 256027,
578
+ "bjn_Arab": 256028,
579
+ "bjn_Latn": 256029,
580
+ "bod_Tibt": 256030,
581
+ "bos_Latn": 256031,
582
+ "bug_Latn": 256032,
583
+ "bul_Cyrl": 256033,
584
+ "cat_Latn": 256034,
585
+ "ceb_Latn": 256035,
586
+ "ces_Latn": 256036,
587
+ "cjk_Latn": 256037,
588
+ "ckb_Arab": 256038,
589
+ "crh_Latn": 256039,
590
+ "cym_Latn": 256040,
591
+ "dan_Latn": 256041,
592
+ "deu_Latn": 256042,
593
+ "dik_Latn": 256043,
594
+ "dyu_Latn": 256044,
595
+ "dzo_Tibt": 256045,
596
+ "ell_Grek": 256046,
597
+ "eng_Latn": 256047,
598
+ "epo_Latn": 256048,
599
+ "est_Latn": 256049,
600
+ "eus_Latn": 256050,
601
+ "ewe_Latn": 256051,
602
+ "fao_Latn": 256052,
603
+ "pes_Arab": 256053,
604
+ "fij_Latn": 256054,
605
+ "fin_Latn": 256055,
606
+ "fon_Latn": 256056,
607
+ "fra_Latn": 256057,
608
+ "fur_Latn": 256058,
609
+ "fuv_Latn": 256059,
610
+ "gla_Latn": 256060,
611
+ "gle_Latn": 256061,
612
+ "glg_Latn": 256062,
613
+ "grn_Latn": 256063,
614
+ "guj_Gujr": 256064,
615
+ "hat_Latn": 256065,
616
+ "hau_Latn": 256066,
617
+ "heb_Hebr": 256067,
618
+ "hin_Deva": 256068,
619
+ "hne_Deva": 256069,
620
+ "hrv_Latn": 256070,
621
+ "hun_Latn": 256071,
622
+ "hye_Armn": 256072,
623
+ "ibo_Latn": 256073,
624
+ "ilo_Latn": 256074,
625
+ "ind_Latn": 256075,
626
+ "isl_Latn": 256076,
627
+ "ita_Latn": 256077,
628
+ "jav_Latn": 256078,
629
+ "jpn_Jpan": 256079,
630
+ "kab_Latn": 256080,
631
+ "kac_Latn": 256081,
632
+ "kam_Latn": 256082,
633
+ "kan_Knda": 256083,
634
+ "kas_Arab": 256084,
635
+ "kas_Deva": 256085,
636
+ "kat_Geor": 256086,
637
+ "knc_Arab": 256087,
638
+ "knc_Latn": 256088,
639
+ "kaz_Cyrl": 256089,
640
+ "kbp_Latn": 256090,
641
+ "kea_Latn": 256091,
642
+ "khm_Khmr": 256092,
643
+ "kik_Latn": 256093,
644
+ "kin_Latn": 256094,
645
+ "kir_Cyrl": 256095,
646
+ "kmb_Latn": 256096,
647
+ "kon_Latn": 256097,
648
+ "kor_Hang": 256098,
649
+ "kmr_Latn": 256099,
650
+ "lao_Laoo": 256100,
651
+ "lvs_Latn": 256101,
652
+ "lij_Latn": 256102,
653
+ "lim_Latn": 256103,
654
+ "lin_Latn": 256104,
655
+ "lit_Latn": 256105,
656
+ "lmo_Latn": 256106,
657
+ "ltg_Latn": 256107,
658
+ "ltz_Latn": 256108,
659
+ "lua_Latn": 256109,
660
+ "lug_Latn": 256110,
661
+ "luo_Latn": 256111,
662
+ "lus_Latn": 256112,
663
+ "mag_Deva": 256113,
664
+ "mai_Deva": 256114,
665
+ "mal_Mlym": 256115,
666
+ "mar_Deva": 256116,
667
+ "min_Latn": 256117,
668
+ "mkd_Cyrl": 256118,
669
+ "plt_Latn": 256119,
670
+ "mlt_Latn": 256120,
671
+ "mni_Beng": 256121,
672
+ "khk_Cyrl": 256122,
673
+ "mos_Latn": 256123,
674
+ "mri_Latn": 256124,
675
+ "zsm_Latn": 256125,
676
+ "mya_Mymr": 256126,
677
+ "nld_Latn": 256127,
678
+ "nno_Latn": 256128,
679
+ "nob_Latn": 256129,
680
+ "npi_Deva": 256130,
681
+ "nso_Latn": 256131,
682
+ "nus_Latn": 256132,
683
+ "nya_Latn": 256133,
684
+ "oci_Latn": 256134,
685
+ "gaz_Latn": 256135,
686
+ "ory_Orya": 256136,
687
+ "pag_Latn": 256137,
688
+ "pan_Guru": 256138,
689
+ "pap_Latn": 256139,
690
+ "pol_Latn": 256140,
691
+ "por_Latn": 256141,
692
+ "prs_Arab": 256142,
693
+ "pbt_Arab": 256143,
694
+ "quy_Latn": 256144,
695
+ "ron_Latn": 256145,
696
+ "run_Latn": 256146,
697
+ "rus_Cyrl": 256147,
698
+ "sag_Latn": 256148,
699
+ "san_Deva": 256149,
700
+ "sat_Beng": 256150,
701
+ "scn_Latn": 256151,
702
+ "shn_Mymr": 256152,
703
+ "sin_Sinh": 256153,
704
+ "slk_Latn": 256154,
705
+ "slv_Latn": 256155,
706
+ "smo_Latn": 256156,
707
+ "sna_Latn": 256157,
708
+ "snd_Arab": 256158,
709
+ "som_Latn": 256159,
710
+ "sot_Latn": 256160,
711
+ "spa_Latn": 256161,
712
+ "als_Latn": 256162,
713
+ "srd_Latn": 256163,
714
+ "srp_Cyrl": 256164,
715
+ "ssw_Latn": 256165,
716
+ "sun_Latn": 256166,
717
+ "swe_Latn": 256167,
718
+ "swh_Latn": 256168,
719
+ "szl_Latn": 256169,
720
+ "tam_Taml": 256170,
721
+ "tat_Cyrl": 256171,
722
+ "tel_Telu": 256172,
723
+ "tgk_Cyrl": 256173,
724
+ "tgl_Latn": 256174,
725
+ "tha_Thai": 256175,
726
+ "tir_Ethi": 256176,
727
+ "taq_Latn": 256177,
728
+ "taq_Tfng": 256178,
729
+ "tpi_Latn": 256179,
730
+ "tsn_Latn": 256180,
731
+ "tso_Latn": 256181,
732
+ "tuk_Latn": 256182,
733
+ "tum_Latn": 256183,
734
+ "tur_Latn": 256184,
735
+ "twi_Latn": 256185,
736
+ "tzm_Tfng": 256186,
737
+ "uig_Arab": 256187,
738
+ "ukr_Cyrl": 256188,
739
+ "umb_Latn": 256189,
740
+ "urd_Arab": 256190,
741
+ "uzn_Latn": 256191,
742
+ "vec_Latn": 256192,
743
+ "vie_Latn": 256193,
744
+ "war_Latn": 256194,
745
+ "wol_Latn": 256195,
746
+ "xho_Latn": 256196,
747
+ "ydd_Hebr": 256197,
748
+ "yor_Latn": 256198,
749
+ "yue_Hant": 256199,
750
+ "zho_Hans": 256200,
751
+ "zho_Hant": 256201,
752
+ "zul_Latn": 256202,
753
+ }
754
+
755
+
756
+ # ============================================================================
757
+ # GLAP Model
758
+ # ============================================================================
759
+
760
+
761
class GlapModel(PreTrainedModel):
    """GLAP (General Language Audio Pretraining) contrastive audio-text model.

    Pairs a Dasheng ViT-style audio encoder with a SONAR-style multilingual
    text encoder over the NLLB vocabulary.  Each tower is followed by a
    two-layer MLP projection into a shared ``config.embed_size`` space and
    L2-normalized, so the dot product of an audio embedding and a text
    embedding is a cosine similarity.
    """

    config_class = GlapConfig

    def __init__(self, config: GlapConfig):
        super().__init__(config)
        self.config = config

        # Audio tower: Dasheng transformer over mel-spectrogram patches.
        self.audio_encoder = DashengAudioEncoder(
            embed_dim=config.audio_embed_dim,
            depth=config.audio_depth,
            num_heads=config.audio_num_heads,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            target_length=config.target_length,
        )

        # Text tower: SONAR transformer encoder over NLLB token ids.
        self.text_encoder = SonarTextEncoder(
            vocab_size=config.text_vocab_size,
            model_dim=config.text_model_dim,
            num_layers=config.text_num_layers,
            num_heads=config.text_num_heads,
            ffn_inner_dim=config.text_ffn_inner_dim,
            max_seq_len=config.text_max_seq_len,
            pad_idx=config.text_pad_idx,
            dropout_p=config.text_dropout_p,
        )

        # Projection heads mapping each tower into the joint embedding space.
        self.audio_proj = nn.Sequential(
            nn.Linear(config.audio_embed_dim, config.embed_size),
            nn.ReLU(),
            nn.Linear(config.embed_size, config.embed_size),
        )
        self.text_proj = nn.Sequential(
            nn.Linear(config.text_model_dim, config.embed_size),
            nn.ReLU(),
            nn.Linear(config.embed_size, config.embed_size),
        )

        # Built lazily on first text encode; needs the sentencepiece .model
        # file, which may not be present until _get_tokenizer resolves it.
        self.tokenizer: Optional[NllbTokenizer] = None
        self.post_init()

    def _init_weights(self, module):
        """Materialize the (non-persistent) sinusoidal position table.

        Called by ``post_init``/``from_pretrained`` for every submodule;
        only ``SinusoidalPositionEncoder`` needs work — its ``freqs`` buffer
        is deterministic and is rebuilt here instead of being stored in the
        checkpoint.
        """
        if isinstance(module, SinusoidalPositionEncoder):
            with torch.no_grad():
                # SONAR/fairseq2 legacy: position steps start after pad_idx.
                start_step = 1 + module._legacy_pad_idx
                steps = torch.arange(
                    start_step,
                    start_step + module.max_seq_len,
                    dtype=torch.float32,
                )
                module.freqs.copy_(module._build_freqs(steps, module.encoding_dim))

    def _get_tokenizer(self) -> NllbTokenizer:
        """Return the (lazily created) NLLB sentencepiece tokenizer.

        HuggingFace copies the model's ``.py`` files into its module cache
        but not the ``.model`` tokenizer file, so we first look next to the
        original model path (``config._name_or_path``) and fall back to this
        file's directory.
        """
        if self.tokenizer is None:
            model_dir = Path(self.config._name_or_path)
            if not model_dir.is_dir():
                model_dir = Path(__file__).parent
            tokenizer_path = model_dir / "sentencepiece.source.256000.model"
            if not tokenizer_path.exists():
                tokenizer_path = (
                    Path(__file__).parent / "sentencepiece.source.256000.model"
                )
            self.tokenizer = NllbTokenizer(tokenizer_path)
        return self.tokenizer

    def encode_audio(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed raw waveforms into the joint space (L2-normalized).

        Args:
            audio: waveform batch; presumably ``(batch, samples)`` at
                ``config.sample_rate`` — confirm against DashengAudioEncoder.
            audio_length: accepted for interface compatibility but currently
                unused by the encoder call.
        """
        audio_embeds = self.audio_encoder(audio)
        audio_embeds = F.normalize(self.audio_proj(audio_embeds), dim=-1)
        return audio_embeds

    def encode_text(
        self,
        text: Sequence[str],
        source_lang: str = "eng_Latn",
    ) -> torch.Tensor:
        """Embed a batch of sentences into the joint space (L2-normalized).

        Args:
            text: sentences to embed.
            source_lang: NLLB language tag used for the tokenizer's
                language prefix token.

        Note:
            The text tower is forced into eval mode and run under
            ``no_grad`` — it is treated as frozen here; gradients never
            flow through ``text_encoder`` (only through ``text_proj``).
        """
        tokenizer = self._get_tokenizer()
        encoder_fn = tokenizer.create_encoder(lang=source_lang)

        # Tokenize and truncate to the encoder's positional capacity.
        all_token_ids: List[List[int]] = []
        max_seq_len = self.config.text_max_seq_len
        for t in text:
            token_ids = encoder_fn(t)[:max_seq_len]
            all_token_ids.append(token_ids)

        max_len = max(len(ids) for ids in all_token_ids) if all_token_ids else 0
        batch_size = len(all_token_ids)

        device = self.audio_proj[0].weight.device

        # Right-pad with pad_idx; padding_mask is True on padded positions.
        padded_ids = torch.full(
            (batch_size, max_len),
            tokenizer.pad_idx,
            dtype=torch.long,
        )
        padding_mask = torch.zeros(batch_size, max_len, dtype=torch.bool)

        for i, ids in enumerate(all_token_ids):
            length = len(ids)
            padded_ids[i, :length] = torch.tensor(ids, dtype=torch.long)
            padding_mask[i, length:] = True

        # Fix: run the encoder on whatever device its parameters live on.
        # Previously the inputs were always built on CPU, which raised a
        # device-mismatch error after model.to("cuda") moved text_encoder.
        enc_device = next(self.text_encoder.parameters()).device
        padded_ids = padded_ids.to(enc_device)
        padding_mask = padding_mask.to(enc_device)

        self.text_encoder.eval()
        with torch.no_grad():
            sentence_embeddings = self.text_encoder(padded_ids, padding_mask)

        text_embeds = F.normalize(
            self.text_proj(sentence_embeddings.to(device)), dim=-1
        )
        return text_embeds

    def get_audio_features(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """CLIP-style alias for :meth:`encode_audio`."""
        return self.encode_audio(audio, audio_length)

    def get_text_features(
        self,
        text: Sequence[str],
        source_lang: str = "eng_Latn",
        **kwargs,
    ) -> torch.Tensor:
        """CLIP-style alias for :meth:`encode_text`."""
        return self.encode_text(text, source_lang=source_lang)

    def forward(
        self,
        audio: Optional[torch.Tensor] = None,
        text: Optional[Sequence[str]] = None,
        audio_length: Optional[torch.Tensor] = None,
        source_lang: str = "eng_Latn",
        **kwargs,
    ):
        """Encode whichever modalities are given.

        Returns:
            ``(audio_embeds, text_embeds)`` — either entry is ``None`` when
            the corresponding input was not provided.
        """
        audio_embeds = None
        text_embeds = None

        if audio is not None:
            audio_embeds = self.encode_audio(audio, audio_length)
        if text is not None:
            text_embeds = self.encode_text(text, source_lang=source_lang)

        return audio_embeds, text_embeds

    def score(self, audio_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Scaled cosine-similarity matrix: ``(n_audio, n_text)`` in [-100, 100].

        Inputs are expected to be L2-normalized (as produced by the encode
        methods); the fixed factor 100 matches CLIP-style logit scaling.
        """
        return 100 * (audio_emb @ text_emb.T)

    def score_forward(
        self,
        audio: torch.Tensor,
        text: Sequence[str],
        audio_length: Optional[torch.Tensor] = None,
        source_lang: str = "eng_Latn",
    ) -> torch.Tensor:
        """Convenience: encode both modalities and return similarity scores."""
        audio_emb, text_emb = self.forward(audio, text, audio_length, source_lang)
        return self.score(audio_emb, text_emb)
sentencepiece.source.256000.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14bb8dfb35c0ffdea7bc01e56cea38b9e3d5efcdcb9c251d6b40538e1aab555a
3
+ size 4852054