import argparse
import gzip
import json
import logging
import os
import random
import sys
import tarfile
from collections import defaultdict
from datetime import datetime
from shutil import copyfile

import tqdm
from torch.utils.data import DataLoader, Dataset, IterableDataset

from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample

parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", default=64, type=int)
parser.add_argument("--max_seq_length", default=300, type=int)
parser.add_argument("--model_name", required=True)
parser.add_argument("--max_passages", default=0, type=int)
parser.add_argument("--epochs", default=10, type=int)
parser.add_argument("--pooling", default="cls")
parser.add_argument("--negs_to_use", default=None,
                    help="From which systems should negatives be used? Multiple systems separated by comma. None = all")
parser.add_argument("--warmup_steps", default=1000, type=int)
parser.add_argument("--lr", default=2e-5, type=float)
parser.add_argument("--name", default='')
parser.add_argument("--num_negs_per_system", default=5, type=int)
parser.add_argument("--use_pre_trained_model", default=False, action="store_true")
parser.add_argument("--use_all_queries", default=False, action="store_true")
args = parser.parse_args()
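
# Example invocation (the model name is illustrative; any Hugging Face encoder checkpoint works):
#   python train_bi-encoder_margin-mse.py --model_name distilbert-base-uncased --train_batch_size 64 --epochs 10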

print(args)

# Route log output through the sentence-transformers LoggingHandler so it prints to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

train_batch_size = args.train_batch_size
model_name = args.model_name
max_passages = args.max_passages
max_seq_length = args.max_seq_length

num_negs_per_system = args.num_negs_per_system
num_epochs = args.epochs

# Load the model: either continue from a pre-trained SBERT checkpoint, or build a new
# bi-encoder from a Hugging Face transformer plus a pooling layer
if args.use_pre_trained_model:
    print("use pretrained SBERT model")
    model = SentenceTransformer(model_name)
    model.max_seq_length = max_seq_length
else:
    print("Create new SBERT model")
    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

model_save_path = f'output/train_bi-encoder-margin_mse_en-{args.name}-{model_name.replace("/", "-")}-batch_size_{train_batch_size}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

os.makedirs(model_save_path, exist_ok=True)

# Save a copy of this script (and the command line that invoked it) next to the model
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

data_folder = 'msmarco-data'

# Read the MS MARCO passage collection (pid -> passage text), downloading it if needed
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

logging.info("Read corpus: collection.tsv")
with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage
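
# corpus now maps passage IDs (strings) to passage text, one entry per
# "pid<TAB>passage" line in collection.tsv.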

# Read the MS MARCO training queries (qid -> query text), downloading them if needed
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query

# Hard negatives with cross-encoder (CE) scores; this is a hardcoded local path, adjust it to your copy
train_filepath = '/home/msmarco/data/hard-negatives/msmarco-hard-negatives-v6.jsonl.gz'
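
# Each line of the hard-negatives file is a JSON object; the fields below are inferred
# from how the parsing loop accesses the data (values are illustrative):
#   {"qid": 1000, "pos": [{"pid": 123, "ce-score": 9.2}],
#    "neg": {"bm25": [{"pid": 456, "ce-score": -2.1}, ...], "other-system": [...]}}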

logging.info("Read train dataset")
train_queries = {}
ce_scores = {}
negs_to_use = None
with gzip.open(train_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn):
        if max_passages > 0 and len(train_queries) >= max_passages:
            break

        data = json.loads(line)

        if data['qid'] not in ce_scores:
            ce_scores[data['qid']] = {}

        # Remember the cross-encoder score for every positive passage
        for item in data['pos']:
            ce_scores[data['qid']][item['pid']] = item['ce-score']

        pos_pids = [item['pid'] for item in data['pos']]

        # Collect up to num_negs_per_system hard negatives from each retrieval system
        neg_pids = set()
        if negs_to_use is None:
            if args.negs_to_use is not None:
                negs_to_use = args.negs_to_use.split(",")
            else:
                negs_to_use = list(data['neg'].keys())
            print("Using negatives from the following systems:", negs_to_use)

        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue

            system_negs = data['neg'][system_name]

            negs_added = 0
            for item in system_negs:
                ce_scores[data['qid']][item['pid']] = item['ce-score']

                pid = item['pid']
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    if negs_added >= num_negs_per_system:
                        break

        if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
            train_queries[data['qid']] = {'qid': data['qid'], 'query': queries[data['qid']], 'pos': pos_pids, 'neg': neg_pids}

logging.info("Train queries: {}".format(len(train_queries)))


class MSMARCODataset(Dataset):
    def __init__(self, queries, corpus, ce_scores):
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus
        self.ce_scores = ce_scores

        for qid in self.queries:
            self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
            self.queries[qid]['neg'] = list(self.queries[qid]['neg'])
            random.shuffle(self.queries[qid]['neg'])

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']
        qid = query['qid']

        if len(query['pos']) > 0:
            pos_id = query['pos'].pop(0)    # Pop positive and re-append, cycling through all positives
            pos_text = self.corpus[pos_id]
            query['pos'].append(pos_id)
        else:   # No positives available: fall back to a negative in the positive slot
            pos_id = query['neg'].pop(0)
            pos_text = self.corpus[pos_id]
            query['neg'].append(pos_id)

        # Get a hard negative passage, cycling through the negatives the same way
        neg_id = query['neg'].pop(0)
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)

        pos_score = self.ce_scores[qid][pos_id]
        neg_score = self.ce_scores[qid][neg_id]

        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score - neg_score)

    def __len__(self):
        return len(self.queries)
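

# MarginMSELoss trains the bi-encoder so that the difference of its similarity scores,
# sim(query, pos) - sim(query, neg), matches the cross-encoder margin stored as the
# InputExample label above, i.e. it distills the cross-encoder into the bi-encoder.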
train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
train_loss = losses.MarginMSELoss(model=model)

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=args.warmup_steps,
          use_amp=True,
          checkpoint_path=model_save_path,
          checkpoint_save_steps=10000,
          checkpoint_save_total_limit=0,
          optimizer_params={'lr': args.lr},
          )
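
# Note: fit() uses a linear warmup schedule by default, and use_amp=True enables
# automatic mixed precision (requires a CUDA GPU). With checkpoint_save_total_limit=0,
# no old checkpoints are pruned, so all intermediate checkpoints are kept.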

# Save the final model
model.save(model_save_path)
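
# After training, the saved model can be loaded for encoding (a minimal usage sketch;
# the query and passage strings are illustrative):
#   model = SentenceTransformer(model_save_path)
#   query_emb = model.encode("what is margin mse?")
#   passage_emb = model.encode("MarginMSE distills a cross-encoder into a bi-encoder.")
#   score = util.dot_score(query_emb, passage_emb)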