lxe committed
Commit dc5c63f
1 Parent(s): 98c88e7

Properly insert eos tokens when training

Files changed (1)
  1. main.py +7 -1
main.py CHANGED
@@ -126,7 +126,10 @@ def tokenize_and_train(
     global tokenizer
 
     if (model is None): load_base_model()
-    if (tokenizer is None): load_tokenizer()
+    if (tokenizer is None):
+        tokenizer = transformers.LlamaTokenizer.from_pretrained(
+            "decapoda-research/llama-7b-hf", add_eos_token=True
+        )
 
     assert model is not None
     assert tokenizer is not None
@@ -134,6 +137,8 @@ def tokenize_and_train(
     tokenizer.pad_token_id = 0
 
     paragraphs = training_text.split("\n\n\n")
+    paragraphs = [x.strip() for x in paragraphs]
+
     print("Number of samples: " + str(len(paragraphs)))
 
     def tokenize(item):
@@ -242,6 +247,7 @@ def tokenize_and_train(
         ),
     )
 
+    model.config.use_cache = False
     result = trainer.train(resume_from_checkpoint=False)
     model.save_pretrained(output_dir)
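
Why this fixes the problem, in brief: LlamaTokenizer only appends the end-of-sequence token when constructed with add_eos_token=True, so the old load_tokenizer() path produced training samples with no EOS marker and the fine-tuned model never learned where a sample ends. A minimal sketch of the effect, assuming the same decapoda-research/llama-7b-hf checkpoint named in the diff (the exact token ids depend on that vocabulary):

    # Minimal sketch (not part of the commit): with add_eos_token=True,
    # every encoded sample ends in the EOS id, which is what teaches the
    # model to stop generating at sample boundaries.
    import transformers

    tokenizer = transformers.LlamaTokenizer.from_pretrained(
        "decapoda-research/llama-7b-hf", add_eos_token=True
    )
    tokenizer.pad_token_id = 0  # matches the setting in main.py

    ids = tokenizer("Hello world")["input_ids"]
    assert ids[-1] == tokenizer.eos_token_id  # EOS now appended at the end

The two smaller changes are housekeeping: stripping each paragraph removes the stray leading and trailing newlines left over from the "\n\n\n" split, and model.config.use_cache = False turns off the key/value cache, which is only useful at inference and is commonly disabled during training (it also conflicts with gradient checkpointing); it is typically re-enabled before calling model.generate().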
253