|
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder |
|
from torch import nn |
|
import csv |
|
from torch.utils.data import DataLoader, Dataset |
|
import torch |
|
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator |
|
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator |
|
import logging |
|
import json |
|
import random |
|
import gzip |
|
|
|
# --- Training configuration --------------------------------------------------

# Pretrained MS MARCO cross-encoder checkpoint that fine-tuning starts from.
model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'

# Optimisation hyper-parameters.
lr = 2e-5
num_epochs = 1
train_batch_size = 8
warmup_steps = 1000

# Maximum sequence length intended for each query/product pair.
max_seq_length = 384

# Directory handed to fit()/save() for the trained model ('.' = current dir).
model_save_path = '.'
|
|
|
class ESCIDataset(Dataset):
    """Graded (query, product) training pairs read from a gzipped JSON-lines file.

    Each non-blank line of the input is parsed as a JSON object holding a
    ``'query'`` string and four product lists under the keys ``'e'``, ``'s'``,
    ``'c'`` and ``'i'``; every product needs ``'title'`` and ``'desc'``.
    Each (query, product) pair becomes one ``InputExample`` whose texts are
    ``[query, title + ' ' + desc]`` and whose label depends on the list the
    product came from (see ``_GRADE_LABELS``).
    """

    # (record key, relevance label), in the order examples are emitted per
    # query. Presumably the ESCI grades Exact / Substitute / Complement /
    # Irrelevant — confirm against the data preparation step.
    _GRADE_LABELS = (('e', 1.0), ('s', 0.1), ('c', 0.01), ('i', 0.0))

    def __init__(self, input):
        """Load every graded (query, product) pair.

        Args:
            input: path to a gzip-compressed JSON-lines file. (The parameter
                shadows the builtin ``input`` but keeps its name so keyword
                callers such as ``ESCIDataset(input=...)`` keep working.)

        Raises:
            KeyError: if a record lacks ``'query'`` or one of the grade lists,
                or a product lacks ``'title'``/``'desc'``.
            ValueError: (``json.JSONDecodeError``) if a non-blank line is not
                valid JSON.
        """
        self.queries = []
        # Never populated; retained only so code that reads it keeps working.
        self.posneg = []
        with gzip.open(input) as jsonfile:
            # Iterate the file lazily instead of readlines() so the whole
            # decompressed file is not held in memory at once.
            for line in jsonfile:
                if not line.strip():
                    # Blank lines (e.g. extra trailing newlines) carry no record.
                    continue
                query = json.loads(line)
                for key, label in self._GRADE_LABELS:
                    for doc in query[key]:
                        self.queries.append(InputExample(
                            texts=[query['query'], doc['title'] + ' ' + doc['desc']],
                            label=label,
                        ))

    def __getitem__(self, item):
        """Return the ``InputExample`` at position ``item``."""
        return self.queries[item]

    def __len__(self):
        """Return the number of (query, product) pairs loaded."""
        return len(self.queries)
|
|
|
class ESCIEvalDataset(Dataset):
    """Evaluation triplets (query, ``'e'`` product, ``'i'`` product).

    Reads a gzipped JSON-lines file with the same record layout as the
    training data. Only records that have at least one ``'e'`` product and at
    least one ``'i'`` product contribute. For each such record, every
    combination of an ``'e'`` product with an ``'i'`` product becomes one
    ``InputExample`` with ``texts=[query, positive_text, negative_text]``.
    Product text is ``title + ' ' + desc``, the same format used for the
    training pairs, so the model is evaluated on the input it was trained on.
    """

    def __init__(self, input):
        """Load the evaluation triplets.

        Args:
            input: path to a gzip-compressed JSON-lines file. (The parameter
                shadows the builtin ``input`` but keeps its name for keyword
                callers.)

        Raises:
            KeyError: if a record lacks ``'query'``, ``'e'`` or ``'i'``, or a
                product lacks ``'title'``/``'desc'``.
            ValueError: (``json.JSONDecodeError``) if a non-blank line is not
                valid JSON.
        """
        self.queries = []
        with gzip.open(input) as jsonfile:
            # Stream lines rather than materialising them all with readlines().
            for line in jsonfile:
                if not line.strip():
                    continue
                query = json.loads(line)
                if query['e'] and query['i']:
                    # Build each product's text once instead of once per pair;
                    # title + description matches the training input format.
                    positives = [p['title'] + ' ' + p['desc'] for p in query['e']]
                    negatives = [n['title'] + ' ' + n['desc'] for n in query['i']]
                    for positive in positives:
                        for negative in negatives:
                            self.queries.append(InputExample(
                                texts=[query['query'], positive, negative]))

    def __getitem__(self, item):
        """Return the triplet ``InputExample`` at position ``item``."""
        return self.queries[item]

    def __len__(self):
        """Return the number of triplets loaded."""
        return len(self.queries)
|
|
|
# Cross-encoder producing one relevance score per (query, product) pair.
# CrossEncoder takes its tokenizer truncation limit from the ``max_length``
# constructor argument; assigning a ``max_seq_length`` attribute afterwards
# is silently ignored, so the limit is passed here instead.
model = CrossEncoder(model_name, num_labels=1, max_length=max_seq_length)

train_dataset = ESCIDataset(input='train-small.json.gz')
eval_dataset = ESCIEvalDataset(input='test-small.json.gz')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

# Regroup the evaluation triplets into one reranking sample per query string.
# ESCIEvalDataset stores the full positive x negative cross product, so each
# positive text recurs once per negative and vice versa. Dicts keyed by text
# act as insertion-ordered sets: every distinct text is scored once, and the
# repeats no longer inflate evaluation cost or skew the reranking metric.
# (Distinct products with byte-identical text are therefore merged.)
grouped = {}
for example in eval_dataset.queries:
    query_text, positive_text, negative_text = example.texts
    positives, negatives = grouped.setdefault(query_text, ({}, {}))
    positives[positive_text] = None
    negatives[negative_text] = None

samples = {
    query_text: {
        'query': query_text,
        'positive': list(positives),
        'negative': list(negatives),
    }
    for query_text, (positives, negatives) in grouped.items()
}

# The evaluator's documented input is a list of
# {'query': str, 'positive': [...], 'negative': [...]} dicts.
evaluator = CERerankingEvaluator(samples=list(samples.values()), name='esci')

model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    use_amp=True,
    optimizer_params={'lr': lr},
    evaluator=evaluator,
    output_path=model_save_path,
)

# Persist the final weights. fit() may already have written a checkpoint to
# the same path when the evaluator score improved; this call overwrites it.
model.save(model_save_path)
|
|