duzx16 committed on
Commit
0ecfe0b
1 Parent(s): 31d45da

Fix prefix projection

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +4 -3
modeling_chatglm.py CHANGED
@@ -68,11 +68,12 @@ class PrefixEncoder(torch.nn.Module):
68
  self.prefix_projection = config.prefix_projection
69
  if self.prefix_projection:
70
  # Use a two-layer MLP to encode the prefix
71
- self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
 
72
  self.trans = torch.nn.Sequential(
73
- torch.nn.Linear(config.hidden_size, config.hidden_size),
74
  torch.nn.Tanh(),
75
- torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
76
  )
77
  else:
78
  self.embedding = torch.nn.Embedding(config.pre_seq_len,
 
68
  self.prefix_projection = config.prefix_projection
69
  if self.prefix_projection:
70
  # Use a two-layer MLP to encode the prefix
71
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
72
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
73
  self.trans = torch.nn.Sequential(
74
+ torch.nn.Linear(kv_size, config.hidden_size),
75
  torch.nn.Tanh(),
76
+ torch.nn.Linear(config.hidden_size, kv_size)
77
  )
78
  else:
79
  self.embedding = torch.nn.Embedding(config.pre_seq_len,