add wiki sentences
- README.md +1 -1
- datasets/{dataset.py → wiki_sentences.py} +0 -0
- train.py +8 -1
- wiki_sentences.sh +23 -0
- wiki_split.sh +20 -0
README.md CHANGED

```diff
@@ -4,7 +4,7 @@ A Transformer-VAE made using flax.
 
 Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
 
-Builds on T5, using an autoencoder to convert it into
+Builds on T5, using an autoencoder to convert it into an MMD-VAE.
 
 [See training logs.](https://wandb.ai/fraser/flax-vae)
 
```
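The README change completes the model description: the autoencoder built on T5 is now described as an MMD-VAE, matching the "calculate MMD loss" item ticked off in the train.py TODO below. For context, here is a minimal sketch of how an MMD penalty between encoder latents and samples from a standard-normal prior is typically computed; the function names and the inverse-multiquadric kernel are illustrative assumptions, not code taken from this repo.

```python
import jax.numpy as jnp
from jax import random


def imq_kernel(x, y, scale=1.0):
    # Inverse multiquadric kernel k(x, y) = C / (C + ||x - y||^2),
    # a common choice in MMD-VAEs because of its heavy tails.
    C = 2.0 * x.shape[-1] * scale
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return C / (C + sq_dists)


def mmd_loss(latents, rng):
    # Biased MMD estimate between encoder latents of shape (batch, latent_dim)
    # and samples drawn from the N(0, I) prior.
    prior = random.normal(rng, latents.shape)
    k_pp = imq_kernel(prior, prior).mean()
    k_qq = imq_kernel(latents, latents).mean()
    k_pq = imq_kernel(prior, latents).mean()
    return k_pp + k_qq - 2.0 * k_pq
```

With the settings in the shell scripts below (`--n_latent_tokens 6`, `--latent_token_size 32`) the latent being regularised would be 6 × 32 = 192 values per example, presumably flattened before the MMD term is applied.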
datasets/{dataset.py → wiki_sentences.py} RENAMED

File without changes
train.py CHANGED

```diff
@@ -2,7 +2,7 @@
 Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
 
 TODO:
-- [
+- [x] Get this running.
 - [x] Don't make decoder input ids.
 - [ ] Add reg loss
 - [x] calculate MMD loss
@@ -372,6 +372,13 @@ def main():
         config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
     )
 
+    if model_args.add_special_tokens:
+        special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
+        num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+        print('We have added', num_added_tokens, 'tokens to GPT2')
+        model.resize_token_embeddings(len(tokenizer))
+        assert tokenizer.pad_token == '<PAD>'
+
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
```
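The block added to `main()` gives the GPT-2 tokenizer the `<PAD>`/`<BOS>`/`<EOS>` tokens it lacks out of the box and grows the embedding table to match; GPT-2 has no pad token by default, so padded batching would otherwise fail. A quick way to see the tokenizer side of this in isolation (a sketch using the public transformers API, not code from this repo):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.pad_token)  # None: GPT-2 ships without a pad token

# Same dict as in the diff; returns how many genuinely new tokens were added.
added = tokenizer.add_special_tokens(
    {"pad_token": "<PAD>", "bos_token": "<BOS>", "eos_token": "<EOS>"}
)
print(added, "tokens added; vocab size is now", len(tokenizer))
# The model's embedding matrix must then be resized to cover the new ids,
# which is what model.resize_token_embeddings(len(tokenizer)) does in the diff.
```

Note that the diff reads `model_args.add_special_tokens`, so `ModelArguments` presumably has (or gains elsewhere) a matching boolean field; that field is not shown in this commit.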
wiki_sentences.sh ADDED

```diff
@@ -0,0 +1,23 @@
+export RUN_NAME=wiki_split
+
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --tokenizer_name=gpt2 \
+    --add_special_tokens \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name=./datasets/wiki_sentences \
+    --input_ids_column=token_ids \
+    --do_train --do_eval \
+    --n_latent_tokens 6 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="256" \
+    --per_device_train_batch_size="10" \
+    --per_device_eval_batch_size="10" \
+    --learning_rate="5e-3" --warmup_steps="1000" \
+    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
+    --overwrite_output_dir \
+    --num_train_epochs="1" \
+    --push_to_hub \
```
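wiki_sentences.sh points `--dataset_name` at the local loading script renamed in this commit (datasets/wiki_sentences.py) and, via `--input_ids_column=token_ids`, tells train.py to read pre-tokenized ids rather than raw text. A quick sanity check before launching the run might look like this; the split name and the column check are assumptions about what the loading script produces:

```python
from datasets import load_dataset

# Load via the local loading script renamed in this commit.
ds = load_dataset("./datasets/wiki_sentences.py", split="train")

# --input_ids_column=token_ids means train.py expects this column to exist.
assert "token_ids" in ds.column_names
print(ds[0]["token_ids"][:10])
```

Note that this script also sets `RUN_NAME=wiki_split`, so it writes to the same `output/wiki_split` directory as wiki_split.sh below.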
wiki_split.sh ADDED

```diff
@@ -0,0 +1,20 @@
+export RUN_NAME=wiki_split
+
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name="wiki_split" \
+    --do_train --do_eval \
+    --n_latent_tokens 6 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="32" \
+    --per_device_train_batch_size="10" \
+    --per_device_eval_batch_size="10" \
+    --learning_rate="5e-3" --warmup_steps="1000" \
+    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
+    --overwrite_output_dir \
+    --num_train_epochs="3" \
+    --push_to_hub \
```
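wiki_split.sh instead pulls the `wiki_split` dataset straight from the Hub and feeds raw text to the default T5 tokenizer (no `--tokenizer_name` or `--add_special_tokens`), with a much shorter `--block_size` (32 vs 256) and 3 epochs rather than 1. For orientation, a small snippet to inspect that dataset; which column train.py autoencodes is not shown in this diff:

```python
from datasets import load_dataset

# WikiSplit: Wikipedia sentences paired with split-and-rephrased versions.
ds = load_dataset("wiki_split", split="train")
print(ds.column_names)
print(ds[0])
```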