Update modeling_llama_no_vcache.py
modeling_llama_no_vcache.py (CHANGED: +11, -18)
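In short: the commit deletes the lazy per-layer loading of the value-autoencoder weights (the `self.load_ae_v` branch) and the first autoencoder round-trip over the vision-token value states at prefill time, keeping only the snapshot `value_states_ = value_states.clone()`. The second round-trip, after the attention softmax, survives, and a new `if reuse:` branch restores the uncompressed values immediately before the attention matmul.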
```diff
@@ -414,28 +414,16 @@ class LlamaAttention(nn.Module):
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+
+
         # import pdb; pdb.set_trace()
 
         if value_states.shape[2]>576:
             reuse = True
-            if self.load_ae_v:
-                self.ae_v.load_state_dict(torch.load("weights/"+"autoencoder_epoch_1_L1_1280_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
-                self.load_ae_v = False
-            else:
-                pass
             value_states_ = value_states.clone()
-            value_states_v = value_states[:,:,35:35+576,:]
-            value_states_v = value_states_v.permute(0, 2, 1, 3)
-            value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
-            # import pdb; pdb.set_trace()
-            value_states_v = self.ae_v(value_states_v)
-            value_states_v = value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1], 40, 128)
-            value_states_v = value_states_v.permute(0, 2, 1, 3)
-            value_states[:,:,35:35+576,:] = value_states_v
         else:
             reuse = False
-
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
```
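The deleted block is easier to read as a standalone helper. The sketch below is not code from this repository: the shapes (40 heads × 128 head_dim = 5120, a 576-token span starting at offset 35, which matches a LLaVA-style 24×24 image-patch grid) are read off the hunk above, while the function name and the out-of-place update are choices made here for clarity.

```python
import torch
import torch.nn as nn

def compress_vision_values(value_states: torch.Tensor, ae_v: nn.Module,
                           offset: int = 35, n_vis: int = 576) -> torch.Tensor:
    """Round-trip the vision-token value states through an autoencoder.

    value_states: (bsz, num_heads=40, seq_len, head_dim=128). Hypothetical
    helper mirroring the deleted in-place block above.
    """
    v = value_states[:, :, offset:offset + n_vis, :]   # vision-token span
    v = v.permute(0, 2, 1, 3)                          # (bsz, n_vis, heads, head_dim)
    v = v.reshape(v.shape[0], v.shape[1], -1)          # (bsz, n_vis, 40*128 = 5120)
    v = ae_v(v)                                        # encode/decode each token vector
    v = v.reshape(v.shape[0], v.shape[1], 40, 128)     # split back into heads
    v = v.permute(0, 2, 1, 3)                          # (bsz, heads, n_vis, head_dim)
    out = value_states.clone()
    out[:, :, offset:offset + n_vis, :] = v            # splice the reconstruction back
    return out
```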
```diff
@@ -475,8 +463,7 @@ class LlamaAttention(nn.Module):
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
-
-        value_states = value_states_
+
         #if self.layer_idx==5:
         # print(value_states[0,0,256,:])
 
```
```diff
@@ -486,8 +473,9 @@ class LlamaAttention(nn.Module):
         else:
             pass
 
+        #if self.layer_idx==5:
+        # print(value_states.shape)
         if value_states.shape[2]>576:
-            value_states_ = value_states.clone()
             value_states_v = value_states[:,:,35:35+576,:]
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
```
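The hunks never show `self.ae_v` itself; only its 5120-dim input/output is implied by the reshapes, and the deleted checkpoint name (`autoencoder_epoch_1_L1_1280_nonorm_layer_<idx>.pth`) hints at a 1280-wide bottleneck trained with an L1 term and no normalization. A minimal shape-compatible sketch, with everything beyond the 5120-dim interface an assumption:

```python
import torch.nn as nn

class ValueAutoencoder(nn.Module):
    """Hypothetical per-layer value autoencoder, 5120 -> 1280 -> 5120.

    The 1280 bottleneck is inferred from the checkpoint filename; depth,
    activation, and training details are guesses, not the repo's module.
    """
    def __init__(self, d_model: int = 5120, d_bottleneck: int = 1280):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_bottleneck)
        self.decoder = nn.Linear(d_bottleneck, d_model)

    def forward(self, x):                    # x: (bsz, n_tokens, d_model)
        return self.decoder(self.encoder(x))
```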
```diff
@@ -497,6 +485,11 @@ class LlamaAttention(nn.Module):
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states[:,:,35:35+576,:] = value_states_v
 
+        if reuse:
+            value_states = value_states_
+
+        #if self.layer_idx==5:
+        # print(value_states[0,0,256,:])
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
```
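Read together, the hunks leave the prefill value path looking roughly like the sketch below. The lines between the last two hunks (new lines 482-484, presumably the `self.ae_v` call and the reshape back to heads, matching the deleted pre-commit block) are not shown in the diff, so this is a reconstruction, reusing the hypothetical `compress_vision_values` helper from above:

```python
import torch

def value_path_after_commit(value_states, attn_weights, ae_v):
    """Sketch of the post-commit value handling in LlamaAttention.forward."""
    reuse = value_states.shape[2] > 576        # long, image-bearing prefill
    value_states_ = value_states.clone() if reuse else None

    # ... RoPE, cache handling, and the attn_weights softmax happen here ...

    if value_states.shape[2] > 576:
        # autoencoder round-trip over the vision-token span (last two hunks)
        value_states = compress_vision_values(value_states, ae_v)

    if reuse:
        # restore the snapshot: the prefill matmul uses the original values;
        # what the compressed tensor feeds is not visible in these hunks
        value_states = value_states_

    return torch.matmul(attn_weights, value_states)
```

As far as the visible lines go, the behavioral change is twofold: the round-trip now runs once per prefill instead of twice, and the attention matmul consumes the original values rather than the compressed ones, since the new `if reuse:` branch overwrites the compression result just before `torch.matmul`.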