klemenk commited on
Commit
c07a579
·
verified ·
1 Parent(s): f61bcb7

Upload AuriStream Parallel base model code

Browse files
README.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - audio
5
+ - speech
6
+ - language-model
7
+ - auristream
8
+ - discrete-diffusion
9
+ library_name: transformers
10
+ ---
11
+
12
+ # AuriStream Parallel - Speech Language Model
13
+
14
+ **AuriStream Parallel** is a discrete diffusion speech language model by **Greta Tuckute** and **Klemen Kotar**.
15
+
16
+ This repository contains shared model code for AuriStream Parallel checkpoints.
17
+
18
+ ## Overview
19
+
20
+ AuriStream Parallel uses:
21
+ - bidirectional transformer attention
22
+ - grouped token projection (`group_size=4` by default)
23
+ - parallel token heads
24
+ - partial-masking diffusion objective
25
+
26
+ ## Usage
27
+
28
+ Load a checkpoint repository that references this base code:
29
+
30
+ ```python
31
+ from transformers import AutoModel
32
+
33
+ model = AutoModel.from_pretrained(
34
+ "TuKoResearch/AuriStreamParallel100M_Group4_BigAudioDataset_180k",
35
+ trust_remote_code=True,
36
+ )
37
+ ```
38
+
39
+ ## Files
40
+
41
+ - `configuration_auristream_parallel.py` - Configuration class
42
+ - `modeling_auristream_parallel.py` - Model implementation
configuration_auristream_parallel.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream Parallel Configuration for HuggingFace Transformers.
3
+ """
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
class AuriStreamParallelConfig(PretrainedConfig):
    """Configuration for AuriStream Parallel models.

    Holds vocabulary/masking ids, transformer dimensions, rotary-embedding
    settings, the token group size, and the diffusion masking schedule.
    Unknown keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "AuriStreamParallel"

    def __init__(
        self,
        vocab_size: int = 8193,
        base_vocab_size: int = 8192,
        mask_token_id: int = 8192,
        ignore_index: int = -100,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        use_rope: bool = True,
        group_size: int = 4,
        seq_len: int = 4096,
        skip_connections: bool = False,
        mask_schedule: str = "linear_text_prime",
        **kwargs,
    ):
        # Record every model hyperparameter as an instance attribute.
        hyperparams = dict(
            vocab_size=vocab_size,
            base_vocab_size=base_vocab_size,
            mask_token_id=mask_token_id,
            ignore_index=ignore_index,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            dropout=dropout,
            bias=bias,
            rope_theta=rope_theta,
            use_rope=use_rope,
            group_size=group_size,
            seq_len=seq_len,
            skip_connections=skip_connections,
            mask_schedule=mask_schedule,
        )
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)

        # Let the HF base class consume the remaining kwargs last, as usual.
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg):
        """Create an AuriStreamParallelConfig from a local dataclass config.

        Only the attributes listed below are copied, and only when present
        on ``local_cfg``; everything else falls back to the defaults above.
        """
        known_attrs = (
            "vocab_size", "base_vocab_size", "mask_token_id", "ignore_index",
            "n_embd", "n_layer", "n_head", "dropout", "bias", "rope_theta",
            "use_rope", "group_size", "seq_len", "skip_connections", "mask_schedule",
        )
        overrides = {
            attr: getattr(local_cfg, attr)
            for attr in known_attrs
            if hasattr(local_cfg, attr)
        }
        return cls(**overrides)
modeling_auristream_parallel.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AuriStream Parallel model for HuggingFace Transformers.
3
+ """
4
+
5
+ import math
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ from transformers import PreTrainedModel
13
+ from transformers.modeling_outputs import CausalLMOutput
14
+
15
+ from .configuration_auristream_parallel import AuriStreamParallelConfig
16
+
17
+
18
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering).

    Normalizes the last dimension by its RMS, then applies an optional
    learned per-channel scale. NOTE(review): the ``bias`` argument is
    accepted for signature compatibility but no bias term is created or
    applied — confirm this is intentional.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        # x / RMS(x), computed as x * rsqrt(mean(x^2) + eps).
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight
30
+
31
+
32
class Rotary(nn.Module):
    """Rotary position embedding (RoPE) angle table.

    Produces per-position cos/sin tensors of shape (1, seq_len, 1, dim // 2)
    for broadcasting over (batch, seq, heads, head_dim) activations.
    """

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        # Standard RoPE inverse frequencies: 1 / base^(2i/dim) for pairs i.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        # Only x's sequence length (dim 1) and device are used here.
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        cos = angles.cos()[None, :, None, :]
        sin = angles.sin()[None, :, None, :]
        return cos, sin
43
+
44
+
45
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by (cos, sin).

    Rotate-half RoPE formulation: with x split as [x1, x2] on the last axis,
    returns [x1*cos + x2*sin, x2*cos - x1*sin] concatenated back together.
    """
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat((rotated_first, rotated_second), dim=3)
52
+
53
+
54
class BidirectionalSelfAttention(nn.Module):
    """Multi-head self-attention with full (non-causal) bidirectional context.

    Uses a fused QKV projection and PyTorch's scaled_dot_product_attention.
    When ``config.use_rope`` is true, rotary position embeddings are applied
    to the queries and keys before attention.
    """

    def __init__(self, config: "AuriStreamParallelConfig"):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0

        # Fused query/key/value projection and output projection (no biases).
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.attn_dropout = nn.Dropout(config.dropout)

        self.rotary = None
        if getattr(config, "use_rope", True):
            # Treat a falsy rope_theta (None / 0) as the standard default.
            rope_theta = getattr(config, "rope_theta", 10000.0) or 10000.0
            self.rotary = Rotary(self.head_dim, base=rope_theta)

    def forward(self, x):
        """Attend over the full sequence. x: (batch, seq, n_embd) -> same shape."""
        bsz, tsz, channels = x.size()

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        q = q.view(bsz, tsz, self.n_head, self.head_dim)
        k = k.view(bsz, tsz, self.n_head, self.head_dim)
        v = v.view(bsz, tsz, self.n_head, self.head_dim)

        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        # Fix: config.dropout was stored in self.attn_dropout but never applied
        # to the attention weights; pass it to SDPA (active in training only,
        # so eval-time behavior — and the dropout=0.0 default — is unchanged).
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=False,
        )

        y = y.transpose(1, 2).contiguous().view(bsz, tsz, channels)
        return self.c_proj(y)
94
+
95
+
96
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> SiLU -> Linear -> Dropout.

    Expands the embedding dimension by a factor of 4 in the hidden layer.
    """

    def __init__(self, config: "AuriStreamParallelConfig"):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.act = nn.SiLU()
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.act(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
109
+
110
+
111
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, config: "AuriStreamParallelConfig"):
        super().__init__()
        # Submodule creation order is kept stable for checkpoint/RNG parity.
        self.attn = BidirectionalSelfAttention(config)
        self.mlp = MLP(config)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
123
+
124
+
125
class AuriStreamPreTrainedModel(PreTrainedModel):
    """Base class wiring AuriStream Parallel modules into the HF ecosystem.

    Declares the config class, serialization prefix, gradient-checkpointing
    support, sharding hints, and the default weight initialization.
    """

    config_class = AuriStreamParallelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
138
+
139
+
140
class AuriStreamModel(AuriStreamPreTrainedModel):
    """HF-compatible AuriStream Parallel model.

    A bidirectional transformer over *grouped* token embeddings: each group of
    ``group_size`` consecutive tokens is collapsed into one transformer
    position on input, and ``group_size`` parallel output heads predict the
    tokens of each group from that single position.
    """

    config_class = AuriStreamParallelConfig

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__(config)
        self.config = config

        # Number of consecutive tokens merged into one transformer position.
        self.group_size = int(getattr(config, "group_size", 4))
        grouped_seq_len = max(1, config.seq_len // self.group_size)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # Learned absolute positions are only created when RoPE is disabled;
        # they index grouped positions, hence the grouped_seq_len capacity.
        self.wpe = None
        if not getattr(config, "use_rope", True):
            self.wpe = nn.Embedding(grouped_seq_len, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = RMSNorm(config.n_embd, bias=config.bias)

        # Projects the concatenated group of token embeddings down to n_embd.
        self.group_in_proj = nn.Linear(self.group_size * config.n_embd, config.n_embd, bias=False)
        # One vocabulary head per within-group position (parallel token heads).
        self.parallel_heads = nn.ModuleList(
            [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(self.group_size)]
        )

        self.apply(self._init_weights)
        # GPT-2-style scaled init for residual-branch output projections.
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def get_input_embeddings(self):
        """Return the token embedding table (HF accessor convention)."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding table (HF accessor convention)."""
        self.wte = value

    def _group_embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Embed tokens and collapse each run of ``group_size`` tokens into one position.

        Args:
            input_ids: (batch, seq) token ids; seq must be a multiple of group_size.

        Returns:
            (batch, seq // group_size, n_embd) embeddings with dropout applied.

        Raises:
            ValueError: if the sequence length is not divisible by group_size.
        """
        bsz, tsz = input_ids.shape
        if tsz % self.group_size != 0:
            raise ValueError(
                f"Sequence length {tsz} must be divisible by group_size={self.group_size}"
            )

        tok_emb = self.wte(input_ids)
        # (B, T, C) -> (B, T/G, G, C) -> (B, T/G, G*C) -> project to (B, T/G, C).
        grouped = tok_emb.view(bsz, tsz // self.group_size, self.group_size, self.config.n_embd)
        grouped = grouped.reshape(bsz, tsz // self.group_size, self.group_size * self.config.n_embd)
        x = self.group_in_proj(grouped)

        if self.wpe is not None:
            # NOTE(review): no bounds check — grouped length beyond the wpe
            # table (seq_len // group_size) would raise an index error here.
            pos = torch.arange(x.size(1), device=input_ids.device)
            x = x + self.wpe(pos)

        return self.drop(x)

    def _decode_parallel_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every parallel head and interleave predictions back to token order.

        Args:
            x: (batch, grouped_seq, n_embd) final hidden states.

        Returns:
            (batch, grouped_seq * group_size, vocab_size) logits, ordered so
            head g predicts the g-th token within each group.
        """
        per_head = [head(x) for head in self.parallel_heads]
        logits = torch.stack(per_head, dim=2)  # (B, T_g, G, V)
        bsz, tg, gsz, vsz = logits.shape
        return logits.reshape(bsz, tg * gsz, vsz)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        """Run the bidirectional transformer and decode per-token logits.

        Args:
            input_ids: (batch, seq) token ids; seq must be divisible by group_size.
            labels: optional (batch, seq) targets for token-level cross-entropy;
                positions equal to config.ignore_index are excluded from the loss.
            output_hidden_states: also collect per-layer hidden states.
            return_dict: return a CausalLMOutput instead of a plain tuple.
            seq / tgt: aliases for input_ids / labels that take precedence when
                given — presumably kept for compatibility with the original
                training code's argument names (verify against callers).

        Returns:
            CausalLMOutput (or tuple when return_dict is falsy) with logits of
            shape (batch, seq, vocab_size) and an optional scalar loss.
        """
        # Legacy argument aliases override the HF-style names.
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt
        if input_ids is None:
            raise ValueError("input_ids (or seq) must be provided")

        x = self._group_embed(input_ids)

        all_hidden_states = ()
        if output_hidden_states:
            # Include the (grouped) input embeddings as the first entry.
            all_hidden_states = (x,)

        for block in self.h:
            x = block(x)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (x,)

        x = self.ln_f(x)
        logits = self._decode_parallel_logits(x)

        loss = None
        if labels is not None:
            # Flattened (B*T, V) vs (B*T,) token-level cross-entropy.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=getattr(self.config, "ignore_index", -100),
            )

        if not return_dict:
            out = (logits,)
            if output_hidden_states:
                out = out + (all_hidden_states,)
            return ((loss,) + out) if loss is not None else out

        # NOTE(review): CausalLMOutput is reused as a container even though
        # attention here is bidirectional (non-causal).
        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=None,
        )