Fraser committed on
Commit
4286a16
·
1 Parent(s): 3d06908

add wiki sentences

README.md CHANGED
@@ -4,7 +4,7 @@ A Transformer-VAE made using flax.
 
 Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
 
-Builds on T5, using an autoencoder to convert it into a VAE.
+Builds on T5, using an autoencoder to convert it into an MMD-VAE.
 
 [See training logs.](https://wandb.ai/fraser/flax-vae)
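The README change renames the model from a plain VAE to an MMD-VAE, i.e. one that regularises the latent space with a Maximum Mean Discrepancy penalty between posterior samples and samples from the prior instead of a KL term. As a rough illustration only (the kernel choice and bandwidth below are assumptions, not taken from this repo), the biased MMD² estimate with an RBF kernel can be sketched as:

```python
import numpy as np

def rbf_kernel(x, y, sigma=1.0):
    # Pairwise RBF kernel between rows of x (n, d) and y (m, d).
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))

def mmd_loss(posterior_samples, prior_samples, sigma=1.0):
    # Biased MMD^2 estimate: mean within-set kernel similarities
    # minus twice the cross-set similarity. Zero when the two
    # sample sets are identical, positive as they drift apart.
    k_pp = rbf_kernel(posterior_samples, posterior_samples, sigma).mean()
    k_qq = rbf_kernel(prior_samples, prior_samples, sigma).mean()
    k_pq = rbf_kernel(posterior_samples, prior_samples, sigma).mean()
    return k_pp + k_qq - 2.0 * k_pq
```

In an MMD-VAE this penalty is added to the reconstruction loss, with `prior_samples` drawn from a standard normal of the same shape as the encoder's latents.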
datasets/{dataset.py → wiki_sentences.py} RENAMED
File without changes
train.py CHANGED
@@ -2,7 +2,7 @@
 Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
 
 TODO:
-- [ ] Get this running.
+- [x] Get this running.
 - [x] Don't make decoder input ids.
 - [ ] Add reg loss
 - [x] calculate MMD loss
@@ -372,6 +372,13 @@ def main():
         config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
     )
 
+    if model_args.add_special_tokens:
+        special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
+        num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+        print('We have added', num_added_tokens, 'tokens to GPT2')
+        model.resize_token_embeddings(len(tokenizer))
+        assert tokenizer.pad_token == '<PAD>'
+
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
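The block added to `train.py` follows the standard Hugging Face pattern: `tokenizer.add_special_tokens` returns the number of genuinely new tokens, and `model.resize_token_embeddings(len(tokenizer))` then grows the embedding matrix so the new token ids have rows. A toy stand-in (not the real `transformers` classes) showing the contract the code relies on:

```python
class ToyTokenizer:
    """Minimal stand-in for a transformers tokenizer (illustration only)."""

    def __init__(self, vocab):
        self.vocab = dict(vocab)
        self.pad_token = None
        self.bos_token = None
        self.eos_token = None

    def __len__(self):
        return len(self.vocab)

    def add_special_tokens(self, special_tokens_dict):
        # Like the real API: record the special token, add unseen tokens
        # to the vocab, and return the number of NEW tokens so callers
        # know whether the embedding matrix needs to grow.
        added = 0
        for attr, token in special_tokens_dict.items():
            setattr(self, attr, token)
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
                added += 1
        return added


tok = ToyTokenizer({'hello': 0, 'world': 1})
num_added = tok.add_special_tokens(
    {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
)
# After this, the model's embedding matrix must be resized to len(tok)
# rows -- which is what model.resize_token_embeddings(len(tokenizer)) does.
```

The final `assert tokenizer.pad_token == '<PAD>'` in the diff is a cheap sanity check that the pad token actually took effect before padding-dependent preprocessing runs.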
wiki_sentences.sh ADDED
@@ -0,0 +1,23 @@
+export RUN_NAME=wiki_split
+
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --tokenizer_name=gpt2 \
+    --add_special_tokens \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name=./datasets/wiki_sentences \
+    --input_ids_column=token_ids \
+    --do_train --do_eval \
+    --n_latent_tokens 6 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="256" \
+    --per_device_train_batch_size="10" \
+    --per_device_eval_batch_size="10" \
+    --learning_rate="5e-3" --warmup_steps="1000" \
+    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
+    --overwrite_output_dir \
+    --num_train_epochs="1" \
+    --push_to_hub \
wiki_split.sh ADDED
@@ -0,0 +1,20 @@
+export RUN_NAME=wiki_split
+
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name="wiki_split" \
+    --do_train --do_eval \
+    --n_latent_tokens 6 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="32" \
+    --per_device_train_batch_size="10" \
+    --per_device_eval_batch_size="10" \
+    --learning_rate="5e-3" --warmup_steps="1000" \
+    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
+    --overwrite_output_dir \
+    --num_train_epochs="3" \
+    --push_to_hub \