chargoddard committed
Commit bd87b10
1 Parent(s): 5c7d89e

Add example linear+ntk monkeypatch

Files changed (1)
  1. llama_scale_rope.py +76 -0
llama_scale_rope.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import transformers
+
+
+ class LlamaComboScaledRope(torch.nn.Module):
+     """
+     stolen from: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
+     https://github.com/jquesnelle/scaled-rope
+     """
+
+     def __init__(
+         self,
+         dim,
+         max_position_embeddings=2048,
+         base=10000,
+         scale=1,
+         alpha=1,
+         device=None,
+     ):
+         super().__init__()
+         if alpha != 1:
+             base = base * alpha ** (dim / (dim - 2))
+
+         self.scale = 1 / scale
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+         self.register_buffer("inv_freq", inv_freq)
+
+         # Build here to make `torch.jit.trace` work.
+         self.max_seq_len_cached = max_position_embeddings
+         t = torch.arange(
+             self.max_seq_len_cached,
+             device=self.inv_freq.device,
+             dtype=self.inv_freq.dtype,
+         )
+         t *= self.scale
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         dtype = torch.get_default_dtype()
+         self.register_buffer(
+             "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
+         )
+         self.register_buffer(
+             "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
+         )
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+         if seq_len > self.max_seq_len_cached:
+             self.max_seq_len_cached = seq_len
+             t = torch.arange(
+                 self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
+             )
+             t *= self.scale
+             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+             # Different from paper, but it uses a different permutation in order to obtain the same calculation
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+             self.register_buffer(
+                 "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False
+             )
+             self.register_buffer(
+                 "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False
+             )
+         return (
+             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+         )
+
+
+ def llama_scale_rope(model: transformers.LlamaForCausalLM, **kwargs):
+     kwargs.update({"device": model.device})
+     for layer in model.model.layers:
+         layer.self_attn.rotary_emb = LlamaComboScaledRope(
+             layer.self_attn.head_dim, **kwargs
+         )
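
For reference, a minimal usage sketch of the monkeypatch. This is not part of the diff above; the checkpoint path and the scale/alpha values are placeholder assumptions chosen for illustration.

import transformers

from llama_scale_rope import llama_scale_rope

# Load a Llama checkpoint (placeholder path, not from the commit).
model = transformers.LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")

# Replace every attention layer's rotary embedding with the combined scaling:
# scale=4 applies linear position interpolation (positions multiplied by 1/4),
# alpha=2 applies NTK-aware scaling by stretching the RoPE base.
# Example values only; pick scale/alpha for your target context length.
llama_scale_rope(model, max_position_embeddings=8192, scale=4, alpha=2)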