jingyaogong committed
Commit c63b0c9 · 1 Parent(s): 6d52f6f
Upload model.py

model.py CHANGED
@@ -1,6 +1,8 @@
 import math
 import struct
 import inspect
+import time
+
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple
 import numpy as np
@@ -80,26 +82,15 @@ class Attention(nn.Module):
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

-
-
-
-
-        self.register_buffer("mask", mask)
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)

-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor,
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
-        if use_kv_cache and self.eval():
-            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-            else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)

-
-        else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -107,6 +98,13 @@ class Attention(nn.Module):

         xq, xk = apply_rotary_emb(xq, xk, pos_cis)

+        # more efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
         xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)

@@ -114,13 +112,12 @@ class Attention(nn.Module):
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)

-        if self.flash:
+        if self.flash and seqlen != 1:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
@@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )

-    def forward(self, x, pos_cis,
-        h = x + self.attention(self.attention_norm(x), pos_cis,
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

@@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-
+                kv_cache=False, **keyargs):
+        current_idx = 0
         if 'input_ids' in keyargs:
             tokens = keyargs['input_ids']
         if 'attention_mask' in keyargs:
             targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])

         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis,
+            h = layer(h, pos_cis, kv_cache)

         h = self.norm(h)

@@ -375,20 +375,24 @@ class Transformer(PreTrainedModel):

         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('last_loss', self.last_loss)
-
         return self.OUT

     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=
-
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
         index = idx.shape[1]
+        init_inference = True
         while idx.shape[1] < max_new_tokens - 1:
-
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+            else:
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
+
             logits = inference_res.logits
             logits = logits[:, -1, :]

             for token in set(idx.tolist()[0]):
-                logits[:, token] /=
+                logits[:, token] /= rp

             if temperature == 0.0:
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
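The reworked Attention.forward now always projects q/k/v for whatever slice of tokens it receives; when kv_cache is enabled at eval time and the step is a single token, it concatenates the cached keys/values in front of the new ones and refreshes the cache. The snippet below is a minimal, standalone sketch of that concat-based cache pattern with plain tensors; the helper project_kv and all shapes are made up for illustration and are not taken from model.py:

```python
import torch

# Sketch of the concat-style KV cache used in the new Attention.forward:
# prefill stores K/V for the whole prompt, each decode step appends one position.
# Shapes follow the (bsz, seqlen, n_kv_heads, head_dim) layout used before repeat_kv.
bsz, n_kv_heads, head_dim = 1, 2, 4

def project_kv(x):
    # stand-in for self.wk(x) / self.wv(x) followed by .view(...)
    seqlen = x.shape[1]
    return (torch.randn(bsz, seqlen, n_kv_heads, head_dim),
            torch.randn(bsz, seqlen, n_kv_heads, head_dim))

k_cache = v_cache = None

prompt = torch.zeros(bsz, 5, 8)        # prefill: a 5-token "prompt"
xk, xv = project_kv(prompt)
k_cache, v_cache = xk, xv              # cache covers positions 0..4

token = torch.zeros(bsz, 1, 8)         # decode: one new token
xk, xv = project_kv(token)
xk = torch.cat((k_cache, xk), dim=1)   # cached keys + new key
xv = torch.cat((v_cache, xv), dim=1)
k_cache, v_cache = xk, xv              # cache now covers positions 0..5

print(k_cache.shape)                   # torch.Size([1, 6, 2, 4])
```

This also explains the new guard `if self.flash and seqlen != 1`: with a single-token query attending over a longer cached key sequence, `is_causal=True` would no longer line queries up with keys, so the single-step case presumably falls back to the explicit-scores branch, where `self.mask[:, :, :1, :1]` adds nothing and the lone query attends to every cached position.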
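Because decode steps now pass only the newest token through Transformer.forward, the rotary embedding slice has to start at that token's absolute position rather than at 0, which is what `pos_cis = self.pos_cis[current_idx:current_idx + seqlen]` does (and why generate passes `current_idx=idx.shape[1] - 1`). The sketch below illustrates the slicing with a generic RoPE-style precompute; the helper name `precompute_pos_cis_sketch`, `dim`, and `theta` are assumptions for illustration, not this file's actual `pos_cis` construction:

```python
import torch

def precompute_pos_cis_sketch(dim: int, max_seq_len: int, theta: float = 10000.0):
    # Generic RoPE precompute: one complex rotation e^{i * t * freq} per position t.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    angles = torch.outer(t, freqs)              # (max_seq_len, dim // 2)
    return torch.polar(torch.ones_like(angles), angles)

pos_cis = precompute_pos_cis_sketch(dim=64, max_seq_len=512)

seqlen = 5
prefill_cis = pos_cis[0:seqlen]                 # prefill: positions 0..4

current_idx = 5                                 # decode: one token at absolute position 5
step_cis = pos_cis[current_idx:current_idx + 1]

assert torch.equal(step_cis, pos_cis[5:6])      # the slice is exactly the position-5 rotation
```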
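In the new generate loop, the first iteration (or every iteration when kv_cache=False) runs a full forward over idx, and later iterations feed only idx[:, -1:] together with current_idx, so the KV cache supplies the earlier positions. Before sampling, each token id already present in the sequence has its logit divided by rp, and temperature == 0.0 reduces to greedy argmax via torch.topk(..., k=1). Below is a small, self-contained sketch of that logit post-processing on fake values; the non-greedy branch (temperature scaling plus top-k and multinomial sampling) is the conventional recipe and is assumed here, since it is cut off in the hunk above:

```python
import torch

torch.manual_seed(0)
vocab_size = 10
logits = torch.randn(1, vocab_size)   # pretend last-step logits, shape (bsz, vocab)
idx = torch.tensor([[3, 7, 3]])       # pretend tokens generated so far

rp = 1.2                              # repetition penalty (rp=1. leaves logits unchanged)
temperature = 0.7
top_k = 8

# penalize every token id that already appears in the sequence, as in generate
for token in set(idx.tolist()[0]):
    logits[:, token] /= rp

if temperature == 0.0:
    # greedy decoding, exactly the branch shown in the diff
    _, idx_next = torch.topk(logits, k=1, dim=-1)
else:
    # assumed standard sampling branch (not visible in this hunk)
    logits = logits / temperature
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
    probs = torch.softmax(logits, dim=-1)
    idx_next = torch.multinomial(probs, num_samples=1)

print(idx_next)                       # a (1, 1) tensor holding the next token id
```

Since kv_cache defaults to True in the new signature, existing callers of generate pick up the incremental decoding path automatically.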