PeteBleackley committed on
Commit 8454a19
Parent: 5c2caf3

Removed train_base_model task. Not needed with RoBERTa models

Files changed (1):
  scripts.py  +1 -28
scripts.py CHANGED
@@ -6,7 +6,6 @@ import tokenizers
 import transformers
 import huggingface_hub
 import qarac.corpora.BNCorpus
-import qarac.corpora.Batcher
 import qarac.models.qarac_base_model
 import qarac.models.QaracTrainerModel
 import qarac.corpora.CombinedCorpus
@@ -72,30 +71,6 @@ def prepare_wiki_qa(filename,outfilename):
     data[['Cleaned_question','Resolved_answer','Label']].to_csv(outfilename)


-def train_base_model(task,filename):
-    tokenizer = tokenizers.Tokenizer.from_pretrained('xlm-roberta-base')
-    tokenizer.add_special_tokens(['<start>','<end>','<pad>'])
-    tokenizer.save('/'.join([os.environ['HOME'],
-                             'QARAC',
-                             'models',
-                             'tokenizer.json']))
-    bnc = qarac.corpora.BNCorpus.BNCorpus(tokenizer=tokenizer,
-                                          task=task)
-    (train,test)=bnc.split(0.01)
-    train_data=qarac.corpora.Batcher.Batcher(train)
-    model = qarac.models.qarac_base_model.qarac_base_model(tokenizer.get_vocab_size(),
-                                                           768,
-                                                           12,
-                                                           task=='decode')
-    #optimizer = keras.optimizers.Nadam(learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5, 100, 0.99))
-    #model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics='accuracy')
-    #model.fit(train_data,
-    #          epochs=100,
-    #          workers = 16,
-    #          use_multiprocessing=True)
-    test_data=qarac.corpora.Batcher.Batcher(test)
-    print(model.evaluate(test_data))
-    model.save(filename)

 def prepare_training_datasets():
     wikiqa = pandas.read_csv('corpora/WikiQA.csv')
@@ -478,9 +453,7 @@ if __name__ == '__main__':
     parser.add_argument('-t','--training-task')
     parser.add_argument('-o','--outputfile')
     args = parser.parse_args()
-    if args.task == 'train_base_model':
-        train_base_model(args.training_task,args.filename)
-    elif args.task == 'prepare_wiki_qa':
+    if args.task == 'prepare_wiki_qa':
         prepare_wiki_qa(args.filename,args.outputfile)
     elif args.task == 'elif args.task == 'prepare_training_datasets':
         prepare_training_datasets()
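Context for the removal (a minimal sketch, not code from this commit): with pretrained RoBERTa checkpoints available through the transformers library that scripts.py already imports, a base encoder no longer has to be trained from scratch on the BNC. Something like the following, using the public xlm-roberta-base checkpoint, would supply an equivalent pretrained tokenizer and encoder; the variable names and the sanity-check sentence are illustrative assumptions.

import transformers

# Sketch only: load a pretrained XLM-RoBERTa tokenizer and encoder instead of
# running the removed train_base_model task.
tokenizer = transformers.AutoTokenizer.from_pretrained('xlm-roberta-base')
encoder = transformers.AutoModel.from_pretrained('xlm-roberta-base')

# Quick sanity check: encode one sentence and inspect the hidden states.
inputs = tokenizer('The British National Corpus is a text corpus.',
                   return_tensors='pt')
outputs = encoder(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768)

The 768 width and 12 layers hard-coded in the removed qarac_base_model call match xlm-roberta-base's configuration, which is presumably why the pretrained checkpoint can stand in for the scratch-trained base model.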
 