jingyaogong committed (verified)
Commit b84b6fc · 1 Parent(s): f7836ea

Upload model.py

Files changed (1): model.py (+9 -8)
model.py CHANGED
@@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
         return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
 
-def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
+def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
@@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
-        self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
-                                                           theta=params.rope_theta), persistent=False)
+        self.register_buffer("pos_cis",
+                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
+                             persistent=False)
         self.OUT = CausalLMOutputWithPast()
 
     def forward(self,
@@ -328,13 +329,13 @@ class MiniMindLM(PreTrainedModel):
                  stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
         # Streaming generation
         if stream:
-            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
 
         # Direct generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
             tokens_list = [tokens[:, -1:] for tokens in out]
             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
             full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -348,14 +349,14 @@ class MiniMindLM(PreTrainedModel):
         ]
         return torch.cat(generated, dim=0)
 
-    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
             if first_seq or not use_cache:
-                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
+                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
             else:
                 out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-                           start_pos=input_ids.shape[1] - 1)
+                           start_pos=input_ids.shape[1] - 1, **args)
             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
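
The first hunk raises the RoPE base `theta` from 1e4 to 1e6 and gives `end` a fixed 32K default rather than requiring the caller to pass a length. The hunk is cut off before the function returns; below is a minimal runnable sketch of the whole routine, assuming it finishes with `torch.polar` as RoPE precomputation conventionally does (that return line is an assumption, not shown in this diff):

```python
import torch

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    # Per-pair inverse frequencies: theta^(-2i/dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()  # (end, dim // 2) table of angles
    # Assumed final step: unit-magnitude complex numbers e^(i * angle).
    return torch.polar(torch.ones_like(freqs), freqs)

pos_cis = precompute_pos_cis(dim=64)
print(pos_cis.shape)  # torch.Size([32768, 32])
```

A larger `theta` stretches the longest rotary wavelength, the usual lever for keeping positions beyond the training length distinguishable, which is consistent with precomputing a 32K-position table here.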
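The second hunk reflows the `register_buffer` call and drops the explicit `params.max_seq_len` argument, relying on the new 32K default for `end`. Because the buffer stays `persistent=False`, the larger table never lands in checkpoints. A small illustration of that behavior, using a hypothetical `Demo` module rather than MiniMind code:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent: the buffer lives on the module and follows
        # .to()/.cuda(), but is excluded from state_dict and therefore
        # recomputed rather than loaded from a checkpoint.
        self.register_buffer("cached", torch.zeros(4), persistent=False)

m = Demo()
print("cached" in m.state_dict())  # False
print(m.cached.shape)              # torch.Size([4])
```

That keeps saved weights the same size even though `pos_cis` now spans 32K positions.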
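The remaining hunks rename `_generate_stream` to `_stream` and thread `**args` through every `self(...)` call, so extra forward kwargs survive generation. The decoding tail applies a repetition penalty and temperature to the last-step logits; the sketch below replays those two lines in isolation (the function name and vocab size are illustrative, not from the source):

```python
import torch

def adjust_logits(logits: torch.Tensor, seen_ids: torch.Tensor,
                  rp: float = 1.0, temperature: float = 0.75) -> torch.Tensor:
    logits = logits.clone()
    # Mirror of `logits[:, list(set(input_ids.tolist()[0]))] /= rp`:
    # divide the logits of every already-generated token by rp.
    logits[:, seen_ids.unique()] /= rp
    # Mirror of `logits /= (temperature + 1e-9)`; the epsilon guards
    # against division by zero when temperature == 0.
    return logits / (temperature + 1e-9)

logits = torch.randn(1, 6400)        # (batch, vocab); vocab size illustrative
seen = torch.tensor([1, 5, 5, 9])    # token ids generated so far
probs = torch.softmax(adjust_logits(logits, seen, rp=1.2), dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
```

Note that the source divides unconditionally, so `rp > 1` only suppresses tokens whose logits are positive; the common Hugging Face variant multiplies negative logits by the penalty instead.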