|
import argparse |
|
import logging |
|
import os |
|
|
|
from mteb import MTEB |
|
from sentence_transformers import SentenceTransformer |
|
# Emit INFO-level records so the per-task progress messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")

# Task names grouped by task type; each name is handed to MTEB individually.
CLASSIFICATION_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
]
STS_LIST = ["ATEC", "BQ", "LCQMC", "PAWSX", "STSB", "AFQMC", "QBQTC"]
PAIRCLASSIFICATION_LIST = ["Ocnli", "Cmnli"]
RERANKING_LIST = ["T2Reranking", "MmarcoReranking", "CMedQAv1", "CMedQAv2"]
CLUSTERING_LIST = [
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
]

# One ordered mapping pairs each group label with its task list, so the two
# parallel sequences derived from it below always stay aligned.
_TASK_GROUPS = {
    "Classification": CLASSIFICATION_LIST,
    "STS": STS_LIST,
    "Pairclassification": PAIRCLASSIFICATION_LIST,
    "Reranking": RERANKING_LIST,
    "Clustering": CLUSTERING_LIST,
}
# Task lists per group, and the matching group labels (the labels also name
# the results sub-folder each group is written to).
TASK_LIST = list(_TASK_GROUPS.values())
names = list(_TASK_GROUPS)
|
|
|
def main(argv=None):
    """Evaluate a SentenceTransformer model on every task in ``TASK_LIST``.

    Each task is run as its own ``MTEB`` evaluation. Results for a task
    group are written to ``<output-dir>/<group name>``, where the group
    name is the entry of ``names`` paired with that group's task list.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Run MTEB evaluation of a SentenceTransformer model on the configured task groups."
    )
    parser.add_argument(
        "--model",
        default="piccolo-base-zh",
        help="Model name or path passed to SentenceTransformer (default: %(default)s).",
    )
    parser.add_argument(
        "--output-dir",
        default="results",
        help="Root folder for results; one sub-folder per task group (default: %(default)s).",
    )
    args = parser.parse_args(argv)

    model = SentenceTransformer(args.model)
    for name, task_list in zip(names, TASK_LIST):
        # Same destination for every task in the group; compute it once.
        output_folder = os.path.join(args.output_dir, name)
        for task in task_list:
            logger.info("Running task: %s", task)
            evaluation = MTEB(tasks=[task])
            evaluation.run(model, output_folder=output_folder)


# Guard so importing this module does not load the model or start evaluation.
if __name__ == "__main__":
    main()
|
|