zyznull commited on
Commit
4492085
1 Parent(s): a883b02

Update scripts/eval_mteb.py

Browse files
Files changed (1) hide show
  1. scripts/eval_mteb.py +25 -8
scripts/eval_mteb.py CHANGED
@@ -405,7 +405,9 @@ class Wrapper:
405
  self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
406
  self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
407
  self.instruction = instruction
408
-
 
 
409
  if self.tokenizer.padding_side != 'right':
410
  logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
411
  self.tokenizer.padding_side = 'right'
@@ -544,9 +546,9 @@ class Wrapper:
544
 
545
  def _tokenize(self, sentences: List[str], is_query: bool):
546
 
547
- batch_dict = tokenizer(sentences, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
548
- batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
549
- batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
550
  batch_dict['is_causal'] = False
551
  return batch_dict
552
 
@@ -672,13 +674,15 @@ class Wrapper:
672
  def main(args):
673
  tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
674
  encoder = Encoder(args.model, args.pooling)
 
675
  model = Wrapper(
676
  tokenizer, encoder,
677
  batch_size=args.batch_size,
678
  max_seq_len=args.max_seq_len,
679
- normalize_embeddings=args.norm
 
680
  )
681
-
682
  if args.task == 'mteb':
683
  task_names = MTEB_TASK_LIST
684
  lang = ['en']
@@ -706,8 +710,21 @@ def main(args):
706
  eval_splits = task_cls.description['eval_splits']
707
  else:
708
  eval_splits = ["test"]
709
-
 
 
 
 
 
 
 
 
 
710
  evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
 
 
 
 
711
  print('\n')
712
 
713
 
@@ -726,4 +743,4 @@ if __name__ == "__main__":
726
  )
727
  _PARSER.add_argument("--norm", action="store_true")
728
  _ARGS = _PARSER.parse_args()
729
- main(_ARGS)
 
405
  self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
406
  self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
407
  self.instruction = instruction
408
+ self.default_query = default_query
409
+ self.sep = sep
410
+ self.force_default = force_default
411
  if self.tokenizer.padding_side != 'right':
412
  logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
413
  self.tokenizer.padding_side = 'right'
 
546
 
547
  def _tokenize(self, sentences: List[str], is_query: bool):
548
 
549
+ batch_dict = self.tokenizer(sentences, max_length=self.max_seq_len - 1, return_attention_mask=False, padding=False, truncation=True)
550
+ batch_dict['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
551
+ batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
552
  batch_dict['is_causal'] = False
553
  return batch_dict
554
 
 
674
  def main(args):
675
  tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
676
  encoder = Encoder(args.model, args.pooling)
677
+ default_query = args.default_type == 'query'
678
  model = Wrapper(
679
  tokenizer, encoder,
680
  batch_size=args.batch_size,
681
  max_seq_len=args.max_seq_len,
682
+ normalize_embeddings=args.norm,
683
+ default_query=default_query
684
  )
685
+ sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
686
  if args.task == 'mteb':
687
  task_names = MTEB_TASK_LIST
688
  lang = ['en']
 
710
  eval_splits = task_cls.description['eval_splits']
711
  else:
712
  eval_splits = ["test"]
713
+ sym = False
714
+ for name in sym_retrievals:
715
+ if task.startswith(name):
716
+ sym = True
717
+ break
718
+ else:
719
+ sym = False
720
+ if sym:
721
+ logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
722
+ model.force_default = True
723
  evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
724
+
725
+ if sym:
726
+ logger.info(f"Switch back.")
727
+ model.force_default = force_default_ori
728
  print('\n')
729
 
730
 
 
743
  )
744
  _PARSER.add_argument("--norm", action="store_true")
745
  _ARGS = _PARSER.parse_args()
746
+ main(_ARGS)