SinclairSchneider committed on
Commit 8fa65ea
1 Parent(s): add1d16

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .ipynb_checkpoints/modeling_dbrx-checkpoint.py +1455 -0
  2. LICENSE.txt +176 -0
  3. NOTICE.txt +1 -0
  4. README.md +172 -0
  5. config.json +38 -0
  6. configuration_dbrx.py +264 -0
  7. generation_config.json +7 -0
  8. huggingface-metadata.txt +65 -0
  9. model-00001-of-00061.safetensors +3 -0
  10. model-00002-of-00061.safetensors +3 -0
  11. model-00003-of-00061.safetensors +3 -0
  12. model-00004-of-00061.safetensors +3 -0
  13. model-00005-of-00061.safetensors +3 -0
  14. model-00006-of-00061.safetensors +3 -0
  15. model-00007-of-00061.safetensors +3 -0
  16. model-00008-of-00061.safetensors +3 -0
  17. model-00009-of-00061.safetensors +3 -0
  18. model-00010-of-00061.safetensors +3 -0
  19. model-00011-of-00061.safetensors +3 -0
  20. model-00012-of-00061.safetensors +3 -0
  21. model-00013-of-00061.safetensors +3 -0
  22. model-00014-of-00061.safetensors +3 -0
  23. model-00015-of-00061.safetensors +3 -0
  24. model-00016-of-00061.safetensors +3 -0
  25. model-00017-of-00061.safetensors +3 -0
  26. model-00018-of-00061.safetensors +3 -0
  27. model-00019-of-00061.safetensors +3 -0
  28. model-00020-of-00061.safetensors +3 -0
  29. model-00021-of-00061.safetensors +3 -0
  30. model-00022-of-00061.safetensors +3 -0
  31. model-00023-of-00061.safetensors +3 -0
  32. model-00024-of-00061.safetensors +3 -0
  33. model-00025-of-00061.safetensors +3 -0
  34. model-00026-of-00061.safetensors +3 -0
  35. model-00027-of-00061.safetensors +3 -0
  36. model-00028-of-00061.safetensors +3 -0
  37. model-00029-of-00061.safetensors +3 -0
  38. model-00030-of-00061.safetensors +3 -0
  39. model-00031-of-00061.safetensors +3 -0
  40. model-00032-of-00061.safetensors +3 -0
  41. model-00033-of-00061.safetensors +3 -0
  42. model-00034-of-00061.safetensors +3 -0
  43. model-00035-of-00061.safetensors +3 -0
  44. model-00036-of-00061.safetensors +3 -0
  45. model-00037-of-00061.safetensors +3 -0
  46. model-00038-of-00061.safetensors +3 -0
  47. model-00039-of-00061.safetensors +3 -0
  48. model-00040-of-00061.safetensors +3 -0
  49. model-00041-of-00061.safetensors +3 -0
  50. model-00042-of-00061.safetensors +3 -0
.ipynb_checkpoints/modeling_dbrx-checkpoint.py ADDED
@@ -0,0 +1,1455 @@
1
+ """PyTorch Dbrx model."""
2
+
3
+ import math
4
+ import warnings
5
+ from copy import deepcopy
6
+ from functools import partial
7
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
14
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
+ from transformers.modeling_outputs import (MoeCausalLMOutputWithPast,
16
+ MoeModelOutputWithPast)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import is_flash_attn_2_available, logging
19
+
20
+ from .configuration_dbrx import DbrxAttentionConfig, DbrxConfig, DbrxFFNConfig
21
+
22
+ if is_flash_attn_2_available():
23
+ try:
24
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
25
+ from flash_attn.bert_padding import pad_input # noqa
26
+ from flash_attn.bert_padding import index_first_axis, unpad_input
27
 + except ImportError:
28
+ pass
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ _CONFIG_FOR_DOC = 'DbrxConfig'
33
+
34
+ #############################################################################
35
+ # Copied from LLaMaRotaryEmbedding
36
+ #############################################################################
37
+
38
+
39
+ class DbrxRotaryEmbedding(nn.Module):
40
+
41
+ def __init__(self,
42
+ dim: int,
43
+ max_position_embeddings: int = 2048,
44
+ base: float = 10000.0,
45
+ scaling_factor: float = 1.0):
46
+ super().__init__()
47
+ self.scaling_factor = scaling_factor
48
+ self.dim = dim
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.base = base
51
+ inv_freq = 1.0 / (self.base**(
52
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
53
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
54
+ # For BC we register cos and sin cached
55
+ self.max_seq_len_cached = max_position_embeddings
56
+
57
+ @torch.no_grad()
58
+ def forward(
59
+ self, x: torch.Tensor, position_ids: torch.LongTensor
60
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ # x: [bs, num_attention_heads, seq_len, head_size]
62
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
63
+ position_ids.shape[0], -1, 1)
64
+ position_ids_expanded = position_ids[:, None, :].float()
65
+ # Force float32 since bfloat16 loses precision on long contexts
66
+ # See https://github.com/huggingface/transformers/pull/29285
67
+ device_type = x.device.type
68
+ device_type = device_type if isinstance(
69
+ device_type, str) and device_type != 'mps' else 'cpu'
70
+ with torch.autocast(device_type=device_type, enabled=False):
71
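 + # Outer product of the inverse frequencies and the positions gives the per-position rotation angles; duplicating them along the last dim matches the (cos, sin) layout used by rotate_half.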
+ freqs = (inv_freq_expanded.float()
72
+ @ position_ids_expanded.float()).transpose(1, 2)
73
+ emb = torch.cat((freqs, freqs), dim=-1)
74
+ cos = emb.cos()
75
+ sin = emb.sin()
76
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
77
+
78
+
79
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
80
+ """Rotates half the hidden dims of the input."""
81
+ x1 = x[..., :x.shape[-1] // 2]
82
+ x2 = x[..., x.shape[-1] // 2:]
83
+ return torch.cat((-x2, x1), dim=-1)
84
+
85
+
86
+ def apply_rotary_pos_emb(
87
+ q: torch.Tensor,
88
+ k: torch.Tensor,
89
+ cos: torch.Tensor,
90
+ sin: torch.Tensor,
91
+ unsqueeze_dim: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
92
+ """Applies Rotary Position Embedding to the query and key tensors.
93
+
94
+ Args:
95
+ q (`torch.Tensor`): The query tensor.
96
+ k (`torch.Tensor`): The key tensor.
97
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
98
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
99
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
100
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
101
+ sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
102
+ that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
103
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
104
+ cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
105
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
106
+
107
+ Returns:
108
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
109
+ """
110
+ cos = cos.unsqueeze(unsqueeze_dim)
111
+ sin = sin.unsqueeze(unsqueeze_dim)
112
+ q_embed = (q * cos) + (rotate_half(q) * sin)
113
+ k_embed = (k * cos) + (rotate_half(k) * sin)
114
+ return q_embed, k_embed
115
+
116
+
117
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
118
+ """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
119
+
120
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
121
+ (batch, num_attention_heads, seqlen, head_dim)
122
+ """
123
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
124
+ if n_rep == 1:
125
+ return hidden_states
126
+ hidden_states = hidden_states[:, :,
127
+ None, :, :].expand(batch, num_key_value_heads,
128
+ n_rep, slen, head_dim)
129
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
130
+ head_dim)
131
+
132
+
133
+ #############################################################################
134
+
135
+ #############################################################################
136
+ # Modified from modeling_mixtral
137
+ #############################################################################
138
+
139
+
140
+ def load_balancing_loss_func(
141
+ gate_logits: torch.Tensor,
142
+ num_experts: int,
143
+ top_k: int,
144
+ attention_mask: Optional[torch.Tensor],
145
+ ) -> torch.Tensor:
146
+ r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
147
+
148
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
149
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
150
+ experts is too unbalanced.
151
+
152
+ Args:
153
 + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
154
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
155
+ shape [batch_size X sequence_length, num_experts].
156
+ num_experts (`int`):
157
+ Number of experts.
158
+ top_k (`int`):
159
+ The number of experts each token is routed to.
160
+ attention_mask (`torch.Tensor`, None):
161
+ The attention_mask used in forward function
162
+ shape [batch_size X sequence_length] if not None.
163
+
164
+ Returns:
165
+ The auxiliary loss.
166
+ """
167
+ if gate_logits is None or not isinstance(gate_logits, tuple):
168
+ return torch.tensor(0.0)
169
+
170
+ if isinstance(gate_logits, tuple):
171
+ compute_device = gate_logits[0].device
172
+ concatenated_gate_logits = torch.cat(
173
+ [layer_gate.to(compute_device) for layer_gate in gate_logits],
174
+ dim=0)
175
+
176
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits,
177
+ dim=-1)
178
+
179
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
180
+
181
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
182
+
183
+ if attention_mask is None:
184
 + # Compute the percentage of tokens routed to each expert
185
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
186
+
187
+ # Compute the average probability of routing to these experts
188
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
189
+ else:
190
+ batch_size, sequence_length = attention_mask.shape
191
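 + # The gate logits of all layers were concatenated along dim 0, so dividing by the number of tokens per layer recovers the layer count.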
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (
192
+ batch_size * sequence_length)
193
+
194
 + # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
195
+ expert_attention_mask = (attention_mask[None, :, :, None, None].expand(
196
+ (num_hidden_layers, batch_size, sequence_length, top_k,
197
+ num_experts)).reshape(-1, top_k, num_experts).to(compute_device))
198
+
199
 + # Compute the percentage of tokens routed to each expert
200
+ tokens_per_expert = torch.sum(
201
+ expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
202
+ expert_attention_mask, dim=0)
203
+
204
 + # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
205
+ router_per_expert_attention_mask = (
206
+ attention_mask[None, :, :, None].expand(
207
+ (num_hidden_layers, batch_size, sequence_length,
208
+ num_experts)).reshape(-1, num_experts).to(compute_device))
209
+
210
+ # Compute the average probability of routing to these experts
211
+ router_prob_per_expert = torch.sum(
212
+ routing_weights * router_per_expert_attention_mask,
213
+ dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
214
+
215
+ overall_loss = torch.sum(tokens_per_expert *
216
+ router_prob_per_expert.unsqueeze(0))
217
+ return overall_loss * num_experts
218
+
219
+
220
+ #############################################################################
221
+
222
+
223
+ def resolve_ffn_act_fn(
224
+ ffn_act_fn: dict) -> Callable[[torch.Tensor], torch.Tensor]:
225
+ """Resolve the activation function for the feed-forward network.
226
+
227
+ Args:
228
+ ffn_act_fn (dict): The configuration dictionary for the activation function.
229
+ The dict config must specify the 'name' of a torch.nn.functional activation
230
+ function. All of other key values pairs are bound to the function as a partial.
231
+
232
+ Returns:
233
+ Callable[[torch.Tensor], torch.Tensor]: The activation function.
234
+ """
235
+ config = deepcopy(ffn_act_fn)
236
+ name = config.pop('name')
237
+ if not hasattr(nn.functional, name):
238
+ raise ValueError(f'Unrecognised activation function name ({name}).')
239
+ act = getattr(nn.functional, name)
240
+ return partial(act, **config)
241
+
242
+
243
+ #############################################################################
244
+ # Copied from LLaMaAttention
245
+ #############################################################################
246
+
247
+
248
+ def _get_unpad_data(attention_mask: torch.Tensor):
249
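 + # Gather the indices of non-padding tokens, the cumulative sequence lengths (cu_seqlens) required by the flash-attn varlen kernels, and the longest sequence in the batch.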
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
250
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
251
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
252
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
253
+ (1, 0))
254
+ return (
255
+ indices,
256
+ cu_seqlens,
257
+ max_seqlen_in_batch,
258
+ )
259
+
260
+
261
+ class DbrxAttention(nn.Module):
262
+ """Multi-head self attention."""
263
+
264
+ def __init__(self,
265
+ hidden_size: int,
266
+ num_heads: int,
267
+ max_position_embeddings: int,
268
+ attn_config: DbrxAttentionConfig,
269
+ block_idx: Optional[int] = None):
270
+ super().__init__()
271
+ self.hidden_size = hidden_size
272
+ self.num_heads = num_heads
273
+ self.head_dim = self.hidden_size // self.num_heads
274
+ self.max_position_embeddings = max_position_embeddings
275
+ self.block_idx = block_idx
276
+ self.config = attn_config
277
+ if block_idx is None:
278
+ logger.warning_once(
279
+ f'Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will '
280
+ +
281
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` '
282
+ + 'when creating this class.')
283
+
284
+ self.attn_pdrop = attn_config.attn_pdrop
285
+ self.clip_qkv = attn_config.clip_qkv
286
+ self.num_key_value_heads = attn_config.kv_n_heads
287
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
288
+ self.rope_theta = attn_config.rope_theta
289
+
290
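 + # Fused QKV projection: hidden_size output features for the queries plus 2 * kv_n_heads * head_dim for the keys and values (grouped-query attention).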
+ self.Wqkv = nn.Linear(self.hidden_size,
291
+ self.hidden_size +
292
+ 2 * self.num_key_value_heads * self.head_dim,
293
+ bias=False)
294
+ self.out_proj = nn.Linear(self.hidden_size,
295
+ self.hidden_size,
296
+ bias=False)
297
+ self.rotary_emb = DbrxRotaryEmbedding(
298
+ self.head_dim,
299
+ max_position_embeddings=self.max_position_embeddings,
300
+ base=self.rope_theta,
301
+ )
302
+
303
+ def forward(
304
+ self,
305
+ hidden_states: torch.Tensor,
306
+ position_ids: torch.LongTensor,
307
+ attention_mask: Optional[torch.Tensor] = None,
308
+ past_key_value: Optional[Cache] = None,
309
+ output_attentions: bool = False,
310
+ use_cache: bool = False,
311
+ cache_position: Optional[torch.LongTensor] = None,
312
+ **kwargs: Any,
313
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
314
+ bsz, q_len, _ = hidden_states.size()
315
+
316
+ qkv_states = self.Wqkv(hidden_states)
317
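 + # Optionally clamp the fused QKV activations to [-clip_qkv, clip_qkv].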
+ if self.clip_qkv is not None:
318
+ qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
319
+
320
+ query_states, key_states, value_states = qkv_states.split(
321
+ [
322
+ self.hidden_size,
323
+ self.num_key_value_heads * self.head_dim,
324
+ self.num_key_value_heads * self.head_dim,
325
+ ],
326
+ dim=2,
327
+ )
328
+
329
+ query_states = query_states.view(bsz, q_len, self.num_heads,
330
+ self.head_dim).transpose(1, 2)
331
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
332
+ self.head_dim).transpose(1, 2)
333
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
334
+ self.head_dim).transpose(1, 2)
335
+
336
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
337
+ cos, sin = self.rotary_emb(value_states, position_ids)
338
+ query_states, key_states = apply_rotary_pos_emb(query_states,
339
+ key_states, cos, sin)
340
+
341
+ if past_key_value is not None:
342
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
343
+ cache_kwargs = {
344
+ 'sin': sin,
345
+ 'cos': cos,
346
+ 'cache_position': cache_position
347
+ }
348
+ key_states, value_states = past_key_value.update(
349
+ key_states, value_states, self.block_idx, cache_kwargs)
350
+
351
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
352
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
353
+
354
+ attn_weights = torch.matmul(query_states, key_states.transpose(
355
+ 2, 3)) / math.sqrt(self.head_dim)
356
+
357
+ if attention_mask is not None: # no matter the length, we just slice it
358
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
359
+ attn_weights = attn_weights + causal_mask
360
+
361
+ # upcast attention to fp32
362
+ attn_weights = nn.functional.softmax(attn_weights,
363
+ dim=-1,
364
+ dtype=torch.float32).to(
365
+ query_states.dtype)
366
+ attn_weights = nn.functional.dropout(attn_weights,
367
+ p=self.attn_pdrop,
368
+ training=self.training)
369
+ attn_output = torch.matmul(attn_weights, value_states)
370
+
371
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
372
+ raise ValueError(
373
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
374
+ + f' {attn_output.size()}')
375
+
376
+ attn_output = attn_output.transpose(1, 2).contiguous()
377
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
378
+ attn_output = self.out_proj(attn_output)
379
+
380
+ if not output_attentions:
381
+ attn_weights = None
382
+
383
+ return attn_output, attn_weights, past_key_value
384
+
385
+
386
+ class DbrxFlashAttention2(DbrxAttention):
387
+ """Dbrx flash attention module.
388
+
389
 + This module inherits from `DbrxAttention`, as the weights of the module stay
390
+ untouched. The only required change would be on the forward pass where it
391
+ calls the public API of flash attention.
392
+ """
393
+
394
+ def __init__(self, *args: Any, **kwargs: Any):
395
+ if not is_flash_attn_2_available():
396
+ raise ImportError(
397
+ 'Flash Attention 2 is not available. Please install it with `pip install flash-attn`.'
398
+ )
399
+
400
+ super().__init__(*args, **kwargs)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states: torch.Tensor,
405
+ attention_mask: Optional[torch.LongTensor] = None,
406
+ position_ids: Optional[torch.LongTensor] = None,
407
+ past_key_value: Optional[Cache] = None,
408
+ output_attentions: bool = False,
409
+ use_cache: bool = False,
410
+ cache_position: Optional[torch.LongTensor] = None,
411
+ **kwargs: Any,
412
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
413
+ Optional[Tuple[torch.Tensor]]]:
414
+ logger.info(
415
+ 'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
416
+ )
417
+ output_attentions = False
418
+
419
+ bsz, q_len, _ = hidden_states.size()
420
+
421
+ qkv_states = self.Wqkv(hidden_states)
422
+ if self.clip_qkv is not None:
423
+ qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
424
+
425
+ query_states, key_states, value_states = qkv_states.split(
426
+ [
427
+ self.hidden_size,
428
+ self.num_key_value_heads * self.head_dim,
429
+ self.num_key_value_heads * self.head_dim,
430
+ ],
431
+ dim=2,
432
+ )
433
+
434
+ # Flash attention requires the input to have the shape
435
 + # batch_size x seq_length x num_heads x head_dim
436
+ # therefore we just need to keep the original shape
437
+ query_states = query_states.view(bsz, q_len, self.num_heads,
438
+ self.head_dim).transpose(1, 2)
439
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
440
+ self.head_dim).transpose(1, 2)
441
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
442
+ self.head_dim).transpose(1, 2)
443
+
444
+ cos, sin = self.rotary_emb(value_states, position_ids)
445
+ query_states, key_states = apply_rotary_pos_emb(query_states,
446
+ key_states, cos, sin)
447
+
448
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
449
+
450
+ if past_key_value is not None:
451
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
452
+ cache_kwargs = {
453
+ 'sin': sin,
454
+ 'cos': cos,
455
+ 'cache_position': cache_position
456
+ }
457
+ key_states, value_states = past_key_value.update(
458
+ key_states, value_states, self.block_idx, cache_kwargs)
459
+
460
 + # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
461
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
462
+ # to be able to avoid many of these transpose/reshape/view.
463
+ query_states = query_states.transpose(1, 2)
464
+ key_states = key_states.transpose(1, 2)
465
+ value_states = value_states.transpose(1, 2)
466
+
467
+ dropout_rate = self.attn_pdrop if self.training else 0.0
468
+
469
 + # In PEFT, we usually cast the layer norms to float32 for training stability reasons
470
 + # therefore the input hidden states get silently cast to float32. Hence, we need to
471
 + # cast them back to the correct dtype just to be sure everything works as expected.
472
 + # This might slow down training & inference, so it is recommended not to cast the LayerNorms
473
+ # in fp32. (LlamaRMSNorm handles it correctly)
474
+ input_dtype = query_states.dtype
475
+ if input_dtype == torch.float32:
476
+ if torch.is_autocast_enabled():
477
+ target_dtype = torch.get_autocast_gpu_dtype()
478
+ # Handle the case where the model is quantized
479
+ elif hasattr(self.config, '_pre_quantization_dtype'):
480
+ target_dtype = self.config._pre_quantization_dtype
481
+ else:
482
+ target_dtype = query_states.dtype
483
+
484
+ logger.warning_once(
485
 + f'The input hidden states seem to have been silently cast to float32; this might be '
486
+ +
487
 + f'related to the fact that you have upcast embedding or layer norm layers in '
488
 + + f'float32. We will cast the input back to {target_dtype}.')
489
+
490
+ query_states = query_states.to(target_dtype)
491
+ key_states = key_states.to(target_dtype)
492
+ value_states = value_states.to(target_dtype)
493
+
494
+ attn_output = self._flash_attention_forward(
495
+ query_states,
496
+ key_states,
497
+ value_states,
498
+ attention_mask,
499
+ q_len,
500
+ dropout=dropout_rate,
501
+ )
502
+
503
+ attn_output = attn_output.reshape(bsz, q_len,
504
+ self.hidden_size).contiguous()
505
+ attn_output = self.out_proj(attn_output)
506
+
507
+ if not output_attentions:
508
+ attn_weights = None
509
+
510
+ return attn_output, attn_weights, past_key_value # type: ignore
511
+
512
+ def _flash_attention_forward(
513
+ self,
514
+ query_states: torch.Tensor,
515
+ key_states: torch.Tensor,
516
+ value_states: torch.Tensor,
517
+ attention_mask: Union[torch.LongTensor, None],
518
+ query_length: int,
519
+ dropout: float = 0.0,
520
+ softmax_scale: Optional[float] = None,
521
+ ):
522
+ """Use FlashAttention, stripping padding tokens if necessary.
523
+
524
+ Args:
525
+ query_states (torch.Tensor): Input query states to be passed to Flash Attention API
526
+ key_states (torch.Tensor): Input key states to be passed to Flash Attention API
527
+ value_states (torch.Tensor): Input value states to be passed to Flash Attention API
528
+ attention_mask (torch.LongTensor | None): The padding mask - corresponds to a tensor of size
529
+ (batch_size, seq_len) where 0 stands for the position of padding tokens and 1
530
+ for the position of non-padding tokens.
531
+ query_length (int): The length of the query sequence
532
+ dropout (float): Attention dropout
533
+ softmax_scale (float, optional): The scaling of QK^T before applying softmax.
534
+ Defaults to 1 / sqrt(head_dim)
535
+ """
536
+ causal = True
537
+ # Contains at least one padding token in the sequence
538
+ if attention_mask is not None:
539
+ batch_size = query_states.shape[0]
540
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
541
+ query_states, key_states, value_states, attention_mask,
542
+ query_length)
543
+
544
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
545
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
546
+
547
+ attn_output_unpad = flash_attn_varlen_func(
548
+ query_states,
549
+ key_states,
550
+ value_states,
551
+ cu_seqlens_q=cu_seqlens_q,
552
+ cu_seqlens_k=cu_seqlens_k,
553
+ max_seqlen_q=max_seqlen_in_batch_q,
554
+ max_seqlen_k=max_seqlen_in_batch_k,
555
+ dropout_p=dropout,
556
+ softmax_scale=softmax_scale,
557
+ causal=causal,
558
+ )
559
+
560
+ attn_output = pad_input(
561
+ attn_output_unpad,
562
+ indices_q,
563
+ batch_size,
564
+ query_length,
565
+ )
566
+ else:
567
+ attn_output = flash_attn_func(
568
+ query_states,
569
+ key_states,
570
+ value_states,
571
+ dropout,
572
+ softmax_scale=softmax_scale,
573
+ causal=causal,
574
+ )
575
+
576
+ return attn_output
577
+
578
+ def _upad_input(self, query_layer: torch.Tensor, key_layer: torch.Tensor,
579
+ value_layer: torch.Tensor, attention_mask: torch.Tensor,
580
+ query_length: int):
581
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
582
+ attention_mask)
583
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
584
+
585
+ key_layer = index_first_axis(
586
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
587
+ head_dim), indices_k)
588
+ value_layer = index_first_axis(
589
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
590
+ head_dim), indices_k)
591
+ if query_length == kv_seq_len:
592
+ query_layer = index_first_axis(
593
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
594
+ head_dim), indices_k)
595
+ cu_seqlens_q = cu_seqlens_k
596
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
597
+ indices_q = indices_k
598
+ elif query_length == 1:
599
+ max_seqlen_in_batch_q = 1
600
+ cu_seqlens_q = torch.arange(
601
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
602
+ ) # There is a memcpy here, that is very bad.
603
+ indices_q = cu_seqlens_q[:-1]
604
+ query_layer = query_layer.squeeze(1)
605
+ else:
606
+ # The -q_len: slice assumes left padding.
607
+ attention_mask = attention_mask[:, -query_length:]
608
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
609
+ query_layer, attention_mask)
610
+
611
+ return (
612
+ query_layer,
613
+ key_layer,
614
+ value_layer,
615
+ indices_q,
616
+ (cu_seqlens_q, cu_seqlens_k),
617
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
618
+ )
619
+
620
+
621
+ DBRX_ATTENTION_CLASSES = {
622
+ 'eager': DbrxAttention,
623
+ 'flash_attention_2': DbrxFlashAttention2,
624
+ }
625
+
626
+
627
+ class DbrxNormAttentionNorm(nn.Module):
628
+
629
+ def __init__(
630
+ self,
631
+ hidden_size: int,
632
+ num_heads: int,
633
+ max_position_embeddings: int,
634
+ resid_pdrop: float,
635
+ attn_implementation: str,
636
+ attn_config: DbrxAttentionConfig,
637
+ block_idx: Optional[int] = None,
638
+ ):
639
+ super().__init__()
640
+ self.block_idx = block_idx
641
+ self.resid_pdrop = resid_pdrop
642
+ self.norm_1 = nn.LayerNorm(hidden_size, bias=False)
643
+ self.attn = DBRX_ATTENTION_CLASSES[attn_implementation](
644
+ hidden_size=hidden_size,
645
+ num_heads=num_heads,
646
+ max_position_embeddings=max_position_embeddings,
647
+ attn_config=attn_config,
648
+ block_idx=block_idx,
649
+ )
650
+ self.norm_2 = nn.LayerNorm(hidden_size, bias=False)
651
+
652
+ def forward(
653
+ self,
654
+ hidden_states: torch.Tensor,
655
+ position_ids: torch.LongTensor,
656
+ attention_mask: Optional[torch.Tensor] = None,
657
+ past_key_value: Optional[Cache] = None,
658
+ output_attentions: bool = False,
659
+ use_cache: bool = False,
660
+ cache_position: Optional[torch.LongTensor] = None,
661
+ **kwargs: Any,
662
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
663
+ Optional[Cache]]:
664
+
665
+ residual_states = hidden_states
666
+ hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
667
+
668
+ hidden_states, attn_weights, past_key_value = self.attn(
669
+ hidden_states=hidden_states,
670
+ attention_mask=attention_mask,
671
+ position_ids=position_ids,
672
+ past_key_value=past_key_value,
673
+ output_attentions=output_attentions,
674
+ use_cache=use_cache,
675
+ cache_position=cache_position,
676
+ **kwargs,
677
+ )
678
+
679
+ hidden_states = nn.functional.dropout(hidden_states,
680
+ p=self.resid_pdrop,
681
+ training=self.training)
682
+ hidden_states = hidden_states + residual_states
683
+
684
+ residual_states = hidden_states
685
+ hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
686
+
687
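 + # Return both the post-attention residual stream and its normalized version: the FFN consumes the normalized states and DbrxBlock adds its output back onto the residual.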
+ return residual_states, hidden_states, attn_weights, past_key_value
688
+
689
+
690
+ class DbrxRouter(nn.Module):
691
+
692
+ def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
693
+ moe_jitter_eps: Optional[float],
694
+ moe_normalize_expert_weights: Optional[float],
695
+ uniform_expert_assignment: bool):
696
+ super().__init__()
697
+ self.hidden_size = hidden_size
698
+ self.moe_num_experts = moe_num_experts
699
+ self.moe_top_k = moe_top_k
700
+ self.moe_jitter_eps = moe_jitter_eps
701
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
702
+ self.uniform_expert_assignment = uniform_expert_assignment
703
+
704
+ self.layer = nn.Linear(self.hidden_size,
705
+ self.moe_num_experts,
706
+ bias=False)
707
+
708
+ def jitter(self, x: torch.Tensor) -> torch.Tensor:
709
+ if self.moe_jitter_eps is None:
710
+ raise RuntimeError('The router does not have moe_jitter_eps set.')
711
+ low = 1.0 - self.moe_jitter_eps
712
+ high = 1.0 + self.moe_jitter_eps
713
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
714
+ return low + noise * (high - low)
715
+
716
+ def forward(
717
+ self, x: torch.Tensor
718
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
719
+ if self.training and self.moe_jitter_eps is not None:
720
+ x = x * self.jitter(x)
721
+
722
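 + # Flatten tokens to (batch * seq_len, hidden_size) and compute per-token expert probabilities in float32 for numerical stability.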
+ weights = self.layer(x.view(-1,
723
+ x.shape[-1])).softmax(dim=-1,
724
+ dtype=torch.float32)
725
+ top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
726
+
727
+ if self.moe_normalize_expert_weights:
728
+ top_weights = top_weights / torch.norm(
729
+ top_weights,
730
+ p=self.moe_normalize_expert_weights,
731
+ dim=-1,
732
+ keepdim=True)
733
+
734
+ if self.uniform_expert_assignment:
735
+ with torch.no_grad():
736
+ uniform_tensor = torch.arange(
737
+ 0,
738
+ top_experts.numel(),
739
+ device=top_experts.device,
740
+ dtype=top_experts.dtype) % self.moe_num_experts
741
+ top_experts = uniform_tensor.reshape(top_experts.shape)
742
+ # Note, weights and top_weights are not changed
743
+
744
+ weights = weights.to(x.dtype)
745
+ top_weights = top_weights.to(x.dtype)
746
+ return weights, top_weights, top_experts # type: ignore
747
+
748
+
749
+ class DbrxMLP(nn.Module):
750
+
751
+ def __init__(self, hidden_size: int, ffn_hidden_size: int, ffn_act_fn: dict):
752
+ super().__init__()
753
+
754
+ self.w1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
755
+ self.v1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
756
+ self.w2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
757
+ self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
758
+
759
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
760
+
761
+ return self.w2(self.activation_fn(self.w1(x)) * self.v1(x))
762
+
763
+
764
+ class DbrxExperts(nn.Module):
765
+
766
+ def __init__(self, hidden_size: int, ffn_hidden_size: int,
767
+ moe_num_experts: int, ffn_act_fn: dict):
768
+ super().__init__()
769
+ self.moe_num_experts = moe_num_experts
770
+ self.mlp = nn.ModuleList([DbrxMLP(hidden_size, ffn_hidden_size, ffn_act_fn) for _ in range(moe_num_experts)])
771
+
772
+ def forward(self, x: torch.Tensor, weights: torch.Tensor,
773
+ top_weights: torch.Tensor,
774
+ top_experts: torch.LongTensor) -> torch.Tensor:
775
+ bsz, q_len, hidden_size = x.shape
776
+ x = x.view(-1, hidden_size)
777
+ out = torch.zeros_like(x)
778
+
779
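 + # one_hot gives (tokens, top_k, num_experts); permuting to (num_experts, top_k, tokens) lets torch.where pick out, for each expert, which tokens were routed to it and at which top-k slot.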
+ expert_mask = nn.functional.one_hot(
780
+ top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
781
+ for expert_idx in range(0, self.moe_num_experts):
782
+ topk_idx, token_idx = torch.where(expert_mask[expert_idx])
783
+ if token_idx.shape[0] == 0:
784
+ continue
785
+
786
+ token_list = token_idx.tolist()
787
+ topk_list = topk_idx.tolist()
788
+
789
+ expert_tokens = x[None, token_list].reshape(-1, hidden_size)
790
+ expert_out = self.mlp[expert_idx](expert_tokens) * top_weights[token_list, topk_list, None]
791
+
792
+ out.index_add_(0, token_idx, expert_out)
793
+
794
+ out = out.reshape(bsz, q_len, hidden_size)
795
+ return out
796
+
797
+
798
+ class DbrxFFN(nn.Module):
799
+
800
+ def __init__(self, hidden_size: int, ffn_config: DbrxFFNConfig):
801
+ super().__init__()
802
+
803
+ self.router = DbrxRouter(
804
+ hidden_size,
805
+ moe_num_experts=ffn_config.moe_num_experts,
806
+ moe_top_k=ffn_config.moe_top_k,
807
+ moe_jitter_eps=ffn_config.moe_jitter_eps,
808
+ moe_normalize_expert_weights=ffn_config.
809
+ moe_normalize_expert_weights,
810
+ uniform_expert_assignment=ffn_config.uniform_expert_assignment,
811
+ )
812
+
813
+ self.experts = DbrxExperts(
814
+ hidden_size=hidden_size,
815
+ ffn_hidden_size=ffn_config.ffn_hidden_size,
816
+ moe_num_experts=ffn_config.moe_num_experts,
817
+ ffn_act_fn=ffn_config.ffn_act_fn,
818
+ )
819
+
820
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
821
+ weights, top_weights, top_experts = self.router(x)
822
+ out = self.experts(x, weights, top_weights, top_experts)
823
+ return out, weights
824
+
825
+
826
+ class DbrxBlock(nn.Module):
827
+
828
+ def __init__(self, config: DbrxConfig, block_idx: int):
829
+ super().__init__()
830
+ self.hidden_size = config.d_model
831
+ self.resid_pdrop = config.resid_pdrop
832
+ self.block_idx = block_idx
833
+ self.norm_attn_norm = DbrxNormAttentionNorm(
834
+ hidden_size=config.d_model,
835
+ num_heads=config.n_heads,
836
+ max_position_embeddings=config.max_seq_len,
837
+ resid_pdrop=config.resid_pdrop,
838
+ attn_implementation=config._attn_implementation,
839
+ attn_config=config.attn_config,
840
+ block_idx=block_idx,
841
+ )
842
+ self.ffn = DbrxFFN(hidden_size=config.d_model,
843
+ ffn_config=config.ffn_config)
844
+
845
+ def forward(
846
+ self,
847
+ hidden_states: torch.Tensor,
848
+ position_ids: torch.LongTensor,
849
+ attention_mask: Optional[torch.Tensor] = None,
850
+ past_key_value: Optional[Cache] = None,
851
+ output_attentions: Optional[bool] = False,
852
+ output_router_logits: Optional[bool] = False,
853
+ use_cache: Optional[bool] = False,
854
+ cache_position: Optional[torch.LongTensor] = None,
855
+ **kwargs: Any,
856
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]],
857
+ Tuple[torch.Tensor, Optional[Cache]], Tuple[
858
+ torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
859
+ Tuple[torch.Tensor, Optional[torch.Tensor],
860
+ Optional[torch.Tensor]], Tuple[
861
+ torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
862
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache],
863
+ Optional[torch.Tensor]],]:
864
+ """Forward function for DbrxBlock.
865
+
866
+ Args:
867
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
868
+ position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
869
+ attention_mask (`torch.Tensor`, optional): attention mask of size (batch_size, sequence_length)
870
+ if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
871
+ if default attention is used.
872
+ past_key_value (`Tuple(torch.Tensor)`, optional): cached past key and value projection states
873
+ output_attentions (`bool`, optional): Whether or not to return the attentions tensors of all
874
+ attention layers. See `attentions` under returned tensors for more detail.
875
+ output_router_logits (`bool`, optional): Whether or not to return the router logits.
876
+ use_cache (`bool`, optional): If set to `True`, `past_key_values` key value states are
877
+ returned and can be used to speed up decoding (see `past_key_values`).
878
+ cache_position (`torch.LongTensor`, optional): position ids of the cache
879
+ """
880
+ if 'padding_mask' in kwargs:
881
+ warnings.warn(
882
 + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
883
+ )
884
+
885
+ # Norm + Attention + Norm
886
+ resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
887
+ hidden_states=hidden_states,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ past_key_value=past_key_value,
891
+ output_attentions=output_attentions,
892
+ use_cache=use_cache,
893
+ cache_position=cache_position,
894
+ **kwargs,
895
+ )
896
+
897
+ # Fully Connected
898
+ hidden_states, router_logits = self.ffn(hidden_states)
899
+ hidden_states = nn.functional.dropout(hidden_states,
900
+ p=self.resid_pdrop,
901
+ training=self.training)
902
+ hidden_states = resid_states + hidden_states
903
+
904
+ outputs = (hidden_states,)
905
+
906
+ if output_attentions:
907
+ outputs += (self_attn_weights,)
908
+
909
+ if use_cache:
910
+ outputs += (present_key_value,)
911
+
912
+ if output_router_logits:
913
+ outputs += (router_logits,)
914
+
915
+ return outputs
916
+
917
+
918
+ class DbrxPreTrainedModel(PreTrainedModel):
919
+ config_class = DbrxConfig
920
+ base_model_prefix = 'transformer'
921
+ supports_gradient_checkpointing = True
922
+ _no_split_modules = ['DbrxBlock']
923
+ _skip_keys_device_placement = ['past_key_values']
924
+ _supports_flash_attn_2 = True
925
+ _supports_sdpa = False
926
+ _supports_cache_class = True
927
+
928
+ def _init_weights(self, module: nn.Module):
929
+ std = self.config.initializer_range
930
+ if isinstance(module, nn.Linear):
931
+ module.weight.data.normal_(mean=0.0, std=std)
932
+ if module.bias is not None:
933
+ module.bias.data.zero_()
934
+ elif isinstance(module, nn.Embedding):
935
+ module.weight.data.normal_(mean=0.0, std=std)
936
+ if module.padding_idx is not None:
937
+ module.weight.data[module.padding_idx].zero_()
938
+ elif isinstance(module, nn.LayerNorm):
939
+ module.weight.data.normal_(mean=0.0, std=std)
940
+ if module.bias is not None:
941
+ module.bias.data.zero_()
942
+
943
+ def _setup_cache(self, cache_cls: Any, max_batch_size: int,
944
+ max_cache_len: int): # TODO: how to set var type of class?
945
+ if self.config._attn_implementation == 'flash_attention_2' and cache_cls == StaticCache:
946
+ raise ValueError(
947
+ '`static` cache implementation is not compatible with ' +
948
+ '`attn_implementation==flash_attention_2`. Make sure to use ' +
949
 + '`sdpa` in the meantime and open an issue at https://github.com/huggingface/transformers.'
950
+ )
951
+
952
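 + # Attach a cache instance of the requested class to every attention module so that decoding can reuse its key/value states.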
+ for block in self.transformer.blocks:
953
+ device = block.norm_attn_norm.norm_1.weight.device
954
+ if hasattr(self.config, '_pre_quantization_dtype'):
955
+ dtype = self.config._pre_quantization_dtype
956
+ else:
957
+ dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
958
+ block.norm_attn_norm.attn.past_key_value = cache_cls(self.config,
959
+ max_batch_size,
960
+ max_cache_len,
961
+ device=device,
962
+ dtype=dtype)
963
+
964
+ def _reset_cache(self):
965
+ for block in self.transformer.blocks:
966
+ block.norm_attn_norm.attn.past_key_value = None
967
+
968
+
969
+ class DbrxModel(DbrxPreTrainedModel):
970
+ """Transformer decoder consisting of *config.num_hidden_layers*
971
+
972
+ [`DbrxBlock`] layers.
973
+
974
+ Args:
975
+ config: DbrxConfig
976
+ """
977
+
978
+ def __init__(self, config: DbrxConfig):
979
+ super().__init__(config)
980
+ self.padding_idx = config.pad_token_id
981
+ self.vocab_size = config.vocab_size
982
+ self.emb_pdrop = config.emb_pdrop
983
+
984
+ self.wte = nn.Embedding(config.vocab_size, config.d_model,
985
+ self.padding_idx)
986
+ self.blocks = nn.ModuleList([
987
+ DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)
988
+ ])
989
+ self.norm_f = nn.LayerNorm(config.d_model, bias=False)
990
+ self.gradient_checkpointing = False
991
+
992
+ # Initialize weights and apply final processing
993
+ self.post_init()
994
+
995
+ def get_input_embeddings(self) -> nn.Embedding:
996
+ return self.wte
997
+
998
+ def set_input_embeddings(self, value: nn.Embedding):
999
+ self.wte = value
1000
+
1001
+ def _autocast_input_embeddings(self,
1002
+ inputs_embeds: torch.Tensor) -> torch.Tensor:
1003
+ if inputs_embeds.device.type == 'cuda' and torch.is_autocast_enabled():
1004
+ return inputs_embeds.to(dtype=torch.get_autocast_gpu_dtype())
1005
+ elif inputs_embeds.device.type == 'cpu' and torch.is_autocast_cpu_enabled(
1006
+ ):
1007
+ return inputs_embeds.to(dtype=torch.get_autocast_cpu_dtype())
1008
+ else:
1009
+ return inputs_embeds
1010
+
1011
+ def forward(
1012
+ self,
1013
+ input_ids: Optional[torch.LongTensor] = None,
1014
+ attention_mask: Optional[torch.Tensor] = None,
1015
+ position_ids: Optional[torch.LongTensor] = None,
1016
+ past_key_values: Optional[Cache] = None,
1017
+ inputs_embeds: Optional[torch.Tensor] = None,
1018
+ use_cache: Optional[bool] = None,
1019
+ output_attentions: Optional[bool] = None,
1020
+ output_hidden_states: Optional[bool] = None,
1021
+ output_router_logits: Optional[bool] = None,
1022
+ return_dict: Optional[bool] = None,
1023
+ cache_position: Optional[torch.LongTensor] = None,
1024
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1025
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1026
+ output_hidden_states = (output_hidden_states
1027
+ if output_hidden_states is not None else
1028
+ self.config.output_hidden_states)
1029
+ output_router_logits = (output_router_logits
1030
+ if output_router_logits is not None else
1031
+ self.config.output_router_logits)
1032
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1033
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034
+
1035
+ if (input_ids is None) ^ (inputs_embeds is not None):
1036
+ raise ValueError(
1037
+ 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
1038
+ )
1039
+
1040
+ if self.gradient_checkpointing and self.training and use_cache:
1041
+ logger.warning_once(
1042
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
1043
+ )
1044
+ use_cache = False
1045
+
1046
+ if inputs_embeds is None:
1047
+ inputs_embeds = self.wte(input_ids)
1048
+
1049
+ inputs_embeds = self._autocast_input_embeddings(
1050
+ inputs_embeds) # type: ignore
1051
+ inputs_embeds = nn.functional.dropout(inputs_embeds,
1052
+ p=self.emb_pdrop,
1053
+ training=self.training)
1054
+
1055
+ past_seen_tokens = 0
1056
+ if use_cache: # kept for BC (cache positions)
1057
+ if not isinstance(past_key_values, StaticCache):
1058
+ past_key_values = DynamicCache.from_legacy_cache(
1059
+ past_key_values)
1060
+ past_seen_tokens = past_key_values.get_seq_length( # type: ignore
1061
+ )
1062
+
1063
+ if cache_position is None:
1064
+ if isinstance(past_key_values, StaticCache):
1065
+ raise ValueError(
1066
+ 'cache_position is a required argument when using StaticCache.'
1067
+ )
1068
+ cache_position = torch.arange( # type: ignore
1069
+ past_seen_tokens,
1070
+ past_seen_tokens + inputs_embeds.shape[1],
1071
+ device=inputs_embeds.device)
1072
+
1073
+ if position_ids is None:
1074
+ position_ids = cache_position.unsqueeze(0) # type: ignore
1075
+
1076
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
1077
+ cache_position) # type: ignore
1078
+
1079
+ # embed positions
1080
+ hidden_states = inputs_embeds
1081
+
1082
+ # decoder layers
1083
+ all_hidden_states = () if output_hidden_states else None
1084
+ all_self_attns = () if output_attentions else None
1085
+ all_router_logits = () if output_router_logits else None
1086
+ next_decoder_cache = None
1087
+
1088
+ for block in self.blocks:
1089
+ if output_hidden_states:
1090
+ all_hidden_states += (hidden_states,) # type: ignore
1091
+
1092
+ if self.gradient_checkpointing and self.training:
1093
+ block_outputs = self._gradient_checkpointing_func(
1094
+ block.__call__,
1095
+ hidden_states,
1096
+ attention_mask=causal_mask,
1097
+ position_ids=position_ids,
1098
+ past_key_values=past_key_values,
1099
+ output_attentions=output_attentions,
1100
+ output_router_logits=output_router_logits,
1101
+ use_cache=use_cache,
1102
+ cache_position=cache_position,
1103
+ )
1104
+ else:
1105
+ block_outputs = block(
1106
+ hidden_states,
1107
+ attention_mask=causal_mask,
1108
+ position_ids=position_ids,
1109
+ past_key_value=past_key_values,
1110
+ output_attentions=output_attentions,
1111
+ output_router_logits=output_router_logits,
1112
+ use_cache=use_cache,
1113
+ cache_position=cache_position,
1114
+ )
1115
+
1116
+ hidden_states = block_outputs[0]
1117
+
1118
+ if use_cache:
1119
+ next_decoder_cache = block_outputs[
1120
+ 2 if output_attentions else 1]
1121
+
1122
+ if output_attentions:
1123
+ all_self_attns += (block_outputs[1],) # type: ignore
1124
+
1125
+ if output_router_logits:
1126
+ all_router_logits += (block_outputs[-1],) # type: ignore
1127
+
1128
+ hidden_states = self.norm_f(hidden_states)
1129
+
1130
+ # add hidden states from the last decoder layer
1131
+ if output_hidden_states:
1132
+ all_hidden_states += (hidden_states,) # type: ignore
1133
+
1134
+ next_cache = None
1135
+ if use_cache:
1136
+ next_cache = (
1137
+ next_decoder_cache.to_legacy_cache() # type: ignore
1138
+ if isinstance(next_decoder_cache, Cache) else
1139
+ next_decoder_cache)
1140
+ if not return_dict:
1141
+ return tuple(v for v in [
1142
+ hidden_states, next_cache, all_hidden_states, all_self_attns,
1143
+ all_router_logits
1144
+ ] if v is not None)
1145
+ return MoeModelOutputWithPast(
1146
+ last_hidden_state=hidden_states,
1147
+ past_key_values=next_cache,
1148
+ hidden_states=all_hidden_states,
1149
+ attentions=all_self_attns,
1150
+ router_logits=all_router_logits,
1151
+ )
1152
+
1153
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1154
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1155
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1156
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1157
+ def _update_causal_mask(
1158
+ self, attention_mask: Optional[torch.Tensor],
1159
+ input_tensor: torch.Tensor,
1160
+ cache_position: torch.Tensor) -> Optional[torch.Tensor]:
1161
+ if self.config._attn_implementation == 'flash_attention_2':
1162
+ if attention_mask is not None and 0.0 in attention_mask:
1163
+ return attention_mask
1164
+ return None
1165
+
1166
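 + # Build an additive mask of shape (batch, 1, seq_len, target_len): positions that may not be attended to are filled with the dtype's minimum value, allowed positions stay 0.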
+ dtype, device = input_tensor.dtype, input_tensor.device
1167
+ min_dtype = torch.finfo(dtype).min
1168
+ sequence_length = input_tensor.shape[1]
1169
+ if hasattr(self.blocks[0].norm_attn_norm.attn,
1170
+ 'past_key_value'): # static cache
1171
+ target_length = self.config.max_position_embeddings
1172
+ else: # dynamic cache
1173
+ target_length = (attention_mask.shape[-1] if isinstance(
1174
+ attention_mask, torch.Tensor) else cache_position[-1] + 1)
1175
+ target_length = int(target_length)
1176
+
1177
+ causal_mask = torch.full((sequence_length, target_length),
1178
+ fill_value=min_dtype,
1179
+ dtype=dtype,
1180
+ device=device)
1181
+ if sequence_length != 1:
1182
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1183
+ causal_mask *= torch.arange(
1184
+ target_length, device=device) > cache_position.reshape(-1, 1)
1185
+ causal_mask = causal_mask[None,
1186
+ None, :, :].expand(input_tensor.shape[0], 1,
1187
+ -1, -1)
1188
+ if attention_mask is not None:
1189
+ causal_mask = causal_mask.clone(
1190
+ ) # copy to contiguous memory for in-place edit
1191
+ if attention_mask.dim() == 2:
1192
+ mask_length = attention_mask.shape[-1]
1193
+ padding_mask = causal_mask[..., :mask_length].eq(
1194
+ 0.0) * attention_mask[:, None, None, :].eq(0.0)
1195
+ causal_mask[..., :mask_length] = causal_mask[
1196
+ ..., :mask_length].masked_fill(padding_mask, min_dtype)
1197
+ elif attention_mask.dim() == 4:
1198
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1199
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1200
+ if attention_mask.shape[
1201
+ -2] < cache_position[0] + sequence_length:
1202
+ offset = cache_position[0]
1203
+ else:
1204
+ offset = 0
1205
+ mask_shape = attention_mask.shape
1206
+ mask_slice = (attention_mask.eq(0.0)).to(
1207
+ dtype=dtype) * min_dtype
1208
+ causal_mask[:mask_shape[0], :mask_shape[1],
1209
+ offset:mask_shape[2] +
1210
+ offset, :mask_shape[3]] = mask_slice
1211
+
1212
+ if (self.config._attn_implementation == 'sdpa' and
1213
+ attention_mask is not None and
1214
+ attention_mask.device.type == 'cuda'):
1215
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1216
+ is_tracing = (
1217
+ torch.jit.is_tracing() or
1218
+ isinstance(input_tensor, torch.fx.Proxy) or # type: ignore
1219
+ (hasattr(torch, '_dynamo') and torch._dynamo.is_compiling()))
1220
+ if not is_tracing and torch.any(attention_mask != 1):
1221
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1224
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1225
+ causal_mask, min_dtype)
1226
+
1227
+ return causal_mask
1228
+
1229
+
1230
+ class DbrxForCausalLM(DbrxPreTrainedModel):
1231
+
1232
+ def __init__(self, config: DbrxConfig):
1233
+ super().__init__(config)
1234
+ self.transformer = DbrxModel(config)
1235
+ self.vocab_size = config.vocab_size
1236
+ self.lm_head = nn.Linear(config.hidden_size,
1237
+ config.vocab_size,
1238
+ bias=False)
1239
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1240
+ self.num_experts = config.ffn_config.moe_num_experts
1241
+ self.num_experts_per_tok = config.ffn_config.moe_top_k
1242
+
1243
+ # Initialize weights and apply final processing
1244
+ self.post_init()
1245
+
1246
+ def get_input_embeddings(self) -> nn.Embedding:
1247
+ return self.transformer.get_input_embeddings()
1248
+
1249
+ def set_input_embeddings(self, value: nn.Embedding):
1250
+ self.transformer.set_input_embeddings(value)
1251
+
1252
+ def get_output_embeddings(self) -> nn.Linear:
1253
+ return self.lm_head
1254
+
1255
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1256
+ self.lm_head = new_embeddings
1257
+
1258
+ def set_decoder(self, decoder: DbrxModel):
1259
+ self.transformer = decoder
1260
+
1261
+ def get_decoder(self) -> DbrxModel:
1262
+ return self.transformer
1263
+
1264
+ def forward(
1265
+ self,
1266
+ input_ids: Optional[torch.LongTensor] = None,
1267
+ attention_mask: Optional[torch.Tensor] = None,
1268
+ position_ids: Optional[torch.LongTensor] = None,
1269
+ past_key_values: Optional[Cache] = None,
1270
+ inputs_embeds: Optional[torch.Tensor] = None,
1271
+ labels: Optional[torch.LongTensor] = None,
1272
+ use_cache: Optional[bool] = None,
1273
+ output_attentions: Optional[bool] = None,
1274
+ output_hidden_states: Optional[bool] = None,
1275
+ output_router_logits: Optional[bool] = None,
1276
+ return_dict: Optional[bool] = None,
1277
+ cache_position: Optional[torch.LongTensor] = None,
1278
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1279
+ r"""Forward function for causal language modeling.
1280
+
1281
+ Example:
1282
+ ```python
1283
+ >>> from transformers import AutoTokenizer, DbrxForCausalLM
1284
+
1285
+ >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx")
1286
+ >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx")
1287
+
1288
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1289
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1290
+
1291
+ >>> # Generate
1292
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1293
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1294
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1295
+ ```
1296
+ """
1297
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1298
+ output_hidden_states = (output_hidden_states
1299
+ if output_hidden_states is not None else
1300
+ self.config.output_hidden_states)
1301
+ output_router_logits = (output_router_logits
1302
+ if output_router_logits is not None else
1303
+ self.config.output_router_logits)
1304
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1305
+
1306
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1307
+ outputs = self.transformer(
1308
+ input_ids=input_ids,
1309
+ attention_mask=attention_mask,
1310
+ position_ids=position_ids,
1311
+ past_key_values=past_key_values,
1312
+ inputs_embeds=inputs_embeds,
1313
+ use_cache=use_cache,
1314
+ output_attentions=output_attentions,
1315
+ output_hidden_states=output_hidden_states,
1316
+ output_router_logits=output_router_logits,
1317
+ return_dict=return_dict,
1318
+ cache_position=cache_position,
1319
+ )
1320
+
1321
+ hidden_states = outputs[0]
1322
+ logits = self.lm_head(hidden_states)
1323
+
1324
+ loss = None
1325
+ if labels is not None:
1326
+ # Shift so that tokens < n predict n
1327
+ shift_logits = logits[..., :-1, :].contiguous()
1328
+ shift_labels = labels[..., 1:].contiguous()
1329
+ # Flatten the tokens
1330
+ loss_fct = nn.CrossEntropyLoss()
1331
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1332
+ shift_labels = shift_labels.view(-1)
1333
+ # Enable model parallelism
1334
+ shift_labels = shift_labels.to(shift_logits.device)
1335
+ loss = loss_fct(shift_logits, shift_labels)
1336
+
1337
+ aux_loss = None
1338
+ if output_router_logits:
1339
+ aux_loss = load_balancing_loss_func(
1340
+ outputs.router_logits if return_dict else outputs[-1],
1341
+ self.num_experts,
1342
+ self.num_experts_per_tok,
1343
+ attention_mask,
1344
+ )
1345
+ if labels is not None and loss is not None:
1346
+ loss += self.router_aux_loss_coef * aux_loss.to(
1347
+ loss.device)  # ensure the aux loss is on the same device as the main loss
1348
+
1349
+ if not return_dict:
1350
+ output = (logits,) + outputs[1:]
1351
+ return (loss,) + output if loss is not None else output
1352
+
1353
+ return MoeCausalLMOutputWithPast(
1354
+ loss=loss,
1355
+ aux_loss=aux_loss,
1356
+ logits=logits,
1357
+ past_key_values=outputs.past_key_values,
1358
+ hidden_states=outputs.hidden_states,
1359
+ attentions=outputs.attentions,
1360
+ router_logits=outputs.router_logits,
1361
+ )
1362
+
1363
+ def prepare_inputs_for_generation(
1364
+ self,
1365
+ input_ids: torch.Tensor,
1366
+ past_key_values: Optional[Cache] = None,
1367
+ attention_mask: Optional[torch.Tensor] = None,
1368
+ inputs_embeds: Optional[torch.Tensor] = None,
1369
+ **kwargs: Any) -> Dict[str, Any]:
1370
+ past_length = 0
1371
+ if past_key_values is not None:
1372
+ if isinstance(past_key_values, Cache):
1373
+ cache_length = past_key_values.get_seq_length()
1374
+ past_length = past_key_values.seen_tokens
1375
+ max_cache_length = past_key_values.get_max_length()
1376
+ else:
1377
+ cache_length = past_length = past_key_values[0][0].shape[2]
1378
+ max_cache_length = None
1379
+
1380
+ # Keep only the unprocessed tokens:
1381
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1382
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1383
+ # input)
1384
+ if attention_mask is not None and attention_mask.shape[
1385
+ 1] > input_ids.shape[1]:
1386
+ input_ids = input_ids[:,
1387
+ -(attention_mask.shape[1] - past_length):]
1388
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1389
+ # input_ids based on the past_length.
1390
+ elif past_length < input_ids.shape[1]:
1391
+ input_ids = input_ids[:, past_length:]
1392
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1393
+
1394
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1395
+ if (max_cache_length is not None and attention_mask is not None and
1396
+ cache_length + input_ids.shape[1] > max_cache_length):
1397
+ attention_mask = attention_mask[:, -max_cache_length:]
1398
+
1399
+ position_ids = kwargs.get('position_ids', None)
1400
+ if attention_mask is not None and position_ids is None:
1401
+ # create position_ids on the fly for batch generation
1402
+ position_ids = attention_mask.long().cumsum(-1) - 1
1403
+ position_ids.masked_fill_(attention_mask == 0, 1)
1404
+ if past_key_values:
1405
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1406
+
1407
+ if self.generation_config.cache_implementation == 'static':
1408
+ # generation with static cache
1409
+ cache_position = kwargs.get('cache_position', None)
1410
+ if cache_position is None:
1411
+ past_length = 0
1412
+ else:
1413
+ past_length = cache_position[-1] + 1
1414
+ input_ids = input_ids[:, past_length:]
1415
+ position_ids = position_ids[:,
1416
+ past_length:] if position_ids is not None else None
1417
+
1418
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1419
+ # same goes for position ids. Could also help with continued generation.
1420
+ input_length = position_ids.shape[
1421
+ -1] if position_ids is not None else input_ids.shape[-1]
1422
+ cache_position = torch.arange(past_length,
1423
+ past_length + input_length,
1424
+ device=input_ids.device)
1425
+ position_ids = position_ids.contiguous(
1426
+ ) if position_ids is not None else None
1427
+
1428
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1429
+ if inputs_embeds is not None and past_key_values is None:
1430
+ model_inputs = {'inputs_embeds': inputs_embeds}
1431
+ else:
1432
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1433
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1434
+ # TODO: use `next_tokens` directly instead.
1435
+ model_inputs = {'input_ids': input_ids.contiguous()}
1436
+
1437
+ model_inputs.update(
1438
+ { # type: ignore
1439
+ 'position_ids': position_ids,
1440
+ 'cache_position': cache_position,
1441
+ 'past_key_values': past_key_values,
1442
+ 'use_cache': kwargs.get('use_cache'),
1443
+ 'attention_mask': attention_mask,
1444
+ }
1445
+ )
1446
+ return model_inputs
1447
+
1448
+ @staticmethod
1449
+ def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor):
1450
+ reordered_past = ()
1451
+ for layer_past in past_key_values:
1452
+ reordered_past += (tuple(
1453
+ past_state.index_select(0, beam_idx.to(past_state.device))
1454
+ for past_state in layer_past),)
1455
+ return reordered_past
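The forward pass above combines the standard shifted cross-entropy loss with the router load-balancing auxiliary loss whenever `output_router_logits` is enabled. A minimal sketch of that combination, using toy tensor shapes and a stand-in scalar for `load_balancing_loss_func` (both are illustrative assumptions, not values produced by the real model):

```python
import torch
import torch.nn as nn

# Toy shapes; vocab_size and the aux-loss coefficient mirror config.json in this repo.
batch, seq_len, vocab_size = 2, 8, 100352
router_aux_loss_coef = 0.05

logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Shift so that tokens < n predict n, then flatten for cross-entropy.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

# Stand-in for load_balancing_loss_func; the real value is computed from the
# per-layer router logits with moe_num_experts=16 and moe_top_k=4.
aux_loss = torch.tensor(0.01)

total_loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)
print(total_loss)
```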
LICENSE.txt ADDED
@@ -0,0 +1,176 @@
1
+ Databricks Open Model License
2
+
3
+ By using, reproducing, modifying, distributing, performing or displaying
4
+ any portion or element of DBRX or DBRX Derivatives, or otherwise accepting
5
+ the terms of this Agreement, you agree to be bound by this Agreement.
6
+
7
+ Version Release Date: March 27, 2024
8
+
9
+
10
+ Section 1: Definitions
11
+
12
+ “Agreement” means these terms and conditions that govern the use, reproduction,
13
+ modification, distribution, performance or display of DBRX and/or DBRX
14
+ Derivatives and any terms and conditions incorporated by reference.
15
+
16
+ “Databricks” or “we” means Databricks, Inc.
17
+
18
+ “Licensee” or “you” means you, or your employer or any other person or entity
19
+ (if you are entering into this Agreement on such person or entity’s behalf),
20
+ of the age required under applicable laws, rules or regulations to provide
21
+ legal consent and that has legal authority to bind your employer or such other
22
+ person or entity if you are entering in this Agreement on their behalf.
23
+
24
+ “DBRX Derivatives” means all (i) modifications to DBRX, (ii) works based on
25
+ DBRX and (iii) any other derivative works thereof. Outputs are not deemed DBRX
26
+ Derivatives.
27
+
28
+ “DBRX” means the foundational large language models and software and
29
+ algorithms, including machine-learning model code, trained model weights,
30
+ inference-enabling code, training-enabling code, fine-tuning enabling code,
31
+ documentation and other elements of the foregoing identified by Databricks at
32
+ https://github.com/databricks/dbrx, regardless of the source that you obtained
33
+ it from.
34
+
35
+ “Output” means the results of operating DBRX or DBRX Derivatives.
36
+
37
+ As used in this Agreement, “including” means “including without limitation.”
38
+
39
+
40
+ Section 2: License Rights and Conditions on Use and Distribution
41
+
42
+ 2.1 Grant of Rights
43
+
44
+ You are granted a non-exclusive, worldwide, non-transferable and royalty-free
45
+ limited license under Databricks’ intellectual property or other rights owned
46
+ by Databricks embodied in DBRX to use, reproduce, distribute, copy, modify,
47
+ and create derivative works of DBRX in accordance with the terms of this
48
+ Agreement.
49
+
50
+ 2.2 Reproduction and Distribution
51
+
52
+ 1. All distributions of DBRX or DBRX Derivatives must be accompanied by a
53
+ "Notice" text file that contains the following notice: "DBRX is provided
54
+ under and subject to the Databricks Open Model License, Copyright ©
55
+ Databricks, Inc. All rights reserved."
56
+
57
+ 2. If you distribute or make DBRX or DBRX Derivatives available to a third
58
+ party, you must provide a copy of this Agreement to such third party.
59
+
60
+ 3. You must cause any modified files that you distribute to carry prominent
61
+ notices stating that you modified the files.
62
+
63
+ You may add your own intellectual property statement to your modifications of
64
+ DBRX and, except as set forth in this Section, may provide additional or
65
+ different terms and conditions for use, reproduction, or distribution of DBRX
66
+ or DBRX Derivatives as a whole, provided your use, reproduction, modification,
67
+ distribution, performance, and display of DBRX or DBRX Derivatives otherwise
68
+ complies with the terms and conditions of this Agreement. Any additional or
69
+ different terms and conditions you impose must not conflict with the terms of
70
+ this Agreement and in the event of a conflict, the terms and conditions of this
71
+ Agreement shall govern over any such additional or different terms and conditions.
72
+
73
+ 2.3 Use Restrictions
74
+
75
+ You will not use DBRX or DBRX Derivatives or any Output to improve any other
76
+ large language model (excluding DBRX or DBRX Derivatives).
77
+
78
+ You will not use DBRX or DBRX Derivatives:
79
+
80
+ 1. for any restricted use set forth in the Databricks Open Model Acceptable
81
+ Use Policy identified at
82
+ https://www.databricks.com/legal/acceptable-use-policy-open-model
83
+ ("Acceptable Use Policy"), which is hereby incorporated by reference into
84
+ this Agreement; or
85
+
86
+ 2. in violation of applicable laws and regulations.
87
+
88
+ To the maximum extent permitted by law, Databricks reserves the right to
89
+ restrict (remotely or otherwise) usage of DBRX or DBRX Derivatives that
90
+ Databricks reasonably believes are in violation of this Agreement.
91
+
92
+
93
+ Section 3: Additional Commercial Terms
94
+
95
+ If, on the DBRX version release date, the monthly active users of the products
96
+ or services made available by or for Licensee, or Licensee’s affiliates, is
97
+ greater than 700 million monthly active users in the preceding calendar month,
98
+ you must request a license from Databricks, which we may grant to you in our
99
+ sole discretion, and you are not authorized to exercise any of the rights under
100
+ this Agreement unless or until Databricks otherwise expressly grants you such
101
+ rights.
102
+
103
+ If you receive DBRX or DBRX Derivatives from a direct or indirect licensee as
104
+ part of an integrated end user product, then this section (Section 3) of the
105
+ Agreement will not apply to you.
106
+
107
+
108
+ Section 4: Additional Provisions
109
+
110
+ 4.1 Updates
111
+
112
+ Databricks may update DBRX from time to time, and you must make reasonable
113
+ efforts to use the latest version of DBRX.
114
+
115
+ 4.2 Intellectual Property
116
+
117
+ a. No trademark licenses are granted under this Agreement, and in connection
118
+ with DBRX or DBRX Derivatives, neither Databricks nor Licensee may use any name
119
+ or mark owned by or associated with the other or any of its affiliates, except
120
+ as required for reasonable and customary use in describing and redistributing
121
+ DBRX or DBRX Derivatives.
122
+
123
+ b. Subject to Databricks’ ownership of DBRX and DRBX Derivatives made by or for
124
+ Databricks, with respect to any DBRX Derivatives that are made by you, as
125
+ between you and Databricks, you are and will be the owner of such DBRX
126
+ Derivatives.
127
+
128
+ c. Databricks claims no ownership rights in Outputs. You are responsible for
129
+ Outputs and their subsequent uses.
130
+
131
+ d. If you institute litigation or other proceedings against Databricks or any
132
+ entity (including a cross-claim or counterclaim in a lawsuit) alleging that
133
+ DBRX or Outputs or results therefrom, or any portion of any of the foregoing,
134
+ constitutes infringement of intellectual property or other rights owned or
135
+ licensable by you, then any licenses granted to you under this Agreement shall
136
+ terminate as of the date such litigation or claim is filed or instituted. You
137
+ will indemnify and hold harmless Databricks from and against any claim by any
138
+ third party arising out of or related to your use or distribution of DBRX or
139
+ DBRX Derivatives.
140
+
141
+ 4.3 DISCLAIMER OF WARRANTY
142
+
143
+ UNLESS REQUIRED BY APPLICABLE LAW, DBRX AND ANY OUTPUT AND RESULTS THEREFROM
144
+ ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER
145
+ EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE,
146
+ NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU
147
+ ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR
148
+ REDISTRIBUTING DBRX OR DBRX DERIVATIVES AND ANY OUTPUT AND ASSUME ANY RISKS
149
+ ASSOCIATED WITH YOUR USE OF DBRX OR DBRX DERIVATIVES AND ANY OUTPUT AND RESULTS.
150
+
151
+ 4.4 LIMITATION OF LIABILITY
152
+
153
+ IN NO EVENT WILL DATABRICKS OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF
154
+ LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR
155
+ OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT,
156
+ SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF
157
+ DATABRICKS OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE
158
+ FOREGOING.
159
+
160
+ 4.5 Term and Termination
161
+
162
+ The term of this Agreement will commence upon your acceptance of this Agreement
163
+ or access to DBRX or DBRX Derivatives and will continue in full force and
164
+ effect until terminated in accordance with the terms and conditions herein.
165
+ Databricks may terminate this Agreement if you are in breach of any term or
166
+ condition of this Agreement. Upon termination of this Agreement, you shall
167
+ delete and cease use of DBRX or any DBRX Derivatives. Sections 1, 4.2(d), 4.3,
168
+ 4.4, and 4.6 shall survive the termination of this Agreement.
169
+
170
+ 4.6 Governing Law and Jurisdiction
171
+
172
+ This Agreement will be governed and construed under the laws of the State of
173
+ California without regard to choice of law principles, and the UN Convention
174
+ on Contracts for the International Sale of Goods does not apply to this
175
+ Agreement. The courts of California shall have exclusive jurisdiction of any
176
+ dispute arising out of this Agreement.
NOTICE.txt ADDED
@@ -0,0 +1 @@
1
+ DBRX is provided under and subject to the Databricks Open Model License, Copyright © Databricks, Inc. All rights reserved.
README.md ADDED
@@ -0,0 +1,172 @@
1
+ ---
2
+ extra_gated_heading: You need to share contact information with Databricks to access this model
3
+ extra_gated_prompt: >-
4
+
5
+ ### DBRX Terms of Use
6
+
7
+ Use of DBRX is governed by the [Databricks Open Model License](https://www.databricks.com/legal/open-model-license) and the [Databricks Open Model Acceptable Use Policy](https://www.databricks.com/legal/acceptable-use-policy-open-model).
8
+
9
+ extra_gated_fields:
10
+ First Name: text
11
+ Last Name: text
12
+ Organization: text
13
+ Purpose for Base Model Access: text
14
+ By clicking 'Submit' below, I accept the terms of the license and acknowledge that the information I provide will be collected, stored, processed, and shared in accordance with Databricks' Privacy Notice and I understand I can update my preferences at any time: checkbox
15
+ extra_gated_description: >-
16
+ The information you provide will be collected, stored, processed, and shared in accordance with Databricks [Privacy Notice](https://www.databricks.com/legal/privacynotice).
17
+ extra_gated_button_content: Submit
18
+ inference: false
19
+ license: other
20
+ license_name: databricks-open-model-license
21
+ license_link: https://www.databricks.com/legal/open-model-license
22
+ ---
23
+
24
+ # DBRX Base
25
+
26
+ * DBRX Base is a mixture-of-experts (MoE) large language model trained from scratch by Databricks.
27
+ * We are releasing both DBRX Base, a pretrained base model, and DBRX Instruct, a fine-tuned version for few-turn interactions, under [an open license](https://www.databricks.com/legal/open-model-license).
28
+ * This is the repository for DBRX Base. DBRX Instruct can be found [here](https://huggingface.co/databricks/dbrx-instruct).
29
+ * For full details on the DBRX models, please read our [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).
30
+
31
+
32
+ ## Model Overview
33
+ DBRX is a [transformer-based](https://www.isattentionallyouneed.com/) decoder-only large language model (LLM) that was trained using next-token prediction.
34
+ It uses a *fine-grained* mixture-of-experts (MoE) architecture with 132B total parameters of which 36B parameters are active on any input.
35
+ It was pre-trained on 12T tokens of text and code data.
36
+ Compared to other open MoE models like Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts. DBRX has 16 experts and chooses 4, while Mixtral-8x7B and Grok-1 have 8 experts and choose 2.
37
+ This provides 65x more possible combinations of experts and we found that this improves model quality.
38
+ DBRX uses rotary position encodings (RoPE), gated linear units (GLU), and grouped query attention (GQA).
39
+ It uses the GPT-4 tokenizer as provided in the [tiktoken](https://github.com/openai/tiktoken) repository.
40
+ We made these choices based on exhaustive evaluation and scaling experiments.
41
+
42
+ DBRX was pretrained on 12T tokens of carefully curated data and a maximum context length of 32K tokens.
43
+ We estimate that this data is at least 2x better token-for-token than the data we used to pretrain the MPT family of models.
44
+ This new dataset was developed using the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing, and Unity Catalog for data management and governance.
45
+ We used curriculum learning for pretraining, changing the data mix during training in ways we found to substantially improve model quality.
46
+
47
+ * **Inputs:** DBRX only accepts text-based inputs and supports a context length of up to 32768 tokens.
48
+ * **Outputs:** DBRX only produces text-based outputs.
49
+ * **Model Architecture:** More detailed information about DBRX Instruct and DBRX Base can be found in our [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).
50
+ * **License:** [Databricks Open Model License](https://www.databricks.com/legal/open-model-license)
51
+ * **Acceptable Use Policy:** [Databricks Open Model Acceptable Use Policy](https://www.databricks.com/legal/acceptable-use-policy-open-model)
52
+ * **Version:** 1.0
53
+ * **Owner:** Databricks, Inc.
54
+
55
+
56
+ ## Usage
57
+ There are several general ways to use the DBRX models:
58
+ * DBRX Base and DBRX Instruct are available for download on HuggingFace (see our Quickstart guide below). This is the HF repository for DBRX Base; DBRX Instruct can be found [here](https://huggingface.co/databricks/dbrx-instruct).
59
+ * The DBRX model repository can be found on GitHub [here](https://github.com/databricks/dbrx).
60
+ * DBRX Base and DBRX Instruct are available with [Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/index.html) via both *Pay-per-token* and *Provisioned Throughput* endpoints. These are enterprise-ready deployments.
61
+ * For more information on how to fine-tune using LLM-Foundry, please take a look at our LLM pretraining and fine-tuning [documentation](https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/README.md).
62
+
63
+
64
+ ## Quickstart Guide
65
+ **NOTE: This is DBRX Base, and has not been instruction finetuned. It has not been trained for interactive chat and is only a completion model.**
66
+ If you are looking for the finetuned model, please use [DBRX Instruct](https://huggingface.co/databricks/dbrx-instruct).
67
+
68
+ Getting started with DBRX models is easy with the `transformers` library. The model requires ~264GB of RAM and the following packages:
69
+
70
+ ```bash
71
+ pip install "transformers>=4.39.2" "tiktoken>=0.6.0"
72
+ ```
73
+
74
+ If you'd like to speed up download time, you can use the `hf_transfer` package as described by Huggingface [here](https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads).
75
+ ```bash
76
+ pip install hf_transfer
77
+ export HF_HUB_ENABLE_HF_TRANSFER=1
78
+ ```
79
+
80
+ You will need to request access to this repository to download the model. Once this is granted,
81
+ [obtain an access token](https://huggingface.co/docs/hub/en/security-tokens) with `read` permission, and supply the token below.
82
+
83
+ ### Run the model on a CPU:
84
+ ```python
85
+ from transformers import AutoTokenizer, AutoModelForCausalLM
86
+ import torch
87
+
88
+ tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base", trust_remote_code=True, token="hf_YOUR_TOKEN")
89
+ model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-base", device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True, token="hf_YOUR_TOKEN")
90
+
91
+ input_text = "Databricks was founded in "
92
+ input_ids = tokenizer(input_text, return_tensors="pt")
93
+
94
+ outputs = model.generate(**input_ids, max_new_tokens=100)
95
+ print(tokenizer.decode(outputs[0]))
96
+ ```
97
+
98
+ ### Run the model on multiple GPUs:
99
+ ```python
100
+ from transformers import AutoTokenizer, AutoModelForCausalLM
101
+ import torch
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base", trust_remote_code=True, token="hf_YOUR_TOKEN")
104
+ model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-base", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, token="hf_YOUR_TOKEN")
105
+
106
+ input_text = "Databricks was founded in "
107
+ input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
108
+
109
+ outputs = model.generate(**input_ids, max_new_tokens=100)
110
+ print(tokenizer.decode(outputs[0]))
111
+ ```
112
+ If your GPU system supports [FlashAttention2](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2), you can add `attn_implementation="flash_attention_2"` as a keyword to `AutoModelForCausalLM.from_pretrained()` to achieve faster inference.
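For example, the multi-GPU load above with FlashAttention 2 enabled might look like the following sketch (assumes the `flash-attn` package is installed and the GPUs support it):

```python
from transformers import AutoModelForCausalLM
import torch

# Same call as the multi-GPU example above, plus the FlashAttention 2 keyword.
model = AutoModelForCausalLM.from_pretrained(
    "databricks/dbrx-base",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    token="hf_YOUR_TOKEN",
)
```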
113
+
114
+
115
+ ## Limitations and Ethical Considerations
116
+ ### Training Dataset Limitations
117
+ The DBRX models were trained on 12T tokens of text, with a knowledge cutoff date of December 2023.
118
+
119
+ The training mix used for DBRX contains both natural-language and code examples. The vast majority of our training data is in the English language. We did not test DBRX for non-English proficiency. Therefore, DBRX should be considered a generalist model for text-based use in the English language.
120
+
121
+ DBRX does not have multimodal capabilities.
122
+
123
+ ### Associated Risks and Recommendations
124
+ All foundation models are novel technologies that carry various risks, and may output information that is inaccurate, incomplete, biased, or offensive.
125
+ Users should exercise judgment and evaluate such output for accuracy and appropriateness for their desired use case before using or sharing it.
126
+ Databricks recommends [using retrieval augmented generation (RAG)](https://www.databricks.com/glossary/retrieval-augmented-generation-rag) in scenarios where accuracy and fidelity are important.
127
+ We also recommend that anyone using or fine-tuning either DBRX Base or DBRX Instruct perform additional testing around safety in the context of their particular application and domain.
128
+
129
+
130
+ ## Intended Uses
131
+ ### Intended Use Cases
132
+ The DBRX models are open, general-purpose LLMs intended and licensed for both commercial and research applications.
133
+ They can be further fine-tuned for various domain-specific natural language and coding tasks.
134
+ DBRX Base can be used as an off-the-shelf model for text completion for general English-language and coding tasks.
135
+
136
+ Please review the Associated Risks section above, as well as the [Databricks Open Model License](https://www.databricks.com/legal/open-model-license) and [Databricks Open Model Acceptable Use Policy](https://www.databricks.com/legal/acceptable-use-policy-open-model) for further information about permissible uses of DBRX Base and its derivatives.
137
+
138
+ ### Out-of-Scope Use Cases
139
+ DBRX models are not intended to be used out-of-the-box in non-English languages and do not support native code execution, or other forms of function-calling.
140
+ DBRX models should not be used in any manner that violates applicable laws or regulations or in any other way that is prohibited by the [Databricks Open Model License](https://www.databricks.com/legal/open-model-license) and [Databricks Open Model Acceptable Use Policy](https://www.databricks.com/legal/acceptable-use-policy-open-model).
141
+
142
+
143
+ ## Training Stack
144
+ MoE models are complicated to train, and the training of DBRX Base and DBRX Instruct was heavily supported by Databricks’ infrastructure for data processing and large-scale LLM training (e.g., [Composer](https://github.com/mosaicml/composer), [Streaming](https://github.com/mosaicml/streaming), [Megablocks](https://github.com/stanford-futuredata/megablocks), and [LLM Foundry](https://github.com/mosaicml/llm-foundry)).
145
+
146
+ Composer is our core library for large-scale training.
147
+ It provides an optimized training loop, easy [checkpointing](https://docs.mosaicml.com/projects/composer/en/latest/trainer/checkpointing.html) and [logging](https://docs.mosaicml.com/projects/composer/en/latest/trainer/logging.html#wood-logging),
148
+ [FSDP](https://pytorch.org/docs/stable/fsdp.html)-based [model sharding](https://docs.mosaicml.com/projects/composer/en/latest/notes/distributed_training.html#fullyshardeddataparallel-fsdp),
149
+ convenient [abstractions](https://docs.mosaicml.com/projects/composer/en/latest/trainer/time.html), extreme customizability via [callbacks](https://docs.mosaicml.com/projects/composer/en/latest/trainer/callbacks.html), and more.
150
+
151
+ Streaming enables fast, low cost, and scalable training on large datasets from cloud storage. It handles a variety of challenges around deterministic resumption as node counts change, avoiding redundant downloads across devices, high-quality shuffling at scale, sample-level random access, and speed.
152
+
153
+ Megablocks is a lightweight library for MoE training. Crucially, it supports “dropless MoE,” which avoids inefficient padding and is intended to provide deterministic outputs for a given sequence no matter what other sequences are in the batch.
154
+
155
+ LLM Foundry ties all of these libraries together to create a simple LLM pretraining, fine-tuning, and inference experience.
156
+
157
+ DBRX was trained using proprietary optimized versions of the above open source libraries, along with our [LLM training platform](https://www.databricks.com/product/machine-learning/mosaic-ai-training).
158
+
159
+
160
+ ## Evaluation
161
+ We find that DBRX outperforms established open-source and open-weight base models on the [Databricks Model Gauntlet](https://www.databricks.com/blog/llm-evaluation-for-icl), the [Hugging Face Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard), and HumanEval.
162
+ The Databricks Model Gauntlet measures performance on more than 30 tasks across six categories: world knowledge, common sense reasoning, language understanding, reading comprehension, symbolic problem solving, and programming.
163
+ The Hugging Face Open LLM Leaderboard measures the average of ARC-Challenge, HellaSwag, MMLU, TruthfulQA, Winogrande and GSM8k.
164
+ HumanEval measures coding ability.
165
+
166
+ Full evaluation details can be found in our [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).
167
+
168
+
169
+ ## Acknowledgements
170
+ The DBRX models were made possible thanks in large part to the open-source community, especially:
171
+ * The [MegaBlocks](https://arxiv.org/abs/2211.15841) library, which established a foundation for our MoE implementation.
172
+ * [PyTorch FSDP](https://arxiv.org/abs/2304.11277), which we built on for distributed training.
config.json ADDED
@@ -0,0 +1,38 @@
 
1
+ {
2
+ "architectures": [
3
+ "DbrxForCausalLM"
4
+ ],
5
+ "attn_config": {
6
+ "clip_qkv": 8,
7
+ "kv_n_heads": 8,
8
+ "model_type": "",
9
+ "rope_theta": 500000
10
+ },
11
+ "auto_map": {
12
+ "AutoConfig": "configuration_dbrx.DbrxConfig",
13
+ "AutoModelForCausalLM": "modeling_dbrx.DbrxForCausalLM"
14
+ },
15
+ "d_model": 6144,
16
+ "emb_pdrop": 0.0,
17
+ "ffn_config": {
18
+ "ffn_hidden_size": 10752,
19
+ "model_type": "",
20
+ "moe_jitter_eps": 0.01,
21
+ "moe_loss_weight": 0.05,
22
+ "moe_num_experts": 16,
23
+ "moe_top_k": 4
24
+ },
25
+ "initializer_range": 0.02,
26
+ "max_seq_len": 32768,
27
+ "model_type": "dbrx",
28
+ "n_heads": 48,
29
+ "n_layers": 40,
30
+ "output_router_logits": false,
31
+ "resid_pdrop": 0.0,
32
+ "router_aux_loss_coef": 0.05,
33
+ "tie_word_embeddings": false,
34
+ "torch_dtype": "bfloat16",
35
+ "transformers_version": "4.38.2",
36
+ "use_cache": true,
37
+ "vocab_size": 100352
38
+ }
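The nested `attn_config` and `ffn_config` blocks above are parsed into the sub-config classes defined in `configuration_dbrx.py` (next file). One way to inspect the key MoE settings, assuming access to the gated repo has been granted (a sketch, not the only way):

```python
from transformers import AutoConfig

# auto_map in config.json routes this to configuration_dbrx.DbrxConfig.
config = AutoConfig.from_pretrained(
    "databricks/dbrx-base", trust_remote_code=True, token="hf_YOUR_TOKEN"
)
print(config.d_model, config.n_layers, config.n_heads)                 # 6144 40 48
print(config.ffn_config.moe_num_experts, config.ffn_config.moe_top_k)  # 16 4
print(config.attn_config.kv_n_heads, config.attn_config.rope_theta)    # 8 500000
```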
configuration_dbrx.py ADDED
@@ -0,0 +1,264 @@
1
+ """Dbrx configuration."""
2
+ from typing import Any, Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+ DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
10
+
11
+
12
+ class DbrxAttentionConfig(PretrainedConfig):
13
+ """Configuration class for Dbrx Attention.
14
+
15
+ This is the configuration class for the [`DbrxAttention`] class. It is used to instantiate attention layers
16
+ according to the specified arguments, defining the layers architecture.
17
+
18
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
19
+ documentation from [`PretrainedConfig`] for more information.
20
+
21
+ Args:
22
+ attn_pdrop (`float`, *optional*, defaults to 0.0):
23
+ The dropout probability for the attention layers.
24
+ clip_qkv (`float`, *optional*, defaults to `None`):
25
+ If not `None`, clip the queries, keys, and values in the attention layer to this value.
26
+ kv_n_heads (Optional[int]): For grouped query attention only; allows the user to specify the number of KV heads.
27
+ rope_theta (float): The base frequency for rope.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ attn_pdrop: float = 0,
33
+ clip_qkv: Optional[float] = None,
34
+ kv_n_heads: int = 1,
35
+ rope_theta: float = 10000.0,
36
+ **kwargs: Any,
37
+ ):
38
+ super().__init__(**kwargs)
39
+ self.attn_pdrop = attn_pdrop
40
+ self.clip_qkv = clip_qkv
41
+ self.kv_n_heads = kv_n_heads
42
+ self.rope_theta = rope_theta
43
+
44
+ for k in ['model_type']:
45
+ if k in kwargs:
46
+ kwargs.pop(k)
47
+ if len(kwargs) != 0:
48
+ raise ValueError(f'Found unknown {kwargs=}')
49
+
50
+ @classmethod
51
+ def from_pretrained(cls, pretrained_model_name_or_path: str,
52
+ **kwargs: Any) -> 'PretrainedConfig':
53
+ cls._set_token_in_kwargs(kwargs)
54
+
55
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path,
56
+ **kwargs)
57
+
58
+ if config_dict.get('model_type') == 'dbrx':
59
+ config_dict = config_dict['attn_config']
60
+
61
+ if 'model_type' in config_dict and hasattr(
62
+ cls,
63
+ 'model_type') and config_dict['model_type'] != cls.model_type:
64
+ logger.warning(
65
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
66
+ +
67
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
68
+ )
69
+
70
+ return cls.from_dict(config_dict, **kwargs)
71
+
72
+
73
+ class DbrxFFNConfig(PretrainedConfig):
74
+ """Configuration class for Dbrx FFN.
75
+
76
+ This is the configuration class for the [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
77
+ the specified arguments, defining the layers architecture.
78
+
79
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
80
+ documentation from [`PretrainedConfig`] for more information.
81
+
82
+ Args:
83
+ ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
84
+ The dict should have a key 'name' with the value being the name of
85
+ the activation function along with any additional keyword arguments.
86
+ ffn_hidden_size (int, optional): The hidden size of the feedforward network.
87
+ moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
88
+ moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
89
+ moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
90
+ moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
91
+ moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
92
+ uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
93
+ This should only be used for benchmarking purposes.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ ffn_act_fn: Optional[dict] = None,
99
+ ffn_hidden_size: int = 3584,
100
+ moe_num_experts: int = 4,
101
+ moe_top_k: int = 1,
102
+ moe_jitter_eps: Optional[float] = None,
103
+ moe_loss_weight: float = 0.01,
104
+ moe_normalize_expert_weights: Optional[float] = 1,
105
+ uniform_expert_assignment: bool = False,
106
+ **kwargs: Any,
107
+ ):
108
+ super().__init__()
109
+ if ffn_act_fn is None:
110
+ ffn_act_fn = {'name': 'silu'}
111
+ self.ffn_act_fn = ffn_act_fn
112
+ self.ffn_hidden_size = ffn_hidden_size
113
+ self.moe_num_experts = moe_num_experts
114
+ self.moe_top_k = moe_top_k
115
+ self.moe_jitter_eps = moe_jitter_eps
116
+ self.moe_loss_weight = moe_loss_weight
117
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
118
+ self.uniform_expert_assignment = uniform_expert_assignment
119
+
120
+ for k in ['model_type']:
121
+ if k in kwargs:
122
+ kwargs.pop(k)
123
+ if len(kwargs) != 0:
124
+ raise ValueError(f'Found unknown {kwargs=}')
125
+
126
+ @classmethod
127
+ def from_pretrained(cls, pretrained_model_name_or_path: str,
128
+ **kwargs: Any) -> 'PretrainedConfig':
129
+ cls._set_token_in_kwargs(kwargs)
130
+
131
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path,
132
+ **kwargs)
133
+
134
+ if config_dict.get('model_type') == 'dbrx':
135
+ config_dict = config_dict['ffn_config']
136
+
137
+ if 'model_type' in config_dict and hasattr(
138
+ cls,
139
+ 'model_type') and config_dict['model_type'] != cls.model_type:
140
+ logger.warning(
141
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
142
+ +
143
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
144
+ )
145
+
146
+ return cls.from_dict(config_dict, **kwargs)
147
+
148
+
149
+ class DbrxConfig(PretrainedConfig):
150
+ """Configuration class for Dbrx.
151
+
152
+ This is the configuration class for a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
153
+ specified arguments, defining the model architecture.
154
+
155
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
156
+ documentation from [`PretrainedConfig`] for more information.
157
+
158
+
159
+ Args:
160
+ d_model (`int`, *optional*, defaults to 6144):
161
+ Dimensionality of the embeddings and hidden states.
162
+ n_heads (`int`, *optional*, defaults to 48):
163
+ Number of attention heads for each attention layer in the Transformer encoder.
164
+ n_layers (`int`, *optional*, defaults to 40):
165
+ Number of hidden layers in the Transformer encoder.
166
+ max_seq_len (`int`, *optional*, defaults to 32768):
167
+ The maximum sequence length of the model.
168
+ vocab_size (`int`, *optional*, defaults to 100352):
169
+ Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
170
+ the `inputs_ids` passed when calling [`DbrxModel`].
171
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
172
+ The dropout probability applied to the attention output before combining with residual.
173
+ emb_pdrop (`float`, *optional*, defaults to 0.0):
174
+ The dropout probability for the embedding layer.
175
+ attn_config (`dict`, *optional*):
176
+ A dictionary used to configure the model's attention module.
177
+ ffn_config (`dict`, *optional*):
178
+ A dictionary used to configure the model's FFN module.
179
+ use_cache (`bool`, *optional*, defaults to `True`):
180
+ Whether or not the model should return the last key/values attentions (not used by all models).
181
+ initializer_range (`float`, *optional*, defaults to 0.02):
182
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
183
+ output_router_logits (`bool`, *optional*, defaults to `False`):
184
+ Whether or not the router logits should be returned by the model. Enabling this will also
185
+ allow the model to output the auxiliary loss. See [here]() for more details
186
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.05):
187
+ The aux loss factor for the total loss.
188
+
189
+
190
+ Example:
191
+ ```python
192
+ >>> from transformers import DbrxConfig, DbrxModel
193
+
194
+ >>> # Initializing a Dbrx configuration
195
+ >>> configuration = DbrxConfig()
196
+
197
+ >>> # Initializing a model (with random weights) from the configuration
198
+ >>> model = DbrxModel(configuration)
199
+
200
+ >>> # Accessing the model configuration
201
+ >>> configuration = model.config
202
+ ```
203
+ """
204
+
205
+ model_type = 'dbrx'
206
+ attribute_map = {
207
+ 'num_attention_heads': 'n_heads',
208
+ 'hidden_size': 'd_model',
209
+ 'num_hidden_layers': 'n_layers',
210
+ 'max_position_embeddings': 'max_seq_len'
211
+ }
212
+
213
+ def __init__(
214
+ self,
215
+ d_model: int = 2048,
216
+ n_heads: int = 16,
217
+ n_layers: int = 24,
218
+ max_seq_len: int = 2048,
219
+ vocab_size: int = 32000,
220
+ resid_pdrop: float = 0.0,
221
+ emb_pdrop: float = 0.0,
222
+ attn_config: Optional[DbrxAttentionConfig] = None,
223
+ ffn_config: Optional[DbrxFFNConfig] = None,
224
+ use_cache: bool = True,
225
+ initializer_range: float = 0.02,
226
+ output_router_logits: bool = False,
227
+ router_aux_loss_coef: float = 0.05,
228
+ **kwargs: Any,
229
+ ):
230
+ if attn_config is None:
231
+ self.attn_config = DbrxAttentionConfig()
232
+ elif isinstance(attn_config, dict):
233
+ self.attn_config = DbrxAttentionConfig(**attn_config)
234
+ else:
235
+ self.attn_config = attn_config
236
+
237
+ if ffn_config is None:
238
+ self.ffn_config = DbrxFFNConfig()
239
+ elif isinstance(ffn_config, dict):
240
+ self.ffn_config = DbrxFFNConfig(**ffn_config)
241
+ else:
242
+ self.ffn_config = ffn_config
243
+
244
+ self.d_model = d_model
245
+ self.n_heads = n_heads
246
+ self.n_layers = n_layers
247
+ self.max_seq_len = max_seq_len
248
+ self.vocab_size = vocab_size
249
+ self.resid_pdrop = resid_pdrop
250
+ self.emb_pdrop = emb_pdrop
251
+ self.use_cache = use_cache
252
+ self.initializer_range = initializer_range
253
+ self.output_router_logits = output_router_logits
254
+ self.router_aux_loss_coef = router_aux_loss_coef
255
+
256
+ tie_word_embeddings = kwargs.pop('tie_word_embeddings', False)
257
+ if tie_word_embeddings:
258
+ raise ValueError(
259
+ 'tie_word_embeddings is not supported for Dbrx models.')
260
+
261
+ super().__init__(
262
+ tie_word_embeddings=tie_word_embeddings,
263
+ **kwargs,
264
+ )
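`DbrxConfig.__init__` above accepts `attn_config` and `ffn_config` either as config objects or as plain dicts (the form they take in config.json). A small sketch of building a scaled-down config locally, with `configuration_dbrx.py` on the Python path and deliberately toy values rather than the released model's:

```python
from configuration_dbrx import DbrxConfig

# Nested dicts are converted to DbrxAttentionConfig / DbrxFFNConfig in __init__.
toy_config = DbrxConfig(
    d_model=512,
    n_heads=8,
    n_layers=4,
    max_seq_len=2048,
    vocab_size=32000,
    attn_config={"kv_n_heads": 2, "rope_theta": 10000.0, "clip_qkv": 8.0},
    ffn_config={"ffn_hidden_size": 1024, "moe_num_experts": 4, "moe_top_k": 2},
)
print(toy_config.ffn_config.moe_num_experts)  # 4
print(toy_config.num_attention_heads)         # 8, resolved via attribute_map
```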
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": [
4
+ 100257
5
+ ],
6
+ "transformers_version": "4.38.2"
7
+ }
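The single `eos_token_id` of 100257 matches the `<|endoftext|>` id of the GPT-4 (`cl100k_base`) tokenizer that the README above says DBRX uses. A quick check with `tiktoken` (illustrative; this only inspects the base encoding):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
# Encoding the special token yields the id used as eos_token_id above.
print(enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))  # [100257]
```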
huggingface-metadata.txt ADDED
@@ -0,0 +1,65 @@
1
+ url: https://huggingface.co/databricks/dbrx-base
2
+ branch: main
3
+ download date: 2024-03-30 00:19:12
4
+ sha256sum:
5
+ 6fc16f714bc5bdae2a9c712ebbd0c60282d51d8102446bdcb42bf73fbcd789cd model-00001-of-00061.safetensors
6
+ 7b33be8a577936012420efcdb9fb42394820083dcc52279ab39f1f469446d92c model-00002-of-00061.safetensors
7
+ 5edb3101b011b9ac53216a8a89344854932cba4a116462d269bafab67b2504ea model-00003-of-00061.safetensors
8
+ 3ab9725625b4be1dbe497733c7da6a826763da7fd4ecb4c7a1b0a01730ad5f00 model-00004-of-00061.safetensors
9
+ e70596f19ce2b40d833456da403f504af4b33c7a28fc663ece64e4b9b5c023c4 model-00005-of-00061.safetensors
10
+ 6a59eb220f7174521cbb8df029c5665f9bd70afd919e5432461c9d31c2498b41 model-00006-of-00061.safetensors
11
+ 2e77dd3c8552550dd8db66372b1680992d438bbfb6a2950071f1718e1d197ef5 model-00007-of-00061.safetensors
12
+ 49830c115d1c6dd80ee1f187e4c7aa2f44e8565ad8a50ff2474d276e6a7ea436 model-00008-of-00061.safetensors
13
+ 3535faaac65410f8e78bd77bd6ab9c22638a7e753f7d7d73a255b8ee71909ed6 model-00009-of-00061.safetensors
14
+ 388344ef7a364c3cf90ce50d7cb44e477c17233966a8ff84cfa0e998455e5fba model-00010-of-00061.safetensors
15
+ 8f423206b83056355629508621b422e839ef0544c9f6099d3a946ad81ffe13c5 model-00011-of-00061.safetensors
16
+ 46630da164ed5472036c49e1247fd7157f1ca1a3ab25957d27a56dd80f1ec4e9 model-00012-of-00061.safetensors
17
+ 7a09b47e2f13c74b72275226ed5ad56191eea1883328dd5ccc8291a49ab64908 model-00013-of-00061.safetensors
18
+ 7708780cfcba2211f806b2164550ea6982ce7e68dc47db9e4fe42a4727e58d89 model-00014-of-00061.safetensors
19
+ 3c208a59d20af1d5504dd72611def467a79f37d8f047b53c2588d2a594de6190 model-00015-of-00061.safetensors
20
+ bb09a5c3e813be20033ca5b75308bebc0eea127bbe4d5c254d4877dae9080731 model-00016-of-00061.safetensors
21
+ 9b5dfe0cbdfcd0131aad4605d16339f6ee3e94737bb01b21bcfb54ca4a9fe325 model-00017-of-00061.safetensors
22
+ 627ccca7006bc60d9374a8b0e59d1960fc300cda6a9caa6e1b187ea9c3af58e5 model-00018-of-00061.safetensors
23
+ 2d327dda14ac1c8a0c4f6590a1cd49861164b26aaae1743c0bdf9b0ac62f2b93 model-00019-of-00061.safetensors
24
+ fc042ddef23192a39cf83b475a3ef29c5dea84a33e1584ec1a90f309f1fd4ba1 model-00020-of-00061.safetensors
25
+ 23143b0dda50c67186262a2d22182fa72c3f9527482d6b693736a43a7c1e5c15 model-00021-of-00061.safetensors
26
+ cf47ac22655c7a9df91ccb39a5b3b5753d7a5e36a122138329c9a9811e7a2953 model-00022-of-00061.safetensors
27
+ 7bb213db0ee84652be17d60b19ba5eba2c7aa4a0e6e858a2b5e6e26281bbfbf5 model-00023-of-00061.safetensors
28
+ 099d8224582075ddaa2c05d91cf14d12f6d8d0f9ba88a871b0ac5fc8191244eb model-00024-of-00061.safetensors
29
+ 6ad2df7ec70500e020701b80f723af9cb42724c3b66a199453b1480a02f8a199 model-00025-of-00061.safetensors
30
+ 6144ecaada0d9662d8572b23e5116d4ed05a3e634ddcaaf42720db684b0435e5 model-00026-of-00061.safetensors
31
+ d5d60b11ef83487c2751a9a43043ff2319843653f2c792abdfd604f36d3d0848 model-00027-of-00061.safetensors
32
+ 1921574ab5bda287a6ea20e3ae2d091fde29973059a688944f204dbd7724b147 model-00028-of-00061.safetensors
33
+ 31291c52c7fe8e6ec8465d42767a9ea1603fe64311e18a4d3c36b27a905be62f model-00029-of-00061.safetensors
34
+ 3c99753094273763636069d6d9b2a3474064b38148d57915c2916d2b1e9f0c7a model-00030-of-00061.safetensors
35
+ 33a05e332cae49e0ebe746944f2e4892ff0517876f8e5c06a23f9333ed32e4e0 model-00031-of-00061.safetensors
36
+ f817a0338220e16a62cce0cbb56cc6859ed85ba2af145ec1a5b336c77ad362dd model-00032-of-00061.safetensors
37
+ 524cff55ac50174e3bc3fab234749217403f106c32d87b3f97af316115eb101a model-00033-of-00061.safetensors
38
+ f9f51a8c51ab37e4ca31ccc2e4aba30ea206e32ad0b864db7d2404d770000476 model-00034-of-00061.safetensors
39
+ 6b16806ace65db95d602c39d5e372901a265a1e6d3b0d757679cc05aa7fab7e8 model-00035-of-00061.safetensors
40
+ f2830ff0574e31e39ba6b434187fbfd3b2bb8264a822fe725d93f1fc5348b3d5 model-00036-of-00061.safetensors
41
+ 9ed804ce59a2a8f2c1cff4d5bde7782b866961aa88a5846e4d0250d8f51c84c8 model-00037-of-00061.safetensors
42
+ c5446321bb918a0e269ccc5a07eb002fca71e0047fd7c08bb7d7dca6709ee999 model-00038-of-00061.safetensors
43
+ d3e03720f169d8923e3457bad304ad3ba0c75037ceb13bd6716807d01403fdf5 model-00039-of-00061.safetensors
44
+ b48bee2e8c38c105cc7a8434cc2cf845b3d023d5a849814c17b65ad896fe2c0f model-00040-of-00061.safetensors
45
+ 6d14cabc491a883d8260aa3bea801eb757cb38db51c4081263237fdfc4053400 model-00041-of-00061.safetensors
46
+ 178ea07958e6b2e573025d2bf61dde4cf0638b8dc6e81311c089400d6fa81717 model-00042-of-00061.safetensors
47
+ b202858f99636de6933dccd1c487d590cfca76b8f9cddee876af04692b38ab80 model-00043-of-00061.safetensors
48
+ c9c823bc0f84dacb3996d598b4795c400361a9c657f36b29c40b4c7982ad0ed9 model-00044-of-00061.safetensors
49
+ 97ac08f2364a4b67f6a17274c4770637c98d7306eca87ec797e3cd754447b1e3 model-00045-of-00061.safetensors
50
+ 23e05dbc72e799bb043b6497be35cacaebfda43c8057dd04db07e8ac03e74751 model-00046-of-00061.safetensors
51
+ 77e4a2617f18043aecb00d4bb04c41f1ab12eb26f5f78038b53efe70e11edc4c model-00047-of-00061.safetensors
52
+ 6f0226bfb4b796f64deafbbeb80b2c7ea5314b6038afc607bc42b3dbd0a79313 model-00048-of-00061.safetensors
53
+ ced55f41dc87ca44d1adc2be76144bf0de0045b340eb3231ae3e16bd16a1f6d5 model-00049-of-00061.safetensors
54
+ 6634f09f0cf60e974c23ca533d00f99868e71c05f94ef3e021a851926f5acc6e model-00050-of-00061.safetensors
55
+ 5f3a6f60f6c73f1fb1097ba66d6672bd1d2c24a5c4ec89ebe0e1308f82dccb4e model-00051-of-00061.safetensors
56
+ 7846d13c5af42185217dd4cacabb389dc2fdb8bee06cfa4591f079dbceca3033 model-00052-of-00061.safetensors
57
+ d898218d5f978b6cc1bff5861d3bb263c5a7c88825e6cdecc5d19be953615b93 model-00053-of-00061.safetensors
58
+ e9509ff00409d2019de6380d8ffceb718f1b1edc433ffe30ced2b659f6e381db model-00054-of-00061.safetensors
59
+ 074057b7b5436e6dc77571895a15743c6de34d54e69378c9937d5f8017a32a35 model-00055-of-00061.safetensors
60
+ dfbd182dfc66404837b09cfed5c1ccee2b9f22745e365a556473bed1ca2fb454 model-00056-of-00061.safetensors
61
+ 25a1ca1e6fa15d26b54f20f93cbb3c1be6f5858bd9123e8e68d4afe03d525ded model-00057-of-00061.safetensors
62
+ d9119d4edf798b355b6625738c540b807916bd605eb21a157d533bd5d4e59fd2 model-00058-of-00061.safetensors
63
+ 95a0fed1c3fa366feae66bdebcef87feced28c54b1081b458bd4503435ddd928 model-00059-of-00061.safetensors
64
+ 32a4d9683495cffbed6ca2e7fa9b7adc47d2d0171984cda92b4535e0e25f6903 model-00060-of-00061.safetensors
65
+ 7d4c98ee4f4f01a06854d8b3a4959484fafa8b7b4371ad25d35a77c880066c43 model-00061-of-00061.safetensors
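The sha256sum list above can be used to verify a local download of the shards. A minimal sketch, assuming the shards and this `huggingface-metadata.txt` file sit in the current directory:

```python
import hashlib
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the file so the ~4 GB shards are not loaded into memory at once."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Parse the "<hash> <filename>" lines and compare against the local files.
for line in Path("huggingface-metadata.txt").read_text().splitlines():
    parts = line.split()
    if len(parts) == 2 and parts[1].endswith(".safetensors"):
        expected, name = parts
        actual = sha256_of(Path(name))
        print(name, "OK" if actual == expected else f"MISMATCH ({actual})")
```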
model-00001-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3c2a81b624cbb921015e97a43a20598b529c193748f8a2a78913016f4c066e7
3
+ size 3523439352
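Each weight file below is stored as a Git LFS pointer with the three fields shown (`version`, `oid sha256:<hex>`, `size` in bytes); the actual shard is fetched from LFS storage at download time. A tiny illustrative parser for that pointer format (not part of this repo):

```python
# Parses the three-line Git LFS pointer format shown above.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    fields["oid"] = fields["oid"].removeprefix("sha256:")
    fields["size"] = int(fields["size"])  # shard size in bytes
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:d3c2a81b624cbb921015e97a43a20598b529c193748f8a2a78913016f4c066e7
size 3523439352"""
print(parse_lfs_pointer(pointer))
```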
model-00002-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed001899eb392f361397d83bb4e90111ee10677023c68aca7ef92986f00171c1
3
+ size 4404245144
model-00003-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a53fc538e17f263dbafac832259701d71c039ceb0e0a133411bd868f67c1412
3
+ size 4227862560
model-00004-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9f781b7cdfffb7c0ce24de37b45bbcbe6ac74ef0f666429c4faf14e4ac1ce37
3
+ size 4404245144
model-00005-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b95e369dd864c52a4249deb1180a1213ba1aa4ff32236c28c59b008a3373be76
3
+ size 4404245144
model-00006-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a9a26d88087d1f17f398bad35b00f8c9df6d45684c2d27ed51c6fd8479da0a
3
+ size 4227862560
model-00007-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23a54c86412163b35be1a339b8b414e44f9591eb2b4b8f926d215be0a4925521
3
+ size 4404245144
model-00008-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e62c2e4874bbce04a6aeb208255d57863539740adb4f889c742f9040ba2f3d96
3
+ size 4404245144
model-00009-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f06f6d9014cf96eb3650c99db4045d37c8b7e34c5117e9a2738d400de1658860
3
+ size 4227862560
model-00010-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f85d6cd1c218a5a69d01157f87cdbd50ad097058cbd0bf4eaf3dcca4dc13e8d7
3
+ size 4404245144
model-00011-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6cbb405f375b9a9fb6a7fae49875ec4b71057a83f8809af7485e80a1f8bcad5
3
+ size 4404245144
model-00012-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e56c3d4c9786fcb2eeeddca5f1ffadd4e8a4b64514c807595572ba0fc0f61b76
3
+ size 4227862560
model-00013-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:947da27479f93c092e40bb4436e04b6d73868f85c413a4e5401300eaec099414
3
+ size 4404245144
model-00014-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0af3684c789d505e5800532e7c0b093a0534b31cc4fe3ff7d2cdd70d3af103df
3
+ size 4404245144
model-00015-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55f55d4b560746620978473685d5a6c87a5eb9ac780b97592a77a27e9fd67dce
3
+ size 4227862560
model-00016-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7689e83173a2e0055e7869d7ae681940d55084c4ec0cde3b1ac066ee42342b4
3
+ size 4404245168
model-00017-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6148c0ed255843e70f502c59328e9c40f639e87a4c540708715cdaac823ca474
3
+ size 4404245184
model-00018-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1b24cd6fbcb7458e33aa6db4702aad519a32637b15fda82454747f5e9432e13
3
+ size 4227862592
model-00019-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4da75153d9ed71593a0f26d8751837a3d9d43930025aaf20012045946ace9de
3
+ size 4404245184
model-00020-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b4b005943135f0a19ccd47ec2977305514d0956b59962fa780378cbfa01cd13
3
+ size 4404245184
model-00021-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45bb5601b485a573af20d7b23bad58e0f4146b349644bd8f728b06252102176b
3
+ size 4227862592
model-00022-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:186133d4fa899fc798454adf2edf927e3c15dd1ac4325fc38123202d52bea8d8
3
+ size 4404245184
model-00023-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f39b6544788e0f7cbe61e78fc0f5b4e7909702b4debd54d08ce8738b8144713
3
+ size 4404245184
model-00024-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2523a4be387557beb04260b0d08cf8a37d0a87006197c9103f365e3cfad06daf
3
+ size 4227862592
model-00025-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c707561d9b07c3a4ee99c6949760eee338e471f013a581ceb2f32a473280cfaa
3
+ size 4404245184
model-00026-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20c4aeafbbcf1806560fcbfd8fd17d5e15355405f2109cf04354a6d353a09cad
3
+ size 4404245184
model-00027-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41879f07d50975e532be376f7efc62bcb5be7719b9303ef8c807706c43c2ad43
3
+ size 4227862592
model-00028-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da3729c87038162e7f6751435cfe69ac2a9b900102fb8bf9135d6dee46dba5e2
3
+ size 4404245184
model-00029-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85c7a0900e23458a71566d7a396ad00ad20103f7c22895e85c69e29599e50755
3
+ size 4404245184
model-00030-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af610db5f894204a2c2c6891f4644b5ae4f38e429f1b33978d3567981aa61faa
3
+ size 4227862592
model-00031-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8815b02db955ec7b62b7940a60fffaaed34ed066e0a752d083c77f966291f12c
3
+ size 4404245184
model-00032-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f85f2db6f5ee9ef26ec51166dce1b9b439975f112e793b3ef2ec7f243b7d2650
3
+ size 4404245184
model-00033-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa6f00ab09861368f0aaf66ec92bf43b17a70cd8bfb5be6264aa5b21d9a36a25
3
+ size 4227862592
model-00034-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89fbac47b8d8fd6487d15b9c8b7b79578204c3f102fcb6c5dbf1943ac64647a2
3
+ size 4404245184
model-00035-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55fbb3c3bcd978b94ec29434ba6c2741d3687c8fd0c9b204857758566f85f1f6
3
+ size 4404245184
model-00036-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:792058de69a54860e03de6dcd68bb88c758e0e701f9a9054ef95e74b66f641c1
3
+ size 4227862592
model-00037-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e44c4782d6f922d6538f55628604b4b0bbf06999d7c2828b9b62b958ed80db0
3
+ size 4404245184
model-00038-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d3a9421a898bd09894fb6569b09796cab23aa564a043a97565b91534f5ed4fe
3
+ size 4404245184
model-00039-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b597724c9c431c9b66ef7f1c5c98c572e1eb2cc04eab00b4b5578c804f7dabe2
3
+ size 4227862592
model-00040-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c286a04587a4c2798f27e6cb8cb1d8cbaf5b77a6f253f2988f5d1954d02ea718
3
+ size 4404245184
model-00041-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff71194ea96b94d3ada6ded57a5d92d4cd3199e85cda0568b17b5fa93e0347a5
3
+ size 4404245184
model-00042-of-00061.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93347f13c06c95ed4131c03f1bcbade4aaf2fc9e8e8738483467ed482e1ff556
3
+ size 4227862592