'''This eval code is borrowed from E5.'''
import os
import json
import tqdm
import numpy as np
import torch
import argparse
from datasets import Dataset
from typing import List, Dict
from functools import partial
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
from transformers.modeling_outputs import BaseModelOutput
from torch.utils.data import DataLoader
from mteb import MTEB, AbsTaskRetrieval, DRESModel
from utils import pool, logger, move_to_cuda
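# Command-line options for the evaluation run.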
parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark')
parser.add_argument('--model-name-or-path', default='bert-base-uncased',
                    type=str, metavar='N', help='which model to use')
parser.add_argument('--output-dir', default='tmp-outputs/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--pool-type', default='avg', help='pool type: cls or avg')
parser.add_argument('--max-length', default=512, type=int, help='max sequence length')
args = parser.parse_args()
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
assert args.pool_type in ['cls', 'avg'], 'pool_type should be cls or avg'
assert args.output_dir, 'output_dir should be set'
os.makedirs(args.output_dir, exist_ok=True)
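# Tokenize the raw 'contents' column lazily; the collator below re-pads each batch to a multiple of 8.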
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    return tokenizer(examples['contents'],
                     max_length=args.max_length,
                     padding=True,
                     return_token_type_ids=False,
                     truncation=True)
class RetrievalModel(DRESModel):
    # Refer to the code of DRESModel for the methods to overwrite
    def __init__(self, **kwargs):
        self.encoder = AutoModel.from_pretrained(args.model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
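        # Use every visible GPU via DataParallel; the encoding batch size below scales with gpu_count.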
        self.gpu_count = torch.cuda.device_count()
        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)
        self.encoder.cuda()
        self.encoder.eval()
    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        # '查询: ' is the Chinese query prefix ("query: ")
        input_texts = ['查询: {}'.format(q) for q in queries]
        return self._do_encode(input_texts)

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        # '结果: ' is the Chinese passage prefix ("result: ")
        input_texts = ['结果: {}'.format(t) for t in input_texts]
        return self._do_encode(input_texts)
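    # Tokenize, batch, and encode all texts; returns an array of shape (num_texts, hidden_size).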
    @torch.no_grad()
    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
        dataset: Dataset = Dataset.from_dict({'contents': input_texts})
        dataset.set_transform(partial(_transform_func, self.tokenizer))
        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        # max(1, ...) guards against a zero batch size when no GPU is detected
        batch_size = 128 * max(1, self.gpu_count)
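        # Preserve input order (shuffle=False, drop_last=False) so embeddings line up with the corpus.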
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            collate_fn=data_collator,
            pin_memory=True)
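        # Mixed-precision forward pass; hidden states are pooled ('cls' or 'avg') and gathered on CPU.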
        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            batch_dict = move_to_cuda(batch_dict)
            with torch.cuda.amp.autocast():
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
                encoded_embeds.append(embeds.cpu().numpy())
        return np.concatenate(encoded_embeds, axis=0)
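# Chinese retrieval tasks from the C-MTEB benchmark.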
TASKS = ["T2Retrieval", "MMarcoRetrieval", "DuRetrieval", "CovidRetrieval", "CmedqaRetrieval", "EcomRetrieval", "MedicalRetrieval", "VideoRetrieval"]
def main():
    assert AbsTaskRetrieval.is_dres_compatible(RetrievalModel)
    model = RetrievalModel()
    task_names = [t.description["name"] for t in MTEB(tasks=TASKS).tasks]
    logger.info('Tasks: {}'.format(task_names))
    for task in task_names:
        logger.info('Processing task: {}'.format(task))
        evaluation = MTEB(tasks=[task])
        evaluation.run(model, output_folder=args.output_dir, overwrite_results=False)

if __name__ == '__main__':
    main()
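# Example invocation (the script filename and model name are illustrative, not part of the source):
#   python eval_retrieval.py --model-name-or-path intfloat/multilingual-e5-base \
#       --pool-type avg --max-length 512 --output-dir tmp-outputs/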