Srinivasan Iyer (sviyer) committed
Commit 739dc71 · unverified · Parent: 6fbaf72

Add rope fp32 (#43)

* Log model

* Add flag for rope outer in fp32

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
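
The new `rope_use_fp32_in_outer_product` flag keeps the RoPE position indices in fp32 for the `torch.outer` call inside `precompute_freqs_cis`. The commit message does not spell out the motivation; one plausible reading (an assumption, not a claim from the commit) is that integer positions stop being exactly representable once they exceed the mantissa range of reduced-precision dtypes, so rotation angles at long offsets would drift if the positions ever ended up in bf16. A small standalone sketch of that effect:

import torch

# Illustrative assumption only: integers above 256 are not exactly
# representable in bfloat16 (7-bit mantissa), so positions kept in reduced
# precision would yield slightly wrong RoPE angles at long context offsets.
t = torch.arange(4096)
t_bf16 = t.to(torch.bfloat16)

print(t_bf16[257].item())                                # 256.0, not 257.0
print((t.float() - t_bf16.float()).abs().max().item())   # error grows with position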

bytelatent/base_transformer.py CHANGED

@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
     norm_eps: float = 1e-5

     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False

     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
+
     freqs = torch.outer(t, freqs).float()

     cos, sin = freqs.cos(), freqs.sin()
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """

-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()

         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product

         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )

     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )

     def forward(
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.eos_id = args.eos_id
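
A minimal usage sketch of the updated `RotaryEmbedding` signature, assuming the import path matches the file above; the dimensions are illustrative, not defaults from any bytelatent config:

from bytelatent.base_transformer import RotaryEmbedding

# Build the rotary embedding with the new flag enabled; freqs_cis is the
# non-persistent buffer registered in __init__ above.
rope = RotaryEmbedding(
    theta=10000.0,
    head_dim=64,
    max_seqlen=4096,
    rope_use_fp32_in_outer_product=True,
)
print(rope.freqs_cis.shape, rope.freqs_cis.dtype)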

bytelatent/model/blt.py CHANGED

@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False

     # Architecture and dimensions
-    dim_token: int = 256
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512
@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"

-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
     # Parameter mixing
     pm_size: int = 0

@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
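
With `rope_theta` and `rope_use_fp32_in_outer_product` removed here, `ByteLatentTransformerArgs` now inherits both fields from `BaseTransformerArgs`, and `dim_token` becomes optional. A sketch of how the touched fields might be overridden in a config; only the key names come from the diff, the surrounding config structure is assumed:

# Hypothetical override dict; only the key names are taken from the diff above.
model_overrides = {
    "rope_theta": 10000.0,
    "rope_use_fp32_in_outer_product": True,  # new flag, inherited from BaseTransformerArgs
    "dim_token": None,                       # now optional (previously int = 256)
}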

bytelatent/model/local_models.py CHANGED

@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.pos_embeddings = None

bytelatent/train.py CHANGED

@@ -325,6 +325,7 @@ def train(args: TrainArgs):

     # log model size

+    logger.info(model)
     logger.info(f"Model size: {model_param_count:,} total parameters")

     gpu_memory_monitor = GPUMemoryMonitor("cuda")
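
The added `logger.info(model)` logs the module tree via `nn.Module.__repr__`. A self-contained sketch of the same call with a toy stand-in model (the `model` in `train.py` is the full training model):

import logging

import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

# logging converts the message with str(), which for nn.Module prints the
# full module tree, one line per submodule.
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
logger.info(model)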