Update scripts/eval_mteb.py

scripts/eval_mteb.py  (+25 −8)  CHANGED
@@ -405,7 +405,9 @@ class Wrapper:
         self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
         self.instruction = instruction
-
+        self.default_query = default_query
+        self.sep = sep
+        self.force_default = force_default
         if self.tokenizer.padding_side != 'right':
             logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
             self.tokenizer.padding_side = 'right'
@@ -544,9 +546,9 @@ class Wrapper:
 
     def _tokenize(self, sentences: List[str], is_query: bool):
 
-        batch_dict = tokenizer(sentences, max_length=self.max_seq_len - 1, return_attention_mask=False, padding=False, truncation=True)
-        batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
-        batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
+        batch_dict = self.tokenizer(sentences, max_length=self.max_seq_len - 1, return_attention_mask=False, padding=False, truncation=True)
+        batch_dict['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
+        batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
        batch_dict['is_causal'] = False
        return batch_dict
@@ -672,13 +674,15 @@ class Wrapper:

 def main(args):
     tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
     encoder = Encoder(args.model, args.pooling)
+    default_query = args.default_type == 'query'
     model = Wrapper(
         tokenizer, encoder,
         batch_size=args.batch_size,
         max_seq_len=args.max_seq_len,
-        normalize_embeddings=args.norm
+        normalize_embeddings=args.norm,
+        default_query=default_query
     )
-
+    sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
     if args.task == 'mteb':
         task_names = MTEB_TASK_LIST
         lang = ['en']
@@ -706,8 +710,21 @@ def main(args):
             eval_splits = task_cls.description['eval_splits']
         else:
             eval_splits = ["test"]
-
+        sym = False
+        for name in sym_retrievals:
+            if task.startswith(name):
+                sym = True
+                break
+        else:
+            sym = False
+        if sym:
+            logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
+            model.force_default = True
         evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
+
+        if sym:
+            logger.info(f"Switch back.")
+            model.force_default = force_default_ori
         print('\n')
@@ -726,4 +743,4 @@ if __name__ == "__main__":
     )
     _PARSER.add_argument("--norm", action="store_true")
     _ARGS = _PARSER.parse_args()
-    main(_ARGS)
+    main(_ARGS)
After the change (new file, lines 405-413; `+` marks added lines):

405          self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
406          self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
407          self.instruction = instruction
408  +       self.default_query = default_query
409  +       self.sep = sep
410  +       self.force_default = force_default
411          if self.tokenizer.padding_side != 'right':
412              logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
413              self.tokenizer.padding_side = 'right'
546 |
|
547 |
def _tokenize(self, sentences: List[str], is_query: bool):
|
548 |
|
549 |
+
batch_dict = self.tokenizer(sentences, max_length=self.max_seq_len - 1, return_attention_mask=False, padding=False, truncation=True)
|
550 |
+
batch_dict['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
|
551 |
+
batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
|
552 |
batch_dict['is_causal'] = False
|
553 |
return batch_dict
|
554 |
|
After the change (new file, lines 674-688; `+` marks added lines):

674  def main(args):
675      tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
676      encoder = Encoder(args.model, args.pooling)
677  +   default_query = args.default_type == 'query'
678      model = Wrapper(
679          tokenizer, encoder,
680          batch_size=args.batch_size,
681          max_seq_len=args.max_seq_len,
682  +       normalize_embeddings=args.norm,
683  +       default_query=default_query
684      )
685  +   sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
686      if args.task == 'mteb':
687          task_names = MTEB_TASK_LIST
688          lang = ['en']
After the change (new file, lines 710-730; `+` marks added lines):

710              eval_splits = task_cls.description['eval_splits']
711          else:
712              eval_splits = ["test"]
713  +       sym = False
714  +       for name in sym_retrievals:
715  +           if task.startswith(name):
716  +               sym = True
717  +               break
718  +       else:
719  +           sym = False
720  +       if sym:
721  +           logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
722  +           model.force_default = True
723          evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
724  +
725  +       if sym:
726  +           logger.info(f"Switch back.")
727  +           model.force_default = force_default_ori
728          print('\n')
729
730
After the change (new file, lines 743-746; `+` marks added lines):

743      )
744      _PARSER.add_argument("--norm", action="store_true")
745      _ARGS = _PARSER.parse_args()
746  +   main(_ARGS)