zwt123home123 committed on
Commit 8ad2fbf · verified · 1 Parent(s): 50e87e1

Update modeling_llama_no_vcache.py

Files changed (1):
  1. modeling_llama_no_vcache.py  +11 -18
modeling_llama_no_vcache.py CHANGED
@@ -414,28 +414,16 @@ class LlamaAttention(nn.Module):
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+
+
         # import pdb; pdb.set_trace()
 
         if value_states.shape[2]>576:
             reuse = True
-            if self.load_ae_v:
-                self.ae_v.load_state_dict(torch.load("weights/"+"autoencoder_epoch_1_L1_1280_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
-                self.load_ae_v = False
-            else:
-                pass
             value_states_ = value_states.clone()
-            value_states_v = value_states[:,:,35:35+576,:]
-            value_states_v = value_states_v.permute(0, 2, 1, 3)
-            value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
-            # import pdb; pdb.set_trace()
-            value_states_v = self.ae_v(value_states_v)
-            value_states_v = value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1], 40, 128)
-            value_states_v = value_states_v.permute(0, 2, 1, 3)
-            value_states[:,:,35:35+576,:] = value_states_v
         else:
             reuse = False
-
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
@@ -475,8 +463,7 @@ class LlamaAttention(nn.Module):
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
-        if reuse:
-            value_states = value_states_
+
         #if self.layer_idx==5:
         #    print(value_states[0,0,256,:])
 
@@ -486,8 +473,9 @@ class LlamaAttention(nn.Module):
         else:
             pass
 
+        #if self.layer_idx==5:
+        #    print(value_states.shape)
         if value_states.shape[2]>576:
-            value_states_ = value_states.clone()
             value_states_v = value_states[:,:,35:35+576,:]
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
@@ -497,6 +485,11 @@ class LlamaAttention(nn.Module):
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states[:,:,35:35+576,:] = value_states_v
 
+        if reuse:
+            value_states = value_states_
+
+        #if self.layer_idx==5:
+        #    print(value_states[0,0,256,:])
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
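
For reference, the value-reconstruction step that remains in the post-cache branch of the new file (lines 478-486, together with the unchanged lines between the last two hunks, presumably the self.ae_v call) boils down to the reshape round-trip sketched below. This is a minimal, self-contained sketch rather than the repository's code: reconstruct_image_values and DummyAE are hypothetical names standing in for the in-place block and the per-layer self.ae_v autoencoder, while the constants 35, 576, 40 and 128 (40 heads x 128 head_dim = 5120) are taken directly from the diff.

import torch
import torch.nn as nn

# DummyAE is a placeholder for the repository's per-layer value autoencoder
# (self.ae_v); it only needs to map [bsz, 576, 5120] -> [bsz, 576, 5120].
class DummyAE(nn.Module):
    def forward(self, x):
        return x  # identity stand-in

def reconstruct_image_values(value_states, ae_v, start=35, num_img_tokens=576):
    # value_states: [bsz, num_heads, seq_len, head_dim]; in the model this is
    # 40 heads x 128 dims, so folding the heads gives the 5120-wide vectors
    # that the autoencoder consumes.
    if value_states.shape[2] <= num_img_tokens:
        return value_states  # decode steps: nothing to reconstruct
    v = value_states[:, :, start:start + num_img_tokens, :]   # [b, h, 576, d]
    v = v.permute(0, 2, 1, 3)                                 # [b, 576, h, d]
    v = v.reshape(v.shape[0], v.shape[1], -1)                 # [b, 576, h*d] = [b, 576, 5120]
    v = ae_v(v)                                               # autoencoder round-trip
    v = v.reshape(v.shape[0], v.shape[1],
                  value_states.shape[1], value_states.shape[3])
    v = v.permute(0, 2, 1, 3)                                 # back to [b, h, 576, d]
    out = value_states.clone()
    out[:, :, start:start + num_img_tokens, :] = v
    return out

# Toy prefill-sized call: 40 heads x 128 head_dim = 5120 hidden width.
vs = torch.randn(1, 40, 700, 128)
print(reconstruct_image_values(vs, DummyAE()).shape)  # torch.Size([1, 40, 700, 128])

Judging only from this diff, the round-trip still executes whenever the value sequence exceeds 576 tokens, but the if reuse: branch added just before the matmul swaps the pre-reconstruction clone value_states_ back in, so the current layer's attention output is computed from the original values rather than the autoencoder reconstruction.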