File size: 3,003 Bytes
a16d1d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
from torch import nn
import csv
from torch.utils.data import DataLoader, Dataset
import torch
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
import logging
import json
import random
import gzip

# Name of the pretrained checkpoint handed to SentenceTransformer() below.
model_name = 'sentence-transformers/all-MiniLM-L6-v2'

# Training configuration consumed by the script further down in this file.
train_batch_size = 100  # batch_size of the training DataLoader
max_seq_length = 128  # assigned to model.max_seq_length after loading
num_epochs = 1  # passed to model.fit as epochs
warmup_steps = 1000  # passed to model.fit as warmup_steps
model_save_path = 'cos-exp'  # output_path for model.fit and target of model.save
lr = 2e-5  # 'lr' entry of the optimizer_params dict passed to model.fit

class ESCIDataset(Dataset):
    """In-memory dataset of (query, product title) pairs with graded labels.

    Every line of the gzip-compressed input is parsed as one JSON object that
    holds a ``query`` string and four lists under the keys ``e``, ``s``, ``c``
    and ``i``; each list element must provide a ``title``. One
    ``InputExample`` is created per listed product, with ``texts`` set to
    ``[query, title]`` and ``label`` taken from ``GRADE_LABELS``.
    """

    # (record key, similarity label) pairs, in the order examples are emitted
    # for each record.
    # NOTE(review): the keys presumably stand for the ESCI grades
    # exact / substitute / complement / irrelevant -- confirm against the
    # script that produces the input files.
    GRADE_LABELS = (('e', 1.0), ('s', 0.1), ('c', 0.01), ('i', 0.0))

    def __init__(self, input):
        """Read all examples from the gzip-compressed JSON-lines file.

        Args:
            input: Path or file object accepted by ``gzip.open``. The name
                shadows the builtin but is kept because callers pass it by
                keyword.

        Raises:
            KeyError: If a record lacks one of the grade keys, or a record
                with products lacks ``query`` or a product lacks ``title``.
            json.JSONDecodeError: If a line is not valid JSON.
        """
        super().__init__()
        self.queries = []
        with gzip.open(input) as jsonfile:
            # Iterate the handle directly rather than calling readlines(), so
            # only one decompressed line is held in memory at a time.
            for line in jsonfile:
                record = json.loads(line)
                for grade, label in self.GRADE_LABELS:
                    self.queries.extend(
                        InputExample(texts=[record['query'], product['title']], label=label)
                        for product in record[grade]
                    )

    def __getitem__(self, item):
        """Return the example stored at position ``item``."""
        return self.queries[item]

    def __len__(self):
        """Return the number of examples that were loaded."""
        return len(self.queries)


# Load the pretrained model on the CPU and apply the configured max_seq_length.
model = SentenceTransformer(model_name, device='cpu')
model.max_seq_length = max_seq_length


train_dataset = ESCIDataset(input='train-small.json.gz')
# NOTE(review): eval_dataset is referenced only by the disabled evaluator code
# below, but it is still loaded here, so 'test-small.json.gz' must exist.
eval_dataset = ESCIDataset(input='test-small.json.gz')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)

# NOTE(review): disabled evaluator setup. It cannot be re-enabled as written:
# ESCIDataset builds two-text examples (query, title), so `query.texts[2]`
# would raise IndexError.
# samples = {}
# for query in eval_dataset.queries:
#     qstr = query.texts[0]
#     sample = samples.get(qstr, {'query': qstr})
#     positive = sample.get('positive', [])
#     positive.append(query.texts[1])
#     sample['positive'] = positive
#     negative = sample.get('negative', [])
#     negative.append(query.texts[2])
#     sample['negative'] = negative
#     samples[qstr] = sample

# evaluator = RerankingEvaluator(samples=samples,name='esci')

# Train the model

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=warmup_steps,
          # Request mixed precision only when the model actually sits on a
          # CUDA device; the model above is created with device='cpu', where
          # a hard-coded True cannot take effect.
          use_amp=model.device.type == 'cuda',
#          checkpoint_path=model_save_path,
#          checkpoint_save_steps=len(train_dataloader),
          optimizer_params = {'lr': lr},
#          evaluator=evaluator,
#          evaluation_steps=1000,
          output_path=model_save_path
          )

# Save the model
# fit() above is also given output_path=model_save_path; this explicit save
# writes the final weights to that same directory.

model.save(model_save_path)