duzx16 committed
Commit fc442f7 (parent: 5fe53eb)

Fix gradient checkpointing and prefix prompt

Files changed (1):
  1. modeling_chatglm.py +4 -4
modeling_chatglm.py CHANGED
@@ -406,11 +406,11 @@ class SelfAttention(torch.nn.Module):
             key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 
         # adjust key and value for inference
+        if kv_cache is not None:
+            cache_k, cache_v = kv_cache
+            key_layer = torch.cat((cache_k, key_layer), dim=0)
+            value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            if kv_cache is not None:
-                cache_k, cache_v = kv_cache
-                key_layer = torch.cat((cache_k, key_layer), dim=0)
-                value_layer = torch.cat((cache_v, value_layer), dim=0)
             kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
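Before this commit, the incoming kv_cache was only merged inside the if use_cache: branch, so any cached keys/values passed in (presumably the prefix-prompt past key/values here, given the commit message) were silently dropped whenever use_cache was False, as it is during training with gradient checkpointing. The fix hoists the concatenation out of the use_cache branch: the cache is always prepended, and use_cache now only controls whether an updated cache is returned. A minimal sketch of the fixed logic, outside the full model: the helper name merge_kv and the tensor sizes are illustrative only, and shapes follow this file's [seq_len, batch, num_heads, head_dim] convention (hence the dim=0 concatenation).

import torch

def merge_kv(key_layer, value_layer, kv_cache, use_cache):
    # Prepend cached (e.g. prefix-prompt) keys/values unconditionally, so the
    # prefix is seen during training (use_cache=False) as well as generation.
    if kv_cache is not None:
        cache_k, cache_v = kv_cache
        key_layer = torch.cat((cache_k, key_layer), dim=0)
        value_layer = torch.cat((cache_v, value_layer), dim=0)
    # Only return an updated cache when the caller asked for one.
    kv_cache = (key_layer, value_layer) if use_cache else None
    return key_layer, value_layer, kv_cache

# Prefix of 8 positions, current chunk of 16 tokens, batch 2, 4 heads, head_dim 32.
prefix = (torch.randn(8, 2, 4, 32), torch.randn(8, 2, 4, 32))
k = torch.randn(16, 2, 4, 32)
v = torch.randn(16, 2, 4, 32)

# With the fix, the prefix is merged even on the training path (use_cache=False):
# attention now runs over 8 + 16 = 24 positions, and no cache is returned.
k_out, v_out, cache = merge_kv(k, v, prefix, use_cache=False)
assert k_out.shape[0] == 24 and v_out.shape[0] == 24 and cache is None

Under the old ordering, the same call would have returned k_out with shape[0] == 16, i.e. attention computed as if the prefix did not exist.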