shuttie commited on
Commit
c4a5571
1 Parent(s): 103bfab

use proper model output dim

Browse files
finetune.py CHANGED
@@ -12,8 +12,8 @@ import gzip
12
 
13
  model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
14
 
15
- train_batch_size = 32
16
- max_seq_length = 128
17
  num_epochs = 1
18
  warmup_steps = 1000
19
  model_save_path = '.'
@@ -27,13 +27,13 @@ class ESCIDataset(Dataset):
27
  for line in jsonfile.readlines():
28
  query = json.loads(line)
29
  for doc in query['e']:
30
- self.queries.append(InputExample(texts=[query['query'], doc['title']], label=1.0))
31
  for doc in query['s']:
32
- self.queries.append(InputExample(texts=[query['query'], doc['title']], label=0.1))
33
  for doc in query['c']:
34
- self.queries.append(InputExample(texts=[query['query'], doc['title']], label=0.01))
35
  for doc in query['i']:
36
- self.queries.append(InputExample(texts=[query['query'], doc['title']], label=0.0))
37
 
38
  def __getitem__(self, item):
39
  return self.queries[item]
@@ -49,9 +49,9 @@ class ESCIEvalDataset(Dataset):
49
  query = json.loads(line)
50
  if len(query['e']) > 0 and len(query['i']) > 0:
51
  for p in query['e']:
52
- positive = p['title']
53
  for n in query['i']:
54
- negative = n['title']
55
  self.queries.append(InputExample(texts=[query['query'], positive, negative]))
56
 
57
  def __getitem__(self, item):
12
 
13
  model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
14
 
15
+ train_batch_size = 8
16
+ max_seq_length = 384
17
  num_epochs = 1
18
  warmup_steps = 1000
19
  model_save_path = '.'
27
  for line in jsonfile.readlines():
28
  query = json.loads(line)
29
  for doc in query['e']:
30
+ self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=1.0))
31
  for doc in query['s']:
32
+ self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=0.1))
33
  for doc in query['c']:
34
+ self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=0.01))
35
  for doc in query['i']:
36
+ self.queries.append(InputExample(texts=[query['query'], doc['title'] + ' ' + doc['desc']], label=0.0))
37
 
38
  def __getitem__(self, item):
39
  return self.queries[item]
49
  query = json.loads(line)
50
  if len(query['e']) > 0 and len(query['i']) > 0:
51
  for p in query['e']:
52
+ positive = p['title'] + ' ' + p['title']
53
  for n in query['i']:
54
+ negative = n['title'] + ' ' + n['title']
55
  self.queries.append(InputExample(texts=[query['query'], positive, negative]))
56
 
57
  def __getitem__(self, item):
onnx_convert.py CHANGED
@@ -1,9 +1,9 @@
1
- from transformers import AutoTokenizer, AutoModel
2
  import torch
3
 
4
  max_seq_length=128
5
 
6
- model = AutoModel.from_pretrained(".")
7
  model.eval()
8
 
9
  inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
1
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
2
  import torch
3
 
4
  max_seq_length=128
5
 
6
+ model = AutoModelForSequenceClassification.from_pretrained(".")
7
  model.eval()
8
 
9
  inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7dcfb2efa8e9be4d55c8353e38f61ccfd7223e0bfc2f24ab8af495b2cbbc8bc3
3
  size 133514357
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8eb5889a76cfd3d6beaaf62bf061723ebf7edd212329fc527ff36c5ed1b571a
3
  size 133514357
pytorch_model.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cb0312525f025d18e7013477ed8c389ad104591fdbfda838599762dad8608acb
3
- size 133694712
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a0fe068eded0383c63e7e63e8d5fef4e6d30a5e4d3011b4e7d1602844fcd251
3
+ size 133717601
test-small.json.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fb557251b12addb55d94af30120d121dfa6391e58bcc4a9aee0f1d35cc2ea1c8
3
- size 8522018
 
 
 
train-small.json.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9c7c14a8910a3a6c09421a08a84cfc0e74fd198d0aaf43ab2c39250a8ae4e4dd
3
- size 19430577