KevinHuSh commited on
Commit
92cae19
·
1 Parent(s): b802ae5

refine rerank (#1056)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1) hide show
  1. rag/llm/rerank_model.py +5 -3
rag/llm/rerank_model.py CHANGED
@@ -67,12 +67,12 @@ class DefaultRerank(Base):
67
  token_count = 0
68
  for _, t in pairs:
69
  token_count += num_tokens_from_string(t)
70
- batch_size = 32
71
  res = []
72
  for i in range(0, len(pairs), batch_size):
73
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
74
- scores = sigmoid(np.array(scores)).tolist()
75
- res.extend(scores)
76
  return np.array(res), token_count
77
 
78
 
@@ -124,7 +124,9 @@ class YoudaoRerank(DefaultRerank):
124
  for i in range(0, len(pairs), batch_size):
125
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
126
  scores = sigmoid(np.array(scores)).tolist()
 
127
  res.extend(scores)
128
  return np.array(res), token_count
129
 
130
 
 
 
67
  token_count = 0
68
  for _, t in pairs:
69
  token_count += num_tokens_from_string(t)
70
+ batch_size = 4096
71
  res = []
72
  for i in range(0, len(pairs), batch_size):
73
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
74
+ if isinstance(scores, float): res.append(scores)
75
+ else: res.extend(scores)
76
  return np.array(res), token_count
77
 
78
 
 
124
  for i in range(0, len(pairs), batch_size):
125
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
126
  scores = sigmoid(np.array(scores)).tolist()
127
+ if isinstance(scores, float): res.append(scores)
128
  res.extend(scores)
129
  return np.array(res), token_count
130
 
131
 
132
+