jingyaogong committed
Commit fad3b10
Parent(s): 50d126d
Upload model.py

model.py CHANGED
@@ -1,6 +1,8 @@
import math
import struct
import inspect
+ import time
+
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple
import numpy as np
@@ -66,93 +68,66 @@ class Attention(nn.Module):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
-
-        self.n_local_heads = args.n_heads // model_parallel_size
-        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+        self.n_local_heads = args.n_heads
+        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.k_cache, self.v_cache = None, None
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
-
-        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

-
-
-
-
-
-
-    def forward(
-            self,
-            x: torch.Tensor,
-            pos_cis: torch.Tensor,
-            use_kv_cache: bool = False,
-            past_kv: Tuple[torch.Tensor] = None
-    ):
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)
+
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
        bsz, seqlen, _ = x.shape
-        # QKV
-        # inference
-        if use_kv_cache:
-            # only compute Q for the last token
-            current_token = x[:, -1:, :]
-
-            if not past_kv:
-                xq = self.wq(x)
-                xk, xv = self.wk(x), self.wv(x)
-            else:
-                past_key, past_value = past_kv
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
-                xk = torch.cat((past_key, self.wk(current_token)), dim=1)
-                xv = torch.cat((past_value, self.wv(current_token)), dim=1)

-
-        else:
-            xq = self.wq(x)
-            xk, xv = self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

-        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

-        #
+        # more efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

-
-        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

-
-        if self.flash:
+        if self.flash and seqlen != 1:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                      dropout_p=self.dropout if self.training else 0.0,
                                                                      is_causal=True)
        else:
-            # manual implementation
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

-        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

-        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
-        return output
+        return output


class FeedForward(nn.Module):
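The rewritten Attention drops the old past_kv tuple plumbing and keeps the cache on the module itself: a full-prompt pass stores k_cache / v_cache, and any later call with seqlen == 1 concatenates the new key/value onto them. Two details worth noting: self.eval() returns the module (always truthy), so the guard is effectively just kv_cache, and the flash path is skipped for single-token steps, presumably because is_causal=True would mask out the cached positions for a single query. Below is a minimal shape check of that behaviour, a sketch that reuses this file's Attention and precompute_pos_cis and substitutes a SimpleNamespace with made-up sizes for LMConfig:

import torch
from types import SimpleNamespace

# Stand-in config carrying only the fields Attention actually reads (hypothetical values).
args = SimpleNamespace(dim=64, n_heads=4, n_kv_heads=4, max_seq_len=32,
                       dropout=0.0, flash_attn=False)
attn = Attention(args).eval()
pos_cis = precompute_pos_cis(args.dim // args.n_heads, args.max_seq_len)

x = torch.randn(1, 8, args.dim)            # full prompt, seqlen = 8
attn(x, pos_cis[:8], kv_cache=True)        # primes k_cache / v_cache
print(attn.k_cache.shape)                  # torch.Size([1, 8, 4, 16])

x_next = torch.randn(1, 1, args.dim)       # one decoded token, seqlen = 1
attn(x_next, pos_cis[8:9], kv_cache=True)  # concatenates onto the cache
print(attn.k_cache.shape)                  # torch.Size([1, 9, 4, 16])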
@@ -182,7 +157,6 @@ class MoEGate(nn.Module):
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

-        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
@@ -194,7 +168,7 @@ class MoEGate(nn.Module):

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
-
+
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
@@ -202,19 +176,15 @@ class MoEGate(nn.Module):
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

-        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

-        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

-        ### expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
-            # always compute aux loss based on the naive greedy topk method
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
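What this hunk keeps is the gate itself: softmax scores over the routed experts, torch.topk to select the top_k largest, and an optional renormalization so the selected weights sum to 1 (the 1e-20 term only guards against division by zero). A tiny self-contained illustration with made-up scores for four experts and top_k = 2:

import torch
import torch.nn.functional as F

scores = F.softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)   # one token, 4 experts
topk_weight, topk_idx = torch.topk(scores, k=2, dim=-1, sorted=False)
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

print(topk_idx)                  # the two highest-scoring experts
print(topk_weight.sum(dim=-1))   # tensor([1.0000]) after renormalization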
@@ -331,11 +301,10 @@ class TransformerBlock(nn.Module):
            dropout=args.dropout,
        )

-    def forward(self, x, pos_cis,
-
-        h = x + attn_res
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
        out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out


class Transformer(PreTrainedModel):
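The block keeps the pre-norm residual layout; only the signature changes so kv_cache can be threaded through to the attention call. A schematic sketch of that layout with stand-in sublayers (the real block uses this file's Attention and FeedForward, and RMSNorm comes from this file as well):

import torch

dim = 64
norm1, norm2 = RMSNorm(dim, eps=1e-5), RMSNorm(dim, eps=1e-5)
attn_fn = torch.nn.Linear(dim, dim)    # stand-in for the attention sublayer
ffn_fn = torch.nn.Linear(dim, dim)     # stand-in for the feed-forward sublayer

x = torch.randn(1, 8, dim)
h = x + attn_fn(norm1(x))              # residual around the normed attention sublayer
out = h + ffn_fn(norm2(h))             # residual around the normed feed-forward sublayer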
@@ -357,22 +326,16 @@ class Transformer(PreTrainedModel):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
-
-        # share the unembedding parameters with the embedding parameters
-        self.tok_embeddings.weight = self.output.weight  # https://paperswithcode.com/method/weight-tying
-
-        # some useful precompute for the RoPE relative positional embeddings
+        self.tok_embeddings.weight = self.output.weight
        pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("pos_cis", pos_cis, persistent=False)

-        # init all weights
        self.apply(self._init_weights)
-
+
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))

-        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None
        self.OUT = CausalLMOutputWithPast()

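The comments are gone but the mechanics are unchanged: assigning self.output.weight to self.tok_embeddings.weight makes the embedding and the output head the same Parameter (weight tying), so a checkpoint stores the tensor once and both uses update it; pos_cis is registered with persistent=False, so it is recomputed rather than saved in the state dict. A quick sanity check, assuming model is an already-constructed Transformer:

# Both attributes point at one and the same Parameter after tying.
assert model.tok_embeddings.weight is model.output.weight
assert model.tok_embeddings.weight.data_ptr() == model.output.weight.data_ptr()

# The shared weight is yielded only once when iterating parameters.
n_params = sum(p.numel() for p in model.parameters())
print(f"total parameters (tied head counted once): {n_params}")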
@@ -384,78 +347,64 @@ class Transformer(PreTrainedModel):
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

-    def forward(self, tokens: Optional[torch.Tensor] = None,
-
-
-        if past_kvs is None:
-            past_kvs = [None for _ in range(self.n_layers)]
+    def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
+                kv_cache=False, **keyargs):
+        current_idx = 0
        if 'input_ids' in keyargs:
            tokens = keyargs['input_ids']
        if 'attention_mask' in keyargs:
            targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])

        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
        for idx, layer in enumerate(self.layers):
-            h
+            h = layer(h, pos_cis, kv_cache)

        h = self.norm(h)

        if targets is not None:
-            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
-
-            logits = self.output(h[:, [-1], :])  # note: using list [-1] to preserve the time dim
+            logits = self.output(h[:, [-1], :])
            self.last_loss = None

        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('last_loss', self.last_loss)
-
-        if use_kv_cache:
-            return self.OUT, past_kvs
        return self.OUT

-
    @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
        index = idx.shape[1]
-
-        past_kvs = [None for _ in range(self.n_layers)]
+        init_inference = True
        while idx.shape[1] < max_new_tokens - 1:
-
-
-            # forward the model to get the logits for the index in the sequence
-            inference_res = self(idx_cond, use_kv_cache=use_kv_cache, past_kvs=past_kvs)
-            if use_kv_cache:
-                logits, past_kvs = inference_res[0].logits, inference_res[1]
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
            else:
-
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)

-            logits = logits
+            logits = inference_res.logits
+            logits = logits[:, -1, :]

-            # Apply repetition penalty
            for token in set(idx.tolist()[0]):
-                logits[:, token] /=
+                logits[:, token] /= rp

            if temperature == 0.0:
-
-                __, idx_next = torch.topk(logits, k=1, dim=-1)
+                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
-                # pluck the logits at the final step and scale by desired temperature
                logits = logits / temperature
-                # optionally crop the logits to only the top k options
                if top_k is not None:
-                    v,
+                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

-                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1, generator=None)
-
+
            if idx_next == eos:
                break

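For reference, the sampling logic inside the new generate() pulled out as a standalone helper (a sketch only; sample_next is not part of this commit): it divides the logits of already-generated tokens by the repetition penalty rp, then either takes the argmax at temperature == 0.0 or samples from a temperature-scaled, optionally top-k-truncated distribution.

import torch
import torch.nn.functional as F

def sample_next(logits, prev_ids, temperature=0.7, top_k=8, rp=1.0):
    # logits: (1, vocab_size) for the last position; prev_ids: (1, T) tokens generated so far.
    logits = logits.clone()
    for token in set(prev_ids.tolist()[0]):          # repetition penalty, as in generate()
        logits[:, token] /= rp
    if temperature == 0.0:
        _, idx_next = torch.topk(logits, k=1, dim=-1)  # greedy
        return idx_next
    logits = logits / temperature
    if top_k is not None:                            # keep only the top_k logits
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

Dividing raw logits by rp mirrors the commit exactly; note that for rp > 1 this also boosts already-seen tokens whose logits are negative, a quirk of this particular formulation of the penalty.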
@@ -468,63 +417,8 @@ class Transformer(PreTrainedModel):

    @torch.inference_mode()
    def eval_answer(self, idx):
-        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
-
-        past_kvs = [None for _ in range(self.n_layers)]
-        inference_res = self(idx_cond, use_kv_cache=False, past_kvs=past_kvs)
+        inference_res = self(idx_cond)
        logits = inference_res.logits
        logits = logits[:, -1, :]
        return logits
-
-    def export(self, filepath='model.bin'):
-        """export the model weights in fp32 into .bin file to be read from C"""
-        f = open(filepath, 'wb')
-
-        def serialize(t):
-            d = t.detach().cpu().view(-1).numpy().astype(np.float32)
-            b = struct.pack(f'{len(d)}f', *d)
-            f.write(b)
-
-        # first write out the header
-        hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
-        p = self.params
-        n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-        header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
-                             n_kv_heads, p.vocab_size, p.max_seq_len)
-        f.write(header)
-
-        # next write out the embedding weights
-        serialize(self.tok_embeddings.weight)
-
-        # now all the layers
-        # attention weights
-        for layer in self.layers:
-            serialize(layer.attention_norm.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wq.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wk.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wv.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wo.weight)
-        # ffn weights
-        for layer in self.layers:
-            serialize(layer.ffn_norm.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w1.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w2.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w3.weight)
-        # final rmsnorm
-        serialize(self.norm.weight)
-        # note: no need to write final classifier weights due to weight sharing
-        # pos_cis
-        serialize(self.freqs_cos[:p.max_seq_len])
-        serialize(self.freqs_sin[:p.max_seq_len])
-
-        # write to binary file
-        f.close()
-        print(f"wrote {filepath}")
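This last hunk removes the llama2.c-style export() helper (fp32 weights written to a .bin file for a C runtime) and leaves eval_answer as a thin wrapper: crop the context to max_seq_len, run a plain forward pass without the cache, and return the logits of the last position. A brief usage sketch, where model and prompt_ids are placeholders:

# prompt_ids: (1, T) token ids; eval_answer returns logits for the next token only.
logits = model.eval_answer(prompt_ids)   # shape (1, vocab_size)
next_id = int(logits.argmax(dim=-1))     # e.g. greedy choice for answer/option scoring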