liuyanyi committed on
Commit 9379593
1 Parent(s): 9fa755c

Update modeling_bge_m3.py

Files changed (1)
  1. modeling_bge_m3.py +11 -25
modeling_bge_m3.py CHANGED
@@ -42,6 +42,7 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
 
         self.init_weights()
 
+    # Copied from FlagEmbedding
     def dense_embedding(self, hidden_state, mask):
         if self.sentence_pooling_method == "cls":
             return hidden_state[:, 0]
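
For context on the method this hunk annotates: "cls" pooling takes the first token's hidden state, while "mean" pooling averages over non-padded positions. A minimal standalone sketch with toy shapes; the masked sum s is not visible in this hunk, so the line computing it below follows the FlagEmbedding reference and should be treated as an assumption:

import torch

hidden_state = torch.randn(2, 4, 8)  # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])  # attention mask

# "cls" pooling: hidden state of the first token
cls_emb = hidden_state[:, 0]  # (2, 8)

# "mean" pooling: average over non-padded positions only
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)  # assumed, per FlagEmbedding
d = mask.sum(axis=1, keepdim=True).float()
mean_emb = s / d  # (2, 8)
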
@@ -50,6 +51,7 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
             d = mask.sum(axis=1, keepdim=True).float()
             return s / d
 
+    # Copied from FlagEmbedding
     def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = False):
         token_weights = torch.relu(self.sparse_linear(hidden_state))
         if not return_embedding:
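
The sparse head tagged above scores each token with relu(Linear(hidden)). A quick sketch, assuming sparse_linear is an nn.Linear(hidden_size, 1) head as in FlagEmbedding:

import torch
import torch.nn as nn

hidden_size = 8
sparse_linear = nn.Linear(hidden_size, 1)  # assumed head shape

hidden_state = torch.randn(2, 4, hidden_size)
# ReLU clamps the per-token scores at zero, so many tokens get weight exactly 0;
# that is what makes the resulting lexical representation sparse.
token_weights = torch.relu(sparse_linear(hidden_state))  # (2, 4, 1), all >= 0
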
@@ -69,11 +71,13 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
         sparse_embedding[:, unused_tokens] *= 0.0
         return sparse_embedding
 
+    # Copied from FlagEmbedding
     def colbert_embedding(self, last_hidden_state, mask):
         colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
         colbert_vecs = colbert_vecs * mask[:, 1:][:, :, None].float()
         return colbert_vecs
 
+    # Modified from FlagEmbedding
     def _process_token_weights(self, token_weights, input_ids, mask):
         token_weights = token_weights.squeeze(-1)
         # convert to dict
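
A toy illustration of colbert_embedding's slicing: the CLS position is dropped and padded positions are zeroed out. The output dimension of colbert_linear is an assumption here, not shown in this diff:

import torch
import torch.nn as nn

hidden_size, colbert_dim = 8, 8  # colbert_dim is a guess
colbert_linear = nn.Linear(hidden_size, colbert_dim)

last_hidden_state = torch.randn(2, 4, hidden_size)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

vecs = colbert_linear(last_hidden_state[:, 1:])  # drop the CLS position -> (2, 3, colbert_dim)
vecs = vecs * mask[:, 1:][:, :, None].float()    # zero out padded positions
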
@@ -81,50 +85,32 @@ class BgeM3Model(XLMRobertaPreTrainedModel):
         unused_tokens = self.config.unused_tokens
         unused_tokens = torch.tensor(unused_tokens, device=input_ids.device)
 
-        # Get the indices of valid tokens
+        # Get valid matrix
         valid_indices = ~torch.isin(input_ids, unused_tokens)
-        # weight must be greater than 0
+        # w>0
         valid_indices = (valid_indices & (token_weights > 0)).bool()
-        # Combine with the attention mask to get valid token indices
         valid_indices = (valid_indices & mask).bool()
 
         for i, valid in enumerate(valid_indices):
             result = defaultdict(int)
 
-            # Get the valid weights and ids
+            # Get valid weight and ids
             valid_weights = token_weights[i][valid]
             valid_ids = input_ids[i][valid]
 
-            # Get the max weight for each id
+            # Get unique token
             unique_ids, inverse_indices = torch.unique(valid_ids, return_inverse=True)
 
-            # Use a loop to find the max weight for each unique id
+            # Get max weight for each token
             for i in range(unique_ids.shape[0]):
                 id_mask = inverse_indices == i
                 result[str(unique_ids[i].item())] = valid_weights[id_mask].max().item()
 
             all_result.append(result)
-        # token_weights = np.ceil(token_weights * 100)
-        # for w, idx, num in zip(token_weights, input_ids, tokens_num):
-        #     r = defaultdict(int)
-        #     token_weight = w[:num]
-        #     idx = idx[:num]
-
-        #     for t_w, t_idx in zip(token_weight, idx):
-        #         if t_idx.item() not in unused_tokens:
-        #             t_idx = str(t_idx.item())
-        #             if t_w > r[t_idx]:
-        #                 r[t_idx] = t_w.item()
-
-        #     result.append(r)
-
-        #     if idx not in unused_tokens and w > 0:
-        #         idx = str(idx)
-        #         # w = int(w)
-        #         if w > result[idx]:
-        #             result[idx] = w
+
         return all_result
 
+    # Copied from FlagEmbedding
     def _process_colbert_vecs(self, colbert_vecs, tokens_num) -> List[torch.Tensor]:
         # delete the vectors of padding tokens
         vecs = []
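
To make the rewritten aggregation above concrete, here is a toy run of the per-sequence inner loop: invalid positions are filtered out, then each unique token id keeps its maximum weight (values invented for illustration):

import torch
from collections import defaultdict

token_weights = torch.tensor([0.7, 0.2, 0.9, 0.4])
input_ids = torch.tensor([5, 9, 5, 9])
valid = torch.tensor([True, True, True, False])  # last position filtered out

result = defaultdict(int)
valid_weights = token_weights[valid]
valid_ids = input_ids[valid]
unique_ids, inverse_indices = torch.unique(valid_ids, return_inverse=True)
for i in range(unique_ids.shape[0]):
    id_mask = inverse_indices == i
    result[str(unique_ids[i].item())] = valid_weights[id_mask].max().item()

print(dict(result))  # token id 5 keeps max(0.7, 0.9); id 9 keeps 0.2
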
 
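Putting the three heads together, a hypothetical usage sketch. The backbone attribute name (model.roberta) and the surrounding calls are assumptions, since this commit only touches the head methods:

import torch

# model: a loaded BgeM3Model; batch: tokenizer output with input_ids and
# attention_mask. Both are placeholders, not part of this commit.
with torch.no_grad():
    outputs = model.roberta(**batch)  # assumed XLM-R backbone attribute
    last_hidden = outputs.last_hidden_state

dense = model.dense_embedding(last_hidden, batch["attention_mask"])
sparse = model.sparse_embedding(last_hidden, batch["input_ids"], return_embedding=True)
colbert = model.colbert_embedding(last_hidden, batch["attention_mask"])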