|
from torch.utils.data import DataLoader |
|
from sentence_transformers import losses, util, models |
|
from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, evaluation |
|
from sentence_transformers.readers import InputExample |
|
import logging |
|
from datetime import datetime |
|
import os |
|
from shutil import copyfile |
|
import sys |
|
import math |
|
import gzip |
|
import random |
|
import tqdm |
|
from transformers import AutoTokenizer, AutoModel, BertModel |
|
import transformers |
|
import torch |
|
from SPARTA import SPARTA |
|
import json |
|
import numpy as np |
|
from torch.cuda.amp import autocast |
|
import os |
|
from shutil import copyfile |
|
import datetime |
|
from collections import defaultdict |
|
from scipy.sparse import csc_matrix, csr_matrix |
|
|
|
# ---------------------------------------------------------------------------
# Script setup: seeding, AMP scaler, logging, device selection, model init.
# ---------------------------------------------------------------------------
random.seed(42)

# Gradient scaler for fp16 mixed-precision training (used in the train loop).
scaler = torch.cuda.amp.GradScaler()

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Temporarily allocate a huge tensor so the CUDA caching allocator grabs a
# large chunk of GPU memory up front (fails fast if the GPU is too small).
# Bug fix: the original allocated on 'cuda' unconditionally, BEFORE the
# availability check above, and crashed on CPU-only machines.
if torch.cuda.is_available():
    fill_gpu = torch.eye(85000, dtype=torch.float, device='cuda')
    del fill_gpu

model_name = sys.argv[1]  # HuggingFace model name or local path
model = SPARTA(model_name, device)

# Output folder: output/msmarco-<model>-<timestamp>
model_save_path = "output/msmarco-{}-{}".format(model_name.rstrip("/").split("/")[-1], datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
model.tokenizer.save_pretrained(model_save_path)
|
|
|
|
|
|
|
# Hyper-parameters: distilled models are smaller, so a larger batch and more
# hard negatives per query fit into GPU memory.
is_distil = 'distil' in model_name
batch_size = 4 if is_distil else 3
num_negatives = 35 if is_distil else 20

logging.info(f"batch_size: {batch_size}")
logging.info(f"num_neg: {num_negatives}")

# Persist a copy of this training script plus the exact command line into the
# output folder, for reproducibility.
os.makedirs(model_save_path, exist_ok=True)

train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
|
|
|
|
|
|
# In-memory stores filled later: full passage collection and training queries.
corpus = {}
train_queries = {}


# ---------------------------------------------------------------------------
# Build the dev set used for periodic MRR@10 evaluation.
# ---------------------------------------------------------------------------
logging.info("Create dev dataset")
dev_corpus_max_size = 100*1000  # cap on dev passages; <= 0 disables the cap

dev_queries_file = '../data/queries.dev.small.tsv'
needed_pids = set()   # passage ids that must be in the dev corpus (relevant docs)
needed_qids = set()   # query ids that have at least one qrel entry
dev_qids = set()

dev_queries = {}      # qid -> query text
dev_corpus = {}       # pid -> passage text
dev_rel_docs = {}     # qid -> set of relevant pids

# First pass over the dev queries: collect all dev query ids.
with open(dev_queries_file) as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        dev_qids.add(qid)

# Read dev qrels; remember which queries/passages are actually needed.
with open('../data/qrels.dev.tsv') as fIn:
    for line in fIn:
        qid, _, pid, _ = line.strip().split('\t')

        if qid not in dev_qids:
            continue

        if qid not in dev_rel_docs:
            dev_rel_docs[qid] = set()
        dev_rel_docs[qid].add(pid)

        needed_pids.add(pid)
        needed_qids.add(qid)

# Second pass: keep only dev queries that have relevance judgements.
with open(dev_queries_file) as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        if qid in needed_qids:
            dev_queries[qid] = query

# Build the dev corpus: all relevant passages plus filler passages up to
# dev_corpus_max_size. Bug fix: the original compared with `<=`, which
# admitted one passage more than the configured cap.
with gzip.open('../data/collection-rnd.tsv.gz', 'rt') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        if pid in needed_pids or dev_corpus_max_size <= 0 or len(dev_corpus) < dev_corpus_max_size:
            dev_corpus[pid] = passage

# Freeze corpus order: parallel lists of pids and passage texts.
dev_corpus_pids = list(dev_corpus.keys())
dev_corpus = [dev_corpus[pid] for pid in dev_corpus_pids]
|
|
|
|
|
|
|
def compute_passage_emb(passages):
    """Compute sparse term-importance vectors for a batch of passages.

    For each passage, every vocabulary term t gets the score
    log(1 + relu(max_j <emb(t), h_j>)) where h_j are the passage's
    contextualized token embeddings; only the top `sparse_vec_size`
    positive-scoring terms are kept.

    :param passages: list of passage strings
    :return: list (one entry per passage) of defaultdict(float)
             mapping token-id -> score
    """
    sparse_embeddings = []
    sparse_vec_size = 2000  # keep at most this many terms per passage

    # Everything here is inference-only. Bug fix: the original built the
    # vocabulary embedding lookup (and mutated it in place) OUTSIDE the
    # no_grad block, creating a needless autograd graph over the full
    # vocabulary embedding matrix on every call.
    with torch.no_grad():
        # Static (non-contextual) input embeddings for the whole vocabulary.
        vocab_ids = torch.arange(len(model.tokenizer), device=device)
        bert_input_emb = model.bert_model.embeddings.word_embeddings(vocab_ids)

        # Special tokens ([CLS], [SEP], [PAD], ...) must never be scored.
        for special_id in model.tokenizer.all_special_ids:
            bert_input_emb[special_id] = 0 * bert_input_emb[special_id]

        tokens = model.tokenizer(passages, padding=True, truncation=True, return_tensors='pt', max_length=500).to(device)
        passage_embeddings = model.bert_model(**tokens).last_hidden_state
        for passage_emb in passage_embeddings:
            # (vocab_size, seq_len): similarity of each vocab term to each token.
            scores = torch.matmul(bert_input_emb, passage_emb.transpose(0, 1))
            max_scores = torch.max(scores, dim=-1).values
            relu_scores = torch.relu(max_scores)
            final_scores = torch.log(relu_scores + 1)  # log-saturation of term scores

            # topk is sorted descending, so we can stop at the first
            # non-positive score.
            top_results = torch.topk(final_scores, k=sparse_vec_size, sorted=True)
            passage_emb = defaultdict(float)
            for score, idx in zip(top_results[0].cpu().tolist(), top_results[1].cpu().tolist()):
                if score > 0:
                    passage_emb[idx] = score
                else:
                    break

            sparse_embeddings.append(passage_emb)

    return sparse_embeddings
|
|
|
def evaluate_msmarco():
    """Encode the dev corpus, score all dev queries, and return MRR@10.

    Reads the module-level dev_corpus / dev_corpus_pids / dev_queries /
    dev_rel_docs built above. A query's score for a passage is the sum of
    the passage's sparse term scores over the query's token ids.
    """
    passage_embs_sorted = []
    batch_size = 32  # local encoding batch size (independent of the training one)

    # Sort passages by decreasing length so batches have similar lengths
    # (less padding -> faster encoding); undo the permutation afterwards.
    length_sorted_idx = np.argsort([-len(pas) for pas in dev_corpus])
    dev_corpus_sorted = [dev_corpus[idx] for idx in length_sorted_idx]

    for start_idx in tqdm.trange(0, len(dev_corpus_sorted), batch_size, desc='encode corpus'):
        passage_embs_sorted.extend(compute_passage_emb(dev_corpus_sorted[start_idx:start_idx + batch_size]))

    passage_embs = [passage_embs_sorted[idx] for idx in np.argsort(length_sorted_idx)]

    # Build a (vocab_size x num_passages) CSR matrix: rows = token ids,
    # cols = passages. Row slicing by query token ids is then cheap.
    logging.info("Create sparse matrix")
    row = []
    col = []
    values = []
    for pid, emb in enumerate(passage_embs):
        for tid, score in emb.items():
            row.append(tid)
            col.append(pid)
            values.append(score)

    # Bug fix: np.float was deprecated and removed in NumPy 1.24; the builtin
    # float (float64) is the documented replacement.
    sparse = csr_matrix((values, (row, col)), shape=(len(model.tokenizer), len(passage_embs)), dtype=float)
    logging.info("Scores: {}".format(sparse.shape))

    mrr = []
    k = 10
    for qid, question in tqdm.tqdm(dev_queries.items(), desc="score"):
        token_ids = model.tokenizer(question, add_special_tokens=False)['input_ids']

        # Per-passage score = sum of the passage's term scores over query tokens.
        scores = np.asarray(sparse[token_ids, :].sum(axis=0)).squeeze(0)
        top_k_ind = np.argpartition(scores, -k)[-k:]
        hits = sorted([(dev_corpus_pids[pid], scores[pid]) for pid in top_k_ind], key=lambda x: x[1], reverse=True)

        # MRR@10: reciprocal rank of the first relevant hit, 0 if none.
        mrr_score = 0
        for rank, hit in enumerate(hits[0:10]):
            pid = hit[0]
            if pid in dev_rel_docs[qid]:
                mrr_score = 1 / (rank + 1)
                break
        mrr.append(mrr_score)

    assert len(mrr) == len(dev_queries)
    mrr = np.mean(mrr)
    logging.info("MRR@10: {:.4f}".format(mrr))
    return mrr
|
|
|
|
|
# Best dev MRR@10 seen so far; a checkpoint is saved whenever it improves.
best_score = 0


# Load the full passage collection: pid -> passage text.
with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage


# Load all training queries. Each query tracks:
#   pos:      annotated relevant passages (from qrels)
#   soft-pos: passages a cross-encoder scored as relevant (bert-score > 0.9)
#   neg:      hard negatives the cross-encoder scored low (bert-score < 0.1)
with open('../data/queries.train.tsv', 'r') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        train_queries[qid] = {'query': query,
                              'pos': set(),
                              'soft-pos': set(),
                              'neg': set()}


with open('../data/qrels.train.tsv') as fIn:
    for line in fIn:
        qid, _, pid, _ = line.strip().split()
        # Robustness fix: guard against qrels entries whose qid is missing
        # from the queries file (the original raised a KeyError there).
        if qid in train_queries:
            train_queries[qid]['pos'].add(pid)


logging.info("Clean train queries")
deleted_queries = 0
for qid in list(train_queries.keys()):
    if len(train_queries[qid]['pos']) == 0:
        deleted_queries += 1
        del train_queries[qid]
        continue

logging.info("Deleted queries pos-empty: {}".format(deleted_queries))

# Mine hard negatives / soft positives from pre-scored retrieval runs.
for hard_neg_file in ['../data/hard-negatives-all.jsonl.gz']:
    logging.info("Read hard negatives: "+hard_neg_file)
    with gzip.open(hard_neg_file, 'rt') as fIn:
        try:
            for line in fIn:
                try:
                    data = json.loads(line)
                except ValueError:
                    # Skip malformed JSON lines; ValueError covers
                    # json.JSONDecodeError. (Original used a bare except.)
                    continue
                qid = data['qid']

                if qid in train_queries:
                    neg_added = 0
                    max_neg_added = 100

                    # Candidate hits sorted by retrieval score, best first.
                    hits = sorted(data['hits'], key=lambda x: x['score'] if 'score' in x else x['bm25-score'], reverse=True)
                    for hit in hits:
                        pid = hit['corpus_id'] if 'corpus_id' in hit else hit['pid']

                        if pid in train_queries[qid]['pos']:
                            continue

                        if hit['bert-score'] < 0.1 and neg_added < max_neg_added:
                            train_queries[qid]['neg'].add(pid)
                            neg_added += 1
                        elif hit['bert-score'] > 0.9:
                            train_queries[qid]['soft-pos'].add(pid)
        except Exception as e:
            # The file may be truncated/partially written; keep whatever was
            # read so far, but log instead of silently swallowing (the
            # original used a bare `except: pass`, which also hid Ctrl-C).
            logging.warning("Error while reading {}: {}".format(hard_neg_file, e))


logging.info("Clean train queries with empty neg set")
deleted_queries = 0
for qid in list(train_queries.keys()):
    if len(train_queries[qid]['neg']) == 0:
        deleted_queries += 1
        del train_queries[qid]
        continue

logging.info("Deleted queries neg empty: {}".format(deleted_queries))

# Convert sets to lists so random.choice / random.sample work during training.
train_queries = list(train_queries.values())
for idx in range(len(train_queries)):
    train_queries[idx]['pos'] = list(train_queries[idx]['pos'])
    train_queries[idx]['neg'] = list(train_queries[idx]['neg'])
    train_queries[idx]['soft-pos'] = list(train_queries[idx]['soft-pos'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parameter groups: no weight decay on biases and LayerNorm parameters,
# following the standard BERT fine-tuning recipe.
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]


grad_acc_steps, lr = 1, 2e-5


num_epochs = 1
# Bug fix: optimizer_grouped_parameters was built above but never used --
# the original passed model.parameters() to AdamW, so the weight-decay
# exclusions had no effect.
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-6)
t_total = math.ceil(len(train_queries)/batch_size*num_epochs)
# Linear warmup over the first 10% of optimizer steps, then linear decay.
num_warmup_steps = int(t_total/grad_acc_steps * 0.1)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
loss_fct = torch.nn.CrossEntropyLoss()
max_grad_norm = 1  # gradient clipping threshold
|
|
|
|
|
# ---------------------------------------------------------------------------
# Training loop: in-batch negatives + mined hard negatives, fp16 via AMP.
# ---------------------------------------------------------------------------
for epoch in tqdm.trange(num_epochs, desc='Epochs'):
    random.shuffle(train_queries)
    idx = 0
    for start_idx in tqdm.trange(0, len(train_queries), batch_size):
        idx += 1
        # Periodic dev evaluation; keep only the best-scoring checkpoint.
        if (idx) % 5000 == 0:
            score = evaluate_msmarco()
            if score > best_score:
                best_score = score
                model.bert_model.save_pretrained(model_save_path)
                logging.info(f"Save to {model_save_path}")

        batch = train_queries[start_idx:start_idx+batch_size]
        queries = [b['query'] for b in batch]

        # One random positive passage per query: passages[i] is the positive
        # for queries[i], so the gold label of query i is simply index i.
        passages = [corpus[random.choice(b['pos'])] for b in batch]

        # Append sampled hard negatives; they act as extra (shared) candidate
        # passages in the cross-entropy softmax below.
        for b in batch:
            for pid in random.sample(b['neg'], k=min(len(b['neg']), num_negatives)):
                passages.append(corpus[pid])

        label = torch.tensor(list(range(len(batch))), device=device)

        # Mixed-precision forward pass; scores are scaled by 5 (acts as a
        # softmax temperature before the cross-entropy loss).
        with autocast():
            final_scores = model(queries, passages)
            final_scores = 5*final_scores
            loss_value = loss_fct(final_scores, label) / grad_acc_steps

        scaler.scale(loss_value).backward()
        if (idx + 1) % grad_acc_steps == 0:
            # Unscale before clipping so max_grad_norm applies to the true
            # (unscaled) gradients; scaler.step skips the update on inf/NaN.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            model.zero_grad()
            scheduler.step()

        # The string below is retained, commented-out reference code for a
        # plain FP32 training step with gradient accumulation.
        """
        #Normal FP32 with grad acc
        final_scores = model(query, passages)
        #Compute loss
        loss_value = loss_fct(final_scores, label)
        if grad_acc_steps > 1:
            loss_value /= grad_acc_steps
        loss_value.backward()

        if (idx+1) % grad_acc_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()
        """


# One final evaluation after training completes.
logging.info("Final eval:")
evaluate_msmarco()
|
|
|
|
|
|