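"""
Train a SentenceTransformer bi-encoder on MS MARCO passage ranking with MarginMSELoss.

The script downloads the MS MARCO passage collection and training queries, a set of
mined hard negatives, and pre-computed cross-encoder (ms-marco-MiniLM-L-6-v2) scores.
Each training example is a (query, positive, negative) triplet whose label is the
cross-encoder score margin CE(query, pos) - CE(query, neg).
"""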
import argparse
import gzip
import json
import logging
import os
import pickle
import random
import sys
import tarfile
from datetime import datetime
from shutil import copyfile

import tqdm
from torch.utils.data import DataLoader, Dataset

from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, losses, InputExample

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
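
# Example invocation (script name and values are illustrative):
#   python train_script.py --model_name nicoladecao/msmarco-word2vec256000-distilbert-base-uncased \
#       --train_batch_size 64 --epochs 30 --num_negs_per_system 5 --lr 2e-5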

parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", default=64, type=int)
parser.add_argument("--max_seq_length", default=250, type=int)
parser.add_argument("--model_name", default="nicoladecao/msmarco-word2vec256000-distilbert-base-uncased")
parser.add_argument("--max_passages", default=0, type=int)
parser.add_argument("--epochs", default=30, type=int)
parser.add_argument("--pooling", default="mean")
parser.add_argument("--negs_to_use", default=None, help="From which systems should negatives be used? Multiple systems separated by comma. None = all")
parser.add_argument("--warmup_steps", default=1000, type=int)
parser.add_argument("--lr", default=2e-5, type=float)
parser.add_argument("--num_negs_per_system", default=5, type=int)
parser.add_argument("--use_all_queries", default=False, action="store_true")
args = parser.parse_args()

logging.info(str(args))

train_batch_size = args.train_batch_size
model_name = args.model_name
max_passages = args.max_passages
max_seq_length = args.max_seq_length
num_negs_per_system = args.num_negs_per_system
num_epochs = args.epochs

logging.info("Create new SBERT model")
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Freeze the token embedding layer, keeping the large word2vec-initialized embedding
# matrix fixed while the rest of the transformer is fine-tuned.
# Note: requires_grad must be set on the parameters; assigning
# `embeddings.requires_grad = False` on the module itself has no effect on training.
for param in word_embedding_model.auto_model.embeddings.parameters():
    param.requires_grad = False
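
# Optional sanity check (illustrative): confirm the freeze took effect by counting
# trainable vs. frozen parameters.
#   frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
#   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   logging.info(f"Frozen params: {frozen}, trainable params: {trainable}")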

model_save_path = f'output/train_bi-encoder-margin_mse-word2vec-{model_name.replace("/", "-")}-batch_size_{train_batch_size}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

os.makedirs(model_save_path, exist_ok=True)

# Save a copy of this script (and the command line it was invoked with) next to the
# model, so the training run can be reproduced later.
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

data_folder = 'msmarco-data'

# Read the MS MARCO corpus (passage id -> passage text)
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

logging.info("Read corpus: collection.tsv")
with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[int(pid)] = passage

# Read the train queries (query id -> query text)
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[int(qid)] = query

# Pre-computed cross-encoder (ms-marco-MiniLM-L-6-v2) scores for (query, passage) pairs.
# These provide the soft labels for MarginMSELoss.
ce_scores_file = os.path.join(data_folder, 'cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz')
if not os.path.exists(ce_scores_file):
    logging.info("Download cross-encoder scores file")
    util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz', ce_scores_file)

logging.info("Load CrossEncoder scores dict")
with gzip.open(ce_scores_file, 'rb') as fIn:
    ce_scores = pickle.load(fIn)
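
# ce_scores is a nested dict mapping query id -> passage id -> cross-encoder score (float);
# it is indexed below as ce_scores[qid][pid] to build the MarginMSE labels.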

# Hard negatives for the MS MARCO training queries, mined by multiple retrieval systems.
hard_negatives_filepath = os.path.join(data_folder, 'msmarco-hard-negatives.jsonl.gz')
if not os.path.exists(hard_negatives_filepath):
    logging.info("Download hard negatives file")
    util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz', hard_negatives_filepath)
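
# Each line of the hard-negatives file is a JSON object of roughly this shape
# (ids and system names illustrative):
#   {"qid": 1000, "pos": [123, ...], "neg": {"bm25": [456, ...], ...}}
# i.e. the positive passage ids plus, per retrieval system, a list of hard-negative ids.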

logging.info("Read hard negatives train file")
train_queries = {}
negs_to_use = None
with gzip.open(hard_negatives_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn):
        # Note: despite its name, --max_passages caps the number of training queries read
        if max_passages > 0 and len(train_queries) >= max_passages:
            break
        data = json.loads(line)

        # Positive (relevant) passage ids for this query
        pos_pids = data['pos']

        # Collect hard negatives from the selected systems
        neg_pids = set()
        if negs_to_use is None:
            if args.negs_to_use is not None:    # Use only the systems given via --negs_to_use
                negs_to_use = args.negs_to_use.split(",")
            else:    # Use all systems present in the data
                negs_to_use = list(data['neg'].keys())
            logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use)))

        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue

            # Take up to num_negs_per_system previously unseen negatives from this system
            system_negs = data['neg'][system_name]
            negs_added = 0
            for pid in system_negs:
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    if negs_added >= num_negs_per_system:
                        break

        # Keep the query only if it has at least one positive and one negative
        # (or unconditionally with --use_all_queries)
        if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
            train_queries[data['qid']] = {'qid': data['qid'], 'query': queries[data['qid']], 'pos': pos_pids, 'neg': neg_pids}

logging.info("Train queries: {}".format(len(train_queries)))


class MSMARCODataset(Dataset):
    """Yields (query, positive passage, negative passage) triplets whose label is the
    cross-encoder score margin CE(query, pos) - CE(query, neg), as required by MarginMSELoss."""

    def __init__(self, queries, corpus, ce_scores):
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus
        self.ce_scores = ce_scores

        for qid in self.queries:
            self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
            self.queries[qid]['neg'] = list(self.queries[qid]['neg'])
            random.shuffle(self.queries[qid]['neg'])

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']
        qid = query['qid']

        if len(query['pos']) > 0:
            pos_id = query['pos'].pop(0)    # Pop positive and re-append (round-robin)
            pos_text = self.corpus[pos_id]
            query['pos'].append(pos_id)
        else:
            # No positives for this query (only reachable with --use_all_queries):
            # use a second negative in place of the positive
            pos_id = query['neg'].pop(0)
            pos_text = self.corpus[pos_id]
            query['neg'].append(pos_id)

        neg_id = query['neg'].pop(0)    # Pop negative and re-append (round-robin)
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)

        pos_score = self.ce_scores[qid][pos_id]
        neg_score = self.ce_scores[qid][neg_id]

        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score - neg_score)

    def __len__(self):
        return len(self.queries)

train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
train_loss = losses.MarginMSELoss(model=model)
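
# MarginMSELoss fits the bi-encoder's score margin to the cross-encoder's margin:
# it minimizes MSE(sim(q, pos) - sim(q, neg), label) with label = CE(q, pos) - CE(q, neg)
# as produced by MSMARCODataset (sim is the dot product between the embeddings).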

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=args.warmup_steps,
          use_amp=True,
          checkpoint_path=model_save_path,
          checkpoint_save_steps=10000,
          optimizer_params={'lr': args.lr},
          )

model.save(model_save_path)
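
# The saved model can later be reloaded for encoding, e.g.:
#   model = SentenceTransformer(model_save_path)
#   query_embedding = model.encode("example query")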