rynmurdock committed on
Commit
c5ca37a
1 Parent(s): 5b8f2e0
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. Optimus/.gitignore +8 -0
  2. Optimus/README.md +121 -0
  3. Optimus/code/README.md +41 -0
  4. Optimus/code/app.py +0 -0
  5. Optimus/code/examples/README.md +392 -0
  6. Optimus/code/examples/__pycache__/utils_glue.cpython-37.pyc +0 -0
  7. Optimus/code/examples/big_ae/__pycache__/grad_app.cpython-310.pyc +0 -0
  8. Optimus/code/examples/big_ae/__pycache__/utils.cpython-37.pyc +0 -0
  9. Optimus/code/examples/big_ae/debug_data.py +6 -0
  10. Optimus/code/examples/big_ae/eval_dialog_multi_response.py +378 -0
  11. Optimus/code/examples/big_ae/eval_dialog_response.py +295 -0
  12. Optimus/code/examples/big_ae/grad_app.py +486 -0
  13. Optimus/code/examples/big_ae/metrics.py +196 -0
  14. Optimus/code/examples/big_ae/modules/__init__.py +7 -0
  15. Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  16. Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-37.pyc +0 -0
  17. Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-310.pyc +0 -0
  18. Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-37.pyc +0 -0
  19. Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-310.pyc +0 -0
  20. Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-37.pyc +0 -0
  21. Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-310.pyc +0 -0
  22. Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-37.pyc +0 -0
  23. Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-310.pyc +0 -0
  24. Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-37.pyc +0 -0
  25. Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-310.pyc +0 -0
  26. Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-37.pyc +0 -0
  27. Optimus/code/examples/big_ae/modules/arae.py +274 -0
  28. Optimus/code/examples/big_ae/modules/cara.py +374 -0
  29. Optimus/code/examples/big_ae/modules/ctrl_gen.py +371 -0
  30. Optimus/code/examples/big_ae/modules/decoders/dec_gpt2.py +358 -0
  31. Optimus/code/examples/big_ae/modules/decoders/decoder.py +79 -0
  32. Optimus/code/examples/big_ae/modules/encoders/__init__.py +1 -0
  33. Optimus/code/examples/big_ae/modules/encoders/enc_lstm.py +126 -0
  34. Optimus/code/examples/big_ae/modules/encoders/encoder.py +58 -0
  35. Optimus/code/examples/big_ae/modules/encoders/gaussian_encoder.py +147 -0
  36. Optimus/code/examples/big_ae/modules/spacefusion.py +143 -0
  37. Optimus/code/examples/big_ae/modules/utils.py +40 -0
  38. Optimus/code/examples/big_ae/modules/vae.py +638 -0
  39. Optimus/code/examples/big_ae/run_data_filtering.py +507 -0
  40. Optimus/code/examples/big_ae/run_dialog_dataloader.py +483 -0
  41. Optimus/code/examples/big_ae/run_encoding_generation.py +487 -0
  42. Optimus/code/examples/big_ae/run_generation_from_prior.py +414 -0
  43. Optimus/code/examples/big_ae/run_gpt2_generation.py +390 -0
  44. Optimus/code/examples/big_ae/run_latent_generation.py +577 -0
  45. Optimus/code/examples/big_ae/run_lm_ae_pretraining.py +692 -0
  46. Optimus/code/examples/big_ae/run_lm_causal_pretraining.py +692 -0
  47. Optimus/code/examples/big_ae/run_lm_finetuning_baseline.py +573 -0
  48. Optimus/code/examples/big_ae/run_lm_gpt2_training.py +658 -0
  49. Optimus/code/examples/big_ae/run_lm_vae_label_ctrl_gen.py +875 -0
  50. Optimus/code/examples/big_ae/run_lm_vae_pretraining.py +669 -0
Optimus/.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ data/datasets/glue_data/glue_data
2
+ data/datasets/glue_data/train.tx
3
+ data/datasets/glue_data/cached_lm_gpt_bert_256_train.jsont
4
+ code/runs
5
+ output/*
6
+ code/pytorch_transformers/__pycache__/*
7
+ code/examples/big_ae/modules/encoders/__pycache__/*
8
+
Optimus/README.md ADDED
@@ -0,0 +1,121 @@
1
+ # Optimus: the first pre-trained Big VAE language model <img src="doc/figs/logo_optimus.png" width="100" align="right">
2
+
3
+ This repository contains source code necessary to reproduce the results presented in the EMNLP 2020 paper [Optimus: Organizing Sentences via Pre-trained Modeling of a Latent Space](https://arxiv.org/abs/2004.04092).
4
+
5
+
6
+ |<img src="doc/figs/optimus_scheme.png" width="350"> | <img src="doc/figs/headfig_optimus.png" width="800">
7
+ |-------------------------|:-------------------------:|
8
+ | The network architecture of Optimus: encoder for representation learning and decoder for generation | Sentences are organized and manipulated in a pre-trained compact and smooth latent space
9
+
10
+
11
+ For more on this project, see the [Microsoft Research Blog post](https://www.microsoft.com/en-us/research/blog/a-deep-generative-model-trifecta-three-advances-that-work-towards-harnessing-large-scale-power/).
12
+
13
+
14
+ ## News
15
+
16
+ May 21, 2020: Releasing a [`demo`](http://40.71.23.172:8899/) for latent space manipulation, including sentence interpolation and analogy. Check out the [`website`](http://40.71.23.172:8899/).
17
+
18
+ May 20, 2020: The latent space manipulation code is cleaned and released. See instructions at [`optimius_for_snli.md`](doc/optimius_for_snli.md).
19
+
20
+ May 13, 2020: The fine-tuning code for language modeling is released. See instructions at [`optimus_finetune_language_models.md`](doc/optimus_finetune_language_models.md)
21
+
22
+ ## Contents
23
+ There are four steps to use this codebase to reproduce the results in the paper.
24
+
25
+ 1. [Dependencies](#dependencies)
26
+ 2. [Prepare datasets](#prepare-datasets)
27
+ 3. [Model training](#Model-training)
28
+ 1. Pre-training on sentences in Wikipedia
29
+ 2. Language Modeling
30
+ 3. Guided Language Generation
31
+ 4. Low-resource Language Understanding
32
+ 4. [Collect and plot results](#collect-and-plot-results)
33
+
34
+
35
+ ## Dependencies
36
+
37
+ Pull the docker image from Docker Hub: `chunyl/pytorch-transformers:v2`. Please see the instructions at [`doc/env.md`](doc/env.md)
38
+
39
+ The project is organized into the following structure, with the essential files & folders shown. `output` stores the model checkpoints.
40
+ ```
41
+ ├── Optimus
42
+     └── code
43
+         ├── examples
44
+             ├── big_ae
45
+                 ├── modules
46
+                     ├── vae.py
47
+                     └── ...
48
+                 ├── run_lm_vae_pretraining_phdist_beta.py
49
+                 ├── run_lm_vae_training.py
50
+                 └── ...
51
+         ├── pytorch_transformers
52
+             ├── modeling_bert.py
53
+             ├── modeling_gpt2.py
54
+             └── ...
55
+         ├── scripts
56
+             ├── scripts_docker
57
+             ├── scripts_local
58
+             ├── scripts_philly
59
+     └── data
60
+         └── datasets
61
+             ├── wikipedia_json_64_filtered
62
+                 └── ...
63
+             ├── snli_data
64
+                 └── ...
65
+     └── output
66
+         ├── pretrain
67
+         ├── LM
68
+         └── ...
69
+ ```
70
+
71
+ ## Prepare Datasets
72
+
73
+ Please download or prepare the data by following the instructions at [`data/download_datasets.md`](data/download_datasets.md).
74
+
75
+ ## Model Training
76
+
77
+ **1. Pre-training on sentences in Wikipedia**
78
+
79
+ We pre-trained our models on Philly (a Microsoft internal compute cluster); the code is specialized for multi-node, multi-GPU compute on this platform. The main pre-training Python script is [`run_lm_vae_pretraining_phdist_beta.py`](code/examples/big_ae/run_lm_vae_pretraining_phdist_beta.py). You may need to adjust the distributed training scripts.
80
+
81
+ **2. Language Modeling**
82
+
83
+ To allow a fair comparison with existing VAE language models, we consider a model with latent dimension 32. The pre-trained model is fine-tuned on four commonly used datasets for one epoch. Please see the details at [`doc/optimus_finetune_language_models.md`](doc/optimus_finetune_language_models.md)
84
+
85
+ **3. Guided Language Generation**
86
+
87
+
88
+ **Latent Space Manipulation** To ensure good performance, we consider a model with latent dimension 768. The pre-trained model is fine-tuned on the SNLI dataset, where sentences show related patterns.
89
+ Please see the details at [`doc/optimius_for_snli.md`](doc/optimius_for_snli.md)
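+
+ As a minimal sketch of what these manipulations do in the latent space (the function names and the raw 768-d latent codes below are illustrative assumptions, not the released API; the actual commands are described in [`doc/optimius_for_snli.md`](doc/optimius_for_snli.md)):
+
+ ```python
+ import torch
+
+ def interpolate(z_a, z_b, steps=10):
+     # Latent codes on the straight line between two sentences' 768-d vectors;
+     # each intermediate code is then decoded back into a sentence.
+     return [(1 - t) * z_a + t * z_b for t in torch.linspace(0, 1, steps)]
+
+ def analogy(z_a, z_b, z_c):
+     # "A is to B as C is to D": apply the latent offset from A to B onto C,
+     # then decode the resulting code.
+     return z_c + (z_b - z_a)
+ ```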
90
+
91
+ **4. Low-resource Language Understanding**
92
+
93
+ ## Collect and Plot Results
94
+
95
+ Once the networks are trained and the results are saved, we extract key results using a Python script. The results can be plotted using the included IPython notebook `plots/main_plots.ipynb`.
96
+ Start the IPython Notebook server:
97
+
98
+ ```
99
+ $ cd plots
100
+ $ ipython notebook
101
+ ```
102
+
103
+ Select the `main_plots.ipynb` notebook and execute the included
104
+ code. Note that we have copied our extracted results into the notebook, so without modification the script will output the figures in the paper. If you've run your own training and wish to plot results, you'll have to organize your results in the same format instead.
105
+
106
+
107
+ ## Questions?
108
+
109
+ Please drop me ([Chunyuan](http://chunyuan.li/)) a line if you have any questions.
110
+
111
+
112
+ ```
113
+ @inproceedings{li2020_Optimus,
114
+ title={Optimus: Organizing Sentences via Pre-trained Modeling of a Latent Space},
115
+ author={Li, Chunyuan and Gao, Xiang and Li, Yuan and Li, Xiujun and Peng, Baolin and Zhang, Yizhe and Gao, Jianfeng},
116
+ booktitle={EMNLP},
117
+ year={2020}
118
+ }
119
+ ```
120
+
121
+
Optimus/code/README.md ADDED
@@ -0,0 +1,41 @@
1
+ ## Set up Environment
2
+
3
+ Pull docker from Docker Hub at: chunyl/pytorch-transformers:v2
4
+
5
+ Edit the project path to the absolute path on your computer by changing the "SCRIPTPATH" in [run_docker.sh](./scripts/scripts_docker/run_docker.sh)
6
+
7
+ In this directory ("code"), run docker:
8
+
9
+ sh scripts/scripts_docker/run_docker.sh
10
+
11
+
12
+
13
+
14
+ ## Fine-tune Language Models
15
+
16
+ sh scripts/scripts_local/run_ft_lm_vae_optimus.sh
17
+
18
+
19
+ The main training script is [`run_lm_vae_training.py`](./examples/big_ae/run_lm_vae_training.py) and conducts the fine-tuning loop, taking the following options (among others) as arguments:
20
+
21
+ - `--checkpoint_dir`: the folder where the pre-trained Optimus checkpoint is saved.
22
+ - `--gloabl_step_eval`: specifies the checkpoint to load (the global step at which Optimus was trained).
23
+ - `--train_data_file` and `--eval_data_file`: the paths to the training and test datasets for the downstream fine-tuning.
24
+ - `--dataset`: the dataset for fine-tuning, such as `Penn`.
25
+ - `--num_train_epochs`: number of training epochs (type=int); default 1.
26
+ - `--dim_target_kl`: the hyper-parameter used for dimension-wise thresholding of the KL term in fine-tuning (type=float); default 0.5.
27
+ - `--beta`: the maximum beta value in the cyclical annealing schedule used in fine-tuning (type=float); default 1.0.
28
+ - `--ratio_zero`: the proportion of one period during which beta = 0 (type=float); default 0.5.
29
+ - `--ratio_increase`: the proportion of one period during which beta increases from 0 to the maximum value (type=float); default 0.25 (a sketch of the schedule follows this list).
30
+
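+ As a rough sketch of how `--beta`, `--ratio_zero` and `--ratio_increase` interact within one period of the cyclical annealing schedule (the function below is illustrative only, not the script's actual implementation):
+
+ ```python
+ def beta_at_step(step, period, beta_max=1.0, ratio_zero=0.5, ratio_increase=0.25):
+     t = (step % period) / period            # position within the current cycle, in [0, 1)
+     if t < ratio_zero:                       # beta held at 0 for the first part of the cycle
+         return 0.0
+     if t < ratio_zero + ratio_increase:      # linear ramp from 0 up to beta_max
+         return beta_max * (t - ratio_zero) / ratio_increase
+     return beta_max                          # beta held at its maximum for the rest of the cycle
+ ```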
31
+
32
+ For more options, please see [`run_lm_vae_training.py`](./examples/big_ae/run_lm_vae_training.py), the examples we provide in [`run_ft_lm_vae_optimus.sh`](./scripts/scripts_local/run_ft_lm_vae_optimus.sh), or [more scripts we used to run the code on a cluster](./scripts/scripts_philly).
33
+
34
+
35
+ ## Play with the latent space
36
+
37
+ sh scripts/scripts_local/eval_optimus_latent_space.sh
38
+
39
+ The main script is [`run_latent_generation.py`](./examples/big_ae/run_latent_generation.py), which evaluates various ways to generate text conditioned on latent vectors, taking the following options (among others) as arguments:
40
+
41
+ - `--play_mode`: The current script supports two ways to play with the pre-trained VAE models: [`reconstrction`, `interpolation`]
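+
+ As a conceptual sketch of how the two modes differ (`encode`/`decode` below are placeholders for the fine-tuned VAE encoder and decoder, not the script's actual functions; the mode strings follow the option values listed above):
+
+ ```python
+ def play(play_mode, encode, decode, sent_a, sent_b=None, steps=10):
+     if play_mode == 'reconstrction':    # encode a sentence, then decode its own latent code
+         return [decode(encode(sent_a))]
+     if play_mode == 'interpolation':    # decode latents along the path between two sentences
+         z_a, z_b = encode(sent_a), encode(sent_b)
+         return [decode(z_a + (z_b - z_a) * t / (steps - 1)) for t in range(steps)]
+ ```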
Optimus/code/app.py ADDED
File without changes
Optimus/code/examples/README.md ADDED
@@ -0,0 +1,392 @@
1
+ # Examples
2
+
3
+ In this section a few examples are put together. All of these examples work for several models, making use of the very
4
+ similar API between the different models.
5
+
6
+ | Section | Description |
7
+ |----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
8
+ | [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
9
+ | [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
10
+ | [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
11
+ | [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
12
+ | [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
13
+
14
+ ## Language model fine-tuning
15
+
16
+ Based on the script [`run_lm_finetuning.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_lm_finetuning.py).
17
+
18
+ Fine-tuning the library models for language modeling on a text dataset for GPT, GPT-2, BERT and RoBERTa (DistilBERT
19
+ to be added soon). GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa
20
+ are fine-tuned using a masked language modeling (MLM) loss.
21
+
22
+ Before running the following example, you should get a file that contains text on which the language model will be
23
+ fine-tuned. A good example of such text is the [WikiText-2 dataset](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/).
24
+
25
+ We will refer to two different files: `$TRAIN_FILE`, which contains text for training, and `$TEST_FILE`, which contains
26
+ text that will be used for evaluation.
27
+
28
+ ### GPT-2/GPT and causal language modeling
29
+
30
+ The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before
31
+ the tokenization). The loss here is that of causal language modeling.
32
+
33
+ ```bash
34
+ export TRAIN_FILE=/path/to/dataset/wiki.train.raw
35
+ export TEST_FILE=/path/to/dataset/wiki.test.raw
36
+
37
+ python run_lm_finetuning.py \
38
+ --output_dir=output \
39
+ --model_type=gpt2 \
40
+ --model_name_or_path=gpt2 \
41
+ --do_train \
42
+ --train_data_file=$TRAIN_FILE \
43
+ --do_eval \
44
+ --eval_data_file=$TEST_FILE
45
+ ```
46
+
47
+ This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run. It reaches
48
+ a score of ~20 perplexity once fine-tuned on the dataset.
49
+
50
+ ### RoBERTa/BERT and masked language modeling
51
+
52
+ The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different
53
+ as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their
54
+ pre-training: masked language modeling.
55
+
56
+ In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore, converge
57
+ slightly slower (over-fitting takes more epochs).
58
+
59
+ We use the `--mlm` flag so that the script may change its loss function.
60
+
61
+ ```bash
62
+ export TRAIN_FILE=/path/to/dataset/wiki.train.raw
63
+ export TEST_FILE=/path/to/dataset/wiki.test.raw
64
+
65
+ python run_lm_finetuning.py \
66
+ --output_dir=output \
67
+ --model_type=roberta \
68
+ --model_name_or_path=roberta-base \
69
+ --do_train \
70
+ --train_data_file=$TRAIN_FILE \
71
+ --do_eval \
72
+ --eval_data_file=$TEST_FILE \
73
+ --mlm
74
+ ```
75
+
76
+ ## Language generation
77
+
78
+ Based on the script [`run_generation.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_generation.py).
79
+
80
+ Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet.
81
+ A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
82
+ can try out the different models available in the library.
83
+
84
+ Example usage:
85
+
86
+ ```bash
87
+ python run_generation.py \
88
+ --model_type=gpt2 \
89
+ --model_name_or_path=gpt2
90
+ ```
91
+
92
+ ## GLUE
93
+
94
+ Based on the script [`run_glue.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_glue.py).
95
+
96
+ Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
97
+ Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.
98
+
99
+ GLUE is made up of a total of 9 different tasks. We get the following results on the dev set of the benchmark with an
100
+ uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran on 8 V100 GPUs with a total train
101
+ batch size of 24. Some of these tasks have a small dataset and training can lead to high variance in the results
102
+ between different runs. We report the median on 5 runs (with different seeds) for each of the metrics.
103
+
104
+ | Task | Metric | Result |
105
+ |-------|------------------------------|-------------|
106
+ | CoLA | Matthew's corr | 48.87 |
107
+ | SST-2 | Accuracy | 91.74 |
108
+ | MRPC | F1/Accuracy | 90.70/86.27 |
109
+ | STS-B | Pearson/Spearman corr. | 91.39/91.04 |
110
+ | QQP | Accuracy/F1 | 90.79/87.66 |
111
+ | MNLI | Matched acc./Mismatched acc. | 83.70/84.83 |
112
+ | QNLI | Accuracy | 89.31 |
113
+ | RTE | Accuracy | 71.43 |
114
+ | WNLI | Accuracy | 43.66 |
115
+
116
+ Some of these results are significantly different from the ones reported on the test set
117
+ of the GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
118
+
119
+ Before running anyone of these GLUE tasks you should download the
120
+ [GLUE data](https://gluebenchmark.com/tasks) by running
121
+ [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
122
+ and unpack it to some directory `$GLUE_DIR`.
123
+
124
+ ```bash
125
+ export GLUE_DIR=/path/to/glue
126
+ export TASK_NAME=MRPC
127
+
128
+ python run_glue.py \
129
+ --model_type bert \
130
+ --model_name_or_path bert-base-cased \
131
+ --task_name $TASK_NAME \
132
+ --do_train \
133
+ --do_eval \
134
+ --do_lower_case \
135
+ --data_dir $GLUE_DIR/$TASK_NAME \
136
+ --max_seq_length 128 \
137
+ --per_gpu_train_batch_size 32 \
138
+ --learning_rate 2e-5 \
139
+ --num_train_epochs 3.0 \
140
+ --output_dir /tmp/$TASK_NAME/
141
+ ```
142
+
143
+ where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.
144
+
145
+ The dev set results will be present within the text file `eval_results.txt` in the specified output_dir.
146
+ In case of MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate
147
+ output folder called `/tmp/MNLI-MM/` in addition to `/tmp/MNLI/`.
148
+
149
+ The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI,
150
+ CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being
151
+ said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
152
+ since the data processor for each task inherits from the base class DataProcessor.
153
+
154
+ ### MRPC
155
+
156
+ #### Fine-tuning example
157
+
158
+ The following examples fine-tune BERT on the Microsoft Research Paraphrase Corpus (MRPC) corpus and runs in less
159
+ than 10 minutes on a single K-80 and in 27 seconds (!) on single tesla V100 16GB with apex installed.
160
+
161
+ Before running anyone of these GLUE tasks you should download the
162
+ [GLUE data](https://gluebenchmark.com/tasks) by running
163
+ [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
164
+ and unpack it to some directory `$GLUE_DIR`.
165
+
166
+ ```bash
167
+ export GLUE_DIR=/path/to/glue
168
+
169
+ python run_glue.py \
170
+ --model_type bert \
171
+ --model_name_or_path bert-base-cased \
172
+ --task_name MRPC \
173
+ --do_train \
174
+ --do_eval \
175
+ --do_lower_case \
176
+ --data_dir $GLUE_DIR/MRPC/ \
177
+ --max_seq_length 128 \
178
+ --per_gpu_train_batch_size 32 \
179
+ --learning_rate 2e-5 \
180
+ --num_train_epochs 3.0 \
181
+ --output_dir /tmp/mrpc_output/
182
+ ```
183
+
184
+ Our tests, run on a few seeds with [the original implementation hyper-
185
+ parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks), gave evaluation
186
+ results between 84% and 88%.
187
+
188
+ #### Using Apex and mixed-precision
189
+
190
+ Using Apex and 16 bit precision, the fine-tuning on MRPC only takes 27 seconds. First install
191
+ [apex](https://github.com/NVIDIA/apex), then run the following example:
192
+
193
+ ```bash
194
+ export GLUE_DIR=/path/to/glue
195
+
196
+ python run_glue.py \
197
+ --model_type bert \
198
+ --model_name_or_path bert-base-cased \
199
+ --task_name MRPC \
200
+ --do_train \
201
+ --do_eval \
202
+ --do_lower_case \
203
+ --data_dir $GLUE_DIR/MRPC/ \
204
+ --max_seq_length 128 \
205
+ --per_gpu_train_batch_size 32 \
206
+ --learning_rate 2e-5 \
207
+ --num_train_epochs 3.0 \
208
+ --output_dir /tmp/mrpc_output/ \
209
+ --fp16
210
+ ```
211
+
212
+ #### Distributed training
213
+
214
+ Here is an example using distributed training on 8 V100 GPUs. The model used is the BERT whole-word-masking model and it
215
+ reaches an F1 > 92 on MRPC.
216
+
217
+ ```bash
218
+ export GLUE_DIR=/path/to/glue
219
+
220
+ python -m torch.distributed.launch \
221
+ --nproc_per_node 8 run_glue.py \
222
+ --model_type bert \
223
+ --model_name_or_path bert-base-cased \
224
+ --task_name MRPC \
225
+ --do_train \
226
+ --do_eval \
227
+ --do_lower_case \
228
+ --data_dir $GLUE_DIR/MRPC/ \
229
+ --max_seq_length 128 \
230
+ --per_gpu_train_batch_size 8 \
231
+ --learning_rate 2e-5 \
232
+ --num_train_epochs 3.0 \
233
+ --output_dir /tmp/mrpc_output/
234
+ ```
235
+
236
+ Training with these hyper-parameters gave us the following results:
237
+
238
+ ```bash
239
+ acc = 0.8823529411764706
240
+ acc_and_f1 = 0.901702786377709
241
+ eval_loss = 0.3418912578906332
242
+ f1 = 0.9210526315789473
243
+ global_step = 174
244
+ loss = 0.07231863956341798
245
+ ```
246
+
247
+ ### MNLI
248
+
249
+ The following example uses the BERT-large, uncased, whole-word-masking model and fine-tunes it on the MNLI task.
250
+
251
+ ```bash
252
+ export GLUE_DIR=/path/to/glue
253
+
254
+ python -m torch.distributed.launch \
255
+ --nproc_per_node 8 run_glue.py \
256
+ --model_type bert \
257
+ --model_name_or_path bert-base-cased \
258
+ --task_name mnli \
259
+ --do_train \
260
+ --do_eval \
261
+ --do_lower_case \
262
+ --data_dir $GLUE_DIR/MNLI/ \
263
+ --max_seq_length 128 \
264
+ --per_gpu_train_batch_size 8 \
265
+ --learning_rate 2e-5 \
266
+ --num_train_epochs 3.0 \
267
+ --output_dir output_dir \
268
+ ```
269
+
270
+ The results are the following:
271
+
272
+ ```bash
273
+ ***** Eval results *****
274
+ acc = 0.8679706601466992
275
+ eval_loss = 0.4911287787382479
276
+ global_step = 18408
277
+ loss = 0.04755385363816904
278
+
279
+ ***** Eval results *****
280
+ acc = 0.8747965825874695
281
+ eval_loss = 0.45516540421714036
282
+ global_step = 18408
283
+ loss = 0.04755385363816904
284
+ ```
285
+
286
+ ## Multiple Choice
287
+
288
+ Based on the script [`run_multiple_choice.py`]().
289
+
290
+ #### Fine-tuning on SWAG
291
+ Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data
292
+
293
+ ```
294
+ #training on 4 tesla V100(16GB) GPUS
295
+ export SWAG_DIR=/path/to/swag_data_dir
296
+ python ./examples/single_model_scripts/run_multiple_choice.py \
297
+ --model_type roberta \
298
+ --task_name swag \
299
+ --model_name_or_path roberta-base \
300
+ --do_train \
301
+ --do_eval \
302
+ --do_lower_case \
303
+ --data_dir $SWAG_DIR \
304
+ --learning_rate 5e-5 \
305
+ --num_train_epochs 3 \
306
+ --max_seq_length 80 \
307
+ --output_dir models_bert/swag_base \
308
+ --per_gpu_eval_batch_size=16 \
309
+ --per_gpu_train_batch_size=16 \
310
+ --gradient_accumulation_steps 2 \
311
+ --overwrite_output
312
+ ```
313
+ Training with the defined hyper-parameters yields the following results:
314
+ ```
315
+ ***** Eval results *****
316
+ eval_acc = 0.8338998300509847
317
+ eval_loss = 0.44457291918821606
318
+ ```
319
+
320
+ ## SQuAD
321
+
322
+ Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py).
323
+
324
+ #### Fine-tuning on SQuAD
325
+
326
+ This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large)
327
+ on a single tesla V100 16GB. The data for SQuAD can be downloaded with the following links and should be saved in a
328
+ $SQUAD_DIR directory.
329
+
330
+ * [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
331
+ * [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
332
+ * [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
333
+
334
+ ```bash
335
+ export SQUAD_DIR=/path/to/SQUAD
336
+
337
+ python run_squad.py \
338
+ --model_type bert \
339
+ --model_name_or_path bert-base-cased \
340
+ --do_train \
341
+ --do_eval \
342
+ --do_lower_case \
343
+ --train_file $SQUAD_DIR/train-v1.1.json \
344
+ --predict_file $SQUAD_DIR/dev-v1.1.json \
345
+ --per_gpu_train_batch_size 12 \
346
+ --learning_rate 3e-5 \
347
+ --num_train_epochs 2.0 \
348
+ --max_seq_length 384 \
349
+ --doc_stride 128 \
350
+ --output_dir /tmp/debug_squad/
351
+ ```
352
+
353
+ Training with the previously defined hyper-parameters yields the following results:
354
+
355
+ ```bash
356
+ f1 = 88.52
357
+ exact_match = 81.22
358
+ ```
359
+
360
+ #### Distributed training
361
+
362
+
363
+ Here is an example using distributed training on 8 V100 GPUs and the BERT whole-word-masking uncased model to reach an F1 > 93 on SQuAD:
364
+
365
+ ```bash
366
+ python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \
367
+ --model_type bert \
368
+ --model_name_or_path bert-base-cased \
369
+ --do_train \
370
+ --do_eval \
371
+ --do_lower_case \
372
+ --train_file $SQUAD_DIR/train-v1.1.json \
373
+ --predict_file $SQUAD_DIR/dev-v1.1.json \
374
+ --learning_rate 3e-5 \
375
+ --num_train_epochs 2 \
376
+ --max_seq_length 384 \
377
+ --doc_stride 128 \
378
+ --output_dir ../models/wwm_uncased_finetuned_squad/ \
379
+ --per_gpu_train_batch_size 24 \
380
+ --gradient_accumulation_steps 12
381
+ ```
382
+
383
+ Training with the previously defined hyper-parameters yields the following results:
384
+
385
+ ```bash
386
+ f1 = 93.15
387
+ exact_match = 86.91
388
+ ```
389
+
390
+ This fine-tuned model is available as a checkpoint under the reference
391
+ `bert-large-uncased-whole-word-masking-finetuned-squad`.
392
+
Optimus/code/examples/__pycache__/utils_glue.cpython-37.pyc ADDED
Binary file (21.5 kB).
 
Optimus/code/examples/big_ae/__pycache__/grad_app.cpython-310.pyc ADDED
Binary file (14 kB).
 
Optimus/code/examples/big_ae/__pycache__/utils.cpython-37.pyc ADDED
Binary file (40.3 kB).
 
Optimus/code/examples/big_ae/debug_data.py ADDED
@@ -0,0 +1,6 @@
1
+ import torch
2
+ import os
3
+
4
+ output_dir = "../output/philly_rr1_vae_wikipedia_pretraining_2nd_file"
5
+
6
+ data = torch.load(os.path.join(output_dir, 'batch_debug_6621.pt'))
Optimus/code/examples/big_ae/eval_dialog_multi_response.py ADDED
@@ -0,0 +1,378 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from nltk.translate.bleu_score import sentence_bleu
5
+ from nltk.translate.bleu_score import SmoothingFunction
6
+ from sklearn.metrics.pairwise import cosine_similarity as cosine
7
+ from collections import Counter
8
+ import os, pickle, pdb
9
+
10
+ class Metrics:
11
+ # based on https://raw.githubusercontent.com/guxd/DialogWAE/29f206af05bfe5fe28fec4448e208310a7c9258d/experiments/metrics.py
12
+
13
+ def __init__(self, path_word2vec='../data/datasets/dailydialog_data/glove.twitter.27B.200d.txt'):
14
+ """
15
+ :param word2vec - a numpy array of word2vec with shape [vocab_size x emb_size]
16
+ """
17
+ super(Metrics, self).__init__()
18
+ self.load_word2vec(path_word2vec)
19
+ #self.word2vec = dict()
20
+
21
+ def load_word2vec(self, path_word2vec):
22
+ path_pkl = path_word2vec + '.pkl'
23
+ if os.path.exists(path_pkl):
24
+ print('loading word2vec from '+path_pkl)
25
+ self.word2vec = pickle.load(open(path_pkl, 'rb'))
26
+ else:
27
+ self.word2vec = dict()
28
+ for i, line in enumerate(open(path_word2vec, encoding='utf-8')):
29
+ ss = line.strip('\n').split()
30
+ self.word2vec[ss[0]] = [float(v) for v in ss[1:]]
31
+ if i % 1e4 == 0:
32
+ print('processed %ik word2vec'%(i/1e3))
33
+ print('dumping word2vec to '+path_pkl)
34
+ pickle.dump(self.word2vec, open(path_pkl, 'wb'))
35
+ self.embed_dim = len(list(self.word2vec.values())[0])
36
+ print('loaded %i word2vec of dim %i'%(len(self.word2vec), self.embed_dim))
37
+
38
+ def embedding(self, seqs):
39
+ # note: different from original implementation
40
+ batch_size, seqlen = seqs.shape
41
+ embs = np.zeros([batch_size, seqlen, self.embed_dim])
42
+ for i in range(batch_size):
43
+ for j in range(seqlen):
44
+ w = seqs[i,j]
45
+ if w != '' and w in self.word2vec:
46
+ embs[i, j, :] = self.word2vec[w]
47
+ return embs
48
+
49
+
50
+ def extrema(self, embs, lens): # embs: [batch_size x seq_len x emb_size] lens: [batch_size]
51
+ """
52
+ computes the value of every single dimension in the word vectors which has the greatest
53
+ difference from zero.
54
+ :param seq: sequence
55
+ :param seqlen: length of sequence
56
+ """
57
+ # Find minimum and maximum value for every dimension in predictions
58
+ batch_size, seq_len, emb_size = embs.shape
59
+ max_mask = np.zeros((batch_size, seq_len, emb_size), dtype=int)
60
+ for i,length in enumerate(lens):
61
+ max_mask[i,:length,:]=1
62
+ min_mask = 1-max_mask
63
+ seq_max = (embs*max_mask).max(1) # [batch_sz x emb_sz]
64
+ seq_min = (embs+min_mask).min(1)
65
+ # Find the maximum absolute value in min and max data
66
+ comp_mask = seq_max >= np.abs(seq_min)# [batch_sz x emb_sz]
67
+ # Add vectors for finding final sequence representation for predictions
68
+ extrema_emb = seq_max* comp_mask + seq_min* np.logical_not(comp_mask)
69
+ return extrema_emb
70
+
71
+ def mean(self, embs, lens):
72
+ batch_size, seq_len, emb_size=embs.shape
73
+ mask = np.zeros((batch_size, seq_len, emb_size), dtype=int)
74
+ for i,length in enumerate(lens):
75
+ mask[i,:length,:]=1
76
+ return (embs*mask).sum(1)/(mask.sum(1)+1e-8)
77
+
78
+ def sim_bleu(self, hyps, ref):
79
+ """
80
+ :param ref - a list of tokens of the reference
81
+ :param hyps - a list of tokens of the hypothesis
82
+
83
+ :return maxbleu - recall bleu
84
+ :return avgbleu - precision bleu
85
+ """
86
+ scores = []
87
+ for hyp in hyps:
88
+ try:
89
+ scores.append(sentence_bleu([ref], hyp, smoothing_function=SmoothingFunction().method7,
90
+ weights=[1./3, 1./3, 1./3]))
91
+ except:
92
+ scores.append(0.0)
93
+ return np.max(scores), np.mean(scores)
94
+
95
+
96
+ def sim_bow(self, pred, pred_lens, ref, ref_lens):
97
+ """
98
+ :param pred - ndarray [batch_size x seqlen]
99
+ :param pred_lens - list of integers
100
+ :param ref - ndarray [batch_size x seqlen]
101
+ """
102
+ # look up word embeddings for prediction and reference
103
+ emb_pred = self.embedding(pred) # [batch_sz x seqlen1 x emb_sz]
104
+ emb_ref = self.embedding(ref) # [batch_sz x seqlen2 x emb_sz]
105
+
106
+ ext_emb_pred=self.extrema(emb_pred, pred_lens)
107
+ ext_emb_ref=self.extrema(emb_ref, ref_lens)
108
+ bow_extrema=cosine(ext_emb_pred, ext_emb_ref) # [batch_sz_pred x batch_sz_ref]
109
+
110
+ avg_emb_pred = self.mean(emb_pred, pred_lens) # Calculate mean over seq
111
+ avg_emb_ref = self.mean(emb_ref, ref_lens)
112
+ bow_avg = cosine(avg_emb_pred, avg_emb_ref) # [batch_sz_pred x batch_sz_ref]
113
+
114
+
115
+ batch_pred, seqlen_pred, emb_size=emb_pred.shape
116
+ batch_ref, seqlen_ref, emb_size=emb_ref.shape
117
+ cos_sim = cosine(emb_pred.reshape((-1, emb_size)), emb_ref.reshape((-1, emb_size))) # [(batch_sz*seqlen1)x(batch_sz*seqlen2)]
118
+ cos_sim = cos_sim.reshape((batch_pred, seqlen_pred, batch_ref, seqlen_ref))
119
+ # Find words with max cosine similarity
120
+ max12 = cos_sim.max(1).mean(2) # max over seqlen_pred
121
+ max21 = cos_sim.max(3).mean(1) # max over seqlen_ref
122
+ bow_greedy=(max12+max21)/2 # [batch_pred x batch_ref(1)]
123
+ return np.max(bow_extrema), np.max(bow_avg), np.max(bow_greedy)
124
+
125
+ def div_distinct(self, seqs, seq_lens):
126
+ """
127
+ distinct-1 distinct-2 metrics for diversity measure proposed
128
+ by Li et al. "A Diversity-Promoting Objective Function for Neural Conversation Models"
129
+ we counted numbers of distinct unigrams and bigrams in the generated responses
130
+ and divide the numbers by total number of unigrams and bigrams.
131
+ The two metrics measure how informative and diverse the generated responses are.
132
+ High numbers and high ratios mean that there is much content in the generated responses,
133
+ and high numbers further indicate that the generated responses are long
134
+ """
135
+ batch_size = seqs.shape[0]
136
+ intra_dist1, intra_dist2=np.zeros(batch_size), np.zeros(batch_size)
137
+
138
+ n_unigrams, n_bigrams, n_unigrams_total , n_bigrams_total = 0. ,0., 0., 0.
139
+ unigrams_all, bigrams_all = Counter(), Counter()
140
+ for b in range(batch_size):
141
+ unigrams= Counter([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
142
+ bigrams = Counter([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
143
+ intra_dist1[b]=(len(unigrams.items())+1e-12)/(seq_lens[b]+1e-5)
144
+ intra_dist2[b]=(len(bigrams.items())+1e-12)/(max(0, seq_lens[b]-1)+1e-5)
145
+
146
+ unigrams_all.update([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
147
+ bigrams_all.update([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
148
+ n_unigrams_total += seq_lens[b]
149
+ n_bigrams_total += max(0, seq_lens[b]-1)
150
+
151
+ inter_dist1 = (len(unigrams_all.items())+1e-12)/(n_unigrams_total+1e-5)
152
+ inter_dist2 = (len(bigrams_all.items())+1e-12)/(n_bigrams_total+1e-5)
153
+ return intra_dist1, intra_dist2, inter_dist1, inter_dist2
154
+
155
+ import pdb
156
+
157
+ def eval_multi_ref(path, path_multi_ref=None):
158
+ """
159
+ based on: https://github.com/guxd/DialogWAE/blob/29f206af05bfe5fe28fec4448e208310a7c9258d/sample.py
160
+ path: each line is '\t'.join([src, ref, hyp])
161
+ path_multi_ref: each line is '\t'.join([src, hyp])
162
+ the order of unique src appeared in `path_multi_ref` should be the same as that in `path`
163
+ """
164
+ metrics = Metrics()
165
+ d_ref = dict()
166
+ d_hyp = dict()
167
+ src2ix = dict()
168
+ ix2src = dict()
169
+ ix = 0
170
+ for line in open(path, encoding='utf-8'):
171
+ line = line.strip('\n').strip()
172
+ if len(line) == 0:
173
+ continue
174
+
175
+ # pdb.set_trace()
176
+ src, ref, hyp = line.split('\t')
177
+ #src, ref = line.split('\t'); hyp = ref
178
+ src = src.replace(' EOS ',' [SEP] ').strip()
179
+ ref = ref.strip().split()
180
+ hyp = hyp.strip().split()
181
+ if src not in d_ref:
182
+ d_ref[src] = ref
183
+ d_hyp[src] = [hyp]
184
+ src2ix[src] = ix
185
+ ix2src[ix] = src
186
+ ix += 1
187
+ else:
188
+ d_hyp[src].append(hyp)
189
+ print('loaded %i src-ref-hyp tuples'%(len(d_ref)))
190
+
191
+ def chr_only(s):
192
+ ret = ''
193
+ for c in s:
194
+ if c.isalpha():
195
+ ret += c
196
+ return ret
197
+
198
+ if path_multi_ref is not None:
199
+ set_src4multiref = set()
200
+ ix = -1
201
+ d_multi_ref = dict()
202
+ for line in open(path_multi_ref, encoding='utf-8'):
203
+ line = line.strip('\n').strip()
204
+ if len(line) == 0:
205
+ continue
206
+ src4multiref, ref = line.split('\t')[:2]
207
+ src4multiref = src4multiref.replace(' EOS ', ' ').replace(' [SEP] ',' ').strip()
208
+ ref = ref.strip().split()
209
+ if src4multiref not in set_src4multiref:
210
+ set_src4multiref.add(src4multiref)
211
+ ix += 1
212
+ src = ix2src[ix]
213
+ id_hyp = chr_only(src)
214
+ id_multiref = chr_only(src4multiref)
215
+ if id_multiref != id_hyp:
216
+ print('[ERROR] cannot match src4multiref and src4hyp')
217
+ print('src4multiref:', src4multiref)
218
+ print('src4hyp:', ix2src[ix])
219
+ # pdb.set_trace()
220
+ raise ValueError
221
+ d_multi_ref[src] = [ref]
222
+ else:
223
+ d_multi_ref[src].append(ref)
224
+
225
+ n_ref = [len(d_multi_ref[k]) for k in d_multi_ref]
226
+ print('loaded %i src with multi-ref, avg n_ref = %.3f'%(len(d_multi_ref), np.mean(n_ref)))
227
+
228
+ n_miss = 0
229
+ for src in d_ref:
230
+ if src not in d_multi_ref:
231
+ n_miss += 1
232
+ print('[WARNING] cannot find multiref for src: '+src)
233
+ d_multi_ref[src] = [d_ref[src]]
234
+ if n_miss > 5:
235
+ raise ValueError
236
+
237
+ n = len(d_ref)
238
+ print(path)
239
+ print('n_src\t%i'%n)
240
+
241
+ avg_lens = 0
242
+ maxbleu = 0
243
+ avgbleu = 0
244
+ intra_dist1, intra_dist2, inter_dist1, inter_dist2 = 0,0,0,0
245
+ bow_extrema, bow_avg, bow_greedy = 0,0,0
246
+ for src in d_ref:
247
+
248
+ # BLEU ----
249
+
250
+ if path_multi_ref is None:
251
+ m, a = metrics.sim_bleu(d_hyp[src], d_ref[src])
252
+ else:
253
+ n_ref = len(d_multi_ref[src])
254
+ m, a = 0, 0
255
+ for ref in d_multi_ref[src]:
256
+ _m, _a = metrics.sim_bleu(d_hyp[src], ref)
257
+ m += _m
258
+ a += _a
259
+ m /= n_ref
260
+ a /= n_ref
261
+
262
+ maxbleu += m
263
+ avgbleu += a
264
+
265
+ # diversity ----
266
+
267
+ seq_len = [len(hyp) for hyp in d_hyp[src]]
268
+ max_len = max(seq_len)
269
+ seqs = []
270
+ for hyp in d_hyp[src]:
271
+ padded = hyp + [''] * (max_len - len(hyp))
272
+ seqs.append(np.reshape(padded, [1, -1]))
273
+ seqs = np.concatenate(seqs, axis=0)
274
+ intra1, intra2, inter1, inter2 = metrics.div_distinct(seqs, seq_len)
275
+ intra_dist1 += np.mean(intra1)
276
+ intra_dist2 += np.mean(intra2)
277
+ inter_dist1 += inter1
278
+ inter_dist2 += inter2
279
+
280
+ avg_lens += np.mean(seq_len)
281
+
282
+ # BOW ----
283
+
284
+ def calc_bow(ref):
285
+ n_hyp = len(d_hyp[src])
286
+ seqs_ref = np.concatenate([np.reshape(ref, [1,-1])] * n_hyp, axis=0)
287
+ seq_len_ref = [len(ref)] * n_hyp
288
+ return metrics.sim_bow(seqs, seq_len, seqs_ref, seq_len_ref)
289
+
290
+ if path_multi_ref is None:
291
+ extrema, avg, greedy = calc_bow(d_ref[src])
292
+ else:
293
+ extrema, avg, greedy = 0, 0, 0
294
+ for ref in d_multi_ref[src]:
295
+ e, a, g = calc_bow(ref)
296
+ extrema += e
297
+ avg += a
298
+ greedy += g
299
+ extrema /= n_ref
300
+ avg /= n_ref
301
+ greedy /= n_ref
302
+
303
+ bow_extrema += extrema
304
+ bow_avg += avg
305
+ bow_greedy += greedy
306
+
307
+ recall_bleu = maxbleu/n
308
+ prec_bleu = avgbleu/n
309
+ f1 = 2*(prec_bleu*recall_bleu) / (prec_bleu+recall_bleu+10e-12)
310
+
311
+ print('BLEU')
312
+ print(' R\t%.3f'%recall_bleu)
313
+ print(' P\t%.3f'%prec_bleu)
314
+ print(' F1\t%.3f'%f1)
315
+ print('BOW')
316
+ print(' A\t%.3f'%(bow_avg/n))
317
+ print(' E\t%.3f'%(bow_extrema/n))
318
+ print(' G\t%.3f'%(bow_greedy/n))
319
+ print('intra_dist')
320
+ print(' 1\t%.3f'%(intra_dist1/n))
321
+ print(' 2\t%.3f'%(intra_dist2/n))
322
+ print('inter_dist')
323
+ print(' 1\t%.3f'%(inter_dist1/n))
324
+ print(' 2\t%.3f'%(inter_dist2/n))
325
+ print('avg_L\t%.1f'%(avg_lens/n))
326
+
327
+ results = {
328
+ "BLEU_R": recall_bleu, "BLEU_P": prec_bleu, "BLEU_F1": f1, "BOW_A": bow_avg/n, "BOW_E": bow_extrema/n, "BOW_G": bow_greedy/n, "intra_dist1": intra_dist1/n, "intra_dist2": intra_dist2/n, "inter_dist1": inter_dist1/n, "inter_dist2": inter_dist2/n, "avg_L": avg_lens/n
329
+ }
330
+
331
+ return results
332
+
333
+
334
+ def create_rand_baseline():
335
+ path = 'data/datasets/dailydialog_data/test.txt'
336
+ srcs = []
337
+ refs = []
338
+ for line in open(path, encoding='utf-8'):
339
+ src, ref = line.strip('\n').split('\t')
340
+ srcs.append(src.strip())
341
+ refs.append(ref.strip())
342
+
343
+ hyps = set()
344
+ path = 'data/datasets/dailydialog_data/train.txt'
345
+ for line in open(path, encoding='utf-8'):
346
+ _, ref = line.strip('\n').split('\t')
347
+ hyps.add(ref)
348
+ if len(hyps) == len(srcs) *10:
349
+ print('collected training ref')
350
+ break
351
+
352
+ hyps = list(hyps)
353
+ lines = []
354
+ j = 0
355
+ for i in range(len(srcs)):
356
+ lines += ['\t'.join([srcs[i], refs[i], hyp]) for hyp in hyps[j:j+10]]
357
+ j = j + 10
358
+ with open('out/rand.tsv', 'w', encoding='utf-8') as f:
359
+ f.write('\n'.join(lines))
360
+
361
+
362
+ def create_human_baseline():
363
+ path = 'data/datasets/dailydialog_data/test.txt'
364
+ lines = []
365
+ for line in open(path, encoding='utf-8'):
366
+ src, ref = line.strip('\n').split('\t')
367
+ src = src.strip()
368
+ ref = ref.strip()
369
+ lines.append('\t'.join([src, ref, ref]))
370
+
371
+ with open('out/human.tsv', 'w', encoding='utf-8') as f:
372
+ f.write('\n'.join(lines))
373
+
374
+
375
+ if __name__ == "__main__":
376
+ path = 'D:/data/switchboard/test.txt.1ref'
377
+ path_multi_ref = 'D:/data/switchboard/test.txt'
378
+ eval_multi_ref(path_multi_ref, path)
Optimus/code/examples/big_ae/eval_dialog_response.py ADDED
@@ -0,0 +1,295 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from nltk.translate.bleu_score import sentence_bleu
5
+ from nltk.translate.bleu_score import SmoothingFunction
6
+ from sklearn.metrics.pairwise import cosine_similarity as cosine
7
+ from collections import Counter
8
+ import os, pickle
9
+
10
+ class Metrics:
11
+ # based on https://raw.githubusercontent.com/guxd/DialogWAE/29f206af05bfe5fe28fec4448e208310a7c9258d/experiments/metrics.py
12
+
13
+ def __init__(self, path_word2vec='../data/datasets/dailydialog_data/glove.twitter.27B.200d.txt'):
14
+ """
15
+ :param word2vec - a numpy array of word2vec with shape [vocab_size x emb_size]
16
+ """
17
+ self.path_word2vec = path_word2vec
18
+ super(Metrics, self).__init__()
19
+ self.load_word2vec(path_word2vec)
20
+
21
+ def load_word2vec(self, path_word2vec):
22
+ path_pkl = path_word2vec + '.pkl'
23
+ if os.path.exists(path_pkl):
24
+ print('loading word2vec from '+path_pkl)
25
+ self.word2vec = pickle.load(open(path_pkl, 'rb'))
26
+ else:
27
+ self.word2vec = dict()
28
+ for i, line in enumerate(open(path_word2vec, encoding='utf-8')):
29
+ ss = line.strip('\n').split()
30
+ self.word2vec[ss[0]] = [float(v) for v in ss[1:]]
31
+ if i % 1e4 == 0:
32
+ print('processed %ik word2vec'%(i/1e3))
33
+ print('dumping word2vec to '+path_pkl)
34
+ pickle.dump(self.word2vec, open(path_pkl, 'wb'))
35
+ # pdb.set_trace()
36
+ self.embed_dim = len(self.word2vec["."]) # len(self.word2vec.values()[0])
37
+ print('loaded %i word2vec of dim %i'%(len(self.word2vec), self.embed_dim))
38
+
39
+ def embedding(self, seqs):
40
+ # note: different from original implementation
41
+ batch_size, seqlen = seqs.shape
42
+ embs = np.zeros([batch_size, seqlen, self.embed_dim])
43
+ for i in range(batch_size):
44
+ for j in range(seqlen):
45
+ w = seqs[i,j]
46
+ if w != '' and w in self.word2vec:
47
+ embs[i, j, :] = self.word2vec[w]
48
+ return embs
49
+
50
+
51
+ def extrema(self, embs, lens): # embs: [batch_size x seq_len x emb_size] lens: [batch_size]
52
+ """
53
+ computes the value of every single dimension in the word vectors which has the greatest
54
+ difference from zero.
55
+ :param seq: sequence
56
+ :param seqlen: length of sequence
57
+ """
58
+ # Find minimum and maximum value for every dimension in predictions
59
+ batch_size, seq_len, emb_size = embs.shape
60
+ max_mask = np.zeros((batch_size, seq_len, emb_size), dtype=int)
61
+ for i,length in enumerate(lens):
62
+ max_mask[i,:length,:]=1
63
+ min_mask = 1-max_mask
64
+ seq_max = (embs*max_mask).max(1) # [batch_sz x emb_sz]
65
+ seq_min = (embs+min_mask).min(1)
66
+ # Find the maximum absolute value in min and max data
67
+ comp_mask = seq_max >= np.abs(seq_min)# [batch_sz x emb_sz]
68
+ # Add vectors for finding final sequence representation for predictions
69
+ extrema_emb = seq_max* comp_mask + seq_min* np.logical_not(comp_mask)
70
+ return extrema_emb
71
+
72
+ def mean(self, embs, lens):
73
+ batch_size, seq_len, emb_size=embs.shape
74
+ mask = np.zeros((batch_size, seq_len, emb_size), dtype=int)
75
+ for i,length in enumerate(lens):
76
+ mask[i,:length,:]=1
77
+ return (embs*mask).sum(1)/(mask.sum(1)+1e-8)
78
+
79
+ def sim_bleu(self, hyps, ref):
80
+ """
81
+ :param ref - a list of tokens of the reference
82
+ :param hyps - a list of tokens of the hypothesis
83
+
84
+ :return maxbleu - recall bleu
85
+ :return avgbleu - precision bleu
86
+ """
87
+ scores = []
88
+ for hyp in hyps:
89
+ try:
90
+ scores.append(sentence_bleu([ref], hyp, smoothing_function=SmoothingFunction().method7,
91
+ weights=[1./3, 1./3, 1./3]))
92
+ except:
93
+ scores.append(0.0)
94
+ return np.max(scores), np.mean(scores)
95
+
96
+
97
+ def sim_bow(self, pred, pred_lens, ref, ref_lens):
98
+ """
99
+ :param pred - ndarray [batch_size x seqlen]
100
+ :param pred_lens - list of integers
101
+ :param ref - ndarray [batch_size x seqlen]
102
+ """
103
+ # look up word embeddings for prediction and reference
104
+ emb_pred = self.embedding(pred) # [batch_sz x seqlen1 x emb_sz]
105
+ emb_ref = self.embedding(ref) # [batch_sz x seqlen2 x emb_sz]
106
+
107
+ ext_emb_pred=self.extrema(emb_pred, pred_lens)
108
+ ext_emb_ref=self.extrema(emb_ref, ref_lens)
109
+ bow_extrema=cosine(ext_emb_pred, ext_emb_ref) # [batch_sz_pred x batch_sz_ref]
110
+
111
+ avg_emb_pred = self.mean(emb_pred, pred_lens) # Calculate mean over seq
112
+ avg_emb_ref = self.mean(emb_ref, ref_lens)
113
+ bow_avg = cosine(avg_emb_pred, avg_emb_ref) # [batch_sz_pred x batch_sz_ref]
114
+
115
+
116
+ batch_pred, seqlen_pred, emb_size=emb_pred.shape
117
+ batch_ref, seqlen_ref, emb_size=emb_ref.shape
118
+ cos_sim = cosine(emb_pred.reshape((-1, emb_size)), emb_ref.reshape((-1, emb_size))) # [(batch_sz*seqlen1)x(batch_sz*seqlen2)]
119
+ cos_sim = cos_sim.reshape((batch_pred, seqlen_pred, batch_ref, seqlen_ref))
120
+ # Find words with max cosine similarity
121
+ max12 = cos_sim.max(1).mean(2) # max over seqlen_pred
122
+ max21 = cos_sim.max(3).mean(1) # max over seqlen_ref
123
+ bow_greedy=(max12+max21)/2 # [batch_pred x batch_ref(1)]
124
+ return np.max(bow_extrema), np.max(bow_avg), np.max(bow_greedy)
125
+
126
+ def div_distinct(self, seqs, seq_lens):
127
+ """
128
+ distinct-1 distinct-2 metrics for diversity measure proposed
129
+ by Li et al. "A Diversity-Promoting Objective Function for Neural Conversation Models"
130
+ we counted numbers of distinct unigrams and bigrams in the generated responses
131
+ and divide the numbers by total number of unigrams and bigrams.
132
+ The two metrics measure how informative and diverse the generated responses are.
133
+ High numbers and high ratios mean that there is much content in the generated responses,
134
+ and high numbers further indicate that the generated responses are long
135
+ """
136
+ batch_size = seqs.shape[0]
137
+ intra_dist1, intra_dist2=np.zeros(batch_size), np.zeros(batch_size)
138
+
139
+ n_unigrams, n_bigrams, n_unigrams_total , n_bigrams_total = 0. ,0., 0., 0.
140
+ unigrams_all, bigrams_all = Counter(), Counter()
141
+ for b in range(batch_size):
142
+ unigrams= Counter([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
143
+ bigrams = Counter([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
144
+ intra_dist1[b]=(len(unigrams.items())+1e-12)/(seq_lens[b]+1e-5)
145
+ intra_dist2[b]=(len(bigrams.items())+1e-12)/(max(0, seq_lens[b]-1)+1e-5)
146
+
147
+ unigrams_all.update([tuple(seqs[b,i:i+1]) for i in range(seq_lens[b])])
148
+ bigrams_all.update([tuple(seqs[b,i:i+2]) for i in range(seq_lens[b]-1)])
149
+ n_unigrams_total += seq_lens[b]
150
+ n_bigrams_total += max(0, seq_lens[b]-1)
151
+
152
+ inter_dist1 = (len(unigrams_all.items())+1e-12)/(n_unigrams_total+1e-5)
153
+ inter_dist2 = (len(bigrams_all.items())+1e-12)/(n_bigrams_total+1e-5)
154
+ return intra_dist1, intra_dist2, inter_dist1, inter_dist2
155
+
156
+ import pdb
157
+
158
+ def eval_dialog_response(generated_text_file_path):
159
+ """
160
+ based on: https://github.com/guxd/DialogWAE/blob/29f206af05bfe5fe28fec4448e208310a7c9258d/sample.py
161
+ quoted from the DialogWAE paper: https://arxiv.org/pdf/1805.12352.pdf
162
+ * "For each test context, we sample 10 responses from the models and compute their BLEU scores"
163
+ * "We use Glove vectors" "For each test context, we report the maximum BOW embedding score among the 10 sampled responses."
164
+ * "intra-dist as the average of distinct values within each sampled response"
165
+ " "inter-dist as the distinct value among all sampled responses."
166
+ """
167
+ metrics = Metrics()
168
+ d_ref = dict()
169
+ d_hyp = dict()
170
+ for line in open(generated_text_file_path, encoding='utf-8'):
171
+ line = line.strip('\n').strip()
172
+ if len(line) == 0:
173
+ continue
174
+ src, ref, hyp = line.split('\t')
175
+ src = src.strip()
176
+ ref = ref.strip().split()
177
+ hyp = hyp.strip().split()
178
+ if src not in d_ref:
179
+ d_ref[src] = ref
180
+ d_hyp[src] = [hyp]
181
+ else:
182
+ d_hyp[src].append(hyp)
183
+
184
+ n = len(d_ref)
185
+ print(generated_text_file_path)
186
+ print('n_src\t%i'%n)
187
+
188
+ avg_lens = 0
189
+ maxbleu = 0
190
+ avgbleu = 0
191
+ intra_dist1, intra_dist2, inter_dist1, inter_dist2 = 0,0,0,0
192
+ bow_extrema, bow_avg, bow_greedy = 0,0,0
193
+ for src in d_ref:
194
+ m, a = metrics.sim_bleu(d_hyp[src], d_ref[src])
195
+ maxbleu += m
196
+ avgbleu += a
197
+
198
+ seq_len = [len(hyp) for hyp in d_hyp[src]]
199
+ max_len = max(seq_len)
200
+ seqs = []
201
+ for hyp in d_hyp[src]:
202
+ padded = hyp + [''] * (max_len - len(hyp))
203
+ seqs.append(np.reshape(padded, [1, -1]))
204
+ seqs = np.concatenate(seqs, axis=0)
205
+ intra1, intra2, inter1, inter2 = metrics.div_distinct(seqs, seq_len)
206
+ intra_dist1 += np.mean(intra1)
207
+ intra_dist2 += np.mean(intra2)
208
+ inter_dist1 += inter1
209
+ inter_dist2 += inter2
210
+
211
+ n_hyp = len(d_hyp[src])
212
+ seqs_ref = np.concatenate([np.reshape(d_ref[src], [1,-1])] * n_hyp, axis=0)
213
+ seq_len_ref = [len(d_ref[src])] * n_hyp
214
+ if metrics.word2vec is not None:
215
+ extrema, avg, greedy = metrics.sim_bow(seqs, seq_len, seqs_ref, seq_len_ref)
216
+ bow_extrema += extrema
217
+ bow_avg += avg
218
+ bow_greedy += greedy
219
+
220
+ avg_lens += np.mean(seq_len)
221
+
222
+ recall_bleu = maxbleu/n
223
+ prec_bleu = avgbleu/n
224
+ f1 = 2*(prec_bleu*recall_bleu) / (prec_bleu+recall_bleu+10e-12)
225
+
226
+ print('BLEU')
227
+ print(' R\t%.3f'%recall_bleu)
228
+ print(' P\t%.3f'%prec_bleu)
229
+ print(' F1\t%.3f'%f1)
230
+ print('BOW')
231
+ print(' A\t%.3f'%(bow_avg/n))
232
+ print(' E\t%.3f'%(bow_extrema/n))
233
+ print(' G\t%.3f'%(bow_greedy/n))
234
+ print('intra_dist')
235
+ print(' 1\t%.3f'%(intra_dist1/n))
236
+ print(' 2\t%.3f'%(intra_dist2/n))
237
+ print('inter_dist')
238
+ print(' 1\t%.3f'%(inter_dist1/n))
239
+ print(' 2\t%.3f'%(inter_dist2/n))
240
+ print('avg_L\t%.1f'%(avg_lens/n))
241
+
242
+ results = {
243
+ "BLEU_R": recall_bleu, "BLEU_P": prec_bleu, "BLEU_F1": f1, "BOW_A": bow_avg/n, "BOW_E": bow_extrema/n, "BOW_G": bow_greedy/n, "intra_dist1": intra_dist1/n, "intra_dist2": intra_dist2/n, "inter_dist1": inter_dist1/n, "inter_dist2": inter_dist2/n, "avg_L": avg_lens/n
244
+ }
245
+
246
+ return results
247
+
248
+
249
+
250
+ def create_rand_baseline():
251
+ path = 'data/datasets/dailydialog_data/test.txt'
252
+ srcs = []
253
+ refs = []
254
+ for line in open(path, encoding='utf-8'):
255
+ src, ref = line.strip('\n').split('\t')
256
+ srcs.append(src.strip())
257
+ refs.append(ref.strip())
258
+
259
+ hyps = set()
260
+ path = 'data/datasets/dailydialog_data/train.txt'
261
+ for line in open(path, encoding='utf-8'):
262
+ _, ref = line.strip('\n').split('\t')
263
+ hyps.add(ref)
264
+ if len(hyps) == len(srcs) *10:
265
+ print('collected training ref')
266
+ break
267
+
268
+ hyps = list(hyps)
269
+ lines = []
270
+ j = 0
271
+ for i in range(len(srcs)):
272
+ lines += ['\t'.join([srcs[i], refs[i], hyp]) for hyp in hyps[j:j+10]]
273
+ j = j + 10
274
+ with open('out/rand.tsv', 'w', encoding='utf-8') as f:
275
+ f.write('\n'.join(lines))
276
+
277
+
278
+ def create_human_baseline():
279
+ path = 'data/datasets/dailydialog_data/test.txt'
280
+ lines = []
281
+ for line in open(path, encoding='utf-8'):
282
+ src, ref = line.strip('\n').split('\t')
283
+ src = src.strip()
284
+ ref = ref.strip()
285
+ lines.append('\t'.join([src, ref, ref]))
286
+
287
+ with open('out/human.tsv', 'w', encoding='utf-8') as f:
288
+ f.write('\n'.join(lines))
289
+
290
+
291
+ if __name__ == "__main__":
292
+ #create_rand_baseline()
293
+ #create_human_baseline()
294
+ eval_dialog_response('out/eval_text_generation_results (1).txt')
295
+ #eval('out/rand.tsv')
Optimus/code/examples/big_ae/grad_app.py ADDED
@@ -0,0 +1,486 @@
1
+ # -*- coding: utf-8 -*-
2
+ """message_bottle.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD
8
+ """
9
+
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib
12
+
13
+ import argparse
14
+ import glob
15
+ import logging
16
+ import os
17
+ import pickle
18
+ import random
19
+
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+ from tqdm import tqdm, trange
26
+ from types import SimpleNamespace
27
+
28
+ import sys
29
+ sys.path.append('/home/ryn_mote/Misc/generative_recommender/text_space/Optimus/code/examples/big_ae/')
30
+ sys.path.append('/home/ryn_mote/Misc/generative_recommender/text_space/Optimus/code/')
31
+ from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
32
+ from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
33
+ from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
34
+ from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
35
+ from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
36
+ from pytorch_transformers import BertForLatentConnector, BertTokenizer
37
+
38
+ from modules import VAE
39
+
40
+ import torch
41
+ import torch.nn as nn
42
+ import torch.nn.functional as F
43
+ torch.set_float32_matmul_precision('high')
44
+
45
+ from tqdm import tqdm
46
+
47
+ ################################################
48
+
49
+
50
+
51
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
52
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
53
+ Args:
54
+ logits: logits distribution shape (vocabulary size)
55
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
56
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
57
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
58
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
59
+ """
60
+ assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
61
+ top_k = min(top_k, logits.size(-1)) # Safety check
62
+ if top_k > 0:
63
+ # Remove all tokens with a probability less than the last token of the top-k
64
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
65
+ logits[indices_to_remove] = filter_value
66
+
67
+ if top_p > 0.0:
68
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
69
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
70
+
71
+ # Remove tokens with cumulative probability above the threshold
72
+ sorted_indices_to_remove = cumulative_probs > top_p
73
+ # Shift the indices to the right to keep also the first token above the threshold
74
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
75
+ sorted_indices_to_remove[..., 0] = 0
76
+
77
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
78
+ logits[indices_to_remove] = filter_value
79
+ return logits
80
+
81
+ def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
82
+
83
+ context = torch.tensor(context, dtype=torch.long, device=device)
84
+ context = context.unsqueeze(0).repeat(num_samples, 1)
85
+ generated = context
86
+ with torch.no_grad():
87
+ while True:
88
+ # for _ in trange(length):
89
+ inputs = {'input_ids': generated, 'past': past}
90
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
91
+ next_token_logits = outputs[0][0, -1, :] / temperature
92
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
93
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
94
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
95
+
96
+ # pdb.set_trace()
97
+ if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
98
+ break
99
+
100
+ return generated
101
+
102
+
103
+ def latent_code_from_text(text):  # args removed; uses the module-level model and tokenizers
104
+ tokenized1 = tokenizer_encoder.encode(text)
105
+ tokenized1 = [101] + tokenized1 + [102]
106
+ coded1 = torch.Tensor([tokenized1])
107
+ coded1 = coded1.long()
108
+ with torch.no_grad():
109
+ x0 = coded1
110
+ x0 = x0.to('cuda')
111
+ pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
112
+ mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
113
+ latent_z = mean.squeeze(1)
114
+ coded_length = len(tokenized1)
115
+ return latent_z, coded_length
116
+
117
+ # args
118
+ def text_from_latent_code(latent_z):
119
+ past = latent_z
120
+ context_tokens = tokenizer_decoder.encode('<BOS>')
121
+
122
+ length = 128 # maximum length, but not used
123
+ out = sample_sequence_conditional(
124
+ model=model_vae.decoder,
125
+ context=context_tokens,
126
+ past=past,
127
+ length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
128
+ temperature=.2,
129
+ top_k=50,
130
+ top_p=.98,
131
+ device='cuda',
132
+ decoder_tokenizer = tokenizer_decoder
133
+ )
134
+ text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
135
+ text_x1 = text_x1.split()[1:-1]
136
+ text_x1 = ' '.join(text_x1)
137
+ return text_x1
138
+
139
+
140
+ ################################################
141
+ # Load model
142
+
143
+
144
+ MODEL_CLASSES = {
145
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
146
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
147
+ }
148
+
149
+ latent_size = 768
150
+ model_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-full-31250/'
151
+ encoder_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-encoder-31250/'
152
+ decoder_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-decoder-31250/'
153
+ block_size = 100
154
+
155
+ # Load a trained Encoder model and vocabulary that you have fine-tuned
156
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES['bert']
157
+ model_encoder = encoder_model_class.from_pretrained(encoder_path, latent_size=latent_size)
158
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained('bert-base-cased', do_lower_case=True)
159
+
160
+ model_encoder.to('cuda')
161
+ if block_size <= 0:
162
+ block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
163
+ block_size = min(block_size, tokenizer_encoder.max_len_single_sentence)
164
+
165
+ # Load a trained Decoder model and vocabulary that you have fine-tuned
166
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES['gpt2']
167
+ model_decoder = decoder_model_class.from_pretrained(decoder_path, latent_size=latent_size)
168
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained('gpt2', do_lower_case=False)
169
+ model_decoder.to('cuda')
170
+ if block_size <= 0:
171
+ block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
172
+ block_size = min(block_size, tokenizer_decoder.max_len_single_sentence)
173
+
174
+ # Load full model
175
+ output_full_dir = '/home/ryn_mote/Misc/generative_recommender/text_space/'
176
+ checkpoint = torch.load(os.path.join(model_path, 'training.bin'))
177
+
178
+ # Chunyuan: Add Padding token to GPT2
179
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
180
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
181
+ print('We have added', num_added_toks, 'tokens to GPT2')
182
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expects the full size of the new vocabulary, i.e. the length of the tokenizer.
183
+ assert tokenizer_decoder.pad_token == '<PAD>'
184
+
185
+
186
+ # Evaluation
187
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, SimpleNamespace(**{'latent_size': latent_size, 'device':'cuda'}))
188
+ model_vae.load_state_dict(checkpoint['model_state_dict'])
189
+ print("Pre-trained Optimus is successfully loaded")
190
+ model_vae.to('cuda').to(torch.bfloat16)
191
+
192
+ l = latent_code_from_text('A photo of a mountain.')[0]
193
+ t = text_from_latent_code(l)
194
+ print(t, l, l.shape)
195
+ ################################################
196
+
197
+ import gradio as gr
198
+ import numpy as np
199
+ from sklearn.svm import SVC
200
+ from sklearn.inspection import permutation_importance
201
+ from sklearn import preprocessing
202
+ import pandas as pd
203
+ import random
204
+ import time
205
+
206
+
207
+ dtype = torch.bfloat16
208
+ torch.set_grad_enabled(False)
209
+
210
+ prompt_list = [p for p in list(set(
211
+ pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
212
+
213
+ start_time = time.time()
214
+
215
+ ####################### Setup Model
216
+
217
+ # TODO put back
218
+ # @spaces.GPU()
219
+ def generate(prompt, in_embs=None,):
220
+ if prompt != '':
221
+ print(prompt)
222
+ #in_embs = in_embs / in_embs.abs().max() * .15 if in_embs != None else None
223
+ in_embs = .9 * in_embs.to('cuda') + .5 * latent_code_from_text(prompt)[0] if in_embs is not None else latent_code_from_text(prompt)[0]
224
+ else:
225
+ print('From embeds.')
226
+ in_embs = in_embs / in_embs.abs().max() * .6
227
+ in_embs = in_embs.to('cuda').to(torch.bfloat16)
228
+ plt.close('all')
229
+ plt.hist(np.array(in_embs.detach().to('cpu').to(torch.float)).flatten(), bins=5)
230
+ plt.savefig('real_im_emb_plot.jpg')
231
+
232
+
233
+ text = text_from_latent_code(in_embs)
234
+ in_embs = latent_code_from_text(text)[0]
235
+ print(text)
236
+ return text, in_embs.to('cpu')
237
+
238
+
239
+ #######################
240
+
241
+ # TODO add to state instead of shared across all
242
+ glob_idx = 0
243
+
244
+ def next_one(embs, ys, calibrate_prompts):
245
+ global glob_idx
246
+ glob_idx = glob_idx + 1
247
+
248
+ with torch.no_grad():
249
+ if len(calibrate_prompts) > 0:
250
+ print('######### Calibrating with sample prompts #########')
251
+ prompt = calibrate_prompts.pop(0)
252
+ text, img_embs = generate(prompt)
253
+ embs += img_embs
254
+ print(len(embs))
255
+ return text, embs, ys, calibrate_prompts
256
+ else:
257
+ print('######### Roaming #########')
258
+
259
+
260
+ # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
261
+ if len(list(set(ys))) <= 1:
262
+ embs.append(.01*torch.randn(latent_size))
263
+ embs.append(.01*torch.randn(latent_size))
264
+ ys.append(0)
265
+ ys.append(1)
266
+ if len(list(ys)) < 10:
267
+ embs += [.01*torch.randn(latent_size)] * 3
268
+ ys += [0] * 3
269
+
270
+ pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
271
+ neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
272
+
273
+ # the embs & ys stay tied by index but we shuffle to drop randomly
274
+ random.shuffle(pos_indices)
275
+ random.shuffle(neg_indices)
276
+
277
+ #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
278
+ # pos_indices = pos_indices[32:]
279
+ if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 6:
280
+ pos_indices = pos_indices[5:]
281
+ if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 6:
282
+ neg_indices = neg_indices[5:]
283
+
284
+
285
+ if len(neg_indices) > 25:
286
+ neg_indices = neg_indices[1:]
287
+
288
+ print(len(pos_indices), len(neg_indices))
289
+ indices = pos_indices + neg_indices
290
+
291
+ embs = [embs[i] for i in indices]
292
+ ys = [ys[i] for i in indices]
293
+
294
+
295
+ indices = list(range(len(embs)))
296
+
297
+ # also add the latest 0 and the latest 1
298
+ has_0 = False
299
+ has_1 = False
300
+ for i in reversed(range(len(ys))):
301
+ if ys[i] == 0 and has_0 == False:
302
+ indices.append(i)
303
+ has_0 = True
304
+ elif ys[i] == 1 and has_1 == False:
305
+ indices.append(i)
306
+ has_1 = True
307
+ if has_0 and has_1:
308
+ break
309
+
310
+ # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
311
+ # this ends up adding a rating but losing an embedding, it seems.
312
+ # let's take off a rating if so to continue without indexing errors.
313
+ if len(ys) > len(embs):
314
+ print('ys are longer than embs; popping latest rating')
315
+ ys.pop(-1)
316
+
317
+ feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices]).to('cpu'))
318
+ scaler = preprocessing.StandardScaler().fit(feature_embs)
319
+ feature_embs = scaler.transform(feature_embs)
320
+ chosen_y = np.array([ys[i] for i in indices])
321
+
322
+ print('Gathering coefficients')
323
+ lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
324
+ coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
325
+ print(coef_.shape, 'COEF')
326
+ print('Gathered')
327
+
328
+ rng_prompt = random.choice(prompt_list)
329
+ w = 1# if len(embs) % 2 == 0 else 0
330
+ im_emb = w * coef_.to(dtype=dtype)
331
+
332
+ prompt= '' if glob_idx % 3 != 0 else rng_prompt
333
+ text, im_emb = generate(prompt, im_emb)
334
+ embs += im_emb
335
+
336
+
337
+ return text, embs, ys, calibrate_prompts
338
+
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+ def start(_, embs, ys, calibrate_prompts):
348
+ text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
349
+ return [
350
+ gr.Button(value='Like (L)', interactive=True),
351
+ gr.Button(value='Neither (Space)', interactive=True),
352
+ gr.Button(value='Dislike (A)', interactive=True),
353
+ gr.Button(value='Start', interactive=False),
354
+ text,
355
+ embs,
356
+ ys,
357
+ calibrate_prompts
358
+ ]
359
+
360
+
361
+ def choose(text, choice, embs, ys, calibrate_prompts):
362
+ if choice == 'Like (L)':
363
+ choice = 1
364
+ elif choice == 'Neither (Space)':
365
+ embs = embs[:-1]
366
+ text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
367
+ return text, embs, ys, calibrate_prompts
368
+ else:
369
+ choice = 0
370
+
371
+ # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
372
+ # TODO skip allowing rating
373
+ if text is None:
374
+ print('NSFW -- choice is disliked')
375
+ choice = 0
376
+
377
+ ys += [choice]*1
378
+ text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
379
+ return text, embs, ys, calibrate_prompts
380
+
381
+ css = '''.gradio-container{max-width: 700px !important}
382
+ #description{text-align: center}
383
+ #description h1, #description h3{display: block}
384
+ #description p{margin-top: 0}
385
+ .fade-in-out {animation: fadeInOut 3s forwards}
386
+ @keyframes fadeInOut {
387
+ 0% {
388
+ background: var(--bg-color);
389
+ }
390
+ 100% {
391
+ background: var(--button-secondary-background-fill);
392
+ }
393
+ }
394
+ '''
395
+ js_head = '''
396
+ <script>
397
+ document.addEventListener('keydown', function(event) {
398
+ if (event.key === 'a' || event.key === 'A') {
399
+ // Trigger click on 'dislike' if 'A' is pressed
400
+ document.getElementById('dislike').click();
401
+ } else if (event.key === ' ' || event.keyCode === 32) {
402
+ // Trigger click on 'neither' if Spacebar is pressed
403
+ document.getElementById('neither').click();
404
+ } else if (event.key === 'l' || event.key === 'L') {
405
+ // Trigger click on 'like' if 'L' is pressed
406
+ document.getElementById('like').click();
407
+ }
408
+ });
409
+ function fadeInOut(button, color) {
410
+ button.style.setProperty('--bg-color', color);
411
+ button.classList.remove('fade-in-out');
412
+ void button.offsetWidth; // This line forces a repaint by accessing a DOM property
413
+
414
+ button.classList.add('fade-in-out');
415
+ button.addEventListener('animationend', () => {
416
+ button.classList.remove('fade-in-out'); // Reset the animation state
417
+ }, {once: true});
418
+ }
419
+ document.body.addEventListener('click', function(event) {
420
+ const target = event.target;
421
+ if (target.id === 'dislike') {
422
+ fadeInOut(target, '#ff1717');
423
+ } else if (target.id === 'like') {
424
+ fadeInOut(target, '#006500');
425
+ } else if (target.id === 'neither') {
426
+ fadeInOut(target, '#cccccc');
427
+ }
428
+ });
429
+
430
+ </script>
431
+ '''
432
+
433
+ with gr.Blocks(css=css, head=js_head) as demo:
434
+ gr.Markdown('''# Compass
435
+ ### Generative Recommenders for Exploration of Text
436
+
437
+ Explore the latent space based on your preferences, without writing prompts. Learn more in [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
438
+ ''', elem_id="description")
439
+ embs = gr.State([])
440
+ ys = gr.State([])
441
+ calibrate_prompts = gr.State([
442
+ 'the moon is melting into my glass of tea',
443
+ 'a sea slug -- pair of claws scuttling -- jelly fish glowing',
444
+ 'an adorable creature. It may be a goblin or a pig or a slug.',
445
+ 'an animation about a gorgeous nebula',
446
+ 'a sketch of an impressive mountain by da vinci',
447
+ 'a watercolor painting: the octopus writhes',
448
+ ])
449
+ def l():
450
+ return None
451
+
452
+ with gr.Row(elem_id='output-image'):
453
+ text = gr.Textbox(interactive=False, elem_id="text")
454
+ with gr.Row(equal_height=True):
455
+ b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
456
+ b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
457
+ b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
458
+ b1.click(
459
+ choose,
460
+ [text, b1, embs, ys, calibrate_prompts],
461
+ [text, embs, ys, calibrate_prompts]
462
+ )
463
+ b2.click(
464
+ choose,
465
+ [text, b2, embs, ys, calibrate_prompts],
466
+ [text, embs, ys, calibrate_prompts]
467
+ )
468
+ b3.click(
469
+ choose,
470
+ [text, b3, embs, ys, calibrate_prompts],
471
+ [text, embs, ys, calibrate_prompts]
472
+ )
473
+ with gr.Row():
474
+ b4 = gr.Button(value='Start')
475
+ b4.click(start,
476
+ [b4, embs, ys, calibrate_prompts],
477
+ [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])
478
+ with gr.Row():
479
+ html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
480
+ <div style='text-align:center; font-size:14px'>Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.
481
+ </div>
482
+ <br><br>
483
+ <div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, especially the interface, and to @maxbittker for feedback.
484
+ </div>''')
485
+
486
+ demo.launch(share=True)
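
The `next_one` loop above fits a linear SVM on liked/disliked latent codes and reuses its weight vector `coef_` as a direction to move through the Optimus latent space. A minimal sketch of that idea, not part of the committed file, with random vectors standing in for real latents:

```python
# Minimal sketch (not part of the diff): a linear SVM's weight vector as a preference
# direction over latent codes. Random data stands in for real Optimus latents.
import numpy as np
import torch
from sklearn import preprocessing
from sklearn.svm import SVC

latent_size = 768
rng = np.random.default_rng(0)
liked = rng.normal(0.5, 1.0, size=(8, latent_size))      # hypothetical "liked" latents
disliked = rng.normal(-0.5, 1.0, size=(8, latent_size))  # hypothetical "disliked" latents

X = np.concatenate([liked, disliked])
y = np.array([1] * len(liked) + [0] * len(disliked))

X = preprocessing.StandardScaler().fit_transform(X)
clf = SVC(kernel='linear', class_weight='balanced', C=.1, max_iter=50000).fit(X, y)

# The separating hyperplane's normal points from "disliked" toward "liked";
# decoding this direction (or a point along it) yields the next candidate text.
direction = torch.tensor(clf.coef_, dtype=torch.float32).squeeze(0)  # (latent_size,)
print(direction.shape)
```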
Optimus/code/examples/big_ae/metrics.py ADDED
@@ -0,0 +1,196 @@
1
+ import os
2
+ from multiprocessing import Pool
3
+ import pdb
4
+ import numpy as np
5
+ import nltk
6
+ nltk.download('punkt')
7
+
8
+ from nltk.translate.bleu_score import SmoothingFunction
9
+
10
+ try:
11
+ from multiprocessing import cpu_count
12
+ except ImportError:
13
+ from os import cpu_count
14
+
15
+ class Metrics(object):
16
+ def __init__(self):
17
+ self.name = 'Metric'
18
+
19
+ def get_name(self):
20
+ return self.name
21
+
22
+ def set_name(self, name):
23
+ self.name = name
24
+
25
+ def get_score(self):
26
+ pass
27
+
28
+
29
+ class Bleu(Metrics):
30
+ def __init__(self, test_text='', real_text='', gram=3, num_real_sentences=500, num_fake_sentences=10000):
31
+ super(Bleu, self).__init__()
32
+ self.name = 'Bleu'
33
+ self.test_data = test_text
34
+ self.real_data = real_text
35
+ self.gram = gram
36
+ self.sample_size = num_real_sentences
37
+ self.reference = None
38
+ self.is_first = True
39
+ self.num_sentences = num_fake_sentences
40
+
41
+
42
+ def get_name(self):
43
+ return self.name
44
+
45
+ def get_score(self, is_fast=True, ignore=False):
46
+ if ignore:
47
+ return 0
48
+ if self.is_first:
49
+ self.get_reference()
50
+ self.is_first = False
51
+ if is_fast:
52
+ return self.get_bleu_fast()
53
+ return self.get_bleu_parallel()
54
+
55
+ # fetch REAL DATA
56
+ def get_reference(self):
57
+ if self.reference is None:
58
+ reference = list()
59
+ with open(self.real_data) as real_data:
60
+ for text in real_data:
61
+ text = nltk.word_tokenize(text)
62
+ reference.append(text)
63
+ self.reference = reference
64
+ return reference
65
+ else:
66
+ return self.reference
67
+
68
+ def get_bleu(self):
69
+ raise Exception('use get_bleu_parallel instead of get_bleu')
70
+ ngram = self.gram
71
+ bleu = list()
72
+ reference = self.get_reference()
73
+ weight = tuple((1. / ngram for _ in range(ngram)))
74
+ with open(self.test_data) as test_data:
75
+ for hypothesis in test_data:
76
+ hypothesis = nltk.word_tokenize(hypothesis)
77
+ bleu.append(nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
78
+ smoothing_function=SmoothingFunction().method1))
79
+ return sum(bleu) / len(bleu)
80
+
81
+ def calc_bleu(self, reference, hypothesis, weight):
82
+ return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
83
+ smoothing_function=SmoothingFunction().method1)
84
+
85
+ def get_bleu_fast(self):
86
+ reference = self.get_reference()
87
+ reference = reference[0:self.sample_size]
88
+ return self.get_bleu_parallel(reference=reference)
89
+
90
+ def get_bleu_parallel(self, reference=None):
91
+ ngram = self.gram
92
+ if reference is None:
93
+ reference = self.get_reference()
94
+ weight = tuple((1. / ngram for _ in range(ngram)))
95
+ pool = Pool(cpu_count())
96
+ result = list()
97
+ maxx = self.num_sentences
98
+ with open(self.test_data) as test_data:
99
+ for i, hypothesis in enumerate(test_data):
100
+ #print('i : {}'.format(i))
101
+ hypothesis = nltk.word_tokenize(hypothesis)
102
+ result.append(pool.apply_async(self.calc_bleu, args=(reference, hypothesis, weight)))
103
+ if i > maxx : break
104
+ score = 0.0
105
+ cnt = 0
106
+ for it, i in enumerate(result):
107
+ #print('i : {}'.format(it))
108
+ score += i.get()
109
+ cnt += 1
110
+ pool.close()
111
+ pool.join()
112
+ return score / cnt
113
+
114
+
115
+
116
+
117
+ class SelfBleu(Metrics):
118
+ def __init__(self, test_text='', gram=3, model_path='', num_sentences=500):
119
+ super(SelfBleu, self).__init__()
120
+ self.name = 'Self-Bleu'
121
+ self.test_data = test_text
122
+ self.gram = gram
123
+ self.sample_size = num_sentences
124
+ self.reference = None
125
+ self.is_first = True
126
+
127
+
128
+ def get_name(self):
129
+ return self.name
130
+
131
+ def get_score(self, is_fast=True, ignore=False):
132
+ if ignore:
133
+ return 0
134
+ if self.is_first:
135
+ self.get_reference()
136
+ self.is_first = False
137
+ if is_fast:
138
+ return self.get_bleu_fast()
139
+ return self.get_bleu_parallel()
140
+
141
+ def get_reference(self):
142
+ if self.reference is None:
143
+ reference = list()
144
+ with open(self.test_data) as real_data:
145
+ for text in real_data:
146
+ text = nltk.word_tokenize(text)
147
+ reference.append(text)
148
+ self.reference = reference
149
+ return reference
150
+ else:
151
+ return self.reference
152
+
153
+ def get_bleu(self):
154
+ ngram = self.gram
155
+ bleu = list()
156
+ reference = self.get_reference()
157
+ weight = tuple((1. / ngram for _ in range(ngram)))
158
+ with open(self.test_data) as test_data:
159
+ for hypothesis in test_data:
160
+ hypothesis = nltk.word_tokenize(hypothesis)
161
+ bleu.append(nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
162
+ smoothing_function=SmoothingFunction().method1))
163
+ return sum(bleu) / len(bleu)
164
+
165
+ def calc_bleu(self, reference, hypothesis, weight):
166
+ return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
167
+ smoothing_function=SmoothingFunction().method1)
168
+
169
+ def get_bleu_fast(self):
170
+ reference = self.get_reference()
171
+ # random.shuffle(reference)
172
+ reference = reference[0:self.sample_size]
173
+ return self.get_bleu_parallel(reference=reference)
174
+
175
+ def get_bleu_parallel(self, reference=None):
176
+ ngram = self.gram
177
+ if reference is None:
178
+ reference = self.get_reference()
179
+ weight = tuple((1. / ngram for _ in range(ngram)))
180
+ pool = Pool(cpu_count())
181
+ result = list()
182
+ sentence_num = len(reference)
183
+ for index in range(sentence_num):
184
+ # leave-one-out: score each sentence against all the other generations as references
185
+ hypothesis = reference[index]
186
+ other = reference[:index] + reference[index+1:]
187
+ result.append(pool.apply_async(self.calc_bleu, args=(other, hypothesis, weight)))
188
+
189
+ score = 0.0
190
+ cnt = 0
191
+ for i in result:
192
+ score += i.get()
193
+ cnt += 1
194
+ pool.close()
195
+ pool.join()
196
+ return score / cnt
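
`SelfBleu` scores each generated sentence against all the other generations as references (leave-one-out), so lower values indicate more diverse samples. A hedged usage sketch, not part of the committed file, assuming it is run from the `examples/big_ae` directory and that `generations.txt` is a hypothetical plain-text file with one generated sentence per line:

```python
# Minimal sketch (not part of the diff): Self-BLEU over a file of generations.
from metrics import SelfBleu

self_bleu = SelfBleu(test_text='generations.txt', gram=3, num_sentences=500)
score = self_bleu.get_score(is_fast=True)   # leave-one-out BLEU, averaged over sentences
print('Self-BLEU-3: %.3f (lower is more diverse)' % score)
```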
Optimus/code/examples/big_ae/modules/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .encoders import *
2
+ from .decoders import *
3
+ from .vae import *
4
+ from .utils import *
5
+ from .spacefusion import *
6
+ from .cara import *
7
+ from .arae import *
Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (327 Bytes)
Optimus/code/examples/big_ae/modules/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (270 Bytes)
Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-310.pyc ADDED
Binary file (6.64 kB)
Optimus/code/examples/big_ae/modules/__pycache__/arae.cpython-37.pyc ADDED
Binary file (6.44 kB)
Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-310.pyc ADDED
Binary file (8.63 kB)
Optimus/code/examples/big_ae/modules/__pycache__/cara.cpython-37.pyc ADDED
Binary file (8.41 kB)
Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-310.pyc ADDED
Binary file (4.44 kB)
Optimus/code/examples/big_ae/modules/__pycache__/spacefusion.cpython-37.pyc ADDED
Binary file (4.37 kB)
Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.34 kB)
Optimus/code/examples/big_ae/modules/__pycache__/utils.cpython-37.pyc ADDED
Binary file (1.28 kB)
Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-310.pyc ADDED
Binary file (14.8 kB)
Optimus/code/examples/big_ae/modules/__pycache__/vae.cpython-37.pyc ADDED
Binary file (15 kB)
Optimus/code/examples/big_ae/modules/arae.py ADDED
@@ -0,0 +1,274 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from .utils import log_sum_exp
5
+ import pdb
6
+ import sys
7
+ sys.path.append('../../')
8
+ from pytorch_transformers.modeling_bert import BertEmbeddings
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class ARAE(nn.Module):
13
+ def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
14
+ super(ARAE, self).__init__()
15
+ self.encoder = encoder
16
+ self.decoder = decoder
17
+ self.tokenizer_encoder = tokenizer_encoder
18
+ self.tokenizer_decoder = tokenizer_decoder
19
+
20
+ self.args = args
21
+ self.nz = args.latent_size
22
+
23
+ self.bos_token_id_list = self.tokenizer_decoder.encode(self.tokenizer_decoder.bos_token)
24
+ self.pad_token_id = self.tokenizer_decoder.encode(self.tokenizer_decoder.pad_token)[0]
25
+
26
+ # connector: from Bert hidden units to the latent space
27
+ self.linear = nn.Linear(encoder.config.hidden_size, self.nz, bias=False)
28
+
29
+ # # Standard Normal prior
30
+ # loc = torch.zeros(self.nz, device=args.device)
31
+ # scale = torch.ones(self.nz, device=args.device)
32
+ # self.prior = torch.distributions.normal.Normal(loc, scale)
33
+
34
+ self.label_embedding = nn.Embedding(args.label_size, self.nz, padding_idx=0) # use the same size as latent_z so as to use the same decoder.linear()
35
+ self.latent_generator = nn.Linear(self.nz, self.nz)
36
+ self.latent_classifier = nn.Linear(self.nz, args.label_size if args.label_size > 2 else 1)
37
+ self.latent_discriminator = nn.Linear(self.nz, 1)
38
+
39
+ self.gpt_embeddings = nn.Embedding(self.decoder.config.vocab_size, self.decoder.config.n_embd)
40
+ self.gpt_embeddings.weight.data = decoder.transformer.wte.weight.data
41
+
42
+ self.conv1 = nn.Conv1d(self.encoder.config.hidden_size, self.encoder.config.hidden_size, 3)
43
+ self.classifier = nn.Linear(self.encoder.config.hidden_size, 1 if args.label_size <= 2 else args.label_size)
44
+
45
+ self.CrossEntropyLoss = torch.nn.CrossEntropyLoss()
46
+ self.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
47
+
48
+ def forward(self, input_seq_ids, tgt_seq_ids, cond_labels, attention_mask=None):
49
+ # inputs: (B, seq_len)
50
+ # labels: (B, seq_len)
51
+ # cond_labels: (B), conditional labels.
52
+
53
+ ones_label = torch.ones_like(cond_labels).to(dtype=torch.float32)
54
+ zeros_label = torch.zeros_like(cond_labels).to(dtype=torch.float32)
55
+ random_noise = torch.nn.init.normal_(torch.empty(input_seq_ids.size(0), self.nz)).to(device=input_seq_ids.device, dtype=torch.float32)
56
+
57
+ # Encode inputs
58
+ outputs = self.encoder(input_seq_ids, attention_mask=attention_mask)
59
+ pooled_hidden_fea = outputs[1] # (B, dim_h)
60
+
61
+ # Encode z
62
+ latent_z = self.linear(pooled_hidden_fea) # (B, nz)
63
+
64
+ # Generate z
65
+ gen_z = self.latent_generator(random_noise) # (B, nz)
66
+
67
+ # Latent discriminator
68
+ prob_encode_z_dis = self.latent_discriminator(latent_z).squeeze(1).float() # (B)
69
+ prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B)
70
+ # Train latent discriminator
71
+ loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label)
72
+ acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float()
73
+ acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float()
74
+ # Train sampler adversarially
75
+ loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label)
76
+
77
+ # Latent classifier
78
+ prob_encode_z_cls = self.latent_classifier(latent_z) # (B, n_labels)
79
+ if self.args.label_size <= 2:
80
+ prob_encode_z_cls = prob_encode_z_cls.squeeze(1) # (B)
81
+ # Train latent classifier
82
+ loss_lsc = self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
83
+ acc_encode_z_cls = ((prob_encode_z_cls >= 0).float() == cond_labels.float()).float()
84
+ # Train encoder adversarially
85
+ loss_encoder = 1 - self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
86
+ else:
87
+ # Train latent classifier
88
+ loss_lsc = self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
89
+ acc_encode_z_cls = (torch.argmax(prob_encode_z_cls, dim=-1) == cond_labels).float()
90
+ # Train encoder adversarially
91
+ loss_encoder = 1 - self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
92
+
93
+ # Embed labels
94
+ label_emb = self.label_embedding(cond_labels) # (B, hidden_size)
95
+ past_label = self.decoder.linear(label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
96
+ if self.args.label_size <= 2:
97
+ sampled_cond_labels = 1 - cond_labels
98
+ else:
99
+ raise NotImplementedError # todo: currently only implemented for binary labels. need to change for multi-class labels.
100
+ sampled_label_emb = self.label_embedding(sampled_cond_labels) # (B, hidden_size)
101
+ past_sampled_label = self.decoder.linear(sampled_label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
102
+
103
+ # Generate based on encoded z and gt labels. (reconstruction)
104
+ past_z = self.decoder.linear(latent_z) # (B, n_blocks * hidden_size)
105
+ gen_past_z = self.decoder.linear(gen_z) # (B, n_blocks * hidden_size)
106
+
107
+ past = torch.cat([past_z.unsqueeze(1), past_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
108
+ outputs = self.decoder(input_ids=tgt_seq_ids, past=past, labels=tgt_seq_ids, label_ignore=self.pad_token_id)
109
+ loss_rec = outputs[0]
110
+
111
+ # Train a classifier in the observation space
112
+ tgt_emb = self.gpt_embeddings(tgt_seq_ids)
113
+ tgt_encode = self.conv1(tgt_emb.transpose(1, 2)) # (B, dim_h, seq_len)
114
+ tgt_encode = torch.mean(tgt_encode, dim=-1) # (B, dim_h)
115
+ prob_cls = self.classifier(tgt_encode) # (B, n_labels)
116
+ if self.args.label_size <= 2:
117
+ prob_cls = prob_cls.squeeze(1)
118
+ loss_cls = self.BCEWithLogitsLoss(prob_cls, cond_labels.float())
119
+ pred_cls = (prob_cls >= 0).to(dtype=torch.long)
120
+ else:
121
+ loss_cls = self.CrossEntropyLoss(prob_cls, cond_labels)
122
+ pred_cls = torch.argmax(prob_cls, dim=-1)
123
+ acc_cls = (pred_cls == cond_labels).float()
124
+
125
+ # Loss
126
+ loss = loss_rec + loss_encoder + loss_lsc + loss_lsd + loss_lsg + loss_cls
127
+
128
+ if not self.training:
129
+ # Generate based on encoded z and gt labels
130
+ generated = self.sample_sequence_conditional_batch(past=past, context=self.bos_token_id_list)
131
+
132
+ # Generate based on encoded z and sampled labels (attribute transfer)
133
+ at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
134
+ at_generated = self.sample_sequence_conditional_batch(past=at_past, context=self.bos_token_id_list) # (B, seq_len)
135
+
136
+ # Generate based on sampled z and sampled labels. (conditional generation)
137
+ cg_past = torch.cat([gen_past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
138
+ cg_generated = self.sample_sequence_conditional_batch(past=cg_past, context=self.bos_token_id_list) # (B, seq_len)
139
+
140
+ # classifier on gt generated sentences.
141
+ ge_emb = self.gpt_embeddings(generated)
142
+ ge_encode = self.conv1(ge_emb.transpose(1, 2)) # (B, dim_h, seq_len)
143
+ ge_encode = torch.mean(ge_encode, dim=-1) # (B, dim_h)
144
+ prob_ge_cls = self.classifier(ge_encode) # (B, 1)
145
+
146
+ if self.args.label_size <= 2:
147
+ pred_ge_cls = (prob_ge_cls.squeeze(1) >= 0).to(torch.long)
148
+ else:
149
+ pred_ge_cls = torch.argmax(prob_ge_cls, dim=-1)
150
+ acc_ge_cls = (pred_ge_cls == cond_labels).float()
151
+
152
+ # classifier on attribute transfer generated sentences.
153
+ at_emb = self.gpt_embeddings(at_generated)
154
+ at_encode = self.conv1(at_emb.transpose(1, 2)) # (B, dim_h, seq_len)
155
+ at_encode = torch.mean(at_encode, dim=-1) # (B, dim_h)
156
+ prob_at_cls = self.classifier(at_encode) # (B, 1)
157
+ if self.args.label_size <= 2:
158
+ pred_at_cls = (prob_at_cls.squeeze(1) >= 0).to(torch.long)
159
+ else:
160
+ pred_at_cls = torch.argmax(prob_at_cls, dim=-1)
161
+ acc_at_cls = (pred_at_cls == sampled_cond_labels).float()
162
+
163
+ # classifier on conditional generated sentences.
164
+ cg_emb = self.gpt_embeddings(cg_generated)
165
+ cg_encode = self.conv1(cg_emb.transpose(1, 2)) # (B, dim_h, seq_len)
166
+ cg_encode = torch.mean(cg_encode, dim=-1) # (B, dim_h)
167
+ prob_cg_cls = self.classifier(cg_encode) # (B, 1)
168
+ if self.args.label_size <= 2:
169
+ pred_cg_cls = (prob_cg_cls.squeeze(1) >= 0).to(torch.long)
170
+ else:
171
+ pred_cg_cls = torch.argmax(prob_cg_cls, dim=-1)
172
+ acc_cg_cls = (pred_cg_cls == sampled_cond_labels).float()
173
+
174
+ result = {
175
+ 'sampled_cond_labels': sampled_cond_labels,
176
+ 'cond_labels': cond_labels,
177
+
178
+ 'tgt_seq_ids': tgt_seq_ids,
179
+ 'generated': generated,
180
+ 'at_generated': at_generated,
181
+ 'cg_generated': cg_generated,
182
+
183
+ 'acc_encode_z_dis': acc_encode_z_dis,
184
+ 'acc_gen_z_dis': acc_gen_z_dis,
185
+ 'acc_encode_z_cls': acc_encode_z_cls,
186
+ 'acc_cls': acc_cls,
187
+ 'acc_ge_cls': acc_ge_cls,
188
+ 'acc_at_cls': acc_at_cls,
189
+ 'acc_cg_cls': acc_cg_cls,
190
+
191
+ 'pred_cls': pred_cls,
192
+ 'pred_ge_cls': pred_ge_cls,
193
+ 'pred_at_cls': pred_at_cls,
194
+ 'pred_cg_cls': pred_cg_cls,
195
+ }
196
+
197
+ return result
198
+
199
+ loss_dict = {
200
+ 'loss': loss,
201
+ 'loss_rec': loss_rec,
202
+ 'loss_encoder': loss_encoder,
203
+ 'loss_lsc': loss_lsc,
204
+ 'loss_lsd': loss_lsd,
205
+ 'loss_lsg': loss_lsg,
206
+ 'loss_cls': loss_cls,
207
+ }
208
+ acc_dict = {
209
+ 'acc_encode_z_dis': acc_encode_z_dis,
210
+ 'acc_gen_z_dis': acc_gen_z_dis,
211
+ 'acc_encode_z_cls': acc_encode_z_cls,
212
+ 'acc_cls': acc_cls,
213
+ }
214
+ return loss_dict, acc_dict
215
+
216
+ def sample_sequence_conditional_batch(self, past, context):
217
+ # context: a single id of <BOS>
218
+ # past: (B, past_seq_len dim_h)
219
+ num_samples = past.size(0)
220
+ context = torch.tensor(context, dtype=torch.long, device=past.device)
221
+ context = context.unsqueeze(0).repeat(num_samples, 1)
222
+ generated = context # (B, 1)
223
+
224
+ # with torch.no_grad():
225
+ while generated.size(-1) < self.args.block_size:
226
+ inputs = {'input_ids': generated, 'past': past}
227
+ outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
228
+ lm_logits = outputs[0]
229
+ next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
230
+ filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, vocab_size)
231
+ filtered_logits = F.softmax(filtered_logits, dim=-1)
232
+ next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
233
+ generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
234
+
235
+ not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
236
+ if torch.sum(not_finished) == 0:
237
+ break
238
+
239
+ return generated # (B, seq_len)
240
+
241
+ def top_k_top_p_filtering_batch(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
242
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
243
+ Args:
244
+ logits: logits distribution shape (vocabulary size)
245
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
246
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
247
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
248
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
249
+ """
250
+ # assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
251
+
252
+ top_k = min(top_k, logits.size(-1)) # Safety check
253
+
254
+ if top_k > 0:
255
+ # Remove all tokens with a probability less than the last token of the top-k
256
+ threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
257
+ logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size)
258
+
259
+ if top_p > 0.0:
260
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size)
261
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size)
262
+
263
+ # Remove tokens with cumulative probability above the threshold
264
+ sorted_indices_to_remove = cumulative_probs > top_p
265
+
266
+ # Shift the indices to the right to keep also the first token above the threshold
267
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
268
+ sorted_indices_to_remove[..., 0] = 0
269
+
270
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
271
+
272
+ logits.masked_fill_(indices_to_remove, filter_value)
273
+
274
+ return logits
Optimus/code/examples/big_ae/modules/cara.py ADDED
@@ -0,0 +1,374 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from .utils import log_sum_exp
5
+ import pdb
6
+ import sys
7
+ sys.path.append('../../')
8
+ from pytorch_transformers.modeling_bert import BertEmbeddings
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class CARA(nn.Module):
13
+ def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
14
+ super(CARA, self).__init__()
15
+ self.encoder = encoder
16
+ self.decoder = decoder
17
+ self.tokenizer_encoder = tokenizer_encoder
18
+ self.tokenizer_decoder = tokenizer_decoder
19
+
20
+ self.args = args
21
+ self.nz = args.latent_size
22
+
23
+ self.bos_token_id_list = self.tokenizer_decoder.encode(self.tokenizer_decoder.bos_token)
24
+ self.pad_token_id = self.tokenizer_decoder.encode(self.tokenizer_decoder.pad_token)[0]
25
+
26
+ # connector: from Bert hidden units to the latent space
27
+ self.linear = nn.Linear(encoder.config.hidden_size, self.nz, bias=False)
28
+
29
+ # # Standard Normal prior
30
+ # loc = torch.zeros(self.nz, device=args.device)
31
+ # scale = torch.ones(self.nz, device=args.device)
32
+ # self.prior = torch.distributions.normal.Normal(loc, scale)
33
+
34
+ self.label_embedding = nn.Embedding(args.label_size, self.nz, padding_idx=0) # use the same size as latent_z so as to use the same decoder.linear()
35
+ self.latent_generator = nn.Linear(self.nz, self.nz)
36
+ self.latent_classifier = nn.Linear(self.nz, args.label_size if args.label_size > 2 else 1)
37
+ self.latent_discriminator = nn.Linear(self.nz, 1)
38
+
39
+ self.gpt_embeddings = nn.Embedding(self.decoder.config.vocab_size, self.decoder.config.n_embd)
40
+ self.gpt_embeddings.weight.data = decoder.transformer.wte.weight.data
41
+
42
+ self.conv1 = nn.Conv1d(self.encoder.config.hidden_size, self.encoder.config.hidden_size, 3)
43
+ self.classifier = nn.Linear(self.encoder.config.hidden_size, 1 if args.label_size <= 2 else args.label_size)
44
+
45
+ self.CrossEntropyLoss = torch.nn.CrossEntropyLoss()
46
+ self.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
47
+
48
+ def forward(self, input_seq_ids, tgt_seq_ids, cond_labels, attention_mask):
49
+ # inputs: (B, seq_len)
50
+ # labels: (B, seq_len)
51
+ # cond_labels: (B), conditional labels.
52
+
53
+ ones_label = torch.ones_like(cond_labels).to(dtype=torch.float32)
54
+ zeros_label = torch.zeros_like(cond_labels).to(dtype=torch.float32)
55
+ random_noise = torch.nn.init.normal_(torch.empty(input_seq_ids.size(0), self.nz)).to(device=input_seq_ids.device, dtype=torch.float32)
56
+
57
+ # Encode inputs
58
+ outputs = self.encoder(input_seq_ids, attention_mask=attention_mask)
59
+ pooled_hidden_fea = outputs[1] # (B, dim_h)
60
+
61
+ # Encode z
62
+ latent_z = self.linear(pooled_hidden_fea) # (B, nz)
63
+
64
+ # Generate z
65
+ gen_z = self.latent_generator(random_noise) # (B, nz)
66
+
67
+ #################### Latent discriminator for sampling from a simple distribution ####################
68
+ prob_encode_z_dis = self.latent_discriminator(latent_z).squeeze(1).float() # (B)
69
+ prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B)
70
+ # Train latent discriminator
71
+ loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label)
72
+ acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float()
73
+ acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float()
74
+ # Train sampler adversarially
75
+ loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label)
76
+
77
+ #################### Latent classifier for disentanglement ####################
78
+ prob_encode_z_cls = self.latent_classifier(latent_z) # (B, n_labels)
79
+ if self.args.label_size <= 2:
80
+ prob_encode_z_cls = prob_encode_z_cls.squeeze(1) # (B)
81
+ # Train latent classifier
82
+ loss_lsc = self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
83
+ acc_encode_z_cls = ((prob_encode_z_cls >= 0).float() == cond_labels.float()).float()
84
+ # Train encoder adversarially
85
+ loss_encoder = 1 - self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
86
+ else:
87
+ # Train latent classifier
88
+ loss_lsc = self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
89
+ acc_encode_z_cls = (torch.argmax(prob_encode_z_cls, dim=-1) == cond_labels).float()
90
+ # Train encoder adversarially
91
+ loss_encoder = 1 - self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
92
+
93
+
94
+ #################### Recontruction loss with latent z and label emb ####################
95
+ # Embed labels
96
+ label_emb = self.label_embedding(cond_labels) # (B, hidden_size)
97
+ # past_label = self.decoder.linear(label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
98
+ if self.args.label_size <= 2:
99
+ sampled_cond_labels = 1 - cond_labels
100
+ else:
101
+ raise NotImplementedError # todo: currently only implemented for binary labels. need to change for multi-class labels.
102
+ sampled_label_emb = self.label_embedding(sampled_cond_labels) # (B, hidden_size)
103
+ # past_sampled_label = self.decoder.linear(sampled_label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
104
+ past_sampled_label = sampled_label_emb
105
+
106
+ # Generate based on encoded z and gt labels. (reconstruction)
107
+ # past_z = self.decoder.linear(latent_z) # (B, n_blocks * hidden_size)
108
+ past_z = latent_z
109
+ # gen_past_z = self.decoder.linear(gen_z) # (B, n_blocks * hidden_size)
110
+ gen_past_z = gen_z # (B, n_blocks * hidden_size)
111
+
112
+ # past = torch.cat([past_z.unsqueeze(1), past_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
113
+
114
+ past = latent_z + label_emb # (B, n_blocks * hidden_size)
115
+
116
+ outputs = self.decoder(input_ids=tgt_seq_ids, past=past, labels=tgt_seq_ids, label_ignore=self.pad_token_id)
117
+ loss_rec = outputs[0]
118
+
119
+ #################### Train a classifier in the observation space ####################
120
+ tgt_emb = self.gpt_embeddings(tgt_seq_ids)
121
+ tgt_encode = self.conv1(tgt_emb.transpose(1, 2)) # (B, dim_h, seq_len)
122
+ tgt_encode = torch.mean(tgt_encode, dim=-1) # (B, dim_h)
123
+ prob_cls = self.classifier(tgt_encode) # (B, n_labels)
124
+ if self.args.label_size <= 2:
125
+ prob_cls = prob_cls.squeeze(1)
126
+ loss_cls = self.BCEWithLogitsLoss(prob_cls, cond_labels.float())
127
+ pred_cls = (prob_cls >= 0).to(dtype=torch.long)
128
+ else:
129
+ loss_cls = self.CrossEntropyLoss(prob_cls, cond_labels)
130
+ pred_cls = torch.argmax(prob_cls, dim=-1)
131
+ acc_cls = (pred_cls == cond_labels).float()
132
+
133
+ # Generate based on encoded z and sampled labels (attribute transfer)
134
+ # at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
135
+ # at_generated_soft = self.sample_sequence_conditional_batch_soft(past=at_past, context=self.bos_token_id_list) # (B, seq_len, vocab_size)
136
+
137
+ # # Classifier on attribute transfer generated sentences. Train Generator on attribute transfer.
138
+ # at_soft_emb = torch.matmul(at_generated_soft, self.gpt_embeddings.weight)
139
+ # at_soft_encode = self.conv1(at_soft_emb.transpose(1, 2)) # (B, dim_h, seq_len)
140
+ # at_soft_encode = torch.mean(at_soft_encode, dim=-1) # (B, dim_h)
141
+ # prob_at_soft_cls = self.classifier(at_soft_encode) # (B, 1)
142
+ # if self.args.label_size <= 2:
143
+ # prob_at_soft_cls = prob_at_soft_cls.squeeze(1)
144
+ # loss_at_soft_cls = self.BCEWithLogitsLoss(prob_at_soft_cls, sampled_cond_labels.float())
145
+ # pred_at_soft_cls = (prob_at_soft_cls >= 0).to(torch.long)
146
+ # else:
147
+ # loss_at_soft_cls = self.CrossEntropyLoss(prob_at_soft_cls, sampled_cond_labels)
148
+ # pred_at_soft_cls = torch.argmax(prob_at_soft_cls, dim=-1)
149
+ # acc_at_soft_cls = (pred_at_soft_cls == sampled_cond_labels).float()
150
+
151
+ # Loss
152
+ loss_latent_space = (loss_encoder + loss_lsc) + (loss_lsd + loss_lsg) + self.args.beta_cls * loss_cls # + loss_at_soft_cls
153
+ loss = loss_rec + 0.0 * loss_latent_space
154
+
155
+ if not self.training:
156
+ # Generate based on encoded z and gt labels
157
+ generated = self.sample_sequence_conditional_batch(past=past, context=self.bos_token_id_list)
158
+
159
+ # Generate based on encoded z and sampled labels (attribute transfer)
160
+ # at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
161
+ at_past = past_z + past_sampled_label # (B, n_blocks * hidden_size)
162
+ at_generated = self.sample_sequence_conditional_batch(past=at_past, context=self.bos_token_id_list) # (B, seq_len)
163
+
164
+ # Generate based on sampled z and sampled labels. (conditional generation)
165
+ # cg_past = torch.cat([gen_past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
166
+ cg_past = gen_past_z + past_sampled_label # (B, n_blocks * hidden_size)
167
+ cg_generated = self.sample_sequence_conditional_batch(past=cg_past, context=self.bos_token_id_list) # (B, seq_len)
168
+
169
+ # classifier on gt generated sentences.
170
+ ge_emb = self.gpt_embeddings(generated)
171
+ ge_encode = self.conv1(ge_emb.transpose(1, 2)) # (B, dim_h, seq_len)
172
+ ge_encode = torch.mean(ge_encode, dim=-1) # (B, dim_h)
173
+ prob_ge_cls = self.classifier(ge_encode) # (B, 1)
174
+
175
+ if self.args.label_size <= 2:
176
+ pred_ge_cls = (prob_ge_cls.squeeze(1) >= 0).to(torch.long)
177
+ else:
178
+ pred_ge_cls = torch.argmax(prob_ge_cls, dim=-1)
179
+ acc_ge_cls = (pred_ge_cls == cond_labels).float()
180
+
181
+ # classifier on attribute transfer generated sentences.
182
+ at_emb = self.gpt_embeddings(at_generated)
183
+ at_encode = self.conv1(at_emb.transpose(1, 2)) # (B, dim_h, seq_len)
184
+ at_encode = torch.mean(at_encode, dim=-1) # (B, dim_h)
185
+ prob_at_cls = self.classifier(at_encode) # (B, 1)
186
+ if self.args.label_size <= 2:
187
+ pred_at_cls = (prob_at_cls.squeeze(1) >= 0).to(torch.long)
188
+ else:
189
+ pred_at_cls = torch.argmax(prob_at_cls, dim=-1)
190
+ acc_at_cls = (pred_at_cls == sampled_cond_labels).float()
191
+
192
+ # classifier on conditional generated sentences.
193
+ cg_emb = self.gpt_embeddings(cg_generated)
194
+ cg_encode = self.conv1(cg_emb.transpose(1, 2)) # (B, dim_h, seq_len)
195
+ cg_encode = torch.mean(cg_encode, dim=-1) # (B, dim_h)
196
+ prob_cg_cls = self.classifier(cg_encode) # (B, 1)
197
+ if self.args.label_size <= 2:
198
+ pred_cg_cls = (prob_cg_cls.squeeze(1) >= 0).to(torch.long)
199
+ else:
200
+ pred_cg_cls = torch.argmax(prob_cg_cls, dim=-1)
201
+ acc_cg_cls = (pred_cg_cls == sampled_cond_labels).float()
202
+
203
+ result = {
204
+ 'sampled_cond_labels': sampled_cond_labels,
205
+ 'cond_labels': cond_labels,
206
+
207
+ 'tgt_seq_ids': tgt_seq_ids,
208
+ 'generated': generated,
209
+ 'at_generated': at_generated,
210
+ 'cg_generated': cg_generated,
211
+
212
+ 'acc_encode_z_dis': acc_encode_z_dis,
213
+ 'acc_gen_z_dis': acc_gen_z_dis,
214
+ 'acc_encode_z_cls': acc_encode_z_cls,
215
+ 'acc_cls': acc_cls,
216
+ 'acc_ge_cls': acc_ge_cls,
217
+ 'acc_at_cls': acc_at_cls,
218
+ 'acc_cg_cls': acc_cg_cls,
219
+
220
+ 'pred_cls': pred_cls,
221
+ 'pred_ge_cls': pred_ge_cls,
222
+ 'pred_at_cls': pred_at_cls,
223
+ 'pred_cg_cls': pred_cg_cls,
224
+ }
225
+
226
+ return result
227
+
228
+ loss_dict = {
229
+ 'loss': loss,
230
+ 'loss_rec': loss_rec,
231
+ 'loss_encoder': loss_encoder,
232
+ 'loss_lsc': loss_lsc,
233
+ 'loss_lsd': loss_lsd,
234
+ 'loss_lsg': loss_lsg,
235
+ 'loss_cls': loss_cls,
236
+ # 'loss_at_soft_cls': loss_at_soft_cls,
237
+ }
238
+ acc_dict = {
239
+ 'acc_encode_z_dis': acc_encode_z_dis,
240
+ 'acc_gen_z_dis': acc_gen_z_dis,
241
+ 'acc_encode_z_cls': acc_encode_z_cls,
242
+ 'acc_cls': acc_cls,
243
+ # 'acc_at_soft_cls': acc_at_soft_cls,
244
+ }
245
+ return loss_dict, acc_dict
246
+
247
+ def sample_sequence_conditional_batch(self, past, context):
248
+ # context: a single id of <BOS>
249
+ # past: (B, past_seq_len dim_h)
250
+ num_samples = past.size(0)
251
+ context = torch.tensor(context, dtype=torch.long, device=past.device)
252
+ context = context.unsqueeze(0).repeat(num_samples, 1)
253
+ generated = context # (B, 1)
254
+
255
+ # with torch.no_grad():
256
+ while generated.size(-1) < self.args.block_size:
257
+ inputs = {'input_ids': generated, 'past': past}
258
+ outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
259
+ lm_logits = outputs[0]
260
+
261
+ # softmax sample
262
+ next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
263
+ filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
264
+ filtered_logits = F.softmax(filtered_logits, dim=-1)
265
+ next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
266
+ generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
267
+
268
+ not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
269
+ if torch.sum(not_finished) == 0:
270
+ break
271
+
272
+ return generated # (B, seq_len)
273
+
274
+ def top_k_top_p_filtering_batch(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
275
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
276
+ Args:
277
+ logits: logits distribution shape (vocabulary size)
278
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
279
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
280
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
281
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
282
+ """
283
+ # assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
284
+
285
+ top_k = min(top_k, logits.size(-1)) # Safety check
286
+
287
+ if top_k > 0:
288
+ # Remove all tokens with a probability less than the last token of the top-k
289
+ threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
290
+ logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size)
291
+
292
+ if top_p > 0.0:
293
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size)
294
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size)
295
+
296
+ # Remove tokens with cumulative probability above the threshold
297
+ sorted_indices_to_remove = cumulative_probs > top_p
298
+
299
+ # Shift the indices to the right to keep also the first token above the threshold
300
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
301
+ sorted_indices_to_remove[..., 0] = 0
302
+
303
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
304
+
305
+ logits.masked_fill_(indices_to_remove, filter_value)
306
+
307
+ return logits
308
+
309
+ def sample_sequence_conditional_batch_soft(self, past, context):
310
+ # context: a single id of <BOS>
311
+ # past: (B, past_seq_len dim_h)
312
+ num_samples = past.size(0)
313
+ context = torch.tensor(context, dtype=torch.long, device=past.device).unsqueeze(0).repeat(num_samples, 1) # (B, 1)
314
+ context_soft = torch.FloatTensor(num_samples, self.decoder.config.vocab_size).zero_().to(device=past.device) # (B, vocab_size)
315
+ context_soft.scatter_(1, context, 1) # (B, vocab_size)
316
+ generated_soft = context_soft.unsqueeze(1) # (B, 1, vocab_size)
317
+
318
+ # with torch.no_grad():
319
+ while generated_soft.size(1) < self.args.block_size: # generated_soft: (B, seq_len, vocab_size)
320
+ inputs = {'soft_ids': generated_soft, 'past': past}
321
+ outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
322
+ lm_logits = outputs[0] # (B, seq_len, vocab_size)
323
+
324
+ # Gumbel softmax sample
325
+ next_tokens_soft = gumbel_softmax(logits=lm_logits[:, -1:, :], temperature=self.args.soft_temperature, hard=False) # (B, 1, vocab_size)
326
+ generated_soft = torch.cat((generated_soft, next_tokens_soft), dim=1) # (B, seq_len+1, vocab_size)
327
+
328
+ # # softmax sample
329
+ # next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
330
+ # filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
331
+ # filtered_logits = F.softmax(filtered_logits, dim=-1)
332
+ # next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
333
+ # generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
334
+
335
+ next_tokens = torch.argmax(next_tokens_soft, dim=-1) # (B, 1)
336
+ not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
337
+ if torch.sum(not_finished) == 0:
338
+ break
339
+
340
+ return generated_soft # (B, seq_len, vocab_size)
341
+
342
+
343
+ ### Gumbel Softmax
344
+ def gumbel_softmax(logits, temperature, hard=False):
345
+ """Sample from the Gumbel-Softmax distribution and optionally discretize.
346
+ Args:
347
+ logits: [..., n_class] unnormalized log-probs
348
+ temperature: non-negative scalar
349
+ hard: if True, take argmax, but differentiate w.r.t. soft sample y
350
+ Returns:
351
+ [..., n_class] sample from the Gumbel-Softmax distribution.
352
+ If hard=True, then the returned sample will be one-hot, otherwise it will be a probability distribution that sums to 1 across classes
353
+ """
354
+ y = gumbel_softmax_sample(logits, temperature) # (..., n_class)
355
+
356
+ if hard: # return onehot
357
+ shape = y.size()
358
+ _, ind = y.max(dim=-1)
359
+ y_hard = torch.zeros_like(y).view(-1, shape[-1])
360
+ y_hard.scatter_(1, ind.view(-1, 1), 1) # one hot
361
+ y_hard = y_hard.view(*shape)
362
+ # Set gradients w.r.t. y_hard gradients w.r.t. y
363
+ y = (y_hard - y).detach() + y
364
+
365
+ return y # (..., n_class)
366
+
367
+ from torch.nn import functional as F
368
+ def gumbel_softmax_sample(logits, temperature):
369
+ y = logits + sample_gumbel(logits.size(), logits.device)
370
+ return F.softmax(y / temperature, dim=-1)
371
+
372
+ def sample_gumbel(shape, device, eps=1e-20):
373
+ U = torch.rand(shape).to(device=device)
374
+ return -torch.log(-torch.log(U + eps) + eps)
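
Note (illustrative sketch, not part of the commit): the module-level gumbel_softmax, gumbel_softmax_sample, and sample_gumbel helpers above implement the straight-through Gumbel-Softmax estimator. A minimal self-contained check of the same trick, with assumed batch size, vocabulary size, and temperature, might look like this:

import torch
import torch.nn.functional as F

def sample_gumbel(shape, device, eps=1e-20):
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax(logits, temperature, hard=False):
    y = F.softmax((logits + sample_gumbel(logits.size(), logits.device)) / temperature, dim=-1)
    if hard:
        # forward pass uses the one-hot argmax, backward pass uses the soft sample
        index = y.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y).scatter_(-1, index, 1.0)
        y = (y_hard - y).detach() + y
    return y

logits = torch.randn(4, 1, 100, requires_grad=True)   # assumed (B, 1, vocab_size)
sample = gumbel_softmax(logits, temperature=0.5, hard=True)
sample.sum().backward()                                # gradients reach logits through the soft path
print(sample.sum(dim=-1))                              # each slice sums to 1 (one-hot in the forward pass)
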
Optimus/code/examples/big_ae/modules/ctrl_gen.py ADDED
@@ -0,0 +1,371 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from .utils import log_sum_exp
5
+ import pdb
6
+ import sys
7
+ sys.path.append('../../')
8
+ from pytorch_transformers.modeling_bert import BertEmbeddings
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class Ctrl_Gen(nn.Module):
13
+ def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
14
+ super(Ctrl_Gen, self).__init__()
15
+ self.encoder = encoder
16
+ self.decoder = decoder
17
+ self.tokenizer_encoder = tokenizer_encoder
18
+ self.tokenizer_decoder = tokenizer_decoder
19
+
20
+ self.args = args
21
+ self.nz = args.latent_size
22
+
23
+ self.bos_token_id_list = self.tokenizer_decoder.encode(self.tokenizer_decoder.bos_token)
24
+ self.pad_token_id = self.tokenizer_decoder.encode(self.tokenizer_decoder.pad_token)[0]
25
+
26
+ # connector: from Bert hidden units to the latent space
27
+ self.linear = nn.Linear(encoder.config.hidden_size, self.nz, bias=False)
28
+
29
+ # # Standard Normal prior
30
+ # loc = torch.zeros(self.nz, device=args.device)
31
+ # scale = torch.ones(self.nz, device=args.device)
32
+ # self.prior = torch.distributions.normal.Normal(loc, scale)
33
+
34
+ self.label_embedding = nn.Embedding(args.label_size, self.nz, padding_idx=0) # use the same size as latent_z so as to use the same decoder.linear()
35
+ self.latent_generator = nn.Linear(self.nz, self.nz)
36
+ self.latent_classifier = nn.Linear(self.nz, args.label_size if args.label_size > 2 else 1)
37
+ self.latent_discriminator = nn.Linear(self.nz, 1)
38
+
39
+ self.gpt_embeddings = nn.Embedding(self.decoder.config.vocab_size, self.decoder.config.n_embd)
40
+ self.gpt_embeddings.weight.data = decoder.transformer.wte.weight.data
41
+
42
+ self.conv1 = nn.Conv1d(self.encoder.config.hidden_size, self.encoder.config.hidden_size, 3)
43
+ self.classifier = nn.Linear(self.encoder.config.hidden_size, 1 if args.label_size <= 2 else args.label_size)
44
+
45
+ self.CrossEntropyLoss = torch.nn.CrossEntropyLoss()
46
+ self.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
47
+
48
+ def forward(self, input_seq_ids, tgt_seq_ids, cond_labels, attention_mask):
49
+ # inputs: (B, seq_len)
50
+ # labels: (B, seq_len)
51
+ # cond_labels: (B), conditional labels.
52
+
53
+ ones_label = torch.ones_like(cond_labels).to(dtype=torch.float32)
54
+ zeros_label = torch.zeros_like(cond_labels).to(dtype=torch.float32)
55
+ random_noise = torch.nn.init.normal_(torch.empty(input_seq_ids.size(0), self.nz)).to(device=input_seq_ids.device, dtype=torch.float32)
56
+
57
+ # Encode inputs
58
+ outputs = self.encoder(input_seq_ids, attention_mask=attention_mask)
59
+ pooled_hidden_fea = outputs[1] # (B, dim_h)
60
+
61
+ # Encode z
62
+ latent_z = self.linear(pooled_hidden_fea) # (B, nz)
63
+
64
+ # Generate z
65
+ gen_z = self.latent_generator(random_noise) # (B, nz)
66
+
67
+ # Latent discriminator
68
+ prob_encode_z_dis = self.latent_discriminator(latent_z).squeeze(1).float() # (B)
69
+ prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B)
70
+ # Train latent discriminator
71
+ loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label)
72
+ acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float()
73
+ acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float()
74
+ # Train sampler adversarially
75
+ loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label)
76
+
77
+ # Latent classifier
78
+ prob_encode_z_cls = self.latent_classifier(latent_z) # (B, n_labels)
79
+ if self.args.label_size <= 2:
80
+ prob_encode_z_cls = prob_encode_z_cls.squeeze(1) # (B)
81
+ # Train latent classifier
82
+ loss_lsc = self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
83
+ acc_encode_z_cls = ((prob_encode_z_cls >= 0).float() == cond_labels.float()).float()
84
+ # Train encoder adversarially
85
+ loss_encoder = 1 - self.BCEWithLogitsLoss(prob_encode_z_cls, cond_labels.float())
86
+ else:
87
+ # Train latent classifier
88
+ loss_lsc = self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
89
+ acc_encode_z_cls = (torch.argmax(prob_encode_z_cls, dim=-1) == cond_labels).float()
90
+ # Train encoder adversarially
91
+ loss_encoder = 1 - self.CrossEntropyLoss(prob_encode_z_cls, cond_labels)
92
+
93
+ # Embed labels
94
+ label_emb = self.label_embedding(cond_labels) # (B, hidden_size)
95
+ # past_label = self.decoder.linear(label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
96
+ if self.args.label_size <= 2:
97
+ sampled_cond_labels = 1 - cond_labels
98
+ else:
99
+ raise NotImplementedError # todo: currently only implemented for binary labels. need to change for multi-class labels.
100
+ sampled_label_emb = self.label_embedding(sampled_cond_labels) # (B, hidden_size)
101
+ # past_sampled_label = self.decoder.linear(sampled_label_emb) # (B, n_blocks * hidden_size) # todo: use the same linear layer for latent_z for now.
102
+ past_sampled_label = sampled_label_emb
103
+
104
+ # Generate based on encoded z and gt labels. (reconstruction)
105
+ # past_z = self.decoder.linear(latent_z) # (B, n_blocks * hidden_size)
106
+ past_z = latent_z
107
+ # gen_past_z = self.decoder.linear(gen_z) # (B, n_blocks * hidden_size)
108
+ gen_past_z = gen_z # (B, n_blocks * hidden_size)
109
+
110
+ # past = torch.cat([past_z.unsqueeze(1), past_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
111
+
112
+ past = latent_z + label_emb # (B, n_blocks * hidden_size)
113
+
114
+ outputs = self.decoder(input_ids=tgt_seq_ids, past=past, labels=tgt_seq_ids, label_ignore=self.pad_token_id)
115
+ loss_rec = outputs[0]
116
+
117
+ # Train a classifier in the observation space
118
+ tgt_emb = self.gpt_embeddings(tgt_seq_ids)
119
+ tgt_encode = self.conv1(tgt_emb.transpose(1, 2)) # (B, dim_h, seq_len)
120
+ tgt_encode = torch.mean(tgt_encode, dim=-1) # (B, dim_h)
121
+ prob_cls = self.classifier(tgt_encode) # (B, n_labels)
122
+ if self.args.label_size <= 2:
123
+ prob_cls = prob_cls.squeeze(1)
124
+ loss_cls = self.BCEWithLogitsLoss(prob_cls, cond_labels.float())
125
+ pred_cls = (prob_cls >= 0).to(dtype=torch.long)
126
+ else:
127
+ loss_cls = self.CrossEntropyLoss(prob_cls, cond_labels)
128
+ pred_cls = torch.argmax(prob_cls, dim=-1)
129
+ acc_cls = (pred_cls == cond_labels).float()
130
+
131
+ # Generate based on encoded z and sampled labels (attribute transfer)
132
+ # at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
133
+ # at_generated_soft = self.sample_sequence_conditional_batch_soft(past=at_past, context=self.bos_token_id_list) # (B, seq_len, vocab_size)
134
+
135
+ # # Classifier on attribute transfer generated sentences. Train Generator on attribute transfer.
136
+ # at_soft_emb = torch.matmul(at_generated_soft, self.gpt_embeddings.weight)
137
+ # at_soft_encode = self.conv1(at_soft_emb.transpose(1, 2)) # (B, dim_h, seq_len)
138
+ # at_soft_encode = torch.mean(at_soft_encode, dim=-1) # (B, dim_h)
139
+ # prob_at_soft_cls = self.classifier(at_soft_encode) # (B, 1)
140
+ # if self.args.label_size <= 2:
141
+ # prob_at_soft_cls = prob_at_soft_cls.squeeze(1)
142
+ # loss_at_soft_cls = self.BCEWithLogitsLoss(prob_at_soft_cls, sampled_cond_labels.float())
143
+ # pred_at_soft_cls = (prob_at_soft_cls >= 0).to(torch.long)
144
+ # else:
145
+ # loss_at_soft_cls = self.CrossEntropyLoss(prob_at_soft_cls, sampled_cond_labels)
146
+ # pred_at_soft_cls = torch.argmax(prob_at_soft_cls, dim=-1)
147
+ # acc_at_soft_cls = (pred_at_soft_cls == sampled_cond_labels).float()
148
+
149
+ # Loss
150
+ loss = loss_rec + loss_encoder + loss_lsc + loss_lsd + loss_lsg + self.args.beta_cls * loss_cls # + loss_at_soft_cls
151
+
152
+ if not self.training:
153
+ # Generate based on encoded z and gt labels
154
+ generated = self.sample_sequence_conditional_batch(past=past, context=self.bos_token_id_list)
155
+
156
+ # Generate based on encoded z and sampled labels (attribute transfer)
157
+ # at_past = torch.cat([past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
158
+ at_past = past_z + past_sampled_label # (B, n_blocks * hidden_size)
159
+ at_generated = self.sample_sequence_conditional_batch(past=at_past, context=self.bos_token_id_list) # (B, seq_len)
160
+
161
+ # Generate based on sampled z and sampled labels. (conditional generation)
162
+ # cg_past = torch.cat([gen_past_z.unsqueeze(1), past_sampled_label.unsqueeze(1)], dim=1) # (B, 2, n_blocks * hidden_size)
163
+ cg_past = gen_past_z + past_sampled_label # (B, n_blocks * hidden_size)
164
+ cg_generated = self.sample_sequence_conditional_batch(past=cg_past, context=self.bos_token_id_list) # (B, seq_len)
165
+
166
+ # classifier on gt generated sentences.
167
+ ge_emb = self.gpt_embeddings(generated)
168
+ ge_encode = self.conv1(ge_emb.transpose(1, 2)) # (B, dim_h, seq_len)
169
+ ge_encode = torch.mean(ge_encode, dim=-1) # (B, dim_h)
170
+ prob_ge_cls = self.classifier(ge_encode) # (B, 1)
171
+
172
+ if self.args.label_size <= 2:
173
+ pred_ge_cls = (prob_ge_cls.squeeze(1) >= 0).to(torch.long)
174
+ else:
175
+ pred_ge_cls = torch.argmax(prob_ge_cls, dim=-1)
176
+ acc_ge_cls = (pred_ge_cls == cond_labels).float()
177
+
178
+ # classifier on attribute transfer generated sentences.
179
+ at_emb = self.gpt_embeddings(at_generated)
180
+ at_encode = self.conv1(at_emb.transpose(1, 2)) # (B, dim_h, seq_len)
181
+ at_encode = torch.mean(at_encode, dim=-1) # (B, dim_h)
182
+ prob_at_cls = self.classifier(at_encode) # (B, 1)
183
+ if self.args.label_size <= 2:
184
+ pred_at_cls = (prob_at_cls.squeeze(1) >= 0).to(torch.long)
185
+ else:
186
+ pred_at_cls = torch.argmax(prob_at_cls, dim=-1)
187
+ acc_at_cls = (pred_at_cls == sampled_cond_labels).float()
188
+
189
+ # classifier on conditional generated sentences.
190
+ cg_emb = self.gpt_embeddings(cg_generated)
191
+ cg_encode = self.conv1(cg_emb.transpose(1, 2)) # (B, dim_h, seq_len)
192
+ cg_encode = torch.mean(cg_encode, dim=-1) # (B, dim_h)
193
+ prob_cg_cls = self.classifier(cg_encode) # (B, 1)
194
+ if self.args.label_size <= 2:
195
+ pred_cg_cls = (prob_cg_cls.squeeze(1) >= 0).to(torch.long)
196
+ else:
197
+ pred_cg_cls = torch.argmax(prob_cg_cls, dim=-1)
198
+ acc_cg_cls = (pred_cg_cls == sampled_cond_labels).float()
199
+
200
+ result = {
201
+ 'sampled_cond_labels': sampled_cond_labels,
202
+ 'cond_labels': cond_labels,
203
+
204
+ 'tgt_seq_ids': tgt_seq_ids,
205
+ 'generated': generated,
206
+ 'at_generated': at_generated,
207
+ 'cg_generated': cg_generated,
208
+
209
+ 'acc_encode_z_dis': acc_encode_z_dis,
210
+ 'acc_gen_z_dis': acc_gen_z_dis,
211
+ 'acc_encode_z_cls': acc_encode_z_cls,
212
+ 'acc_cls': acc_cls,
213
+ 'acc_ge_cls': acc_ge_cls,
214
+ 'acc_at_cls': acc_at_cls,
215
+ 'acc_cg_cls': acc_cg_cls,
216
+
217
+ 'pred_cls': pred_cls,
218
+ 'pred_ge_cls': pred_ge_cls,
219
+ 'pred_at_cls': pred_at_cls,
220
+ 'pred_cg_cls': pred_cg_cls,
221
+ }
222
+
223
+ return result
224
+
225
+ loss_dict = {
226
+ 'loss': loss,
227
+ 'loss_rec': loss_rec,
228
+ 'loss_encoder': loss_encoder,
229
+ 'loss_lsc': loss_lsc,
230
+ 'loss_lsd': loss_lsd,
231
+ 'loss_lsg': loss_lsg,
232
+ 'loss_cls': loss_cls,
233
+ # 'loss_at_soft_cls': loss_at_soft_cls,
234
+ }
235
+ acc_dict = {
236
+ 'acc_encode_z_dis': acc_encode_z_dis,
237
+ 'acc_gen_z_dis': acc_gen_z_dis,
238
+ 'acc_encode_z_cls': acc_encode_z_cls,
239
+ 'acc_cls': acc_cls,
240
+ # 'acc_at_soft_cls': acc_at_soft_cls,
241
+ }
242
+ return loss_dict, acc_dict
243
+
244
+ def sample_sequence_conditional_batch(self, past, context):
245
+ # context: a single id of <BOS>
246
+ # past: (B, past_seq_len, dim_h)
247
+ num_samples = past.size(0)
248
+ context = torch.tensor(context, dtype=torch.long, device=past.device)
249
+ context = context.unsqueeze(0).repeat(num_samples, 1)
250
+ generated = context # (B, 1)
251
+
252
+ # with torch.no_grad():
253
+ while generated.size(-1) < self.args.block_size:
254
+ inputs = {'input_ids': generated, 'past': past}
255
+ outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
256
+ lm_logits = outputs[0]
257
+
258
+ # softmax sample
259
+ next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, vocab_size)
260
+ filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, vocab_size)
261
+ filtered_logits = F.softmax(filtered_logits, dim=-1)
262
+ next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
263
+ generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
264
+
265
+ not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
266
+ if torch.sum(not_finished) == 0:
267
+ break
268
+
269
+ return generated # (B, seq_len)
270
+
271
+ def top_k_top_p_filtering_batch(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
272
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
273
+ Args:
274
+ logits: logits distribution of shape (batch size, vocabulary size)
275
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
276
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
277
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
278
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
279
+ """
280
+ # assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
281
+
282
+ top_k = min(top_k, logits.size(-1)) # Safety check
283
+
284
+ if top_k > 0:
285
+ # Remove all tokens with a probability less than the last token of the top-k
286
+ threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
287
+ logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size)
288
+
289
+ if top_p > 0.0:
290
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size)
291
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size)
292
+
293
+ # Remove tokens with cumulative probability above the threshold
294
+ sorted_indices_to_remove = cumulative_probs > top_p
295
+
296
+ # Shift the indices to the right to keep also the first token above the threshold
297
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
298
+ sorted_indices_to_remove[..., 0] = 0
299
+
300
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
301
+
302
+ logits.masked_fill_(indices_to_remove, filter_value)
303
+
304
+ return logits
305
+
306
+ def sample_sequence_conditional_batch_soft(self, past, context):
307
+ # context: a single id of <BOS>
308
+ # past: (B, past_seq_len, dim_h)
309
+ num_samples = past.size(0)
310
+ context = torch.tensor(context, dtype=torch.long, device=past.device).unsqueeze(0).repeat(num_samples, 1) # (B, 1)
311
+ context_soft = torch.FloatTensor(num_samples, self.decoder.config.vocab_size).zero_().to(device=past.device) # (B, vocab_size)
312
+ context_soft.scatter_(1, context, 1) # (B, vocab_size)
313
+ generated_soft = context_soft.unsqueeze(1) # (B, 1, vocab_size)
314
+
315
+ # with torch.no_grad():
316
+ while generated_soft.size(1) < self.args.block_size: # generated_soft: (B, seq_len, vocab_size)
317
+ inputs = {'soft_ids': generated_soft, 'past': past}
318
+ outputs = self.decoder(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
319
+ lm_logits = outputs[0] # (B, seq_len, vocab_size)
320
+
321
+ # Gumbel softmax sample
322
+ next_tokens_soft = gumbel_softmax(logits=lm_logits[:, -1:, :], temperature=self.args.soft_temperature, hard=False) # (B, 1, vocab_size)
323
+ generated_soft = torch.cat((generated_soft, next_tokens_soft), dim=1) # (B, seq_len+1, vocab_size)
324
+
325
+ # # softmax sample
326
+ # next_tokens_logits = lm_logits[:, -1, :] / self.args.temperature # (B, 1, vocab_size)
327
+ # filtered_logits = self.top_k_top_p_filtering_batch(next_tokens_logits, top_k=self.args.top_k, top_p=self.args.top_p) # (B, 1, vocab_size)
328
+ # filtered_logits = F.softmax(filtered_logits, dim=-1)
329
+ # next_tokens = torch.multinomial(filtered_logits, num_samples=1) # (B, 1)
330
+ # generated = torch.cat((generated, next_tokens), dim=1) # (B, seq_len+1)
331
+
332
+ next_tokens = torch.argmax(next_tokens_soft, dim=-1) # (B, 1)
333
+ not_finished = next_tokens != self.tokenizer_decoder.encode('<EOS>')[0]
334
+ if torch.sum(not_finished) == 0:
335
+ break
336
+
337
+ return generated_soft # (B, seq_len, vocab_size)
338
+
339
+
340
+ ### Gumbel Softmax
341
+ def gumbel_softmax(logits, temperature, hard=False):
342
+ """Sample from the Gumbel-Softmax distribution and optionally discretize.
343
+ Args:
344
+ logits: [..., n_class] unnormalized log-probs
345
+ temperature: non-negative scalar
346
+ hard: if True, take argmax, but differentiate w.r.t. soft sample y
347
+ Returns:
348
+ [..., n_class] sample from the Gumbel-Softmax distribution.
349
+ If hard=True, then the returned sample will be one-hot, otherwise it will be a probability distribution that sums to 1 across classes
350
+ """
351
+ y = gumbel_softmax_sample(logits, temperature) # (..., n_class)
352
+
353
+ if hard: # return onehot
354
+ shape = y.size()
355
+ _, ind = y.max(dim=-1)
356
+ y_hard = torch.zeros_like(y).view(-1, shape[-1])
357
+ y_hard.scatter_(1, ind.view(-1, 1), 1) # one hot
358
+ y_hard = y_hard.view(*shape)
359
+ # Set gradients w.r.t. y_hard gradients w.r.t. y
360
+ y = (y_hard - y).detach() + y
361
+
362
+ return y # (..., n_class)
363
+
364
+ from torch.nn import functional as F
365
+ def gumbel_softmax_sample(logits, temperature):
366
+ y = logits + sample_gumbel(logits.size(), logits.device)
367
+ return F.softmax(y / temperature, dim=-1)
368
+
369
+ def sample_gumbel(shape, device, eps=1e-20):
370
+ U = torch.rand(shape).to(device=device)
371
+ return -torch.log(-torch.log(U + eps) + eps)
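
Note (illustrative sketch, not part of the commit): the top-p branch of top_k_top_p_filtering_batch above selects indices with a boolean mask; a commonly used batched variant instead maps the removal mask back to vocabulary order per row. The sizes and hyper-parameters below are assumptions, not values from the repo:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering_batch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: (B, vocab_size)
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        threshold = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
        logits = logits.masked_fill(logits < threshold, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # shift right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # map the removal mask from sorted order back to vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits

logits = torch.randn(8, 50257)                                   # assumed (B, vocab_size)
filtered = top_k_top_p_filtering_batch(logits, top_k=50, top_p=0.9)
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), 1)  # (B, 1)
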
Optimus/code/examples/big_ae/modules/decoders/dec_gpt2.py ADDED
@@ -0,0 +1,358 @@
1
+ # import torch
2
+
3
+ import time
4
+ import argparse
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
11
+
12
+ import numpy as np
13
+
14
+ from .decoder import DecoderBase
15
+
16
+ class LSTMDecoder(DecoderBase):
17
+ """LSTM decoder with constant-length data"""
18
+ def __init__(self, args, vocab, model_init, emb_init):
19
+ super(LSTMDecoder, self).__init__()
20
+ self.ni = args.ni
21
+ self.nh = args.dec_nh
22
+ self.nz = args.nz
23
+ self.vocab = vocab
24
+ self.device = args.device
25
+
26
+ # no padding when setting padding_idx to -1
27
+ self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1)
28
+
29
+ self.dropout_in = nn.Dropout(args.dec_dropout_in)
30
+ self.dropout_out = nn.Dropout(args.dec_dropout_out)
31
+
32
+ # for initializing hidden state and cell
33
+ self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)
34
+
35
+ # concatenate z with input
36
+ self.lstm = nn.LSTM(input_size=args.ni + args.nz,
37
+ hidden_size=args.dec_nh,
38
+ num_layers=1,
39
+ batch_first=True)
40
+
41
+ # prediction layer
42
+ self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)
43
+
44
+ vocab_mask = torch.ones(len(vocab))
45
+ # vocab_mask[vocab['<pad>']] = 0
46
+ self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')
47
+
48
+ self.reset_parameters(model_init, emb_init)
49
+
50
+ def reset_parameters(self, model_init, emb_init):
51
+ # for name, param in self.lstm.named_parameters():
52
+ # # self.initializer(param)
53
+ # if 'bias' in name:
54
+ # nn.init.constant_(param, 0.0)
55
+ # # model_init(param)
56
+ # elif 'weight' in name:
57
+ # model_init(param)
58
+
59
+ # model_init(self.trans_linear.weight)
60
+ # model_init(self.pred_linear.weight)
61
+ for param in self.parameters():
62
+ model_init(param)
63
+ emb_init(self.embed.weight)
64
+
65
+ def sample_text(self, input, z, EOS, device):
66
+ sentence = [input]
67
+ max_index = 0
68
+
69
+ input_word = input
70
+ batch_size, n_sample, _ = z.size()
71
+ seq_len = 1
72
+ z_ = z.expand(batch_size, seq_len, self.nz)
73
+ seq_len = input.size(1)
74
+ softmax = torch.nn.Softmax(dim=0)
75
+ while max_index != EOS and len(sentence) < 100:
76
+ # (batch_size, seq_len, ni)
77
+ word_embed = self.embed(input_word)
78
+ word_embed = torch.cat((word_embed, z_), -1)
79
+ c_init = self.trans_linear(z).unsqueeze(0)
80
+ h_init = torch.tanh(c_init)
81
+ if len(sentence) == 1:
82
+ h_init = h_init.squeeze(dim=1)
83
+ c_init = c_init.squeeze(dim=1)
84
+ output, hidden = self.lstm.forward(word_embed, (h_init, c_init))
85
+ else:
86
+ output, hidden = self.lstm.forward(word_embed, hidden)
87
+ # (batch_size * n_sample, seq_len, vocab_size)
88
+ output_logits = self.pred_linear(output)
89
+ output_logits = output_logits.view(-1)
90
+ probs = softmax(output_logits)
91
+ # max_index = torch.argmax(output_logits)
92
+ max_index = torch.multinomial(probs, num_samples=1)
93
+ input_word = torch.tensor([[max_index]]).to(device)
94
+ sentence.append(max_index)
95
+ return sentence
96
+
97
+ def decode(self, input, z):
98
+ """
99
+ Args:
100
+ input: (batch_size, seq_len)
101
+ z: (batch_size, n_sample, nz)
102
+ """
103
+
104
+ # not predicting start symbol
105
+ # sents_len -= 1
106
+
107
+ batch_size, n_sample, _ = z.size()
108
+ seq_len = input.size(1)
109
+
110
+ # (batch_size, seq_len, ni)
111
+ word_embed = self.embed(input)
112
+ word_embed = self.dropout_in(word_embed)
113
+
114
+ if n_sample == 1:
115
+ z_ = z.expand(batch_size, seq_len, self.nz)
116
+
117
+ else:
118
+ word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
119
+ .contiguous()
120
+
121
+ # (batch_size * n_sample, seq_len, ni)
122
+ word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
123
+
124
+ z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
125
+ z_ = z_.view(batch_size * n_sample, seq_len, self.nz)
126
+
127
+ # (batch_size * n_sample, seq_len, ni + nz)
128
+ word_embed = torch.cat((word_embed, z_), -1)
129
+
130
+ z = z.view(batch_size * n_sample, self.nz)
131
+ c_init = self.trans_linear(z).unsqueeze(0)
132
+ h_init = torch.tanh(c_init)
133
+ # h_init = self.trans_linear(z).unsqueeze(0)
134
+ # c_init = h_init.new_zeros(h_init.size())
135
+ output, _ = self.lstm(word_embed, (h_init, c_init))
136
+
137
+ output = self.dropout_out(output)
138
+
139
+ # (batch_size * n_sample, seq_len, vocab_size)
140
+ output_logits = self.pred_linear(output)
141
+
142
+ return output_logits
143
+
144
+ def reconstruct_error(self, x, z):
145
+ """Cross Entropy in the language case
146
+ Args:
147
+ x: (batch_size, seq_len)
148
+ z: (batch_size, n_sample, nz)
149
+ Returns:
150
+ loss: (batch_size, n_sample). Loss
151
+ across different sentence and z
152
+ """
153
+
154
+ #remove end symbol
155
+ src = x[:, :-1]
156
+
157
+ # remove start symbol
158
+ tgt = x[:, 1:]
159
+
160
+ batch_size, seq_len = src.size()
161
+ n_sample = z.size(1)
162
+
163
+ # (batch_size * n_sample, seq_len, vocab_size)
164
+ output_logits = self.decode(src, z)
165
+
166
+ if n_sample == 1:
167
+ tgt = tgt.contiguous().view(-1)
168
+ else:
169
+ # (batch_size * n_sample * seq_len)
170
+ tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
171
+ .contiguous().view(-1)
172
+
173
+ # (batch_size * n_sample * seq_len)
174
+ loss = self.loss(output_logits.view(-1, output_logits.size(2)),
175
+ tgt)
176
+
177
+
178
+ # (batch_size, n_sample)
179
+ return loss.view(batch_size, n_sample, -1).sum(-1)
180
+
181
+
182
+ def log_probability(self, x, z):
183
+ """Cross Entropy in the language case
184
+ Args:
185
+ x: (batch_size, seq_len)
186
+ z: (batch_size, n_sample, nz)
187
+ Returns:
188
+ log_p: (batch_size, n_sample).
189
+ log_p(x|z) across different x and z
190
+ """
191
+
192
+ return -self.reconstruct_error(x, z)
193
+
194
+
195
+
196
+
197
+ def greedy_decode(self, z):
198
+ return self.sample_decode(z, greedy=True)
199
+
200
+ def sample_decode(self, z, greedy=False):
201
+ """sample/greedy decoding from z
202
+ Args:
203
+ z: (batch_size, nz)
204
+ Returns: List1
205
+ List1: the decoded word sentence list
206
+ """
207
+
208
+ batch_size = z.size(0)
209
+ decoded_batch = [[] for _ in range(batch_size)]
210
+
211
+ # (batch_size, 1, nz)
212
+ c_init = self.trans_linear(z).unsqueeze(0)
213
+ h_init = torch.tanh(c_init)
214
+
215
+ decoder_hidden = (h_init, c_init)
216
+ decoder_input = torch.tensor([self.vocab["<s>"]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1)
217
+ end_symbol = torch.tensor([self.vocab["</s>"]] * batch_size, dtype=torch.long, device=self.device)
218
+
219
+ mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device)
220
+ length_c = 1
221
+ while mask.sum().item() != 0 and length_c < 100:
222
+
223
+ # (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
224
+ word_embed = self.embed(decoder_input)
225
+ word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)
226
+
227
+ output, decoder_hidden = self.lstm(word_embed, decoder_hidden)
228
+
229
+ # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
230
+ decoder_output = self.pred_linear(output)
231
+ output_logits = decoder_output.squeeze(1)
232
+
233
+ # (batch_size)
234
+ if greedy:
235
+ max_index = torch.argmax(output_logits, dim=1)
236
+ else:
237
+ probs = F.softmax(output_logits, dim=1)
238
+ max_index = torch.multinomial(probs, num_samples=1).squeeze(1)
239
+
240
+ decoder_input = max_index.unsqueeze(1)
241
+ length_c += 1
242
+
243
+ for i in range(batch_size):
244
+ word = self.vocab.id2word(max_index[i].item())
245
+ if mask[i].item():
246
+ decoded_batch[i].append(self.vocab.id2word(max_index[i].item()))
247
+
248
+ mask = torch.mul((max_index != end_symbol), mask)
249
+
250
+ return decoded_batch
251
+
252
+ class VarLSTMDecoder(LSTMDecoder):
253
+ """LSTM decoder with constant-length data"""
254
+ def __init__(self, args, vocab, model_init, emb_init):
255
+ super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)
256
+
257
+ self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
258
+ vocab_mask = torch.ones(len(vocab))
259
+ vocab_mask[vocab['<pad>']] = 0
260
+ self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')
261
+
262
+ self.reset_parameters(model_init, emb_init)
263
+
264
+ def decode(self, input, z):
265
+ """
266
+ Args:
267
+ input: tuple which contains x and sents_len
268
+ x: (batch_size, seq_len)
269
+ sents_len: long tensor of sentence lengths
270
+ z: (batch_size, n_sample, nz)
271
+ """
272
+
273
+ input, sents_len = input
274
+
275
+ # not predicting start symbol
276
+ sents_len = sents_len - 1
277
+
278
+ batch_size, n_sample, _ = z.size()
279
+ seq_len = input.size(1)
280
+
281
+ # (batch_size, seq_len, ni)
282
+ word_embed = self.embed(input)
283
+ word_embed = self.dropout_in(word_embed)
284
+
285
+ if n_sample == 1:
286
+ z_ = z.expand(batch_size, seq_len, self.nz)
287
+
288
+ else:
289
+ word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
290
+ .contiguous()
291
+
292
+ # (batch_size * n_sample, seq_len, ni)
293
+ word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
294
+
295
+ z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
296
+ z_ = z_.view(batch_size * n_sample, seq_len, self.nz)
297
+
298
+ # (batch_size * n_sample, seq_len, ni + nz)
299
+ word_embed = torch.cat((word_embed, z_), -1)
300
+
301
+ sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
302
+ packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)
303
+
304
+ z = z.view(batch_size * n_sample, self.nz)
305
+ # h_init = self.trans_linear(z).unsqueeze(0)
306
+ # c_init = h_init.new_zeros(h_init.size())
307
+ c_init = self.trans_linear(z).unsqueeze(0)
308
+ h_init = torch.tanh(c_init)
309
+ output, _ = self.lstm(packed_embed, (h_init, c_init))
310
+ output, _ = pad_packed_sequence(output, batch_first=True)
311
+
312
+ output = self.dropout_out(output)
313
+
314
+ # (batch_size * n_sample, seq_len, vocab_size)
315
+ output_logits = self.pred_linear(output)
316
+
317
+ return output_logits
318
+
319
+ def reconstruct_error(self, x, z):
320
+ """Cross Entropy in the language case
321
+ Args:
322
+ x: tuple which contains x_ and sents_len
323
+ x_: (batch_size, seq_len)
324
+ sents_len: long tensor of sentence lengths
325
+ z: (batch_size, n_sample, nz)
326
+ Returns:
327
+ loss: (batch_size, n_sample). Loss
328
+ across different sentence and z
329
+ """
330
+
331
+ x, sents_len = x
332
+
333
+ #remove end symbol
334
+ src = x[:, :-1]
335
+
336
+ # remove start symbol
337
+ tgt = x[:, 1:]
338
+
339
+ batch_size, seq_len = src.size()
340
+ n_sample = z.size(1)
341
+
342
+ # (batch_size * n_sample, seq_len, vocab_size)
343
+ output_logits = self.decode((src, sents_len), z)
344
+
345
+ if n_sample == 1:
346
+ tgt = tgt.contiguous().view(-1)
347
+ else:
348
+ # (batch_size * n_sample * seq_len)
349
+ tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
350
+ .contiguous().view(-1)
351
+
352
+ # (batch_size * n_sample * seq_len)
353
+ loss = self.loss(output_logits.view(-1, output_logits.size(2)),
354
+ tgt)
355
+
356
+
357
+ # (batch_size, n_sample)
358
+ return loss.view(batch_size, n_sample, -1).sum(-1)
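
Note (illustrative sketch, not part of the commit): reconstruct_error above pairs each target sentence with its n_sample latent draws by repeating the targets and flattening before the token-level cross entropy. A toy trace with assumed sizes:

import torch
import torch.nn as nn

batch_size, n_sample, seq_len, vocab_size = 2, 3, 5, 11
output_logits = torch.randn(batch_size * n_sample, seq_len, vocab_size)
tgt = torch.randint(0, vocab_size, (batch_size, seq_len))

loss_fn = nn.CrossEntropyLoss(reduction='none')
tgt_rep = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len).contiguous().view(-1)
loss = loss_fn(output_logits.view(-1, vocab_size), tgt_rep)   # (B * n_sample * seq_len,)
per_pair = loss.view(batch_size, n_sample, -1).sum(-1)        # (B, n_sample), summed over tokens
print(per_pair.shape)
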
Optimus/code/examples/big_ae/modules/decoders/decoder.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class DecoderBase(nn.Module):
6
+ """docstring for Decoder"""
7
+ def __init__(self):
8
+ super(DecoderBase, self).__init__()
9
+
10
+
11
+ def freeze(self):
12
+ for param in self.parameters():
13
+ param.requires_grad = False
14
+
15
+ def decode(self, x, z):
16
+ """
17
+ Args:
18
+ x: (batch_size, seq_len)
19
+ z: (batch_size, n_sample, nz)
20
+ Returns: Tensor1
21
+ Tensor1: the output logits with size (batch_size * n_sample, seq_len, vocab_size)
22
+ """
23
+
24
+ raise NotImplementedError
25
+
26
+ def reconstruct_error(self, x, z):
27
+ """reconstruction loss
28
+ Args:
29
+ x: (batch_size, *)
30
+ z: (batch_size, n_sample, nz)
31
+ Returns:
32
+ loss: (batch_size, n_sample). Loss
33
+ across different sentence and z
34
+ """
35
+
36
+ raise NotImplementedError
37
+
38
+ def beam_search_decode(self, z, K):
39
+ """beam search decoding
40
+ Args:
41
+ z: (batch_size, nz)
42
+ K: the beam size
43
+ Returns: List1
44
+ List1: the decoded word sentence list
45
+ """
46
+
47
+ raise NotImplementedError
48
+
49
+ def sample_decode(self, z):
50
+ """sampling from z
51
+ Args:
52
+ z: (batch_size, nz)
53
+ Returns: List1
54
+ List1: the decoded word sentence list
55
+ """
56
+
57
+ raise NotImplementedError
58
+
59
+ def greedy_decode(self, z):
60
+ """greedy decoding from z
61
+ Args:
62
+ z: (batch_size, nz)
63
+ Returns: List1
64
+ List1: the decoded word sentence list
65
+ """
66
+
67
+ raise NotImplementedError
68
+
69
+ def log_probability(self, x, z):
70
+ """
71
+ Args:
72
+ x: (batch_size, *)
73
+ z: (batch_size, n_sample, nz)
74
+ Returns:
75
+ log_p: (batch_size, n_sample).
76
+ log_p(x|z) across different x and z
77
+ """
78
+
79
+ raise NotImplementedError
Optimus/code/examples/big_ae/modules/encoders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .enc_lstm import *
Optimus/code/examples/big_ae/modules/encoders/enc_lstm.py ADDED
@@ -0,0 +1,126 @@
1
+ from itertools import chain
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
7
+ from .gaussian_encoder import GaussianEncoderBase
8
+ from ..utils import log_sum_exp
9
+
10
+ class GaussianLSTMEncoder(GaussianEncoderBase):
11
+ """Gaussian LSTM Encoder with constant-length input"""
12
+ def __init__(self, args, vocab_size, model_init, emb_init):
13
+ super(GaussianLSTMEncoder, self).__init__()
14
+ self.ni = args.ni
15
+ self.nh = args.enc_nh
16
+ self.nz = args.nz
17
+ self.args = args
18
+
19
+ self.embed = nn.Embedding(vocab_size, args.ni)
20
+
21
+ self.lstm = nn.LSTM(input_size=args.ni,
22
+ hidden_size=args.enc_nh,
23
+ num_layers=1,
24
+ batch_first=True,
25
+ dropout=0)
26
+
27
+ self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False)
28
+
29
+ self.reset_parameters(model_init, emb_init)
30
+
31
+ def reset_parameters(self, model_init, emb_init):
32
+ # for name, param in self.lstm.named_parameters():
33
+ # # self.initializer(param)
34
+ # if 'bias' in name:
35
+ # nn.init.constant_(param, 0.0)
36
+ # # model_init(param)
37
+ # elif 'weight' in name:
38
+ # model_init(param)
39
+
40
+ # model_init(self.linear.weight)
41
+ # emb_init(self.embed.weight)
42
+ for param in self.parameters():
43
+ model_init(param)
44
+ emb_init(self.embed.weight)
45
+
46
+
47
+ def forward(self, input):
48
+ """
49
+ Args:
50
+ x: (batch_size, seq_len)
51
+ Returns: Tensor1, Tensor2
52
+ Tensor1: the mean tensor, shape (batch, nz)
53
+ Tensor2: the logvar tensor, shape (batch, nz)
54
+ """
55
+
56
+ # (batch_size, seq_len-1, args.ni)
57
+ word_embed = self.embed(input)
58
+
59
+ _, (last_state, last_cell) = self.lstm(word_embed)
60
+
61
+ mean, logvar = self.linear(last_state).chunk(2, -1)
62
+
63
+ # fix variance as a pre-defined value
64
+ if self.args.fix_var > 0:
65
+ logvar = mean.new_tensor([[[math.log(self.args.fix_var)]]]).expand_as(mean)
66
+
67
+ return mean.squeeze(0), logvar.squeeze(0)
68
+
69
+ # def eval_inference_mode(self, x):
70
+ # """compute the mode points in the inference distribution
71
+ # (in Gaussian case)
72
+ # Returns: Tensor
73
+ # Tensor: the posterior mode points with shape (*, nz)
74
+ # """
75
+
76
+ # # (batch_size, nz)
77
+ # mu, logvar = self.forward(x)
78
+
79
+
80
+ class VarLSTMEncoder(GaussianLSTMEncoder):
81
+ """Gaussian LSTM Encoder with variable-length input"""
82
+ def __init__(self, args, vocab_size, model_init, emb_init):
83
+ super(VarLSTMEncoder, self).__init__(args, vocab_size, model_init, emb_init)
84
+
85
+
86
+ def forward(self, input):
87
+ """
88
+ Args:
89
+ input: tuple which contains x and sents_len
90
+ x: (batch_size, seq_len)
91
+ sents_len: long tensor of sentence lengths
92
+ Returns: Tensor1, Tensor2
93
+ Tensor1: the mean tensor, shape (batch, nz)
94
+ Tensor2: the logvar tensor, shape (batch, nz)
95
+ """
96
+
97
+ input, sents_len = input
98
+ # (batch_size, seq_len, args.ni)
99
+ word_embed = self.embed(input)
100
+
101
+ packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)
102
+
103
+ _, (last_state, last_cell) = self.lstm(packed_embed)
104
+
105
+ mean, logvar = self.linear(last_state).chunk(2, -1)
106
+
107
+ return mean.squeeze(0), logvar.squeeze(0)
108
+
109
+ def encode(self, input, nsamples):
110
+ """perform the encoding and compute the KL term
111
+ Args:
112
+ input: tuple which contains x and sents_len
113
+ Returns: Tensor1, Tensor2
114
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
115
+ Tensor2: the tensor of KL for each x with shape [batch]
116
+ """
117
+
118
+ # (batch_size, nz)
119
+ mu, logvar = self.forward(input)
120
+
121
+ # (batch, nsamples, nz)
122
+ z = self.reparameterize(mu, logvar, nsamples)
123
+
124
+ KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
125
+
126
+ return z, KL
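
For reference (illustrative, not part of the commit): the closed-form KL term used in encode() is the KL between the diagonal Gaussian posterior and a standard normal prior. A quick numerical check against torch.distributions, with assumed shapes:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 32)       # (batch, nz)
logvar = torch.randn(4, 32)

kl_closed_form = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
q = Normal(mu, logvar.mul(0.5).exp())
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum(dim=1)

print(torch.allclose(kl_closed_form, kl_reference, atol=1e-5))   # True
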
Optimus/code/examples/big_ae/modules/encoders/encoder.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from ..utils import log_sum_exp
6
+
7
+ class EncoderBase(nn.Module):
8
+ """docstring for EncoderBase"""
9
+ def __init__(self):
10
+ super(EncoderBase, self).__init__()
11
+
12
+ def forward(self, x):
13
+ """
14
+ Args:
15
+ x: (batch_size, *)
16
+ Returns: the tensors required to parameterize a distribution.
17
+ E.g. for Gaussian encoder it returns the mean and variance tensors
18
+ """
19
+
20
+ raise NotImplementedError
21
+
22
+ def sample(self, input, nsamples):
23
+ """sampling from the encoder
24
+ Returns: Tensor1
25
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
26
+ """
27
+
28
+ raise NotImplementedError
29
+
30
+ def encode(self, input, nsamples):
31
+ """perform the encoding and compute the KL term
32
+ Returns: Tensor1, Tensor2
33
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
34
+ Tensor2: the tensor of KL for each x with shape [batch]
35
+ """
36
+
37
+ raise NotImplementedError
38
+
39
+
40
+ def eval_inference_dist(self, x, z, param=None):
41
+ """this function computes log q(z | x)
42
+ Args:
43
+ z: tensor
44
+ different z points that will be evaluated, with
45
+ shape [batch, nsamples, nz]
46
+ Returns: Tensor1
47
+ Tensor1: log q(z|x) with shape [batch, nsamples]
48
+ """
49
+
50
+ raise NotImplementedError
51
+
52
+ def calc_mi(self, x):
53
+ """Approximate the mutual information between x and z
54
+ I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
55
+ Returns: Float
56
+ """
57
+
58
+ raise NotImplementedError
Optimus/code/examples/big_ae/modules/encoders/gaussian_encoder.py ADDED
@@ -0,0 +1,147 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .encoder import EncoderBase
6
+ from ..utils import log_sum_exp
7
+
8
+ class GaussianEncoderBase(EncoderBase):
9
+ """docstring for EncoderBase"""
10
+ def __init__(self):
11
+ super(GaussianEncoderBase, self).__init__()
12
+
13
+ def freeze(self):
14
+ for param in self.parameters():
15
+ param.requires_grad = False
16
+
17
+ def forward(self, x):
18
+ """
19
+ Args:
20
+ x: (batch_size, *)
21
+ Returns: Tensor1, Tensor2
22
+ Tensor1: the mean tensor, shape (batch, nz)
23
+ Tensor2: the logvar tensor, shape (batch, nz)
24
+ """
25
+
26
+ raise NotImplementedError
27
+
28
+ def encode_stats(self, x):
29
+
30
+ return self.forward(x)
31
+
32
+ def sample(self, input, nsamples):
33
+ """sampling from the encoder
34
+ Returns: Tensor1
35
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
36
+ """
37
+
38
+ # (batch_size, nz)
39
+ mu, logvar = self.forward(input)
40
+
41
+ # (batch, nsamples, nz)
42
+ z = self.reparameterize(mu, logvar, nsamples)
43
+
44
+ return z, (mu, logvar)
45
+
46
+ def encode(self, input, nsamples):
47
+ """perform the encoding and compute the KL term
48
+ Returns: Tensor1, Tensor2
49
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
50
+ Tensor2: the tensor of KL for each x with shape [batch]
51
+ """
52
+
53
+ # (batch_size, nz)
54
+ mu, logvar = self.forward(input)
55
+
56
+ # (batch, nsamples, nz)
57
+ z = self.reparameterize(mu, logvar, nsamples)
58
+
59
+ KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
60
+
61
+ return z, KL
62
+
63
+ def reparameterize(self, mu, logvar, nsamples=1):
64
+ """sample from posterior Gaussian family
65
+ Args:
66
+ mu: Tensor
67
+ Mean of gaussian distribution with shape (batch, nz)
68
+ logvar: Tensor
69
+ logvar of the Gaussian distribution with shape (batch, nz)
70
+ Returns: Tensor
71
+ Sampled z with shape (batch, nsamples, nz)
72
+ """
73
+ batch_size, nz = mu.size()
74
+ std = logvar.mul(0.5).exp()
75
+
76
+ mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
77
+ std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
78
+
79
+ eps = torch.zeros_like(std_expd).normal_()
80
+
81
+ return mu_expd + torch.mul(eps, std_expd)
82
+
83
+ def eval_inference_dist(self, x, z, param=None):
84
+ """this function computes log q(z | x)
85
+ Args:
86
+ z: tensor
87
+ different z points that will be evaluated, with
88
+ shape [batch, nsamples, nz]
89
+ Returns: Tensor1
90
+ Tensor1: log q(z|x) with shape [batch, nsamples]
91
+ """
92
+
93
+ nz = z.size(2)
94
+
95
+ if not param:
96
+ mu, logvar = self.forward(x)
97
+ else:
98
+ mu, logvar = param
99
+
100
+ # (batch_size, 1, nz)
101
+ mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
102
+ var = logvar.exp()
103
+
104
+ # (batch_size, nsamples, nz)
105
+ dev = z - mu
106
+
107
+ # (batch_size, nsamples)
108
+ log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
109
+ 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
110
+
111
+ return log_density
112
+
113
+
114
+
115
+ def calc_mi(self, x):
116
+ """Approximate the mutual information between x and z
117
+ I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
118
+ Returns: Float
119
+ """
120
+
121
+ # [x_batch, nz]
122
+ mu, logvar = self.forward(x)
123
+
124
+ x_batch, nz = mu.size()
125
+
126
+ # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
127
+ neg_entropy = (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).mean()
128
+
129
+ # [z_batch, 1, nz]
130
+ z_samples = self.reparameterize(mu, logvar, 1)
131
+
132
+ # [1, x_batch, nz]
133
+ mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
134
+ var = logvar.exp()
135
+
136
+ # (z_batch, x_batch, nz)
137
+ dev = z_samples - mu
138
+
139
+ # (z_batch, x_batch)
140
+ log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
141
+ 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
142
+
143
+ # log q(z): aggregate posterior
144
+ # [z_batch]
145
+ log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)
146
+
147
+ return (neg_entropy - log_qz.mean(-1)).item()
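
Note (illustrative sketch, not part of the commit): reparameterize() above draws z = mu + eps * std with eps ~ N(0, 1), so gradients flow to the posterior parameters. A minimal trace with assumed shapes:

import torch

batch, nsamples, nz = 4, 2, 32
mu = torch.randn(batch, nz, requires_grad=True)
logvar = torch.randn(batch, nz, requires_grad=True)

std = logvar.mul(0.5).exp()
eps = torch.randn(batch, nsamples, nz)
z = mu.unsqueeze(1) + eps * std.unsqueeze(1)   # (batch, nsamples, nz)

z.sum().backward()
print(mu.grad.shape, logvar.grad.shape)        # gradients reach both posterior parameters
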
Optimus/code/examples/big_ae/modules/spacefusion.py ADDED
@@ -0,0 +1,143 @@
1
+ from .vae import VAE
2
+ import numpy as np
3
+ import torch, copy, pdb
4
+ import torch.nn.functional as F
5
+
6
+ from torch import nn
7
+
8
+ import pdb
9
+
10
+
11
+ def set_trainable(module, value):
12
+ for param in module.parameters():
13
+ param.requires_grad = value
14
+
15
+ class SpaceFusion(VAE):
16
+ def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args):
17
+ super(SpaceFusion, self).__init__(encoder, decoder, tokenizer_encoder, tokenizer_decoder, args)
18
+ children = [v for v in encoder.encoder.layer.children()] # list of 12 BertLayer
19
+
20
+ self.num_s2s_bert_layer = args.num_s2s_bert_layer
21
+ self.S2S_layers = nn.ModuleList([copy.deepcopy(c) for c in children[-args.num_s2s_bert_layer:] ]) # the last layer of encoder
22
+ self.S2S_pooler = copy.deepcopy(encoder.pooler)
23
+ self.ix_turn_sep = tokenizer_encoder.convert_tokens_to_ids('[SEP]')
24
+ if args.freeze_bert:
25
+ print('@'*20 + f' freezing BERT {args.num_frozen_bert_layer} layers')
26
+ for child in children[:args.num_frozen_bert_layer]:
27
+ set_trainable(child, False)
28
+
29
+
30
+
31
+ def ids2speaker(self, ids):
32
+ # 0 for speaker A, 1 for speaker B
33
+ N, T = ids.shape
34
+ speaker = np.zeros((N, T))
35
+ sep = ids == self.ix_turn_sep
36
+ for i in range(N):
37
+ is_B = False # start with speaker A
38
+ for t in range(T):
39
+ speaker[i,t] = int(is_B)
40
+ if sep[i,t].item():
41
+ is_B = not is_B
42
+
43
+ # make sure the final speaker is speaker B (so response is always speaker A)
44
+ if not is_B:
45
+ speaker = 1 - speaker
46
+
47
+ return torch.LongTensor(speaker).to(ids.device)
48
+
49
+ def forward(self, inputs_src, inputs_tgt, labels_tgt, return_vec=False): # [batch, time]
50
+ # toggle config to get desired encoder output
51
+ self.encoder.encoder.output_attentions = False
52
+ self.encoder.encoder.output_hidden_states = True
53
+
54
+
55
+ # AE encoder
56
+ mask = (inputs_tgt > 0).float().to(inputs_src.device)
57
+ outputs = self.encoder(inputs_tgt, attention_mask=mask)
58
+ z_AE, _ = self.connect(outputs[1])
59
+ z_AE = z_AE.squeeze(1)
60
+
61
+ # S2S encoder
62
+ mask = (inputs_src > 0).float()
63
+ speaker = self.ids2speaker(inputs_src)
64
+ outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
65
+ _, _, all_layer_attn = outputs # last_layer_attn, pooled, all_layer_attn = outputs
66
+ seq_z_prev = all_layer_attn[-self.num_s2s_bert_layer-1] # seq of z at layer 11 ()
67
+
68
+ for s2s in self.S2S_layers:
69
+ layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
70
+ seq_z_prev = layer_outputs[0]
71
+
72
+ z_S2S = self.encoder.pooler(layer_outputs[0])
73
+ z_S2S, _ = self.connect(z_S2S)
74
+ z_S2S = z_S2S.squeeze(1)
75
+
76
+ if return_vec:
77
+ return z_AE, z_S2S
78
+
79
+ # interpolation/smoothness
80
+ u = torch.FloatTensor(np.random.random((z_AE.shape[0], 1))).to(inputs_tgt.device)
81
+ z_interp = u * z_AE + (1 - u) * z_S2S
82
+ std = 0.1
83
+ noise = torch.FloatTensor(np.random.normal(size=z_interp.shape) * std).to(z_interp.device)
84
+ z_interp = z_interp + noise
85
+
86
+ loss_rec = 0
87
+ z_idx = 0
88
+ for z in [z_AE, z_S2S, z_interp]:
89
+ #pdb.set_trace()
90
+ past = z # past = self.decoder.linear(z)
91
+ outputs = self.decoder(input_ids=labels_tgt, past=past, labels=labels_tgt, label_ignore=self.pad_token_id)
92
+ if z_idx == 1:
93
+ loss_rec = loss_rec + 1.0 * outputs[0]
94
+ else:
95
+ loss_rec = loss_rec + outputs[0]
96
+ z_idx += 1
97
+ loss_rec = loss_rec/3
98
+
99
+ # fusion/regularization
100
+ L_pull = self.dist_pair(z_AE, z_S2S)
101
+ L_push = torch.stack([self.dist_batch(z) for z in [z_AE, z_S2S]]).min()
102
+ loss_reg = (L_pull - L_push * 2) / np.sqrt(z.shape[-1])
103
+
104
+ loss = loss_rec + self.args.beta * loss_reg
105
+ return loss_rec, loss_reg, loss
106
+
107
+ def sent2latent(self, inputs_src):
108
+ # toggle config to get desired encoder output
109
+ self.encoder.encoder.output_attentions = False
110
+ self.encoder.encoder.output_hidden_states = True
111
+
112
+ # S2S encoder
113
+ mask = (inputs_src > 0).float()
114
+ speaker = self.ids2speaker(inputs_src)
115
+ outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
116
+
117
+ _, _, all_layer_attn = outputs # last_layer_attn, pooled, all_layer_attn = outputs
118
+ # seq_z_prev = all_layer_attn[-2] # seq of z at layer 11 ()
119
+ # layer_outputs = self.S2S_layer(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
120
+
121
+ seq_z_prev = all_layer_attn[-self.num_s2s_bert_layer-1] # seq of z at layer 11 ()
122
+ for s2s in self.S2S_layers:
123
+ layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
124
+ seq_z_prev = layer_outputs[0]
125
+
126
+ z_S2S = self.encoder.pooler(layer_outputs[0])
127
+ z_S2S, _ = self.connect(z_S2S)
128
+ z_S2S = z_S2S.squeeze(1)
129
+
130
+ return z_S2S
131
+
132
+
133
+ def dist_pair(self, a, b):
134
+ return F.pairwise_distance(a, b).mean()
135
+
136
+
137
+ def dist_batch(self, vec):
138
+ n = vec.shape[0]
139
+ dmin = []
140
+ for i in range(n):
141
+ dd = F.pairwise_distance(vec[i:i+1,:].repeat(n,1), vec)
142
+ dmin.append(dd.min())
143
+ return torch.stack(dmin).mean()
Optimus/code/examples/big_ae/modules/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+
3
+ def safe_log(z):
4
+ return torch.log(z + 1e-7)
5
+
6
+ def log_sum_exp(value, dim=None, keepdim=False):
7
+ """Numerically stable implementation of the operation
8
+ value.exp().sum(dim, keepdim).log()
9
+ """
10
+ if dim is not None:
11
+ m, _ = torch.max(value, dim=dim, keepdim=True)
12
+ value0 = value - m
13
+ if keepdim is False:
14
+ m = m.squeeze(dim)
15
+ return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
16
+ else:
17
+ m = torch.max(value)
18
+ sum_exp = torch.sum(torch.exp(value - m))
19
+ return m + torch.log(sum_exp)
20
+
21
+
22
+ def generate_grid(zmin, zmax, dz, device, ndim=2):
23
+ """generate a 1- or 2-dimensional grid
24
+ Returns: Tensor, int
25
+ Tensor: The grid tensor with shape (k^2, 2),
26
+ where k=(zmax - zmin)/dz
27
+ int: k
28
+ """
29
+
30
+ if ndim == 2:
31
+ x = torch.arange(zmin, zmax, dz)
32
+ k = x.size(0)
33
+
34
+ x1 = x.unsqueeze(1).repeat(1, k).view(-1)
35
+ x2 = x.repeat(k)
36
+
37
+ return torch.cat((x1.unsqueeze(-1), x2.unsqueeze(-1)), dim=-1).to(device), k
38
+
39
+ elif ndim == 1:
40
+ return torch.arange(zmin, zmax, dz).unsqueeze(1).to(device)
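
Note (illustrative, not part of the commit): the shifted log-sum-exp in utils.py stays finite where the naive version overflows. A small check with assumed values:

import torch

value = torch.tensor([[1000.0, 1000.5, 999.0]])

naive = torch.log(torch.exp(value).sum(dim=1))           # inf: exp(1000) overflows float32
m, _ = torch.max(value, dim=1, keepdim=True)
stable = m.squeeze(1) + torch.log(torch.exp(value - m).sum(dim=1))
print(naive, stable)                                      # tensor([inf]) vs. a finite value (about 1001.1)
# torch.logsumexp(value, dim=1) gives the same stable result
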
Optimus/code/examples/big_ae/modules/vae.py ADDED
@@ -0,0 +1,638 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .utils import log_sum_exp
6
+
7
+ import pdb
8
+
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class VAE(nn.Module):
14
+ """VAE with normal prior"""
15
+ def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args): #
16
+ super(VAE, self).__init__()
17
+ self.encoder = encoder
18
+ self.decoder = decoder
19
+
20
+ self.args = args
21
+ self.nz = args.latent_size
22
+
23
+ self.eos_token_id = tokenizer_decoder.convert_tokens_to_ids([tokenizer_decoder.eos_token])[0]
24
+ self.pad_token_id = tokenizer_decoder.convert_tokens_to_ids([tokenizer_decoder.pad_token])[0]
25
+
26
+
27
+ # connector: from Bert hidden units to the latent space
28
+ # self.linear = nn.Linear(args.nz, 2 * args.nz, bias=False)
29
+
30
+ # Standard Normal prior
31
+ loc = torch.zeros(self.nz, device=args.device)
32
+ scale = torch.ones(self.nz, device=args.device)
33
+ self.prior = torch.distributions.normal.Normal(loc, scale)
34
+
35
+ def connect(self, bert_fea, nsamples=1):
36
+ """
37
+ Returns: Tensor1, Tensor2
38
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
39
+ Tensor2: the tensor of KL for each x with shape [batch]
40
+ """
41
+
42
+ # (batch_size, nz)
43
+
44
+ mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
45
+ # pdb.set_trace()
46
+ # mean, logvar = mean.squeeze(0), logvar.squeeze(0)
47
+
48
+ # (batch, nsamples, nz)
49
+ z = self.reparameterize(mean, logvar, nsamples)
50
+ KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
51
+
52
+ return z, KL
53
+
54
+ def connect_deterministic(self, bert_fea, nsamples=1):
55
+ """
56
+ Returns: Tensor1, Tensor2
57
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
58
+ Tensor2: the tensor of KL for each x with shape [batch]
59
+ """
60
+
61
+ # (batch_size, nz)
62
+
63
+ mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
64
+ # pdb.set_trace()
65
+ # mean, logvar = mean.squeeze(0), logvar.squeeze(0)
66
+
67
+ logvar.fill_(.0)
68
+ # (batch, nsamples, nz)
69
+ z = self.reparameterize(mean, logvar, nsamples)
70
+ KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
71
+
72
+ return z, KL
73
+
74
+
75
+
76
+ def reparameterize(self, mu, logvar, nsamples=1):
77
+ """sample from posterior Gaussian family
78
+ Args:
79
+ mu: Tensor
80
+ Mean of gaussian distribution with shape (batch, nz)
81
+ logvar: Tensor
82
+ logvar of the Gaussian distribution with shape (batch, nz)
83
+ Returns: Tensor
84
+ Sampled z with shape (batch, nsamples, nz)
85
+ """
86
+ batch_size, nz = mu.size()
87
+ std = logvar.mul(0.5).exp()
88
+
89
+ mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
90
+ std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
91
+
92
+ eps = torch.zeros_like(std_expd).normal_()
93
+
94
+ return mu_expd + torch.mul(eps, std_expd)
95
+
96
+ def forward(self, inputs, labels):
97
+
98
+ # pdb.set_trace()
99
+
100
+ attention_mask=(inputs > 0).float()
101
+ # logger.info(inputs)
102
+ # logger.info(attention_mask)
103
+ # logger.info(labels)
104
+ reconstruction_mask = (labels != 50257).float() # 50257 is the padding token for GPT2
105
+ sent_length = torch.sum(reconstruction_mask, dim=1)
106
+
107
+
108
+ outputs = self.encoder(inputs, attention_mask)
109
+ pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
110
+
111
+ if self.args.fb_mode==0:
112
+ # Connect hidden feature to the latent space
113
+ latent_z, loss_kl = self.connect(pooled_hidden_fea)
114
+ latent_z = latent_z.squeeze(1)
115
+
116
+
117
+ # Decoding
118
+ outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
119
+ loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
120
+
121
+ elif self.args.fb_mode==1:
122
+ # Connect hidden feature to the latent space
123
+ mu, logvar = self.encoder.linear(pooled_hidden_fea).chunk(2, -1)
124
+ latent_z = self.reparameterize(mu, logvar, nsamples=1)
125
+ latent_z = latent_z.squeeze(1)
126
+ loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
127
+ kl_mask = (loss_kl > self.args.dim_target_kl).float()
128
+ loss_kl = (kl_mask * loss_kl).sum(dim=1)
129
+
130
+ # pdb.set_trace()
131
+ # past = self.decoder.linear(latent_z)
132
+ # Decoding
133
+ outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
134
+ loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
135
+
136
+ elif self.args.fb_mode==2:
137
+ # Connect hidden feature to the latent space
138
+ latent_z, loss_kl = self.connect_deterministic(pooled_hidden_fea)
139
+ latent_z = latent_z.squeeze(1)
140
+
141
+ # past = self.decoder.linear(latent_z)
142
+ # Decoding
143
+ outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
144
+ loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
145
+
146
+
147
+ # pdb.set_trace()
148
+ if self.args.length_weighted_loss:
149
+ loss = loss_rec / sent_length + self.args.beta * loss_kl
150
+ else:
151
+ loss = loss_rec + self.args.beta * loss_kl
152
+
153
+
154
+ return loss_rec, loss_kl, loss
155
+
156
+
157
+
158
+ def encoder_sample(self, bert_fea, nsamples):
159
+ """sampling from the encoder
160
+ Returns: Tensor1
161
+ Tensor1: the tensor latent z with shape [batch, nsamples, nz]
162
+ """
163
+
164
+ # (batch_size, nz)
165
+
166
+ mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
167
+ mu, logvar = mu.squeeze(0), logvar.squeeze(0)
168
+
169
+ # (batch, nsamples, nz)
170
+ z = self.reparameterize(mu, logvar, nsamples)
171
+
172
+ return z, (mu, logvar)
173
+
174
+
175
+ def encode_stats(self, x):
176
+ """
177
+ Returns: Tensor1, Tensor2
178
+ Tensor1: the mean of latent z with shape [batch, nz]
179
+ Tensor2: the logvar of latent z with shape [batch, nz]
180
+ """
181
+
182
+ return self.encoder.encode_stats(x)
183
+
184
+ def decode(self, z, strategy, K=10):
185
+ """generate samples from z given strategy
186
+ Args:
187
+ z: [batch, nsamples, nz]
188
+ strategy: "beam" or "greedy" or "sample"
189
+ K: the beam width parameter
190
+ Returns: List1
191
+ List1: a list of decoded word sequence
192
+ """
193
+
194
+ if strategy == "beam":
195
+ return self.decoder.beam_search_decode(z, K)
196
+ elif strategy == "greedy":
197
+ return self.decoder.greedy_decode(z)
198
+ elif strategy == "sample":
199
+ return self.decoder.sample_decode(z)
200
+ else:
201
+ raise ValueError("the decoding strategy is not supported")
202
+
203
+
204
+ def reconstruct(self, x, decoding_strategy="greedy", K=5):
205
+ """reconstruct from input x
206
+ Args:
207
+ x: (batch, *)
208
+ decoding_strategy: "beam" or "greedy" or "sample"
209
+ K: the beam width parameter
210
+ Returns: List1
211
+ List1: a list of decoded word sequence
212
+ """
213
+ z = self.sample_from_inference(x).squeeze(1)
214
+
215
+ return self.decode(z, decoding_strategy, K)
216
+
217
+ def log_probability(self, x, z):
218
+ """Cross Entropy in the language case
219
+ Args:
220
+ x: (batch_size, seq_len)
221
+ z: (batch_size, n_sample, nz)
222
+ Returns:
223
+ log_p: (batch_size, n_sample).
224
+ log_p(x|z) across different x and z
225
+ """
226
+ outputs = self.decoder(input_ids=x, past=z, labels=x, label_ignore=self.pad_token_id)
227
+ loss_rec = outputs[0]
228
+ return -loss_rec
229
+
230
+
231
+
232
+ def loss_iw(self, x0, x1, nsamples=50, ns=1):
233
+ """
234
+ Args:
235
+ x: if the data is constant-length, x is the data tensor with
236
+ shape (batch, *). Otherwise x is a tuple that contains
237
+ the data tensor and length list
238
+ Returns: Tensor1, Tensor2, Tensor3
239
+ Tensor1: total loss [batch]
240
+ Tensor2: reconstruction loss shape [batch]
241
+ Tensor3: KL loss shape [batch]
242
+ """
243
+
244
+ # encoding into bert features
245
+ bert_fea = self.encoder(x0)[1]
246
+
247
+ # (batch_size, nz)
248
+
249
+ mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
250
+
251
+
252
+ ##################
253
+ # compute KL
254
+ ##################
255
+ # pdb.set_trace()
256
+ KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
257
+
258
+ # mu, logvar = mu.squeeze(0), logvar.squeeze(0)
259
+ ll_tmp, rc_tmp = [], []
260
+ for _ in range(int(nsamples / ns)):
261
+
262
+ # (batch, nsamples, nz)
263
+ z = self.reparameterize(mu, logvar, ns)
264
+ # past = self.decoder.linear(z)
265
+ past = z
266
+
267
+ # [batch, nsamples]
268
+ log_prior = self.eval_prior_dist(z)
269
+ log_gen = self.eval_cond_ll(x1, past)
270
+ log_infer = self.eval_inference_dist(z, (mu, logvar))
271
+
272
+ # pdb.set_trace()
273
+ log_gen = log_gen.unsqueeze(0).contiguous().view(z.shape[0],-1)
274
+
275
+
276
+ # pdb.set_trace()
277
+ rc_tmp.append(log_gen)
278
+ ll_tmp.append(log_gen + log_prior - log_infer)
279
+
280
+
281
+
282
+ log_prob_iw = log_sum_exp(torch.cat(ll_tmp, dim=-1), dim=-1) - math.log(nsamples)
283
+ log_gen_iw = torch.mean(torch.cat(rc_tmp, dim=-1), dim=-1)
284
+
285
+ return log_prob_iw, log_gen_iw , KL
286
+
287
+
288
+ def nll_iw(self, x0, x1, nsamples, ns=1):
289
+ """compute the importance weighting estimate of the log-likelihood
290
+ Args:
291
+ x0, x1: two different tokenization results of x, where x is the data tensor with shape (batch, *).
292
+ nsamples: Int
293
+ the number of samples required to estimate marginal data likelihood
294
+ Returns: Tensor1
295
+ Tensor1: the estimate of log p(x), shape [batch]
296
+ """
297
+
298
+ # compute iw every ns samples to address the memory issue
299
+ # nsamples = 500, ns = 100
300
+ # nsamples = 500, ns = 10
301
+
302
+ # TODO: note that x is forwarded twice in self.encoder.sample(x, ns) and self.eval_inference_dist(x, z, param)
303
+ #. this problem is to be solved in order to speed up
304
+
305
+ tmp = []
306
+ for _ in range(int(nsamples / ns)):
307
+ # [batch, ns, nz]
308
+
309
+ # Chunyuan:
310
+ # encoding into bert features
311
+ pooled_hidden_fea = self.encoder(x0)[1]
312
+
313
+ # param is the parameters required to evaluate q(z|x)
314
+ z, param = self.encoder_sample(pooled_hidden_fea, ns)
315
+
316
+ # [batch, ns]
317
+ log_comp_ll = self.eval_complete_ll(x1, z)
318
+ log_infer_ll = self.eval_inference_dist(z, param)
319
+
320
+ tmp.append(log_comp_ll - log_infer_ll)
321
+
322
+ ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)
323
+
324
+ return ll_iw
325
+
326
+ def KL(self, x):
327
+ _, KL = self.encode(x, 1)
328
+
329
+ return KL
330
+
331
+ def eval_prior_dist(self, zrange):
332
+ """perform grid search to calculate the true posterior
333
+ Args:
334
+ zrange: tensor
335
+ different z points that will be evaluated, with
336
+ shape (k^2, nz), where k=(zmax - zmin)/space
337
+ """
338
+
339
+ # (k^2)
340
+ return self.prior.log_prob(zrange).sum(dim=-1)
341
+
342
+ def eval_complete_ll(self, x, z):
343
+ """compute log p(z,x)
344
+ Args:
345
+ x: Tensor
346
+ input with shape [batch, seq_len]
347
+ z: Tensor
348
+ evaluation points with shape [batch, nsamples, nz]
349
+ Returns: Tensor1
350
+ Tensor1: log p(z,x) Tensor with shape [batch, nsamples]
351
+ """
352
+
353
+ # [batch, nsamples]
354
+ log_prior = self.eval_prior_dist(z)
355
+ log_gen = self.eval_cond_ll(x, z)
356
+
357
+ return log_prior + log_gen
358
+
359
+
360
+
361
+ def eval_cond_ll(self, x, z):
362
+ """compute log p(x|z)
363
+ """
364
+ x_shape = list(x.size())
365
+ z_shape = list(z.size())
366
+ if len(z_shape) == 3:
367
+ x = x.unsqueeze(1).repeat(1, z_shape[1], 1).contiguous().view(x_shape[0]*z_shape[1], x_shape[-1])
368
+ z = z.contiguous().view(x_shape[0]*z_shape[1], z_shape[-1])
369
+
370
+ return self.log_probability(x, z)
371
+
372
+
373
+
374
+ def eval_log_model_posterior(self, x, grid_z):
375
+ """perform grid search to calculate the true posterior
376
+ this function computes p(z|x)
377
+ Args:
378
+ grid_z: tensor
379
+ different z points that will be evaluated, with
380
+ shape (k^2, nz), where k=(zmax - zmin)/pace
381
+ Returns: Tensor
382
+ Tensor: the log posterior distribution log p(z|x) with
383
+ shape [batch_size, K^2]
384
+ """
385
+ try:
386
+ batch_size = x.size(0)
387
+ except:
388
+ batch_size = x[0].size(0)
389
+
390
+ # (batch_size, k^2, nz)
391
+ grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()
392
+
393
+ # (batch_size, k^2)
394
+ log_comp = self.eval_complete_ll(x, grid_z)
395
+
396
+ # normalize to posterior
397
+ log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)
398
+
399
+ return log_posterior
400
+
401
+ def sample_from_inference(self, x, nsamples=1):
402
+ """perform sampling from inference net
403
+ Returns: Tensor
404
+ Tensor: samples from infernece nets with
405
+ shape (batch_size, nsamples, nz)
406
+ """
407
+ z, _ = self.encoder.sample(x, nsamples)
408
+
409
+ return z
410
+
411
+
412
+ def sample_from_posterior(self, x, nsamples):
413
+ """perform MH sampling from model posterior
414
+ Returns: Tensor
415
+ Tensor: samples from model posterior with
416
+ shape (batch_size, nsamples, nz)
417
+ """
418
+
419
+ # use the samples from inference net as initial points
420
+ # for MCMC sampling. [batch_size, nsamples, nz]
421
+ cur = self.encoder.sample_from_inference(x, 1)
422
+ cur_ll = self.eval_complete_ll(x, cur)
423
+ total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
424
+ samples = []
425
+ for iter_ in range(total_iter):
426
+ next = torch.normal(mean=cur,
427
+ std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
428
+ # [batch_size, 1]
429
+ next_ll = self.eval_complete_ll(x, next)
430
+ ratio = next_ll - cur_ll
431
+
432
+ accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))
433
+
434
+ uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()
435
+
436
+ # [batch_size, 1]
437
+ mask = (uniform_t < accept_prob).float()
438
+ mask_ = mask.unsqueeze(2)
439
+
440
+ cur = mask_ * next + (1 - mask_) * cur
441
+ cur_ll = mask * next_ll + (1 - mask) * cur_ll
442
+
443
+ if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
444
+ samples.append(cur.unsqueeze(1))
445
+
446
+ return torch.cat(samples, dim=1)
447
+
448
+
449
+ def calc_model_posterior_mean(self, x, grid_z):
450
+ """compute the mean value of model posterior, i.e. E_{z ~ p(z|x)}[z]
451
+ Args:
452
+ grid_z: different z points that will be evaluated, with
453
+ shape (k^2, nz), where k=(zmax - zmin)/pace
454
+ x: [batch, *]
455
+ Returns: Tensor1
456
+ Tensor1: the mean value tensor with shape [batch, nz]
457
+ """
458
+
459
+ # [batch, K^2]
460
+ log_posterior = self.eval_log_model_posterior(x, grid_z)
461
+ posterior = log_posterior.exp()
462
+
463
+ # [batch, nz]
464
+ return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)
465
+
466
+ def calc_infer_mean(self, x):
467
+ """
468
+ Returns: Tensor1
469
+ Tensor1: the mean of inference distribution, with shape [batch, nz]
470
+ """
471
+
472
+ mean, logvar = self.encoder.forward(x)
473
+
474
+ return mean
475
+
476
+
477
+
478
+
479
+ def eval_inference_dist(self, z, param):
480
+ """this function computes log q(z | x)
481
+ Args:
482
+ z: tensor
483
+ different z points that will be evaluated, with
484
+ shape [batch, nsamples, nz]
485
+ Returns: Tensor1
486
+ Tensor1: log q(z|x) with shape [batch, nsamples]
487
+ """
488
+
489
+ nz = z.size(2)
490
+ mu, logvar = param
491
+
492
+ # (batch_size, 1, nz)
493
+ mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
494
+ var = logvar.exp()
495
+
496
+ # (batch_size, nsamples, nz)
497
+ dev = z - mu
498
+
499
+ # (batch_size, nsamples)
500
+ log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
501
+ 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
502
+
503
+ return log_density
504
+
505
+
506
+
507
+ def calc_mi(self, test_data_batch, args):
508
+ # calc_mi_v3
509
+ # math and log_sum_exp are already available from the module-level imports above
510
+
511
+
512
+ mi = 0
513
+ num_examples = 0
514
+
515
+ mu_batch_list, logvar_batch_list = [], []
516
+ neg_entropy = 0.
517
+ for batch_data in test_data_batch:
518
+
519
+ x0, _, _ = batch_data
520
+ x0 = x0.to(args.device)
521
+
522
+ # encoding into bert features
523
+ bert_fea = self.encoder(x0)[1]
524
+
525
+ # (batch_size, nz)
526
+ mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
527
+
528
+ x_batch, nz = mu.size()
529
+
530
+ #print(x_batch, end=' ')
531
+
532
+ num_examples += x_batch
533
+
534
+ # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
535
+
536
+ neg_entropy += (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).sum().item()
537
+ mu_batch_list += [mu.cpu()]
538
+ logvar_batch_list += [logvar.cpu()]
539
+
540
+ # pdb.set_trace()
541
+
542
+ neg_entropy = neg_entropy / num_examples
543
+ ##print()
544
+
545
+ num_examples = 0
546
+ log_qz = 0.
547
+ for i in range(len(mu_batch_list)):
548
+ ###############
549
+ # get z_samples
550
+ ###############
551
+ mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()
552
+
553
+ # [z_batch, 1, nz]
554
+
555
+ z_samples = self.reparameterize(mu, logvar, 1)
556
+
557
+ z_samples = z_samples.view(-1, 1, nz)
558
+ num_examples += z_samples.size(0)
559
+
560
+ ###############
561
+ # compute density
562
+ ###############
563
+ # [1, x_batch, nz]
564
+ #mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda()
565
+ #indices = list(np.random.choice(np.arange(len(mu_batch_list)), 10)) + [i]
566
+ indices = np.arange(len(mu_batch_list))
567
+ mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).cuda()
568
+ logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).cuda()
569
+ x_batch, nz = mu.size()
570
+
571
+ mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
572
+ var = logvar.exp()
573
+
574
+ # (z_batch, x_batch, nz)
575
+ dev = z_samples - mu
576
+
577
+ # (z_batch, x_batch)
578
+ log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
579
+ 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
580
+
581
+ # log q(z): aggregate posterior
582
+ # [z_batch]
583
+ log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1)
584
+
585
+ log_qz /= num_examples
586
+ mi = neg_entropy - log_qz
587
+
588
+ return mi
589
+
590
+
591
+
592
+ def calc_au(self, eval_dataloader, args, delta=0.01):
593
+ """compute the number of active units
594
+ """
595
+ cnt = 0
596
+ for batch_data in eval_dataloader:
597
+
598
+ x0, _, _ = batch_data
599
+ x0 = x0.to(args.device)
600
+
601
+ # encoding into bert features
602
+ bert_fea = self.encoder(x0)[1]
603
+
604
+ # (batch_size, nz)
605
+ mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
606
+
607
+ if cnt == 0:
608
+ means_sum = mean.sum(dim=0, keepdim=True)
609
+ else:
610
+ means_sum = means_sum + mean.sum(dim=0, keepdim=True)
611
+ cnt += mean.size(0)
612
+
613
+ # (1, nz)
614
+ mean_mean = means_sum / cnt
615
+
616
+ cnt = 0
617
+ for batch_data in eval_dataloader:
618
+
619
+ x0, _, _ = batch_data
620
+ x0 = x0.to(args.device)
621
+
622
+ # encoding into bert features
623
+ bert_fea = self.encoder(x0)[1]
624
+
625
+ # (batch_size, nz)
626
+ mean, _ = self.encoder.linear(bert_fea).chunk(2, -1)
627
+
628
+ if cnt == 0:
629
+ var_sum = ((mean - mean_mean) ** 2).sum(dim=0)
630
+ else:
631
+ var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0)
632
+ cnt += mean.size(0)
633
+
634
+ # (nz)
635
+ au_var = var_sum / (cnt - 1)
636
+
637
+ return (au_var >= delta).sum().item(), au_var
638
+
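The core of VAE.connect() above is the reparameterization trick plus the closed-form KL divergence between the diagonal-Gaussian posterior q(z|x) = N(mu, diag(exp(logvar))) and the standard-normal prior. A self-contained sketch with toy tensors (no BERT/GPT-2 weights involved) that reproduces both expressions as they appear in the class:

    import torch

    torch.manual_seed(0)
    batch, nz = 2, 4
    mean   = torch.randn(batch, nz)    # first half of the encoder's linear head
    logvar = torch.randn(batch, nz)    # second half of the encoder's linear head

    # Reparameterization, as in VAE.reparameterize with nsamples=1: z = mu + eps * sigma
    std = logvar.mul(0.5).exp()
    eps = torch.randn_like(std)
    z = mean + eps * std                                              # (batch, nz)

    # KL(q(z|x) || N(0, I)) per example, as in VAE.connect
    kl = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)   # (batch,)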
Optimus/code/examples/big_ae/run_data_filtering.py ADDED
@@ -0,0 +1,507 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+
25
+ import pdb
26
+ import argparse
27
+ import glob
28
+ import logging
29
+
30
+ import os
31
+ import pickle
32
+ import json
33
+ import random
34
+ from pathlib import Path
35
+
36
+ import numpy as np
37
+ import torch
38
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
39
+ from torch.utils.data.distributed import DistributedSampler
40
+ from tensorboardX import SummaryWriter
41
+ from tqdm import tqdm, trange
42
+ from collections import defaultdict
43
+
44
+ # from azure.cosmosdb.table.tableservice import TableService
45
+ # from azure.cosmosdb.table.models import Entity
46
+ from datetime import datetime
47
+
48
+
49
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
50
+ BertConfig, BertForLatentConnector, BertTokenizer,
51
+ GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
52
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
53
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
54
+
55
+ from utils import (calc_iwnll, calc_mi, calc_au, BucketingDataLoader, MultipleFiles_DataLoader, BucketingMultipleFiles_DataLoader, frange_cycle_linear, frange_cycle_zero_linear)
56
+
57
+ from modules import VAE
58
+
59
+
60
+ # logging.getLogger("azure").setLevel(logging.WARNING)
61
+ # logging.getLogger("TableService").setLevel(logging.WARNING)
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ MODEL_CLASSES = {
67
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
68
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
69
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
70
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
71
+ }
72
+
73
+
74
+ storage_name="textae"
75
+ key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
76
+ # ts = TableService(account_name=storage_name, account_key=key)
77
+
78
+
79
+
80
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
81
+ if isinstance(tokenizer, list):
82
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
83
+ file_path=args.input_file_path
84
+ dataloader = MultipleFiles_DataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True, use_tensor=False)
85
+ else:
86
+ pass
87
+ return dataloader
88
+
89
+
90
+ def set_seed(args):
91
+ random.seed(args.seed)
92
+ np.random.seed(args.seed)
93
+ torch.manual_seed(args.seed)
94
+ if args.n_gpu > 0:
95
+ torch.cuda.manual_seed_all(args.seed)
96
+
97
+
98
+ def mask_tokens(inputs, tokenizer, args):
99
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
100
+ labels = inputs.clone()
101
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
102
+
103
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
104
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
105
+
106
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
107
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
108
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
109
+
110
+ # 10% of the time, we replace masked input tokens with random word
111
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
112
+ indices_random = indices_random
113
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
114
+ inputs[indices_random] = random_words[indices_random]
115
+
116
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
117
+ return inputs, labels
118
+
119
+
120
+ def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
121
+ """ Train the model """
122
+ if args.local_rank in [-1, 0]:
123
+ tb_writer = SummaryWriter()
124
+
125
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
126
+ # train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
127
+ # train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
128
+
129
+ if args.max_steps > 0:
130
+ t_total = args.max_steps
131
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
132
+ else:
133
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
134
+
135
+ # Prepare optimizer and schedule (linear warmup and decay)
136
+
137
+
138
+ # model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
139
+ no_decay = ['bias', 'LayerNorm.weight']
140
+ optimizer_grouped_parameters = [
141
+ {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
142
+ {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
143
+ ]
144
+
145
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
146
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
147
+
148
+
149
+ if args.fp16:
150
+ try:
151
+ from apex import amp
152
+ except ImportError:
153
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
154
+ model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
155
+
156
+ # multi-gpu training (should be after apex fp16 initialization)
157
+ if args.n_gpu > 1:
158
+ model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
159
+
160
+ # Distributed training (should be after apex fp16 initialization)
161
+ if args.local_rank != -1:
162
+ model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
163
+ output_device=args.local_rank,
164
+ find_unused_parameters=True)
165
+
166
+
167
+
168
+ files = Path(args.input_file_path)
169
+ num_files = len(list(files.glob('*seq64*.json')))
170
+
171
+ # create output file folder
172
+ if not os.path.exists(args.output_file_path) and args.local_rank in [-1, 0]:
173
+ os.makedirs(args.output_file_path)
174
+
175
+
176
+ # Train!
177
+ logger.info("***** Running training *****")
178
+ logger.info(" Num files = %d", num_files)
179
+ logger.info(" Num examples of first file = %d", train_dataloader.num_examples)
180
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
181
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
182
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
183
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
184
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
185
+ logger.info(" Total optimization steps = %d", t_total)
186
+
187
+
188
+ num_collected, num_dropped = 0, 0
189
+
190
+ model_vae.zero_grad()
191
+ num_train_epochs_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
192
+
193
+ n_iter = int(args.num_train_epochs) * len(train_dataloader)
194
+
195
+ tmp_list = []
196
+ dict_token_length = defaultdict(int)
197
+
198
+
199
+ if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
200
+ os.makedirs(args.output_dir)
201
+
202
+ dict_file = os.path.join(args.output_dir, args.dataset.lower()+f'.length_freq.json' )
203
+
204
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
205
+ for epoch in num_train_epochs_iterator:
206
+
207
+ for idx_file in range(num_files):
208
+
209
+ examples = []
210
+ cached_features_file = os.path.join(args.output_file_path, args.dataset.lower()+f'.segmented.nltk.split.seq64.{train_dataloader.file_idx}.json' )
211
+ logger.info(f"Epoch {epoch}, File idx {train_dataloader.file_idx}")
212
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
213
+
214
+ # if idx_file > 11:
215
+ # break
216
+
217
+ for step, batch in enumerate(epoch_iterator):
218
+
219
+ inst, token_lengths = batch
220
+ dict_token_length[ token_lengths[0,0].item() ] += 1
221
+
222
+ if ( token_lengths> 256 ).sum().item()>0:
223
+ over_length_tensor = ( token_lengths> 256 ).sum(-1)
224
+ inst_ = [inst[i] for i in range(len(inst)) if over_length_tensor[i]==0 ]
225
+ examples += inst_
226
+ num_collected += len(inst_)
227
+ num_dropped += len(inst) - len(inst_)
228
+ logger.info(f"{num_dropped} files filtered.")
229
+ else:
230
+ examples += inst
231
+ num_collected += len(inst)
232
+
233
+ # Good practice: save your data multiple times on Philly
234
+
235
+ if args.use_philly:
236
+ save_solid = False
237
+ while not save_solid:
238
+ try:
239
+ with open(cached_features_file, 'w') as fp:
240
+ json.dump(examples, fp)
241
+ save_solid = True
242
+ except:
243
+ pass
244
+ else:
245
+ with open(cached_features_file, 'w') as fp:
246
+ json.dump(examples, fp)
247
+ logger.info(f"Saving features in the cached file at {cached_features_file}")
248
+
249
+ train_dataloader.reset()
250
+
251
+ if args.local_rank in [-1, 0]:
252
+ tb_writer.close()
253
+
254
+ logger.info(dict_token_length)
255
+ # Good practice: save your dict multiple times on Philly
256
+ if args.use_philly:
257
+ save_solid = False
258
+ while not save_solid:
259
+ try:
260
+ with open(dict_file, 'w') as fp:
261
+ json.dump(dict_token_length, fp)
262
+ save_solid = True
263
+ except:
264
+ pass
265
+ else:
266
+ with open(dict_file, 'w') as fp:
267
+ json.dump(dict_token_length, fp)
268
+
269
+ return num_collected, num_dropped
270
+
271
+
272
+ def main():
273
+ parser = argparse.ArgumentParser()
274
+
275
+ ## Required parameters
276
+ parser.add_argument("--input_file_path", default=None, type=str, required=True,
277
+ help="The output directory where the input files will be written.")
278
+ parser.add_argument("--output_file_path", default=None, type=str, required=True,
279
+ help="The output directory where the output files will be written.")
280
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
281
+ help="The output directory where the logs and results will be saved.")
282
+ parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
283
+
284
+
285
+
286
+ ## Other parameters
287
+ parser.add_argument("--ExpName", default="", type=str,
288
+ help="The experiment name used in Azure Table.")
289
+
290
+ ## Encoder options
291
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
292
+ help="The encoder model architecture to be fine-tuned.")
293
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
294
+ help="The encoder model checkpoint for weights initialization.")
295
+ parser.add_argument("--encoder_config_name", default="", type=str,
296
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
297
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
298
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
299
+
300
+ ## Decoder options
301
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
302
+ help="The decoder model architecture to be fine-tuned.")
303
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
304
+ help="The decoder model checkpoint for weights initialization.")
305
+ parser.add_argument("--decoder_config_name", default="", type=str,
306
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
307
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
308
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
309
+
310
+ ## Variational auto-encoder
311
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
312
+ parser.add_argument("--use_deterministic_connect", action='store_true',
313
+ help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
314
+
315
+ ## Objective functions
316
+ parser.add_argument("--mlm", action='store_true',
317
+ help="Train with masked-language modeling loss instead of language modeling.")
318
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
319
+ help="Ratio of tokens to mask for masked language modeling loss")
320
+ parser.add_argument("--beta", type=float, default=1.0,
321
+ help="The weighting hyper-parameter of the KL term in VAE")
322
+
323
+
324
+ parser.add_argument("--cache_dir", default="", type=str,
325
+ help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
326
+ parser.add_argument("--max_seq_length", default=512, type=int,
327
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
328
+ parser.add_argument("--block_size", default=-1, type=int,
329
+ help="Optional input sequence length after tokenization."
330
+ "The training dataset will be truncated in block of this size for training."
331
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
332
+ parser.add_argument("--do_train", action='store_true',
333
+ help="Whether to run training.")
334
+ parser.add_argument("--do_eval", action='store_true',
335
+ help="Whether to run eval on the dev set.")
336
+ parser.add_argument("--evaluate_during_training", action='store_true',
337
+ help="Run evaluation during training at each logging step.")
338
+ parser.add_argument("--do_lower_case", action='store_true',
339
+ help="Set this flag if you are using an uncased model.")
340
+
341
+
342
+ # Training Schedules
343
+ parser.add_argument("--ratio_increase", default=0.25, type=float,
344
+ help="Learning schedule, the percentage for the annealing stage.")
345
+ parser.add_argument("--ratio_zero", default=0.25, type=float,
346
+ help="Learning schedule, the percentage for the pure auto-encoding stage.")
347
+ parser.add_argument("--fb_mode", default=0, type=int,
348
+ help="free bit training mode.")
349
+ parser.add_argument("--dim_target_kl", default=3.0, type=float,
350
+ help="dim_target_kl free bit training mode.")
351
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
352
+ help="Batch size per GPU/CPU for training.")
353
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
354
+ help="Batch size per GPU/CPU for evaluation.")
355
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
356
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
357
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
358
+ help="The initial learning rate for Adam.")
359
+ parser.add_argument("--weight_decay", default=0.0, type=float,
360
+ help="Weight deay if we apply some.")
361
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
362
+ help="Epsilon for Adam optimizer.")
363
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
364
+ help="Max gradient norm.")
365
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
366
+ help="Total number of training epochs to perform.")
367
+ parser.add_argument("--max_steps", default=-1, type=int,
368
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
369
+ parser.add_argument("--warmup_steps", default=0, type=int,
370
+ help="Linear warmup over warmup_steps.")
371
+ parser.add_argument("--use_philly", action='store_true',
372
+ help="Use Philly for computing.")
373
+
374
+ ## IO: Logging and Saving
375
+ parser.add_argument('--logging_steps', type=int, default=50,
376
+ help="Log every X updates steps.")
377
+ parser.add_argument('--save_steps', type=int, default=50,
378
+ help="Save checkpoint every X updates steps.")
379
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
380
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
381
+ parser.add_argument("--no_cuda", action='store_true',
382
+ help="Avoid using CUDA when available")
383
+ parser.add_argument('--overwrite_output_dir', action='store_true',
384
+ help="Overwrite the content of the output directory")
385
+ parser.add_argument('--overwrite_cache', action='store_true',
386
+ help="Overwrite the cached training and evaluation sets")
387
+ parser.add_argument('--seed', type=int, default=42,
388
+ help="random seed for initialization")
389
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
390
+ help="Evaluate the results at the given global step")
391
+
392
+ # Precision & Distributed Training
393
+ parser.add_argument('--fp16', action='store_true',
394
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
395
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
396
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
397
+ "See details at https://nvidia.github.io/apex/amp.html")
398
+ parser.add_argument("--local_rank", type=int, default=-1,
399
+ help="For distributed training: local_rank")
400
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
401
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
402
+ args = parser.parse_args()
403
+
404
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
405
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
406
+ "flag (masked language modeling).")
407
+
408
+ if os.path.exists(args.output_file_path) and os.listdir(args.output_file_path) and args.do_train and not args.overwrite_output_dir:
409
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_file_path))
410
+
411
+ # Setup distant debugging if needed
412
+ if args.server_ip and args.server_port:
413
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
414
+ import ptvsd
415
+ print("Waiting for debugger attach")
416
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
417
+ ptvsd.wait_for_attach()
418
+
419
+ # Setup CUDA, GPU & distributed training
420
+ logger.info(f'Local rank is {args.local_rank}')
421
+ if args.local_rank == -1 or args.no_cuda:
422
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
423
+ args.n_gpu = torch.cuda.device_count()
424
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
425
+ torch.cuda.set_device(args.local_rank)
426
+ device = torch.device("cuda", args.local_rank)
427
+ torch.distributed.init_process_group(backend='nccl')
428
+ args.n_gpu = 1
429
+ args.device = device
430
+
431
+ # Setup logging
432
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
433
+ datefmt = '%m/%d/%Y %H:%M:%S',
434
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
435
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
436
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
437
+
438
+ args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
439
+ table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
440
+ try:
441
+ ts.create_table(table_name)
442
+ except:
443
+ pass
444
+
445
+
446
+ # Set seed
447
+ set_seed(args)
448
+
449
+ # Load pretrained model and tokenizer
450
+ if args.local_rank not in [-1, 0]:
451
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
452
+
453
+ ## Encoder
454
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
455
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
456
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
457
+ if args.block_size <= 0:
458
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
459
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
460
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
461
+ # model_encoder.to(args.device)
462
+
463
+ ## Decoder
464
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
465
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
466
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
467
+ if args.block_size <= 0:
468
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
469
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
470
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
471
+
472
+ # Chunyuan: Add Padding token to GPT2
473
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
474
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
475
+ print('We have added', num_added_toks, 'tokens to GPT2')
476
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
477
+ assert tokenizer_decoder.pad_token == '<PAD>'
478
+
479
+ # model_decoder.to(args.device)
480
+
481
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
482
+
483
+ # on_gpu = next(model_vae.parameters()).is_cuda
484
+
485
+
486
+
487
+ if args.local_rank == 0:
488
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
489
+
490
+ logger.info("Training/evaluation parameters %s", args)
491
+
492
+ global_step= 0
493
+ # Training
494
+ if args.do_train:
495
+ if args.local_rank not in [-1, 0]:
496
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
497
+
498
+ train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
499
+
500
+ if args.local_rank == 0:
501
+ torch.distributed.barrier()
502
+
503
+ num_collected, num_dropped = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
504
+ logger.info(" num_collected = %s, num_dropped = %s", num_collected, num_dropped)
505
+
506
+ if __name__ == "__main__":
507
+ main()
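The filtering rule at the heart of train() above keeps only the instances whose token lengths all stay at or below 256 and re-serializes the survivors per input file. A toy sketch of that rule, with plain Python lists standing in for the dataloader batches (the 256 threshold is the one hard-coded above; the variable names are illustrative):

    import torch

    threshold = 256
    inst = ["ex0", "ex1", "ex2"]                                      # one batch of raw examples
    token_lengths = torch.tensor([[120, 80], [300, 64], [64, 64]])    # per-example token lengths

    if (token_lengths > threshold).sum().item() > 0:
        over = (token_lengths > threshold).sum(-1)                    # over-length count per example
        kept = [inst[i] for i in range(len(inst)) if over[i] == 0]
    else:
        kept = list(inst)

    print(kept)   # ['ex0', 'ex2'] -- the example containing a 300-token segment is dropped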
Optimus/code/examples/big_ae/run_dialog_dataloader.py ADDED
@@ -0,0 +1,483 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+
25
+ import pdb
26
+ import argparse
27
+ import glob
28
+ import logging
29
+
30
+ import os
31
+ import pickle
32
+ import random
33
+
34
+ import numpy as np
35
+ import torch
36
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
37
+ from torch.utils.data.distributed import DistributedSampler
38
+ from tensorboardX import SummaryWriter
39
+ from tqdm import tqdm, trange
40
+ from collections import defaultdict
41
+
42
+ # from azure.cosmosdb.table.tableservice import TableService
43
+ # from azure.cosmosdb.table.models import Entity
44
+ from datetime import datetime
45
+
46
+
47
+
48
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
49
+ BertConfig, BertForLatentConnector, BertTokenizer,
50
+ GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
51
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
52
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
53
+
54
+ from utils import (calc_iwnll, calc_mi, calc_au, Dialog_BucketingDataLoader, TextDataset_Split, TextDataset_2Tokenizers, frange_cycle_linear, frange_cycle_zero_linear)
55
+
56
+
57
+ from modules import VAE
58
+
59
+
60
+ # logging.getLogger("azure").setLevel(logging.WARNING)
61
+ # logging.getLogger("TableService").setLevel(logging.WARNING)
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ MODEL_CLASSES = {
67
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
68
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
69
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
70
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
71
+ }
72
+
73
+
74
+ storage_name="textae"
75
+ key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
76
+ # ts = TableService(account_name=storage_name, account_key=key)
77
+
78
+
79
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
80
+ if isinstance(tokenizer, list):
81
+ if not evaluate:
82
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
83
+ file_path=args.train_data_file
84
+ else:
85
+ args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
86
+ file_path=args.eval_data_file
87
+ dataloader = Dialog_BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True)
88
+ else:
89
+ pass
90
+ return dataloader
91
+
92
+
93
+
94
+
95
+ def set_seed(args):
96
+ random.seed(args.seed)
97
+ np.random.seed(args.seed)
98
+ torch.manual_seed(args.seed)
99
+ if args.n_gpu > 0:
100
+ torch.cuda.manual_seed_all(args.seed)
101
+
102
+
103
+
104
+ def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
105
+ """ Train the model """
106
+ if args.local_rank in [-1, 0]:
107
+ tb_writer = SummaryWriter()
108
+
109
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
110
+ # train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
111
+ # train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
112
+
113
+ if args.max_steps > 0:
114
+ t_total = args.max_steps
115
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
116
+ else:
117
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
118
+
119
+ # Prepare optimizer and schedule (linear warmup and decay)
120
+
121
+
122
+ # model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
123
+ no_decay = ['bias', 'LayerNorm.weight']
124
+ optimizer_grouped_parameters = [
125
+ {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
126
+ {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
127
+ ]
128
+
129
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
130
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
131
+
132
+
133
+ if args.fp16:
134
+ try:
135
+ from apex import amp
136
+ except ImportError:
137
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
138
+ model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
139
+
140
+ # multi-gpu training (should be after apex fp16 initialization)
141
+ if args.n_gpu > 1:
142
+ model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
143
+
144
+ # Distributed training (should be after apex fp16 initialization)
145
+ if args.local_rank != -1:
146
+ model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
147
+ output_device=args.local_rank,
148
+ find_unused_parameters=True)
149
+
150
+
151
+ # Train!
152
+ logger.info("***** Running training *****")
153
+ logger.info(" Num examples = %d", train_dataloader.num_examples)
154
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
155
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
156
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
157
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
158
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
159
+ logger.info(" Total optimization steps = %d", t_total)
160
+
161
+ global_step = 0
162
+ tr_loss, logging_loss = 0.0, 0.0
163
+
164
+
165
+ model_vae.zero_grad()
166
+
167
+ # model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
168
+
169
+ train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
170
+
171
+ n_iter = int(args.num_train_epochs) * len(train_dataloader)
172
+ beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
173
+
174
+ tmp_list = []
175
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
176
+ for epoch in train_iterator:
177
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
178
+ for step, batch in enumerate(epoch_iterator):
179
+
180
+ input_ids_bert_ctx, input_ids_bert, input_ids_gpt, token_lengths = batch
181
+
182
+ logger.info(f'Context in Bert, Length {token_lengths[0]} ; Tokens: {input_ids_bert_ctx}')
183
+ logger.info(f'Response in Bert, Length {token_lengths[1]} ; Tokens: {input_ids_bert}')
184
+ logger.info(f'Response in GPT2, Length {token_lengths[2]} ; Tokens: {input_ids_gpt}')
185
+ # TODO: write down training scripts for dialog response generation
186
+
187
+
188
+ if (step + 1) % args.gradient_accumulation_steps == 0:
189
+
190
+ global_step += 1
191
+
192
+
193
+ if args.max_steps > 0 and global_step > args.max_steps:
194
+ epoch_iterator.close()
195
+ break
196
+
197
+
198
+ if args.max_steps > 0 and global_step > args.max_steps:
199
+ train_iterator.close()
200
+ break
201
+
202
+ if args.local_rank in [-1, 0]:
203
+ tb_writer.close()
204
+
205
+ return global_step
206
+
207
+
208
+
209
+
210
+
211
+
212
+ def main():
213
+ parser = argparse.ArgumentParser()
214
+
215
+ ## Required parameters
216
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
217
+ help="The input training data file (a text file).")
218
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
219
+ help="The output directory where the model predictions and checkpoints will be written.")
220
+ parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
221
+
222
+ ## Other parameters
223
+ parser.add_argument("--eval_data_file", default=None, type=str,
224
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
225
+ parser.add_argument("--ExpName", default="", type=str,
226
+ help="The experiment name used in Azure Table.")
227
+
228
+ ## Encoder options
229
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
230
+ help="The encoder model architecture to be fine-tuned.")
231
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
232
+ help="The encoder model checkpoint for weights initialization.")
233
+ parser.add_argument("--encoder_config_name", default="", type=str,
234
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
235
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
236
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
237
+
238
+ ## Decoder options
239
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
240
+ help="The decoder model architecture to be fine-tuned.")
241
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
242
+ help="The decoder model checkpoint for weights initialization.")
243
+ parser.add_argument("--decoder_config_name", default="", type=str,
244
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
245
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
246
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
247
+
248
+ ## Variational auto-encoder
249
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
250
+ parser.add_argument("--use_deterministic_connect", action='store_true',
251
+ help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
252
+ parser.add_argument("--use_pretrained_model", action='store_true',
253
+ help="Use pre-trained auto-encoder models as the initialization")
254
+
255
+ ## Objective functions
256
+ parser.add_argument("--mlm", action='store_true',
257
+ help="Train with masked-language modeling loss instead of language modeling.")
258
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
259
+ help="Ratio of tokens to mask for masked language modeling loss")
260
+ parser.add_argument("--beta", type=float, default=1.0,
261
+ help="The weighting hyper-parameter of the KL term in VAE")
262
+
263
+
264
+ parser.add_argument("--cache_dir", default="", type=str,
265
+ help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
266
+ parser.add_argument("--max_seq_length", default=512, type=int,
267
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
268
+ parser.add_argument("--block_size", default=-1, type=int,
269
+ help="Optional input sequence length after tokenization."
270
+ "The training dataset will be truncated in block of this size for training."
271
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
272
+ parser.add_argument("--do_train", action='store_true',
273
+ help="Whether to run training.")
274
+ parser.add_argument("--do_eval", action='store_true',
275
+ help="Whether to run eval on the dev set.")
276
+ parser.add_argument("--evaluate_during_training", action='store_true',
277
+ help="Run evaluation during training at each logging step.")
278
+ parser.add_argument("--do_lower_case", action='store_true',
279
+ help="Set this flag if you are using an uncased model.")
280
+
281
+
282
+ # Training Schedules
283
+ parser.add_argument("--ratio_increase", default=0.25, type=float,
284
+ help="Learning schedule, the percentage for the annealing stage.")
285
+ parser.add_argument("--ratio_zero", default=0.25, type=float,
286
+ help="Learning schedule, the percentage for the pure auto-encoding stage.")
287
+ parser.add_argument("--fb_mode", default=0, type=int,
288
+ help="free bit training mode.")
289
+ parser.add_argument("--dim_target_kl", default=3.0, type=float,
290
+ help="dim_target_kl free bit training mode.")
291
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
292
+ help="Batch size per GPU/CPU for training.")
293
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
294
+ help="Batch size per GPU/CPU for evaluation.")
295
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
296
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
297
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
298
+ help="The initial learning rate for Adam.")
299
+ parser.add_argument("--weight_decay", default=0.0, type=float,
300
+ help="Weight deay if we apply some.")
301
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
302
+ help="Epsilon for Adam optimizer.")
303
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
304
+ help="Max gradient norm.")
305
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
306
+ help="Total number of training epochs to perform.")
307
+ parser.add_argument("--max_steps", default=-1, type=int,
308
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
309
+ parser.add_argument("--warmup_steps", default=0, type=int,
310
+ help="Linear warmup over warmup_steps.")
311
+ parser.add_argument("--use_philly", action='store_true',
312
+ help="Use Philly for computing.")
313
+
314
+ ## IO: Logging and Saving
315
+ parser.add_argument('--logging_steps', type=int, default=50,
316
+ help="Log every X updates steps.")
317
+ parser.add_argument('--save_steps', type=int, default=50,
318
+ help="Save checkpoint every X updates steps.")
319
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
320
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
321
+ parser.add_argument("--no_cuda", action='store_true',
322
+ help="Avoid using CUDA when available")
323
+ parser.add_argument('--overwrite_output_dir', action='store_true',
324
+ help="Overwrite the content of the output directory")
325
+ parser.add_argument('--overwrite_cache', action='store_true',
326
+ help="Overwrite the cached training and evaluation sets")
327
+ parser.add_argument('--seed', type=int, default=42,
328
+ help="random seed for initialization")
329
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
330
+ help="Evaluate the results at the given global step")
331
+
332
+ # Precision & Distributed Training
333
+ parser.add_argument('--fp16', action='store_true',
334
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
335
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
336
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
337
+ "See details at https://nvidia.github.io/apex/amp.html")
338
+ parser.add_argument("--local_rank", type=int, default=-1,
339
+ help="For distributed training: local_rank")
340
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
341
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
342
+ args = parser.parse_args()
343
+
344
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
345
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
346
+ "flag (masked language modeling).")
347
+ if args.eval_data_file is None and args.do_eval:
348
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
349
+ "or remove the --do_eval argument.")
350
+
351
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
352
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
353
+
354
+ # Setup distant debugging if needed
355
+ if args.server_ip and args.server_port:
356
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
357
+ import ptvsd
358
+ print("Waiting for debugger attach")
359
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
360
+ ptvsd.wait_for_attach()
361
+
362
+ # Setup CUDA, GPU & distributed training
363
+ if args.local_rank == -1 or args.no_cuda:
364
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
365
+ args.n_gpu = torch.cuda.device_count()
366
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
367
+ torch.cuda.set_device(args.local_rank)
368
+ device = torch.device("cuda", args.local_rank)
369
+ torch.distributed.init_process_group(backend='nccl')
370
+ args.n_gpu = 1
371
+ args.device = device
372
+
373
+ # Setup logging
374
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
375
+ datefmt = '%m/%d/%Y %H:%M:%S',
376
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
377
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
378
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
379
+
380
+ args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
381
+ table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
382
+ try:
383
+ ts.create_table(table_name)
384
+ except:
385
+ pass
386
+
387
+
388
+ # Set seed
389
+ set_seed(args)
390
+
391
+ # Load pretrained model and tokenizer
392
+ if args.local_rank not in [-1, 0]:
393
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
394
+
395
+ if args.use_pretrained_model:
396
+
397
+ args.encoder_model_type = args.encoder_model_type.lower()
398
+ args.decoder_model_type = args.decoder_model_type.lower()
399
+
400
+ global_step = args.gloabl_step_eval
401
+
402
+ output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
403
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
404
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
405
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
406
+
407
+ # Load a trained Encoder model and vocabulary
408
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
409
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
410
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
411
+
412
+ model_encoder.to(args.device)
413
+ if args.block_size <= 0:
414
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
415
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
416
+
417
+ # Load a trained Decoder model and vocabulary
418
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
419
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
420
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
421
+ model_decoder.to(args.device)
422
+ if args.block_size <= 0:
423
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
424
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
425
+
426
+ else:
427
+ ## Encoder
428
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
429
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
430
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
431
+ if args.block_size <= 0:
432
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
433
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
434
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
435
+ # model_encoder.to(args.device)
436
+
437
+ ## Decoder
438
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
439
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
440
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
441
+ if args.block_size <= 0:
442
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
443
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
444
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
445
+
446
+ # pdb.set_trace()  # debugging breakpoint; keep commented out so training runs unattended
447
+
448
+ # Chunyuan: Add Padding token to GPT2
449
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
450
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
451
+ print('We have added', num_added_toks, 'tokens to GPT2')
452
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
453
+ assert tokenizer_decoder.pad_token == '<PAD>'
454
+
455
+ # model_decoder.to(args.device)
456
+
457
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
458
+
459
+ # on_gpu = next(model_vae.parameters()).is_cuda
460
+
461
+
462
+
463
+ if args.local_rank == 0:
464
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
465
+
466
+ logger.info("Training/evaluation parameters %s", args)
467
+
468
+ global_step= 0
469
+ # Training
470
+ if args.do_train:
471
+ if args.local_rank not in [-1, 0]:
472
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
473
+
474
+ train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
475
+
476
+ if args.local_rank == 0:
477
+ torch.distributed.barrier()
478
+
479
+ global_step = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
480
+ logger.info(" global_step = %s", global_step)
481
+
482
+ if __name__ == "__main__":
483
+ main()
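A minimal sketch of the annealing behaviour that the training-schedule flags above (--beta, --ratio_zero, --ratio_increase, --fb_mode, --dim_target_kl) are meant to control, assuming a single annealing cycle as in the Optimus setup; the function name and exact ramp are illustrative, not the script's own implementation.

def beta_schedule(step, total_steps, beta_max=1.0, ratio_zero=0.25, ratio_increase=0.25):
    # Stage 1: pure auto-encoding, KL weight held at zero for the first ratio_zero fraction of steps.
    zero_steps = int(total_steps * ratio_zero)
    if step < zero_steps:
        return 0.0
    # Stage 2: linear annealing of the KL weight up to beta_max over the next ratio_increase fraction.
    ramp_steps = max(1, int(total_steps * ratio_increase))
    return min(beta_max, beta_max * (step - zero_steps) / ramp_steps)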
Optimus/code/examples/big_ae/run_encoding_generation.py ADDED
@@ -0,0 +1,487 @@
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
4
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
18
+ """
19
+ from __future__ import absolute_import, division, print_function, unicode_literals
20
+
21
+ import argparse
22
+ import glob
23
+ import logging
24
+ import os
25
+ import pickle
26
+ import random
27
+
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import numpy as np
32
+
33
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
34
+ from torch.utils.data.distributed import DistributedSampler
35
+ from tqdm import tqdm, trange
36
+
37
+
38
+ from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
39
+ from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
40
+ from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
41
+ from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
42
+ from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
43
+ from pytorch_transformers import BertForLatentConnector, BertTokenizer
44
+
45
+ from collections import defaultdict
46
+ from modules import VAE
47
+ from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
48
+
49
+
50
+ import pdb
51
+
52
+
53
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
54
+ datefmt = '%m/%d/%Y %H:%M:%S',
55
+ level = logging.INFO)
56
+ logger = logging.getLogger(__name__)
57
+
58
+ MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
59
+
60
+ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
61
+
62
+ MODEL_CLASSES = {
63
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
64
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
65
+ }
66
+
67
+ # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
68
+ # in https://github.com/rusiaaman/XLNet-gen#methodology
69
+ # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
70
+ PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
71
+ (except for Alexei and Maria) are discovered.
72
+ The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
73
+ remainder of the story. 1883 Western Siberia,
74
+ a young Grigori Rasputin is asked by his father and a group of men to perform magic.
75
+ Rasputin has a vision and denounces one of the men as a horse thief. Although his
76
+ father initially slaps him for making such an accusation, Rasputin watches as the
77
+ man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
78
+ the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
79
+ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
80
+
81
+
82
+ def set_seed(args):
83
+ np.random.seed(args.seed)
84
+ torch.manual_seed(args.seed)
85
+ if args.n_gpu > 0:
86
+ torch.cuda.manual_seed_all(args.seed)
87
+
88
+
89
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
90
+ if isinstance(tokenizer, list):
91
+ dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
92
+ else:
93
+ dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
94
+ return dataset
95
+
96
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
97
+ if isinstance(tokenizer, list):
98
+ if not evaluate:
99
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
100
+ file_path=args.train_data_file
101
+ else:
102
+ args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
103
+ file_path=args.eval_data_file
104
+ dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
105
+ else:
106
+ pass
107
+ return dataloader
108
+
109
+
110
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
111
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
112
+ Args:
113
+ logits: logits distribution shape (vocabulary size)
114
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
115
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
116
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
117
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
118
+ """
119
+ assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
120
+ top_k = min(top_k, logits.size(-1)) # Safety check
121
+ if top_k > 0:
122
+ # Remove all tokens with a probability less than the last token of the top-k
123
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
124
+ logits[indices_to_remove] = filter_value
125
+
126
+ if top_p > 0.0:
127
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
128
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
129
+
130
+ # Remove tokens with cumulative probability above the threshold
131
+ sorted_indices_to_remove = cumulative_probs > top_p
132
+ # Shift the indices to the right to keep also the first token above the threshold
133
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
134
+ sorted_indices_to_remove[..., 0] = 0
135
+
136
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
137
+ logits[indices_to_remove] = filter_value
138
+ return logits
139
+
140
+
141
+ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
142
+ context = torch.tensor(context, dtype=torch.long, device=device)
143
+ context = context.unsqueeze(0).repeat(num_samples, 1)
144
+ generated = context
145
+ with torch.no_grad():
146
+ for _ in trange(length):
147
+
148
+ inputs = {'input_ids': generated}
149
+ if is_xlnet:
150
+ # XLNet is a direct (predict same token, not next token) and bi-directional model by default
151
+ # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
152
+ input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
153
+ perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
154
+ perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
155
+ target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
156
+ target_mapping[0, 0, -1] = 1.0 # predict last token
157
+ inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
158
+
159
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
160
+ next_token_logits = outputs[0][0, -1, :] / temperature
161
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
162
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
163
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
164
+ return generated
165
+
166
+ def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
167
+
168
+ context = torch.tensor(context, dtype=torch.long, device=device)
169
+ context = context.unsqueeze(0).repeat(num_samples, 1)
170
+ generated = context
171
+ with torch.no_grad():
172
+ while True:
173
+ # for _ in trange(length):
174
+ inputs = {'input_ids': generated, 'past': past}
175
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
176
+ next_token_logits = outputs[0][0, -1, :] / temperature
177
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
178
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
179
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
180
+
181
+ # pdb.set_trace()
182
+ if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
183
+ break
184
+
185
+ return generated
186
+
187
+
188
+
189
+ # a wrapper function to choose between different play modes
190
+ def evaluate_latent_space(args, model_vae, encoder_tokenizer, decoder_tokenizer, prefix=""):
191
+
192
+ eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
193
+
194
+ # Eval!
195
+ logger.info("***** Running recontruction evaluation {} *****".format(prefix))
196
+ logger.info(" Num examples = %d", len(eval_dataloader))
197
+ logger.info(" Batch size = %d", args.per_gpu_eval_batch_size)
198
+
199
+ model_vae.eval()
200
+
201
+ model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
202
+
203
+ if args.play_mode == 'reconstruction':
204
+ result = calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
205
+ result_file_name = "eval_recontruction_results.txt"
206
+ elif args.play_mode == 'interpolation':
207
+ result = calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
208
+ result_file_name = "eval_interpolation_results.txt"
209
+ else:
210
+ logger.info("Please specify the corrent play mode [reconstrction, interpolation]")
211
+
212
+
213
+ eval_output_dir = args.output_dir
214
+ output_eval_file = os.path.join(eval_output_dir, result_file_name)
215
+
216
+ with open(output_eval_file, "w") as writer:
217
+ logger.info("***** Eval {} results *****".format(args.play_mode))
218
+ for key in sorted(result.keys()):
219
+ logger.info(" %s \n %s", key, str(result[key]))
220
+ writer.write("%s \n %s\n" % (key, str(result[key])))
221
+
222
+ return result
223
+
224
+
225
+ def calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
226
+
227
+ count = 0
228
+ result = defaultdict(str)
229
+ for batch in tqdm(eval_dataloader, desc="Evaluating recontruction"):
230
+ # pdb.set_trace()
231
+ x0, x1, x_lengths = batch
232
+
233
+ max_len_values, _ = x_lengths.max(0)
234
+ x0 = x0[:,:max_len_values[0]]
235
+ x1 = x1[:,:max_len_values[1]]
236
+
237
+ x0 = x0.to(args.device)
238
+ x1 = x1.to(args.device)
239
+ x_lengths = x_lengths.to(args.device)
240
+
241
+ context_tokens = decoder_tokenizer.encode('<BOS>')
242
+
243
+ with torch.no_grad():
244
+
245
+ text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
246
+ # result["INPUT TEXT " + str(count)].append(text_x0)
247
+
248
+ pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
249
+
250
+ # Connect hidden feature to the latent space
251
+ # latent_z, loss_kl = model_vae.connect(pooled_hidden_fea)
252
+ mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
253
+ latent_z = mean.squeeze(1)
254
+
255
+ past = latent_z
256
+ out = sample_sequence_conditional(
257
+ model=model_vae.decoder,
258
+ context=context_tokens,
259
+ past=past,
260
+ length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
261
+ temperature=args.temperature,
262
+ top_k=args.top_k,
263
+ top_p=args.top_p,
264
+ device=args.device,
265
+ decoder_tokenizer = decoder_tokenizer
266
+ )
267
+ text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
268
+ text_x1 = text_x1.split()[1:-1]
269
+ text_x1 = ' '.join(text_x1) + '\n'
270
+ result[text_x0] = text_x1
271
+
272
+ count += 1
273
+ if count>args.total_sents:
274
+ break
275
+
276
+
277
+ return result
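A minimal sketch of the "connect" step used above: the encoder's linear head emits 2 * latent_size values that are split into a mean and a log-variance, and deterministic reconstruction keeps only the mean. The 768-dimensional pooled feature is an assumption matching bert-base; the tensors here are stand-ins.

import torch
import torch.nn as nn

latent_size = 32
linear = nn.Linear(768, 2 * latent_size)      # plays the role of model_vae.encoder.linear
pooled_hidden_fea = torch.randn(1, 768)       # stand-in for the BERT pooled feature
mean, logvar = linear(pooled_hidden_fea).chunk(2, -1)
latent_z = mean                                # shape (1, latent_size); used as `past` for the GPT-2 decoder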
278
+
279
+
280
+
281
+
282
+ def calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
283
+
284
+ count = 0
285
+ latent_codes = []
286
+ sample_interval = 0
287
+ for batch in tqdm(eval_dataloader, desc="Evaluating interpolation"):
288
+ # pdb.set_trace()
289
+ x0, x1, x_lengths = batch
290
+
291
+ max_len_values, _ = x_lengths.max(0)
292
+ x0 = x0[:,:max_len_values[0]]
293
+ x0 = x0.to(args.device)
294
+ x_lengths = x_lengths.to(args.device)
295
+
296
+
297
+ with torch.no_grad():
298
+ if sample_interval == 0 or sample_interval == args.total_sents:
299
+ text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
300
+ pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
301
+
302
+ # Connect hidden feature to the latent space
303
+ mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
304
+ latent_z = mean.squeeze(1)
305
+
306
+ latent_codes.append(latent_z)
307
+
308
+ if sample_interval == 5:
309
+ latent_codes.append(latent_z)
310
+ sample_interval = 0
311
+ continue
312
+ else:
313
+ sample_interval += 1
314
+ continue
315
+
316
+ count += 1
317
+ if count>args.total_sents:
318
+ break
319
+
320
+ context_tokens = decoder_tokenizer.encode('<BOS>')
321
+ result = defaultdict(str)
322
+ latent_codes_interpolation = []
323
+ num_steps = args.num_interpolation_steps
324
+ for step in range(num_steps+1):
325
+ latent_z = latent_codes[0] + (latent_codes[1] - latent_codes[0]) * step * 1.0/num_steps
326
+
327
+ past = latent_z
328
+ out = sample_sequence_conditional(
329
+ model=model_vae.decoder,
330
+ context=context_tokens,
331
+ past=past,
332
+ length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
333
+ temperature=args.temperature,
334
+ top_k=args.top_k,
335
+ top_p=args.top_p,
336
+ device=args.device,
337
+ decoder_tokenizer = decoder_tokenizer
338
+ )
339
+ text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
340
+ text_x1 = text_x1.split()[1:-1]
341
+ text_x1 = ' '.join(text_x1)
342
+ result[step] = text_x1
343
+
344
+ return result
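A minimal sketch of the interpolation schedule implemented above: the two stored latent codes are mixed linearly, z_t = z_0 + (z_1 - z_0) * t / num_steps, and each mixture is decoded into a sentence. The random codes below are stand-ins for encoded sentences.

import torch

z0, z1 = torch.randn(1, 32), torch.randn(1, 32)
num_steps = 10
interpolants = [z0 + (z1 - z0) * step / num_steps for step in range(num_steps + 1)]
# interpolants[0] equals z0 and interpolants[-1] equals z1; the rest trace a straight line in latent space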
345
+
346
+
347
+
348
+
349
+ def main():
350
+ parser = argparse.ArgumentParser()
351
+
352
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
353
+ help="The input training data file (a text file).")
354
+ parser.add_argument("--eval_data_file", default=None, type=str,
355
+ help="An input evaluation data file to evaluate the perplexity on (a text file).")
356
+ parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
357
+ help="The directory where checkpoints are saved.")
358
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
359
+ help="The output directory where the model predictions and checkpoints will be written.")
360
+ parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
361
+
362
+ ## Variational auto-encoder
363
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
364
+ parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
365
+ parser.add_argument("--num_interpolation_steps", default=10, type=int, help="Total sentences to test recontruction.")
366
+ parser.add_argument("--play_mode", default="interpolation", type=str,
367
+ help="interpolation or reconstruction.")
368
+
369
+
370
+ ## Encoder options
371
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
372
+ help="The encoder model architecture to be fine-tuned.")
373
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
374
+ help="The encoder model checkpoint for weights initialization.")
375
+ parser.add_argument("--encoder_config_name", default="", type=str,
376
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
377
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
378
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
379
+
380
+ ## Decoder options
381
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
382
+ help="The decoder model architecture to be fine-tuned.")
383
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
384
+ help="The decoder model checkpoint for weights initialization.")
385
+ parser.add_argument("--decoder_config_name", default="", type=str,
386
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
387
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
388
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
389
+
390
+
391
+ parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
392
+ help="Batch size per GPU/CPU for training.")
393
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
394
+ help="Batch size per GPU/CPU for evaluation.")
395
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
396
+ help="Evaluate the results at the given global step")
397
+
398
+ parser.add_argument("--max_seq_length", default=512, type=int,
399
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
400
+
401
+
402
+ ## Variational auto-encoder
403
+ parser.add_argument("--nz", default=32, type=int,
404
+ help="Latent space dimension.")
405
+
406
+ parser.add_argument("--prompt", type=str, default="")
407
+ parser.add_argument("--padding_text", type=str, default="")
408
+ parser.add_argument("--length", type=int, default=20)
409
+ parser.add_argument("--temperature", type=float, default=1.0)
410
+ parser.add_argument("--top_k", type=int, default=0)
411
+ parser.add_argument("--top_p", type=float, default=0.9)
412
+ parser.add_argument("--no_cuda", action='store_true',
413
+ help="Avoid using CUDA when available")
414
+ parser.add_argument('--seed', type=int, default=42,
415
+ help="random seed for initialization")
416
+
417
+ parser.add_argument("--block_size", default=-1, type=int,
418
+ help="Optional input sequence length after tokenization."
419
+ "The training dataset will be truncated in block of this size for training."
420
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
421
+ parser.add_argument("--do_lower_case", action='store_true',
422
+ help="Set this flag if you are using an uncased model.")
423
+
424
+ parser.add_argument("--use_philly", action='store_true',
425
+ help="Use Philly for computing.")
426
+
427
+ args = parser.parse_args()
428
+
429
+ args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
430
+ args.n_gpu = torch.cuda.device_count()
431
+
432
+ set_seed(args)
433
+
434
+
435
+ args.encoder_model_type = args.encoder_model_type.lower()
436
+ args.decoder_model_type = args.decoder_model_type.lower()
437
+
438
+
439
+ global_step = args.gloabl_step_eval
440
+
441
+ output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
442
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
443
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
444
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
445
+
446
+ # Load a trained Encoder model and vocabulary that you have fine-tuned
447
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
448
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
449
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
450
+
451
+ model_encoder.to(args.device)
452
+ if args.block_size <= 0:
453
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
454
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
455
+
456
+ # Load a trained Decoder model and vocabulary that you have fine-tuned
457
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
458
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
459
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
460
+ model_decoder.to(args.device)
461
+ if args.block_size <= 0:
462
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
463
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
464
+
465
+ # Load full model
466
+ output_full_dir = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
467
+ checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
468
+
469
+ # Chunyuan: Add Padding token to GPT2
470
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
471
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
472
+ print('We have added', num_added_toks, 'tokens to GPT2')
473
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
474
+ assert tokenizer_decoder.pad_token == '<PAD>'
475
+
476
+
477
+ # Evaluation
478
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)
479
+ model_vae.load_state_dict(checkpoint['model_state_dict'])
480
+ logger.info("Pre-trained Optimus is successfully loaded")
481
+ model_vae.to(args.device)
482
+
483
+ result = evaluate_latent_space(args, model_vae, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
484
+
485
+
486
+ if __name__ == '__main__':
487
+ main()
Optimus/code/examples/big_ae/run_generation_from_prior.py ADDED
@@ -0,0 +1,414 @@
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
4
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
18
+ """
19
+ from __future__ import absolute_import, division, print_function, unicode_literals
20
+
21
+ import argparse
22
+ import glob
23
+ import logging
24
+ import os
25
+ import pickle
26
+ import random
27
+
28
+
29
+ cwd = os.getcwd()
30
+ print(f"Current working dir is {cwd}")
31
+
32
+ import sys
33
+ sys.path.append('./')
34
+ pt_path = os.path.join( cwd, 'pytorch_transformers')
35
+ sys.path.append(pt_path)
36
+ print(f"Pytorch Transformer {pt_path}")
37
+
38
+ import torch
39
+ import torch.nn.functional as F
40
+ import numpy as np
41
+
42
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
43
+ from torch.utils.data.distributed import DistributedSampler
44
+ from tqdm import tqdm, trange
45
+
46
+
47
+ from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
48
+ from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
49
+ from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
50
+ from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
51
+ from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
52
+ from pytorch_transformers import BertForLatentConnector, BertTokenizer
53
+
54
+ import pytorch_transformers
55
+
56
+ from collections import defaultdict
57
+ from modules import VAE
58
+ from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
59
+ from metrics import Bleu, SelfBleu
60
+
61
+
62
+
63
+ import pdb
64
+
65
+
66
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
67
+ datefmt = '%m/%d/%Y %H:%M:%S',
68
+ level = logging.INFO)
69
+ logger = logging.getLogger(__name__)
70
+
71
+ MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
72
+
73
+ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
74
+
75
+ MODEL_CLASSES = {
76
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
77
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
78
+ }
79
+
80
+ # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
81
+ # in https://github.com/rusiaaman/XLNet-gen#methodology
82
+ # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
83
+ PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
84
+ (except for Alexei and Maria) are discovered.
85
+ The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
86
+ remainder of the story. 1883 Western Siberia,
87
+ a young Grigori Rasputin is asked by his father and a group of men to perform magic.
88
+ Rasputin has a vision and denounces one of the men as a horse thief. Although his
89
+ father initially slaps him for making such an accusation, Rasputin watches as the
90
+ man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
91
+ the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
92
+ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
93
+
94
+
95
+ def set_seed(args):
96
+ np.random.seed(args.seed)
97
+ torch.manual_seed(args.seed)
98
+ if args.n_gpu > 0:
99
+ torch.cuda.manual_seed_all(args.seed)
100
+
101
+
102
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
103
+ if isinstance(tokenizer, list):
104
+ dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
105
+ else:
106
+ dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
107
+ return dataset
108
+
109
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
110
+ if isinstance(tokenizer, list):
111
+ if not evaluate:
112
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
113
+ file_path=args.train_data_file
114
+ else:
115
+ args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
116
+ file_path=args.eval_data_file
117
+ dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
118
+ else:
119
+ pass
120
+ return dataloader
121
+
122
+
123
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
124
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
125
+ Args:
126
+ logits: logits distribution shape (vocabulary size)
127
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
128
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
129
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
130
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
131
+ """
132
+ assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
133
+
134
+ # top-k
135
+ top_k = min(top_k, logits.size(-1)) # Safety check
136
+ if top_k > 0:
137
+ # Remove all tokens with a probability less than the last token of the top-k
138
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
139
+ logits[indices_to_remove] = filter_value
140
+
141
+ # top-p
142
+ if top_p > 0.0:
143
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
144
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
145
+
146
+ # Remove tokens with cumulative probability above the threshold
147
+ sorted_indices_to_remove = cumulative_probs > top_p
148
+ # Shift the indices to the right to keep also the first token above the threshold
149
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
150
+ sorted_indices_to_remove[..., 0] = 0
151
+
152
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
153
+ logits[indices_to_remove] = filter_value
154
+ return logits
155
+
156
+
157
+ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
158
+ context = torch.tensor(context, dtype=torch.long, device=device)
159
+ context = context.unsqueeze(0).repeat(num_samples, 1)
160
+ generated = context
161
+ with torch.no_grad():
162
+ for _ in trange(length):
163
+
164
+ inputs = {'input_ids': generated}
165
+ if is_xlnet:
166
+ # XLNet is a direct (predict same token, not next token) and bi-directional model by default
167
+ # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
168
+ input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
169
+ perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
170
+ perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
171
+ target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
172
+ target_mapping[0, 0, -1] = 1.0 # predict last token
173
+ inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
174
+
175
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
176
+ next_token_logits = outputs[0][0, -1, :] / temperature
177
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
178
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
179
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
180
+ return generated
181
+
182
+ def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None, max_seq_length=-1):
183
+
184
+ context = torch.tensor(context, dtype=torch.long, device=device)
185
+ context = context.unsqueeze(0).repeat(num_samples, 1)
186
+ generated = context
187
+ gen_seq_length = 0
188
+ with torch.no_grad():
189
+ while True:
190
+ inputs = {'input_ids': generated, 'past': past}
191
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
192
+ next_token_logits = outputs[0][0, -1, :] / temperature
193
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
194
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
195
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
196
+ gen_seq_length += 1
197
+ # pdb.set_trace()
198
+ if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
199
+ break
200
+ if max_seq_length>0 and gen_seq_length>max_seq_length:
201
+ break
202
+
203
+ return generated
204
+
205
+
206
+ def evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args, ns=1):
207
+
208
+ loc = torch.zeros([args.nz]).to(args.device)
209
+ scale = torch.ones([args.nz]).to(args.device)
210
+ prior = torch.distributions.normal.Normal(loc, scale)
211
+
212
+ context_tokens = decoder_tokenizer.encode('<BOS>')
213
+
214
+ count = 0
215
+ result = defaultdict(str)
216
+ for i in tqdm(range(args.num_sents)):
217
+
218
+ with torch.no_grad():
219
+ latent_z = prior.sample()
220
+ # pdb.set_trace()
221
+ past = model_vae.decoder.linear(latent_z.unsqueeze(0))
222
+
223
+ # pdb.set_trace()
224
+ out = sample_sequence_conditional(
225
+ model=model_vae.decoder,
226
+ context=context_tokens,
227
+ past=past,
228
+ length=args.max_seq_length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
229
+ temperature=args.temperature,
230
+ top_k=args.top_k,
231
+ top_p=args.top_p,
232
+ device=args.device,
233
+ decoder_tokenizer = decoder_tokenizer,
234
+ max_seq_length = args.max_seq_length
235
+ )
236
+ text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
237
+ text_x1 = text_x1.split()[1:-1]
238
+ text_x1 = ' '.join(text_x1) + '\n'
239
+ result[i] = text_x1
240
+
241
+ if args.use_philly:
242
+ print("PROGRESS: {}%".format( round(100 * i /args.num_sents , 4)))
243
+
244
+ with open(args.output_generation_file, "w") as writer:
245
+ logger.info("***** SHOW generated sentences from prior *****")
246
+ for key in sorted(result.keys()):
247
+ # logger.info(" %s \n %s", key, str(result[key]))
248
+ # writer.write("%s \n %s\n" % (key, str(result[key])))
249
+ writer.write("%s" % str(result[key]))
250
+
251
+ return result
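A minimal sketch of the prior-sampling loop above: a latent code is drawn from a standard normal prior and projected by the decoder's linear layer into the `past` that conditions GPT-2 generation. The latent size and the commented projection are assumptions mirroring the code above.

import torch

nz = 32
prior = torch.distributions.normal.Normal(torch.zeros(nz), torch.ones(nz))
latent_z = prior.sample()                                  # z ~ N(0, I), shape (nz,)
# past = model_vae.decoder.linear(latent_z.unsqueeze(0))   # conditions sample_sequence_conditional above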
252
+
253
+
254
+ # bleu = evaluate_bleu(results, args)
255
+
256
+
257
+
258
+
259
+
260
+
261
+ def main():
262
+ parser = argparse.ArgumentParser()
263
+
264
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
265
+ help="The input training data file (a text file).")
266
+ parser.add_argument("--eval_data_file", default=None, type=str,
267
+ help="An input evaluation data file to evaluate the perplexity on (a text file).")
268
+ parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
269
+ help="The directory where checkpoints are saved.")
270
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
271
+ help="The output directory where the model predictions and checkpoints will be written.")
272
+ parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
273
+
274
+ ## Variational auto-encoder
275
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
276
+ parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
277
+ parser.add_argument("--num_sents", default=10, type=int, help="Total sentences to generate.")
278
+
279
+
280
+ ## Encoder options
281
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
282
+ help="The encoder model architecture to be fine-tuned.")
283
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
284
+ help="The encoder model checkpoint for weights initialization.")
285
+ parser.add_argument("--encoder_config_name", default="", type=str,
286
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
287
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
288
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
289
+
290
+ ## Decoder options
291
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
292
+ help="The decoder model architecture to be fine-tuned.")
293
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
294
+ help="The decoder model checkpoint for weights initialization.")
295
+ parser.add_argument("--decoder_config_name", default="", type=str,
296
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
297
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
298
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
299
+
300
+
301
+ parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
302
+ help="Batch size per GPU/CPU for training.")
303
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
304
+ help="Batch size per GPU/CPU for evaluation.")
305
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
306
+ help="Evaluate the results at the given global step")
307
+
308
+ parser.add_argument("--max_seq_length", default=512, type=int,
309
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
310
+
311
+
312
+ ## Variational auto-encoder
313
+ parser.add_argument("--nz", default=32, type=int,
314
+ help="Latent space dimension.")
315
+
316
+ parser.add_argument("--prompt", type=str, default="")
317
+ parser.add_argument("--padding_text", type=str, default="")
318
+ parser.add_argument("--length", type=int, default=20)
319
+ parser.add_argument("--temperature", type=float, default=1.0)
320
+ parser.add_argument("--top_k", type=int, default=0)
321
+ parser.add_argument("--top_p", type=float, default=0.9)
322
+ parser.add_argument("--no_cuda", action='store_true',
323
+ help="Avoid using CUDA when available")
324
+ parser.add_argument('--seed', type=int, default=42,
325
+ help="random seed for initialization")
326
+
327
+ parser.add_argument("--block_size", default=-1, type=int,
328
+ help="Optional input sequence length after tokenization."
329
+ "The training dataset will be truncated in block of this size for training."
330
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
331
+ parser.add_argument("--do_lower_case", action='store_true',
332
+ help="Set this flag if you are using an uncased model.")
333
+
334
+ parser.add_argument("--use_philly", action='store_true',
335
+ help="Use Philly for computing.")
336
+
337
+ args = parser.parse_args()
338
+
339
+ args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
340
+ args.n_gpu = torch.cuda.device_count()
341
+
342
+ set_seed(args)
343
+
344
+
345
+ args.encoder_model_type = args.encoder_model_type.lower()
346
+ args.decoder_model_type = args.decoder_model_type.lower()
347
+
348
+
349
+ global_step = args.gloabl_step_eval
350
+
351
+ output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
352
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
353
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
354
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
355
+
356
+ # Load a trained Encoder model and vocabulary that you have fine-tuned
357
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
358
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
359
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
360
+
361
+ model_encoder.to(args.device)
362
+ if args.block_size <= 0:
363
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
364
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
365
+
366
+ # Load a trained Decoder model and vocabulary that you have fine-tuned
367
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
368
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
369
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
370
+ model_decoder.to(args.device)
371
+ if args.block_size <= 0:
372
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
373
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
374
+
375
+ # pdb.set_trace()
376
+ # Chunyuan: Add Padding token to GPT2
377
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
378
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
379
+ print('We have added', num_added_toks, 'tokens to GPT2')
380
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
381
+ assert tokenizer_decoder.pad_token == '<PAD>'
382
+
383
+
384
+ # Evaluation
385
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)
386
+
387
+ if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)
388
+ args.output_generation_file = os.path.join(args.output_dir, f"generation_from_vae_prior_t{args.temperature}_p{args.top_p}.txt")
389
+ # args.output_generation_file = args.train_data_file
390
+ result = evaluate_generation_fromp_prior(model_vae, tokenizer_decoder, args)
391
+
392
+
393
+ bleu5 = Bleu(test_text= args.output_generation_file,
394
+ real_text=args.eval_data_file,
395
+ num_real_sentences=args.num_sents,
396
+ num_fake_sentences=args.num_sents,
397
+ gram=5).get_score()
398
+ logger.info(f'The bleu score is {bleu5}')
399
+
400
+ sbleu5 = SelfBleu(test_text= args.output_generation_file,
401
+ num_sentences=args.num_sents,
402
+ gram=5).get_score()
403
+ logger.info(f'The self-bleu score is {sbleu5}')
404
+
405
+ args.eval_results_file = os.path.join(args.output_dir, f"eval_results_t{args.temperature}_p{args.top_p}.txt")
406
+ eval_results = {'bleu5':bleu5 , 'sbleu5':sbleu5}
407
+ with open(args.eval_results_file, "w") as writer:
408
+ logger.info("***** SHOW the quantative evalution results *****")
409
+ for key in sorted(eval_results.keys()):
410
+ writer.write("%s %s" % (key, str(eval_results[key])) )
411
+
412
+
413
+ if __name__ == '__main__':
414
+ main()
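The two scores above come from the repo's Bleu and SelfBleu classes in metrics.py: BLEU-5 compares the generated file against the reference corpus, while Self-BLEU-5 compares each generation against the other generations (higher Self-BLEU means less diversity). As a rough, self-contained sketch of what these numbers measure, not the repo's implementation, and assuming nltk is available:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def approx_bleu5(gen_file, ref_file, n=100):
    # Average BLEU-5 of each generated sentence against the reference corpus.
    gen = [line.split() for line in open(gen_file, encoding="utf-8") if line.strip()][:n]
    ref = [line.split() for line in open(ref_file, encoding="utf-8") if line.strip()][:n]
    smooth = SmoothingFunction().method1
    weights = (0.2,) * 5
    return sum(sentence_bleu(ref, g, weights=weights, smoothing_function=smooth)
               for g in gen) / len(gen)

def approx_self_bleu5(gen_file, n=100):
    # Self-BLEU-5: each generation is scored against all the other generations.
    gen = [line.split() for line in open(gen_file, encoding="utf-8") if line.strip()][:n]
    smooth = SmoothingFunction().method1
    weights = (0.2,) * 5
    return sum(sentence_bleu(gen[:i] + gen[i + 1:], g, weights=weights, smoothing_function=smooth)
               for i, g in enumerate(gen)) / len(gen)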
Optimus/code/examples/big_ae/run_gpt2_generation.py ADDED
@@ -0,0 +1,390 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
4
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
18
+ """
19
+ from __future__ import absolute_import, division, print_function, unicode_literals
20
+
21
+ import argparse
22
+ import glob
23
+ import logging
24
+ import os
25
+ import pickle
26
+ import random
27
+
28
+
29
+ cwd = os.getcwd()
30
+ print(f"Current working dir is {cwd}")
31
+
32
+ import sys
33
+ sys.path.append('./')
34
+ pt_path = os.path.join( cwd, 'pytorch_transformers')
35
+ sys.path.append(pt_path)
36
+ print(f"Pytorch Transformer {pt_path}")
37
+
38
+ import torch
39
+ import torch.nn.functional as F
40
+ import numpy as np
41
+
42
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
43
+ from torch.utils.data.distributed import DistributedSampler
44
+ from tqdm import tqdm, trange
45
+
46
+
47
+ from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
48
+ from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
49
+ from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
50
+ from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
51
+ from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
52
+ from pytorch_transformers import BertForLatentConnector, BertTokenizer
53
+
54
+ import pytorch_transformers
55
+
56
+ from collections import defaultdict
57
+ from modules import VAE
58
+ from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
59
+ from metrics import Bleu, SelfBleu
60
+
61
+
62
+
63
+ import pdb
64
+
65
+
66
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
67
+ datefmt = '%m/%d/%Y %H:%M:%S',
68
+ level = logging.INFO)
69
+ logger = logging.getLogger(__name__)
70
+
71
+ MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
72
+
73
+ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
74
+
75
+ MODEL_CLASSES = {
76
+ 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
77
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
78
+ }
79
+
80
+ # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
81
+ # in https://github.com/rusiaaman/XLNet-gen#methodology
82
+ # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
83
+ PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
84
+ (except for Alexei and Maria) are discovered.
85
+ The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
86
+ remainder of the story. 1883 Western Siberia,
87
+ a young Grigori Rasputin is asked by his father and a group of men to perform magic.
88
+ Rasputin has a vision and denounces one of the men as a horse thief. Although his
89
+ father initially slaps him for making such an accusation, Rasputin watches as the
90
+ man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
91
+ the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
92
+ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
93
+
94
+
95
+ def set_seed(args):
96
+ np.random.seed(args.seed)
97
+ torch.manual_seed(args.seed)
98
+ if args.n_gpu > 0:
99
+ torch.cuda.manual_seed_all(args.seed)
100
+
101
+
102
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
103
+ if isinstance(tokenizer, list):
104
+ dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
105
+ else:
106
+ dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
107
+ return dataset
108
+
109
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
110
+ if isinstance(tokenizer, list):
111
+ if not evaluate:
112
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
113
+ file_path=args.train_data_file
114
+ else:
115
+ args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
116
+ file_path=args.eval_data_file
117
+ dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
118
+ else:
119
+ pass
120
+ return dataloader
121
+
122
+
123
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
124
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
125
+ Args:
126
+ logits: logits distribution shape (vocabulary size)
127
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
128
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
129
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
130
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
131
+ """
132
+ assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
133
+
134
+ # top-k
135
+ top_k = min(top_k, logits.size(-1)) # Safety check
136
+ if top_k > 0:
137
+ # Remove all tokens with a probability less than the last token of the top-k
138
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
139
+ logits[indices_to_remove] = filter_value
140
+
141
+ # top-p
142
+ if top_p > 0.0:
143
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
144
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
145
+
146
+ # Remove tokens with cumulative probability above the threshold
147
+ sorted_indices_to_remove = cumulative_probs > top_p
148
+ # Shift the indices to the right to keep also the first token above the threshold
149
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
150
+ sorted_indices_to_remove[..., 0] = 0
151
+
152
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
153
+ logits[indices_to_remove] = filter_value
154
+ return logits
155
+
156
+
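A minimal usage sketch for the filtering helper above, with a toy five-token vocabulary (illustrative values only; assumes the top_k_top_p_filtering function defined in this file is in scope):

import torch
import torch.nn.functional as F

toy_logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])      # pretend vocabulary of five tokens
filtered = top_k_top_p_filtering(toy_logits.clone(), top_k=3, top_p=0.9)
probs = F.softmax(filtered, dim=-1)                          # filtered-out entries get probability 0
next_token = torch.multinomial(probs, num_samples=1)         # sample one token id, as the samplers below do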
157
+ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu', decoder_tokenizer=None, max_seq_length=-1):
158
+ context = torch.tensor(context, dtype=torch.long, device=device)
159
+ context = context.unsqueeze(0).repeat(num_samples, 1)
160
+ generated = context
161
+ gen_seq_length = 0
162
+ with torch.no_grad():
163
+ while True:
164
+
165
+ inputs = {'input_ids': generated}
166
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
167
+ next_token_logits = outputs[0][0, -1, :] / temperature
168
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
169
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
170
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
171
+ gen_seq_length += 1
172
+ if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
173
+ break
174
+ if max_seq_length>0 and gen_seq_length>max_seq_length:
175
+ break
176
+
177
+
178
+ return generated
179
+
180
+ def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None, max_seq_length=-1):
181
+
182
+ context = torch.tensor(context, dtype=torch.long, device=device)
183
+ context = context.unsqueeze(0).repeat(num_samples, 1)
184
+ generated = context
185
+ gen_seq_length = 0
186
+ with torch.no_grad():
187
+ while True:
188
+ inputs = {'input_ids': generated, 'past': past}
189
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
190
+ next_token_logits = outputs[0][0, -1, :] / temperature
191
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
192
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
193
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
194
+ gen_seq_length += 1
195
+ # pdb.set_trace()
196
+ if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
197
+ break
198
+ if max_seq_length>0 and gen_seq_length>max_seq_length:
199
+ break
200
+
201
+ return generated
202
+
203
+
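Both samplers above are open-ended: instead of running for a fixed number of steps, they stop when the decoder emits its <EOS> token or when gen_seq_length exceeds max_seq_length. A model-free toy sketch of that loop structure (a six-token vocabulary with id 5 standing in for <EOS>; random logits replace the model):

import torch
import torch.nn.functional as F

vocab_size, eos_id, max_seq_length = 6, 5, 20
generated = torch.tensor([[0]])                              # a single <BOS>-like start token
while True:
    next_token_logits = torch.randn(vocab_size)              # a real model would produce these
    next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    if next_token.item() == eos_id or generated.size(1) > max_seq_length:
        break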
204
+ def evaluate_generation_from_gpt2(model, decoder_tokenizer, args, ns=1):
205
+
206
+ loc = torch.zeros([args.nz]).to(args.device)
207
+ scale = torch.ones([args.nz]).to(args.device)
208
+ prior = torch.distributions.normal.Normal(loc, scale)
209
+
210
+ context_tokens = decoder_tokenizer.encode('<BOS>')
211
+
212
+ count = 0
213
+ result = defaultdict(str)
214
+ for i in tqdm(range(args.num_sents)):
215
+
216
+ with torch.no_grad():
217
+
218
+ out = sample_sequence(
219
+ model=model,
220
+ context=context_tokens,
221
+ length=args.max_seq_length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
222
+ temperature=args.temperature,
223
+ top_k=args.top_k,
224
+ top_p=args.top_p,
225
+ device=args.device,
226
+ decoder_tokenizer = decoder_tokenizer,
227
+ max_seq_length = args.max_seq_length
228
+ )
229
+ text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
230
+ text_x1 = text_x1.split()[1:-1]
231
+ text_x1 = ' '.join(text_x1) + '\n'
232
+ result[i] = text_x1
233
+
234
+ if args.use_philly:
235
+ print("PROGRESS: {}%".format( round(100 * i /args.num_sents , 4)))
236
+
237
+ with open(args.output_generation_file, "w") as writer:
238
+ logger.info("***** SHOW generated sentences from prior *****")
239
+ for key in sorted(result.keys()):
240
+ # logger.info(" %s \n %s", key, str(result[key]))
241
+ # writer.write("%s \n %s\n" % (key, str(result[key])))
242
+ writer.write("%s" % str(result[key]))
243
+
244
+ return result
245
+
246
+
247
+ # bleu = evaluate_bleu(results, args)
248
+
249
+
250
+
251
+
252
+
253
+
254
+ def main():
255
+ parser = argparse.ArgumentParser()
256
+
257
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
258
+ help="The input training data file (a text file).")
259
+ parser.add_argument("--eval_data_file", default=None, type=str,
260
+ help="An input evaluation data file to evaluate the perplexity on (a text file).")
261
+ parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
262
+ help="The directory where checkpoints are saved.")
263
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
264
+ help="The output directory where the model predictions and checkpoints will be written.")
265
+ parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
266
+
267
+ ## Variational auto-encoder
268
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
269
+ parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
270
+ parser.add_argument("--num_sents", default=10, type=int, help="Total sentences to generate.")
271
+
272
+
273
+ ## Encoder options
274
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
275
+ help="The encoder model architecture to be fine-tuned.")
276
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
277
+ help="The encoder model checkpoint for weights initialization.")
278
+ parser.add_argument("--encoder_config_name", default="", type=str,
279
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
280
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
281
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
282
+
283
+ ## Decoder options
284
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
285
+ help="The decoder model architecture to be fine-tuned.")
286
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
287
+ help="The decoder model checkpoint for weights initialization.")
288
+ parser.add_argument("--decoder_config_name", default="", type=str,
289
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
290
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
291
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
292
+
293
+
294
+ parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
295
+ help="Batch size per GPU/CPU for training.")
296
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
297
+ help="Batch size per GPU/CPU for evaluation.")
298
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
299
+ help="Evaluate the results at the given global step")
300
+
301
+ parser.add_argument("--max_seq_length", default=512, type=int,
302
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
303
+
304
+
305
+ ## Variational auto-encoder
306
+ parser.add_argument("--nz", default=32, type=int,
307
+ help="Latent space dimension.")
308
+
309
+ parser.add_argument("--prompt", type=str, default="")
310
+ parser.add_argument("--padding_text", type=str, default="")
311
+ parser.add_argument("--length", type=int, default=20)
312
+ parser.add_argument("--temperature", type=float, default=1.0)
313
+ parser.add_argument("--top_k", type=int, default=0)
314
+ parser.add_argument("--top_p", type=float, default=0.9)
315
+ parser.add_argument("--no_cuda", action='store_true',
316
+ help="Avoid using CUDA when available")
317
+ parser.add_argument('--seed', type=int, default=42,
318
+ help="random seed for initialization")
319
+
320
+ parser.add_argument("--block_size", default=-1, type=int,
321
+ help="Optional input sequence length after tokenization."
322
+ "The training dataset will be truncated in block of this size for training."
323
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
324
+ parser.add_argument("--do_lower_case", action='store_true',
325
+ help="Set this flag if you are using an uncased model.")
326
+
327
+ parser.add_argument("--use_philly", action='store_true',
328
+ help="Use Philly for computing.")
329
+
330
+ args = parser.parse_args()
331
+
332
+ args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
333
+ args.n_gpu = torch.cuda.device_count()
334
+
335
+ set_seed(args)
336
+ args.decoder_model_type = args.decoder_model_type.lower()
337
+
338
+
339
+ global_step = args.gloabl_step_eval
340
+
341
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-{}'.format(global_step))
342
+ checkpoints = [ output_decoder_dir ]
343
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
344
+
345
+ # Load a trained Decoder model and vocabulary that you have fine-tuned
346
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
347
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
348
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
349
+ model_decoder.to(args.device)
350
+ if args.block_size <= 0:
351
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
352
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
353
+
354
+ # pdb.set_trace()
355
+ # Chunyuan: Add Padding token to GPT2
356
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
357
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
358
+ print('We have added', num_added_toks, 'tokens to GPT2')
359
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
360
+ assert tokenizer_decoder.pad_token == '<PAD>'
361
+
362
+
363
+ # Evaluation
364
+ if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)
365
+ args.output_generation_file = os.path.join(args.output_dir, f"generation_from_gpt2_t{args.temperature}_p{args.top_p}.txt")
366
+ # args.output_generation_file = args.train_data_file
367
+ result = evaluate_generation_from_gpt2(model_decoder, tokenizer_decoder, args)
368
+
369
+ bleu5 = Bleu(test_text= args.output_generation_file,
370
+ real_text=args.eval_data_file,
371
+ num_real_sentences=args.num_sents,
372
+ num_fake_sentences=args.num_sents,
373
+ gram=5).get_score()
374
+ logger.info(f'The bleu score is {bleu5}')
375
+
376
+ sbleu5 = SelfBleu(test_text= args.output_generation_file,
377
+ num_sentences=args.num_sents,
378
+ gram=5).get_score()
379
+ logger.info(f'The self-bleu score is {sbleu5}')
380
+
381
+ args.eval_results_file = os.path.join(args.output_dir, f"eval_results_t{args.temperature}_p{args.top_p}.txt")
382
+ eval_results = {'bleu5':bleu5 , 'sbleu5':sbleu5}
383
+ with open(args.eval_results_file, "w") as writer:
384
+ logger.info("***** SHOW the quantative evalution results *****")
385
+ for key in sorted(eval_results.keys()):
386
+ writer.write("%s %s" % (key, str(eval_results[key])) )
387
+
388
+
389
+ if __name__ == '__main__':
390
+ main()
Optimus/code/examples/big_ae/run_latent_generation.py ADDED
@@ -0,0 +1,577 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
4
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
18
+ """
19
+ from __future__ import absolute_import, division, print_function, unicode_literals
20
+
21
+ import argparse
22
+ import glob
23
+ import logging
24
+ import os
25
+ import pickle
26
+ import random
27
+
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import numpy as np
32
+
33
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
34
+ from torch.utils.data.distributed import DistributedSampler
35
+ from tqdm import tqdm, trange
36
+
37
+
38
+ from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
39
+ from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
40
+ from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
41
+ from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
42
+ from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
43
+ from pytorch_transformers import BertForLatentConnector, BertTokenizer
44
+
45
+ from collections import defaultdict
46
+ from modules import VAE
47
+ from utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)
48
+
49
+
50
+ import pdb
51
+
52
+
53
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
54
+ datefmt = '%m/%d/%Y %H:%M:%S',
55
+ level = logging.INFO)
56
+ logger = logging.getLogger(__name__)
57
+
58
+ MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
59
+
60
+ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
61
+
62
+ MODEL_CLASSES = {
63
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
64
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
65
+ }
66
+
67
+ # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
68
+ # in https://github.com/rusiaaman/XLNet-gen#methodology
69
+ # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
70
+ PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
71
+ (except for Alexei and Maria) are discovered.
72
+ The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
73
+ remainder of the story. 1883 Western Siberia,
74
+ a young Grigori Rasputin is asked by his father and a group of men to perform magic.
75
+ Rasputin has a vision and denounces one of the men as a horse thief. Although his
76
+ father initially slaps him for making such an accusation, Rasputin watches as the
77
+ man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
78
+ the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
79
+ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
80
+
81
+
82
+ def set_seed(args):
83
+ np.random.seed(args.seed)
84
+ torch.manual_seed(args.seed)
85
+ if args.n_gpu > 0:
86
+ torch.cuda.manual_seed_all(args.seed)
87
+
88
+
89
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
90
+ if isinstance(tokenizer, list):
91
+ dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
92
+ else:
93
+ dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
94
+ return dataset
95
+
96
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
97
+ if isinstance(tokenizer, list):
98
+ if not evaluate:
99
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
100
+ file_path=args.train_data_file
101
+ else:
102
+ args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
103
+ file_path=args.eval_data_file
104
+ dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)
105
+ else:
106
+ pass
107
+ return dataloader
108
+
109
+
110
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
111
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
112
+ Args:
113
+ logits: logits distribution shape (vocabulary size)
114
+ top_k > 0: keep only top k tokens with highest probability (top-k filtering).
115
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
116
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
117
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
118
+ """
119
+ assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
120
+ top_k = min(top_k, logits.size(-1)) # Safety check
121
+ if top_k > 0:
122
+ # Remove all tokens with a probability less than the last token of the top-k
123
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
124
+ logits[indices_to_remove] = filter_value
125
+
126
+ if top_p > 0.0:
127
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
128
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
129
+
130
+ # Remove tokens with cumulative probability above the threshold
131
+ sorted_indices_to_remove = cumulative_probs > top_p
132
+ # Shift the indices to the right to keep also the first token above the threshold
133
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
134
+ sorted_indices_to_remove[..., 0] = 0
135
+
136
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
137
+ logits[indices_to_remove] = filter_value
138
+ return logits
139
+
140
+
141
+ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
142
+ context = torch.tensor(context, dtype=torch.long, device=device)
143
+ context = context.unsqueeze(0).repeat(num_samples, 1)
144
+ generated = context
145
+ with torch.no_grad():
146
+ for _ in trange(length):
147
+
148
+ inputs = {'input_ids': generated}
149
+ if is_xlnet:
150
+ # XLNet is a direct (predict same token, not next token) and bi-directional model by default
151
+ # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
152
+ input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
153
+ perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
154
+ perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
155
+ target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
156
+ target_mapping[0, 0, -1] = 1.0 # predict last token
157
+ inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
158
+
159
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
160
+ next_token_logits = outputs[0][0, -1, :] / temperature
161
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
162
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
163
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
164
+ return generated
165
+
166
+ def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
167
+
168
+ context = torch.tensor(context, dtype=torch.long, device=device)
169
+ context = context.unsqueeze(0).repeat(num_samples, 1)
170
+ generated = context
171
+ with torch.no_grad():
172
+ while True:
173
+ # for _ in trange(length):
174
+ inputs = {'input_ids': generated, 'past': past}
175
+ outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
176
+ next_token_logits = outputs[0][0, -1, :] / temperature
177
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
178
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
179
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
180
+
181
+ # pdb.set_trace()
182
+ if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
183
+ break
184
+
185
+ return generated
186
+
187
+
188
+ def latent_code_from_text(text, tokenizer_encoder, model_vae, args):
189
+ tokenized1 = tokenizer_encoder.encode(text)
190
+ tokenized1 = [101] + tokenized1 + [102] # wrap with BERT's [CLS] (101) and [SEP] (102) token ids
191
+ coded1 = torch.Tensor([tokenized1])
192
+ coded1 = torch.Tensor.long(coded1)
193
+ with torch.no_grad():
194
+ x0 = coded1
195
+ x0 = x0.to(args.device)
196
+ pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
197
+ mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
198
+ latent_z = mean.squeeze(1)
199
+ coded_length = len(tokenized1)
200
+ return latent_z, coded_length
201
+
202
+ def text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder):
203
+ past = latent_z
204
+ context_tokens = tokenizer_decoder.encode('<BOS>')
205
+
206
+ length = 128 # maximum length, but not used
207
+ out = sample_sequence_conditional(
208
+ model=model_vae.decoder,
209
+ context=context_tokens,
210
+ past=past,
211
+ length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
212
+ temperature=args.temperature,
213
+ top_k=args.top_k,
214
+ top_p=args.top_p,
215
+ device=args.device,
216
+ decoder_tokenizer = tokenizer_decoder
217
+ )
218
+ text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
219
+ text_x1 = text_x1.split()[1:-1]
220
+ text_x1 = ' '.join(text_x1)
221
+ return text_x1
222
+
223
+
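latent_code_from_text takes the posterior mean as the sentence code: the encoder's pooled feature is projected to 2 * latent_size values and split into a mean and a log-variance, and only the mean is kept (no sampling). A toy sketch of that projection with random tensors (hidden size 768 and latent size 32 assumed, matching bert-base and the --latent_size default):

import torch
import torch.nn as nn

hidden_size, latent_size = 768, 32
linear = nn.Linear(hidden_size, 2 * latent_size)     # plays the role of model_vae.encoder.linear
pooled = torch.randn(1, hidden_size)                 # stands in for BERT's pooled feature
mean, logvar = linear(pooled).chunk(2, -1)           # split into posterior mean and log-variance
latent_z = mean                                      # deterministic code: keep the mean only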
224
+ # a wrapper function to choose between different play modes
225
+ def evaluate_latent_space(args, model_vae, encoder_tokenizer, decoder_tokenizer, prefix=""):
226
+
227
+ eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
228
+
229
+ # Eval!
230
+ logger.info("***** Running recontruction evaluation {} *****".format(prefix))
231
+ logger.info(" Num examples = %d", len(eval_dataloader))
232
+ logger.info(" Batch size = %d", args.per_gpu_eval_batch_size)
233
+
234
+ model_vae.eval()
235
+
236
+ model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
237
+
238
+ if args.play_mode == 'reconstruction':
239
+ result = calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
240
+ result_file_name = "eval_recontruction_results.txt"
241
+ elif args.play_mode == 'interpolation':
242
+ result = calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=100)
243
+ result_file_name = "eval_interpolation_results.txt"
244
+ else:
245
+ logger.info("Please specify the corrent play mode [reconstrction, interpolation]")
246
+
247
+
248
+ eval_output_dir = args.output_dir
249
+ output_eval_file = os.path.join(eval_output_dir, result_file_name)
250
+
251
+ with open(output_eval_file, "w") as writer:
252
+ logger.info("***** Eval {} results *****".format(args.play_mode))
253
+ for key in sorted(result.keys()):
254
+ logger.info(" %s \n %s", key, str(result[key]))
255
+ writer.write("%s \n %s\n" % (key, str(result[key])))
256
+
257
+ return result
258
+
259
+
260
+ def calc_rec(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
261
+
262
+ count = 0
263
+ result = defaultdict(str)
264
+ for batch in tqdm(eval_dataloader, desc="Evaluating recontruction"):
265
+ # pdb.set_trace()
266
+ x0, x1, x_lengths = batch
267
+
268
+ max_len_values, _ = x_lengths.max(0)
269
+ x0 = x0[:,:max_len_values[0]]
270
+ x1 = x1[:,:max_len_values[1]]
271
+
272
+ x0 = x0.to(args.device)
273
+ x1 = x1.to(args.device)
274
+ x_lengths = x_lengths.to(args.device)
275
+
276
+ context_tokens = decoder_tokenizer.encode('<BOS>')
277
+
278
+ with torch.no_grad():
279
+
280
+ text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
281
+ # result["INPUT TEXT " + str(count)].append(text_x0)
282
+
283
+ pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
284
+
285
+ # Connect hidden feature to the latent space
286
+ # latent_z, loss_kl = model_vae.connect(pooled_hidden_fea)
287
+ mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
288
+ latent_z = mean.squeeze(1)
289
+
290
+ past = latent_z
291
+ out = sample_sequence_conditional(
292
+ model=model_vae.decoder,
293
+ context=context_tokens,
294
+ past=past,
295
+ length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
296
+ temperature=args.temperature,
297
+ top_k=args.top_k,
298
+ top_p=args.top_p,
299
+ device=args.device,
300
+ decoder_tokenizer = decoder_tokenizer
301
+ )
302
+ text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
303
+ text_x1 = text_x1.split()[1:-1]
304
+ text_x1 = ' '.join(text_x1) + '\n'
305
+ result[text_x0] = text_x1
306
+
307
+ count += 1
308
+ if count>args.total_sents:
309
+ break
310
+
311
+
312
+ return result
313
+
314
+
315
+
316
+
317
+ def calc_interpolate(model_vae, eval_dataloader, encoder_tokenizer, decoder_tokenizer, args, ns=1):
318
+
319
+ count = 0
320
+ latent_codes = []
321
+ sample_interval = 0
322
+ for batch in tqdm(eval_dataloader, desc="Evaluating interpolation"):
323
+ # pdb.set_trace()
324
+ x0, x1, x_lengths = batch
325
+
326
+ max_len_values, _ = x_lengths.max(0)
327
+ x0 = x0[:,:max_len_values[0]]
328
+ x0 = x0.to(args.device)
329
+ x_lengths = x_lengths.to(args.device)
330
+
331
+
332
+ with torch.no_grad():
333
+ if sample_interval == 0 or sample_interval == args.total_sents:
334
+ text_x0 = encoder_tokenizer.decode(x0[0,:x_lengths[0,0]].tolist(), clean_up_tokenization_spaces=True)[0]
335
+ pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
336
+
337
+ # Connect hidden feature to the latent space
338
+ mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
339
+ latent_z = mean.squeeze(1)
340
+
341
+ latent_codes.append(latent_z)
342
+
343
+ if sample_interval == 5:
344
+ latent_codes.append(latent_z)
345
+ sample_interval = 0
346
+ continue
347
+ else:
348
+ sample_interval += 1
349
+ continue
350
+
351
+ count += 1
352
+ if count>args.total_sents:
353
+ break
354
+
355
+ context_tokens = decoder_tokenizer.encode('<BOS>')
356
+ result = defaultdict(str)
357
+ latent_codes_interpolation = []
358
+ num_steps = args.num_interpolation_steps
359
+ for step in range(num_steps+1):
360
+ latent_z = latent_codes[0] + (latent_codes[1] - latent_codes[0]) * step * 1.0/num_steps
361
+
362
+ past = latent_z
363
+ out = sample_sequence_conditional(
364
+ model=model_vae.decoder,
365
+ context=context_tokens,
366
+ past=past,
367
+ length=x_lengths[0,1], # Chunyuan: Fix length; or use <EOS> to complete a sentence
368
+ temperature=args.temperature,
369
+ top_k=args.top_k,
370
+ top_p=args.top_p,
371
+ device=args.device,
372
+ decoder_tokenizer = decoder_tokenizer
373
+ )
374
+ text_x1 = decoder_tokenizer.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
375
+ text_x1 = text_x1.split()[1:-1]
376
+ text_x1 = ' '.join(text_x1)
377
+ result[step] = text_x1
378
+
379
+ return result
380
+
381
+
382
+ def interpolate(model_vae, tokenizer_encoder, tokenizer_decoder, args):
383
+ # and then in the main function
384
+ latent_z1, coded_length1 = latent_code_from_text(args.sent_source, tokenizer_encoder, model_vae, args)
385
+ latent_z2, coded_length2 = latent_code_from_text(args.sent_target, tokenizer_encoder, model_vae, args)
386
+
387
+ result = defaultdict(str)
388
+
389
+ num_steps = args.num_interpolation_steps + 1
390
+ for step in range(num_steps+1):
391
+ latent_z = latent_z1 + (latent_z2 - latent_z1) * step * 1.0/num_steps
392
+
393
+ text_interpolate = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)
394
+ result[step] = text_interpolate
395
+ print(text_interpolate)
396
+
397
+ return result
398
+
399
+
400
+ def analogy(model_vae, tokenizer_encoder, tokenizer_decoder, args):
401
+
402
+ latent_z1, coded_length1 = latent_code_from_text(args.sent_source, tokenizer_encoder, model_vae, args)
403
+ latent_z2, coded_length2 = latent_code_from_text(args.sent_target, tokenizer_encoder, model_vae, args)
404
+ latent_z3, coded_length3 = latent_code_from_text(args.sent_input, tokenizer_encoder, model_vae, args)
405
+
406
+ result = defaultdict(str)
407
+
408
+ latent_z = latent_z3 + args.degree_to_target * (latent_z2 - latent_z1)
409
+
410
+ text_analogy = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)
411
+ result[0] = text_analogy
412
+ print(text_analogy)
413
+
414
+ return result
415
+
416
+
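Both interpolate and analogy above reduce to vector arithmetic on latent codes before decoding. A self-contained toy sketch of the two schedules, with random tensors standing in for real sentence encodings (latent size 32 assumed):

import torch

z1, z2, z3 = torch.randn(1, 32), torch.randn(1, 32), torch.randn(1, 32)

# Interpolation: step 0 decodes z1, the final step decodes z2.
num_steps = 10
path = [z1 + (z2 - z1) * step / num_steps for step in range(num_steps + 1)]

# Analogy: shift the input code z3 by the (source -> target) offset, scaled by degree_to_target.
degree_to_target = 1.0
z_analogy = z3 + degree_to_target * (z2 - z1)
# Each latent would then be decoded via sample_sequence_conditional(past=z, ...).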
417
+ def main():
418
+ parser = argparse.ArgumentParser()
419
+
420
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
421
+ help="The input training data file (a text file).")
422
+ parser.add_argument("--eval_data_file", default=None, type=str,
423
+ help="An input evaluation data file to evaluate the perplexity on (a text file).")
424
+ parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
425
+ help="The directory where checkpoints are saved.")
426
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
427
+ help="The output directory where the model predictions and checkpoints will be written.")
428
+ parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")
429
+
430
+ ## Variational auto-encoder
431
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
432
+ parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
433
+ parser.add_argument("--num_interpolation_steps", default=10, type=int, help="Total sentences to test recontruction.")
434
+ parser.add_argument("--play_mode", default="interpolation", type=str,
435
+ help="interpolation or reconstruction.")
436
+
437
+
438
+ ## Encoder options
439
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
440
+ help="The encoder model architecture to be fine-tuned.")
441
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
442
+ help="The encoder model checkpoint for weights initialization.")
443
+ parser.add_argument("--encoder_config_name", default="", type=str,
444
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
445
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
446
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
447
+
448
+ ## Decoder options
449
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
450
+ help="The decoder model architecture to be fine-tuned.")
451
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
452
+ help="The decoder model checkpoint for weights initialization.")
453
+ parser.add_argument("--decoder_config_name", default="", type=str,
454
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
455
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
456
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
457
+
458
+
459
+ parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
460
+ help="Batch size per GPU/CPU for training.")
461
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
462
+ help="Batch size per GPU/CPU for evaluation.")
463
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
464
+ help="Evaluate the results at the given global step")
465
+
466
+ parser.add_argument("--max_seq_length", default=512, type=int,
467
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
468
+
469
+ # Interact with users
470
+ parser.add_argument("--interact_with_user_input", action='store_true', help="Use user input to interact_with.")
471
+ parser.add_argument("--sent_source", type=str, default="")
472
+ parser.add_argument("--sent_target", type=str, default="")
473
+ parser.add_argument("--sent_input", type=str, default="")
474
+ parser.add_argument("--degree_to_target", type=float, default="1.0")
475
+
476
+ ## Variational auto-encoder
477
+ parser.add_argument("--nz", default=32, type=int,
478
+ help="Latent space dimension.")
479
+
480
+ parser.add_argument("--prompt", type=str, default="")
481
+ parser.add_argument("--padding_text", type=str, default="")
482
+ parser.add_argument("--length", type=int, default=20)
483
+ parser.add_argument("--temperature", type=float, default=1.0)
484
+ parser.add_argument("--top_k", type=int, default=0)
485
+ parser.add_argument("--top_p", type=float, default=1.0)
486
+ parser.add_argument("--no_cuda", action='store_true',
487
+ help="Avoid using CUDA when available")
488
+ parser.add_argument('--seed', type=int, default=42,
489
+ help="random seed for initialization")
490
+
491
+ parser.add_argument("--block_size", default=-1, type=int,
492
+ help="Optional input sequence length after tokenization."
493
+ "The training dataset will be truncated in block of this size for training."
494
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
495
+ parser.add_argument("--do_lower_case", action='store_true',
496
+ help="Set this flag if you are using an uncased model.")
497
+
498
+ parser.add_argument("--use_philly", action='store_true',
499
+ help="Use Philly for computing.")
500
+
501
+ args = parser.parse_args()
502
+
503
+ args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
504
+ args.n_gpu = torch.cuda.device_count()
505
+
506
+ set_seed(args)
507
+
508
+
509
+ args.encoder_model_type = args.encoder_model_type.lower()
510
+ args.decoder_model_type = args.decoder_model_type.lower()
511
+
512
+
513
+ global_step = args.gloabl_step_eval
514
+
515
+ output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
516
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
517
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
518
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
519
+
520
+ # Load a trained Encoder model and vocabulary that you have fine-tuned
521
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
522
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
523
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
524
+
525
+ model_encoder.to(args.device)
526
+ if args.block_size <= 0:
527
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
528
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
529
+
530
+ # Load a trained Decoder model and vocabulary that you have fine-tuned
531
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
532
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
533
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
534
+ model_decoder.to(args.device)
535
+ if args.block_size <= 0:
536
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
537
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
538
+
539
+ # Load full model
540
+ output_full_dir = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
541
+ checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
542
+
543
+ # Chunyuan: Add Padding token to GPT2
544
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
545
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
546
+ print('We have added', num_added_toks, 'tokens to GPT2')
547
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
548
+ assert tokenizer_decoder.pad_token == '<PAD>'
549
+
550
+
551
+ # Evaluation
552
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)
553
+ model_vae.load_state_dict(checkpoint['model_state_dict'])
554
+ logger.info("Pre-trained Optimus is successfully loaded")
555
+ model_vae.to(args.device)
556
+
557
+ if args.interact_with_user_input:
558
+
559
+ if args.play_mode == 'interpolation':
560
+ if len(args.sent_source) > 0 and len(args.sent_target) > 0:
561
+ result = interpolate(model_vae, tokenizer_encoder, tokenizer_decoder, args)
562
+ else:
563
+ print('Please check: specify the source and target sentences!')
564
+
565
+ if args.play_mode == 'analogy':
566
+ if len(args.sent_source) > 0 and len(args.sent_target) > 0 and len(args.sent_input) > 0:
567
+ result = analogy(model_vae, tokenizer_encoder, tokenizer_decoder, args)
568
+ else:
569
+ print('Please check: specify the source, target and input analogy sentences!')
570
+
571
+
572
+ else:
573
+ result = evaluate_latent_space(args, model_vae, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
574
+
575
+
576
+ if __name__ == '__main__':
577
+ main()
Optimus/code/examples/big_ae/run_lm_ae_pretraining.py ADDED
@@ -0,0 +1,692 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+
25
+ import pdb
26
+ import argparse
27
+ import glob
28
+ import logging
29
+ import os
30
+ import pickle
31
+ import random
32
+
33
+ import numpy as np
34
+ import torch
35
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
36
+ from torch.utils.data.distributed import DistributedSampler
37
+ from tensorboardX import SummaryWriter
38
+ from tqdm import tqdm, trange
39
+
40
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
41
+ BertConfig, BertModel, BertTokenizer,
42
+ GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
43
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
44
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
45
+
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+ MODEL_CLASSES = {
51
+ 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
52
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
53
+ 'bert': (BertConfig, BertModel, BertTokenizer),
54
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
55
+ }
56
+
57
+
58
+ class TextDataset(Dataset):
59
+ def __init__(self, tokenizer, file_path='train', block_size=512):
60
+ assert os.path.isfile(file_path)
61
+ directory, filename = os.path.split(file_path)
62
+ cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
63
+
64
+ if os.path.exists(cached_features_file):
65
+ logger.info("Loading features from cached file %s", cached_features_file)
66
+ with open(cached_features_file, 'rb') as handle:
67
+ self.examples = pickle.load(handle)
68
+ else:
69
+ logger.info("Creating features from dataset file at %s", directory)
70
+
71
+ self.examples = []
72
+ with open(file_path, encoding="utf-8") as f:
73
+ text = f.read()
74
+
75
+
76
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
77
+
78
+ while len(tokenized_text) >= block_size: # Truncate in block of block_size
79
+ self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
80
+ tokenized_text = tokenized_text[block_size:]
81
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
82
+ # If your dataset is small, first you should look for a bigger one :-) and second you
83
+ # can change this behavior by adding (model specific) padding.
84
+
85
+ logger.info("Saving features into cached file %s", cached_features_file)
86
+ with open(cached_features_file, 'wb') as handle:
87
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
88
+
89
+ def __len__(self):
90
+ return len(self.examples)
91
+
92
+ def __getitem__(self, item):
93
+ return torch.tensor(self.examples[item])
94
+
95
+
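TextDataset simply chops the tokenized corpus into fixed-size blocks and drops the final partial block. A toy sketch of that chunking with fake token ids (block_size 4; the add_special_tokens_single_sentence call is omitted here):

block_size = 4
tokenized_text = list(range(10))                # stand-in for tokenizer.convert_tokens_to_ids(...)
examples = []
while len(tokenized_text) >= block_size:
    examples.append(tokenized_text[:block_size])
    tokenized_text = tokenized_text[block_size:]
# examples == [[0, 1, 2, 3], [4, 5, 6, 7]]; the trailing [8, 9] is dropped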
96
+
97
+ class TextDataset_2Tokenizers(Dataset):
98
+ def __init__(self, tokenizers, file_path='train', block_size=512):
99
+ assert os.path.isfile(file_path)
100
+ directory, filename = os.path.split(file_path)
101
+ cached_features_file = os.path.join(directory, f'cached_lm_gpt_bert_{block_size}_{filename}')
102
+
103
+
104
+
105
+ if os.path.exists(cached_features_file):
106
+ logger.info("Loading features from cached file %s", cached_features_file)
107
+ with open(cached_features_file, 'rb') as handle:
108
+ self.examples = pickle.load(handle)
109
+ else:
110
+ logger.info("Creating features from dataset file at %s", directory)
111
+
112
+
113
+ with open(file_path, encoding="utf-8") as f:
114
+ text = f.read()
115
+
116
+ # pdb.set_trace()
117
+ self.examples = []
118
+ # Chunyuan: divide the linguistic text into the same length, then different tokenization schemes are applied
119
+ while len(text) >= block_size: # Truncate in block of block_size
120
+
121
+ tokenized_text0 = tokenizers[0].convert_tokens_to_ids(tokenizers[0].tokenize(text[:block_size]))
122
+ tokenized_text0 = tokenizers[0].add_special_tokens_single_sentence(tokenized_text0)
123
+ tokenized_text0_length = len(tokenized_text0)
124
+ pad_token=tokenizers[0].convert_tokens_to_ids([tokenizers[0].pad_token])[0]
125
+ tokenized_text0 = tokenized_text0 + ([pad_token] * (block_size - tokenized_text0_length) ) # Pad up to the sequence length.
126
+ assert len(tokenized_text0) == block_size
127
+
128
+ tokenized_text1 = tokenizers[1].convert_tokens_to_ids(tokenizers[1].tokenize(text[:block_size]))
129
+ tokenized_text1 = tokenizers[1].add_special_tokens_single_sentence(tokenized_text1)
130
+ tokenized_text1_length = len(tokenized_text1)
131
+ pad_token=tokenizers[1].convert_tokens_to_ids([tokenizers[1].pad_token])[0]
132
+ tokenized_text1 = tokenized_text1 + ([pad_token] * (block_size - tokenized_text1_length) ) # Pad up to the sequence length.
133
+ assert len(tokenized_text1) == block_size
134
+
135
+ self.examples.append([tokenized_text0, tokenized_text0_length, tokenized_text1, tokenized_text1_length])
136
+
137
+ text = text[block_size:]
138
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
139
+ # If your dataset is small, first you should look for a bigger one :-) and second you
140
+ # can change this behavior by adding (model specific) padding.
141
+
142
+ logger.info("Saving features into cached file %s", cached_features_file)
143
+ with open(cached_features_file, 'wb') as handle:
144
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
145
+
146
+ def __len__(self):
147
+ return len(self.examples)
148
+
149
+ def __getitem__(self, item):
150
+ # pdb.set_trace()
151
+ # Convert to Tensors and build dataset
152
+ tokenized_text0= torch.tensor(self.examples[item][0], dtype=torch.long)
153
+ tokenized_text1= torch.tensor(self.examples[item][2], dtype=torch.long)
154
+ tokenized_text_lengths = torch.tensor([self.examples[item][1], self.examples[item][3]], dtype=torch.long)
155
+ # pdb.set_trace()
156
+ return (tokenized_text0, tokenized_text1, tokenized_text_lengths)
157
+
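# Illustrative sketch (editor's addition, not part of the original script): how a batch
# from TextDataset_2Tokenizers is typically consumed. The two tokenizers are assumed to
# be a BERT tokenizer (encoder side) and a GPT-2 tokenizer that already has a pad token
# added, as main() below does; the file path and batch size are hypothetical placeholders.
def _example_paired_batch(tokenizer_bert, tokenizer_gpt2, path='train.txt'):
    dataset = TextDataset_2Tokenizers([tokenizer_bert, tokenizer_gpt2], file_path=path, block_size=512)
    loader = DataLoader(dataset, batch_size=4)
    bert_ids, gpt2_ids, lengths = next(iter(loader))
    # bert_ids / gpt2_ids: (4, 512) padded id tensors; lengths: (4, 2) true lengths per tokenizer
    return bert_ids, gpt2_ids, lengths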
158
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
159
+ if isinstance(tokenizer, list):
160
+ dataset = TextDataset_2Tokenizers(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
161
+ else:
162
+ dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
163
+ return dataset
164
+
165
+
166
+ def set_seed(args):
167
+ random.seed(args.seed)
168
+ np.random.seed(args.seed)
169
+ torch.manual_seed(args.seed)
170
+ if args.n_gpu > 0:
171
+ torch.cuda.manual_seed_all(args.seed)
172
+
173
+
174
+ def mask_tokens(inputs, tokenizer, args):
175
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
176
+ labels = inputs.clone()
177
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
178
+
179
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
180
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
181
+
182
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
183
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
184
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
185
+
186
+ # 10% of the time, we replace masked input tokens with random word
187
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
188
+ indices_random = indices_random
189
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
190
+ inputs[indices_random] = random_words[indices_random]
191
+
192
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
193
+ return inputs, labels
194
+
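# Illustrative sketch (editor's addition): mask_tokens above follows the standard BERT
# 80/10/10 recipe. Of the positions drawn with probability args.mlm_probability, 80%
# become the mask token; of the remaining 20%, half (10% overall) receive a random token,
# which is why the second Bernoulli draw uses p=0.5; the final 10% stay unchanged.
# A rough empirical check of those fractions on dummy data:
def _example_masking_fractions(n=100000, mlm_probability=0.15):
    shape = (n,)
    masked = torch.bernoulli(torch.full(shape, mlm_probability)).to(torch.uint8)
    replaced = torch.bernoulli(torch.full(shape, 0.8)).to(torch.uint8) & masked
    randomized = torch.bernoulli(torch.full(shape, 0.5)).to(torch.uint8) & masked & ~replaced
    kept = masked & ~replaced & ~randomized
    # expected ~ (0.12, 0.015, 0.015) of all tokens, i.e. 80% / 10% / 10% of the masked ones
    return replaced.float().mean().item(), randomized.float().mean().item(), kept.float().mean().item()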
195
+
196
+ def train(args, train_dataset, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer):
197
+ """ Train the model """
198
+ if args.local_rank in [-1, 0]:
199
+ tb_writer = SummaryWriter()
200
+
201
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
202
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
203
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
204
+
205
+ if args.max_steps > 0:
206
+ t_total = args.max_steps
207
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
208
+ else:
209
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
210
+
211
+ # Prepare optimizer and schedule (linear warmup and decay)
212
+ no_decay = ['bias', 'LayerNorm.weight']
213
+ optimizer_grouped_encoder_parameters = [
214
+ {'params': [p for n, p in model_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
215
+ {'params': [p for n, p in model_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
216
+ ]
217
+
218
+ optimizer_grouped_decoder_parameters = [
219
+ {'params': [p for n, p in model_decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
220
+ {'params': [p for n, p in model_decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
221
+ ]
222
+
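# Editor's note (illustrative): the two parameter groups above apply args.weight_decay to
# ordinary weights while exempting biases and LayerNorm weights (the `no_decay` list),
# which is the usual convention when fine-tuning BERT/GPT-2 with AdamW.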
223
+
224
+ optimizer_encoder = AdamW(optimizer_grouped_encoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
225
+ optimizer_decoder = AdamW(optimizer_grouped_decoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
226
+ scheduler_encoder = WarmupLinearSchedule(optimizer_encoder, warmup_steps=args.warmup_steps, t_total=t_total)
227
+ scheduler_decoder = WarmupLinearSchedule(optimizer_decoder, warmup_steps=args.warmup_steps, t_total=t_total)
228
+
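# Editor's note (illustrative): WarmupLinearSchedule scales the base learning rate by a
# factor that ramps linearly from 0 to 1 over the first warmup_steps updates and then
# decays linearly back toward 0 at t_total, roughly
#     factor(step) = step / warmup_steps                          if step < warmup_steps
#     factor(step) = (t_total - step) / (t_total - warmup_steps)  otherwise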
229
+ if args.fp16:
230
+ try:
231
+ from apex import amp
232
+ except ImportError:
233
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
234
+ model_encoder, optimizer_encoder = amp.initialize(model_encoder, optimizer_encoder, opt_level=args.fp16_opt_level)
235
+ model_decoder, optimizer_decoder = amp.initialize(model_decoder, optimizer_decoder, opt_level=args.fp16_opt_level)
236
+
237
+ # multi-gpu training (should be after apex fp16 initialization)
238
+ if args.n_gpu > 1:
239
+ model_encoder = torch.nn.DataParallel(model_encoder)
240
+ model_decoder = torch.nn.DataParallel(model_decoder)
241
+
242
+ # Distributed training (should be after apex fp16 initialization)
243
+ if args.local_rank != -1:
244
+ model_encoder = torch.nn.parallel.DistributedDataParallel(model_encoder, device_ids=[args.local_rank],
245
+ output_device=args.local_rank,
246
+ find_unused_parameters=True)
247
+ model_decoder = torch.nn.parallel.DistributedDataParallel(model_decoder, device_ids=[args.local_rank],
248
+ output_device=args.local_rank,
249
+ find_unused_parameters=True)
250
+
251
+ # Train!
252
+ logger.info("***** Running training *****")
253
+ logger.info(" Num examples = %d", len(train_dataset))
254
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
255
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
256
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
257
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
258
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
259
+ logger.info(" Total optimization steps = %d", t_total)
260
+
261
+ global_step = 0
262
+ tr_loss, logging_loss = 0.0, 0.0
263
+ model_encoder.zero_grad()
264
+ model_decoder.zero_grad()
265
+
266
+ train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
267
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
268
+ for _ in train_iterator:
269
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
270
+ for step, batch in enumerate(epoch_iterator):
271
+
272
+ tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
273
+ # tokenized_text0 = tokenized_text0.to(args.device)
274
+ # tokenized_text1 = tokenized_text1.to(args.device)
275
+ # prepare input-output data for reconstruction
276
+ inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
277
+ labels = tokenized_text1
278
+
279
+ inputs = inputs.to(args.device)
280
+ labels = labels.to(args.device)
281
+
282
+ model_encoder.train()
283
+ model_decoder.train()
284
+
285
+
286
+ # Encoding
287
+ outputs = model_encoder(inputs)
288
+ pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
289
+
290
+
291
+ # Decoding
292
+ outputs = model_decoder(input_ids=tokenized_text1, past=pooled_hidden_fea, labels=labels)
293
+ loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
294
+
295
+
296
+ if args.n_gpu > 1:
297
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
298
+ if args.gradient_accumulation_steps > 1:
299
+ loss = loss / args.gradient_accumulation_steps
300
+
301
+ if args.fp16:
302
+ with amp.scale_loss(loss, [optimizer_encoder, optimizer_decoder]) as scaled_loss:  # scale the loss w.r.t. both amp-initialized optimizers
303
+ scaled_loss.backward()
304
+ else:
305
+ loss.backward()
306
+
307
+ tr_loss += loss.item()
308
+ if (step + 1) % args.gradient_accumulation_steps == 0:
309
+ if args.fp16:
310
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_encoder), args.max_grad_norm)
311
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_decoder), args.max_grad_norm)
312
+ else:
313
+ torch.nn.utils.clip_grad_norm_(model_encoder.parameters(), args.max_grad_norm)
314
+ torch.nn.utils.clip_grad_norm_(model_decoder.parameters(), args.max_grad_norm)
315
+ optimizer_encoder.step()
316
+ optimizer_decoder.step()
317
+ scheduler_encoder.step() # Update learning rate schedule
318
+ scheduler_decoder.step()
319
+ model_encoder.zero_grad()
320
+ model_decoder.zero_grad()
321
+ global_step += 1
322
+
323
+
324
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
325
+ # Log metrics
326
+ if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
327
+ results = evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer)
328
+ for key, value in results.items():
329
+ tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
330
+ tb_writer.add_scalar('lr_encoder', scheduler_encoder.get_lr()[0], global_step)
331
+ tb_writer.add_scalar('lr_decoder', scheduler_decoder.get_lr()[0], global_step)
332
+ tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
333
+ logging_loss = tr_loss
334
+
335
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
336
+ # Save model checkpoint
337
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
338
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
339
+ if not os.path.exists(output_encoder_dir):
340
+ os.makedirs(output_encoder_dir)
341
+ if not os.path.exists(output_decoder_dir):
342
+ os.makedirs(output_decoder_dir)
343
+
344
+ model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
345
+ model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
346
+
347
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
348
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
349
+
350
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
351
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
352
+
353
+ logger.info("Saving model checkpoint to %s", output_encoder_dir)
354
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
355
+
356
+ if args.max_steps > 0 and global_step > args.max_steps:
357
+ epoch_iterator.close()
358
+ break
359
+ if args.max_steps > 0 and global_step > args.max_steps:
360
+ train_iterator.close()
361
+ break
362
+
363
+ if args.local_rank in [-1, 0]:
364
+ tb_writer.close()
365
+
366
+ return global_step, tr_loss / global_step
367
+
368
+
369
+ def evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer, prefix=""):
370
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
371
+ eval_output_dir = args.output_dir
372
+
373
+ eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
374
+
375
+ if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
376
+ os.makedirs(eval_output_dir)
377
+
378
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
379
+ # Note that DistributedSampler samples randomly
380
+ eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
381
+ eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
382
+
383
+ # Eval!
384
+ logger.info("***** Running evaluation {} *****".format(prefix))
385
+ logger.info(" Num examples = %d", len(eval_dataset))
386
+ logger.info(" Batch size = %d", args.eval_batch_size)
387
+ eval_loss = 0.0
388
+ nb_eval_steps = 0
389
+ model_encoder.eval()
390
+ model_decoder.eval()
391
+
392
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
393
+ # pdb.set_trace()
394
+ tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
395
+ # prepare input-output data for evaluation
396
+ inputs, labels = tokenized_text0, tokenized_text1
397
+
398
+ tokenized_text1 = tokenized_text1.to(args.device)
399
+ inputs = inputs.to(args.device)
400
+ labels = labels.to(args.device)
401
+
402
+ with torch.no_grad():
403
+ # Encoding
404
+ outputs = model_encoder(inputs)
405
+ pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
406
+
407
+ # Decoding
408
+ outputs = model_decoder(input_ids=tokenized_text1, past=pooled_hidden_fea, labels=labels)
409
+ lm_loss = outputs[0]
410
+
411
+ eval_loss += lm_loss.mean().item()
412
+ nb_eval_steps += 1
413
+
414
+ eval_loss = eval_loss / nb_eval_steps
415
+ perplexity = torch.exp(torch.tensor(eval_loss))
416
+
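# Editor's note (illustrative): eval_loss is the mean per-batch cross-entropy (in nats),
# so the reported metric is perplexity = exp(mean cross-entropy). Because it averages
# batch means rather than weighting by token count, it is an approximation whenever
# batches contain different numbers of target tokens.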
417
+ result = {
418
+ "perplexity": perplexity
419
+ }
420
+
421
+ output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
422
+ with open(output_eval_file, "w") as writer:
423
+ logger.info("***** Eval results {} *****".format(prefix))
424
+ for key in sorted(result.keys()):
425
+ logger.info(" %s = %s", key, str(result[key]))
426
+ writer.write("%s = %s\n" % (key, str(result[key])))
427
+
428
+ return result
429
+
430
+
431
+ def main():
432
+ parser = argparse.ArgumentParser()
433
+
434
+ ## Required parameters
435
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
436
+ help="The input training data file (a text file).")
437
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
438
+ help="The output directory where the model predictions and checkpoints will be written.")
439
+
440
+ ## Other parameters
441
+ parser.add_argument("--eval_data_file", default=None, type=str,
442
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
443
+
444
+ ## Encoder options
445
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
446
+ help="The encoder model architecture to be fine-tuned.")
447
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
448
+ help="The encoder model checkpoint for weights initialization.")
449
+ parser.add_argument("--encoder_config_name", default="", type=str,
450
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
451
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
452
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
453
+
454
+ ## Decoder options
455
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
456
+ help="The decoder model architecture to be fine-tuned.")
457
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
458
+ help="The decoder model checkpoint for weights initialization.")
459
+ parser.add_argument("--decoder_config_name", default="", type=str,
460
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
461
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
462
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
463
+
464
+ ## Objective functions
465
+ parser.add_argument("--mlm", action='store_true',
466
+ help="Train with masked-language modeling loss instead of language modeling.")
467
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
468
+ help="Ratio of tokens to mask for masked language modeling loss")
469
+
470
+
471
+
472
+ parser.add_argument("--cache_dir", default="", type=str,
473
+ help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)")
474
+ parser.add_argument("--block_size", default=-1, type=int,
475
+ help="Optional input sequence length after tokenization."
476
+ "The training dataset will be truncated in block of this size for training."
477
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
478
+ parser.add_argument("--do_train", action='store_true',
479
+ help="Whether to run training.")
480
+ parser.add_argument("--do_eval", action='store_true',
481
+ help="Whether to run eval on the dev set.")
482
+ parser.add_argument("--evaluate_during_training", action='store_true',
483
+ help="Run evaluation during training at each logging step.")
484
+ parser.add_argument("--do_lower_case", action='store_true',
485
+ help="Set this flag if you are using an uncased model.")
486
+
487
+
488
+ # Training Schedules
489
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
490
+ help="Batch size per GPU/CPU for training.")
491
+ parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
492
+ help="Batch size per GPU/CPU for evaluation.")
493
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
494
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
495
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
496
+ help="The initial learning rate for Adam.")
497
+ parser.add_argument("--weight_decay", default=0.0, type=float,
498
+ help="Weight decay if we apply some.")
499
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
500
+ help="Epsilon for Adam optimizer.")
501
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
502
+ help="Max gradient norm.")
503
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
504
+ help="Total number of training epochs to perform.")
505
+ parser.add_argument("--max_steps", default=-1, type=int,
506
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
507
+ parser.add_argument("--warmup_steps", default=0, type=int,
508
+ help="Linear warmup over warmup_steps.")
509
+
510
+
511
+ ## IO: Logging and Saving
512
+ parser.add_argument('--logging_steps', type=int, default=50,
513
+ help="Log every X updates steps.")
514
+ parser.add_argument('--save_steps', type=int, default=50,
515
+ help="Save checkpoint every X updates steps.")
516
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
517
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with the step number")
518
+ parser.add_argument("--no_cuda", action='store_true',
519
+ help="Avoid using CUDA when available")
520
+ parser.add_argument('--overwrite_output_dir', action='store_true',
521
+ help="Overwrite the content of the output directory")
522
+ parser.add_argument('--overwrite_cache', action='store_true',
523
+ help="Overwrite the cached training and evaluation sets")
524
+ parser.add_argument('--seed', type=int, default=42,
525
+ help="random seed for initialization")
526
+
527
+ # Precision & Distributed Training
528
+ parser.add_argument('--fp16', action='store_true',
529
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
530
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
531
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
532
+ "See details at https://nvidia.github.io/apex/amp.html")
533
+ parser.add_argument("--local_rank", type=int, default=-1,
534
+ help="For distributed training: local_rank")
535
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
536
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
537
+ args = parser.parse_args()
538
+
539
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
540
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
541
+ "flag (masked language modeling).")
542
+ if args.eval_data_file is None and args.do_eval:
543
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
544
+ "or remove the --do_eval argument.")
545
+
546
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
547
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
548
+
549
+ # Setup distant debugging if needed
550
+ if args.server_ip and args.server_port:
551
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
552
+ import ptvsd
553
+ print("Waiting for debugger attach")
554
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
555
+ ptvsd.wait_for_attach()
556
+
557
+ # Setup CUDA, GPU & distributed training
558
+ if args.local_rank == -1 or args.no_cuda:
559
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
560
+ args.n_gpu = torch.cuda.device_count()
561
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
562
+ torch.cuda.set_device(args.local_rank)
563
+ device = torch.device("cuda", args.local_rank)
564
+ torch.distributed.init_process_group(backend='nccl')
565
+ args.n_gpu = 1
566
+ args.device = device
567
+
568
+ # Setup logging
569
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
570
+ datefmt = '%m/%d/%Y %H:%M:%S',
571
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
572
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
573
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
574
+
575
+ # Set seed
576
+ set_seed(args)
577
+
578
+ # Load pretrained model and tokenizer
579
+ if args.local_rank not in [-1, 0]:
580
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
581
+
582
+ ## Encoder
583
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
584
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
585
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
586
+ if args.block_size <= 0:
587
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
588
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
589
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config)
590
+ model_encoder.to(args.device)
591
+
592
+ ## Decoder
593
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
594
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
595
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
596
+ if args.block_size <= 0:
597
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
598
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
599
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config)
600
+
601
+ # Chunyuan: Add Padding token to GPT2
602
+ special_tokens_dict = {'pad_token': '<PAD>'}
603
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
604
+ print('We have added', num_added_toks, 'tokens')
605
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
606
+ assert tokenizer_decoder.pad_token == '<PAD>'
607
+
608
+ model_decoder.to(args.device)
609
+
610
+ if args.local_rank == 0:
611
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
612
+
613
+ logger.info("Training/evaluation parameters %s", args)
614
+
615
+ global_step= 0
616
+ # Training
617
+ if args.do_train:
618
+ if args.local_rank not in [-1, 0]:
619
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
620
+
621
+ train_dataset = load_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
622
+
623
+ if args.local_rank == 0:
624
+ torch.distributed.barrier()
625
+
626
+ global_step, tr_loss = train(args, train_dataset, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder)
627
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
628
+
629
+
630
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
631
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
632
+ # Create output directory if needed
633
+ # Save model checkpoint
634
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
635
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
636
+ if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
637
+ os.makedirs(output_encoder_dir)
638
+ if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
639
+ os.makedirs(output_decoder_dir)
640
+
641
+ logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
642
+ logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
643
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
644
+ # They can then be reloaded using `from_pretrained()`
645
+
646
+ model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
647
+ model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
648
+
649
+ # Good practice: save your training arguments together with the trained model
650
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
651
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
652
+
653
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
654
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
655
+
656
+
657
+ # Load a trained model and vocabulary that you have fine-tuned
658
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir)
659
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(output_encoder_dir, do_lower_case=args.do_lower_case)
660
+ model_encoder.to(args.device)
661
+
662
+ # Load a trained model and vocabulary that you have fine-tuned
663
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
664
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(output_decoder_dir, do_lower_case=args.do_lower_case)
665
+ model_decoder.to(args.device)
666
+
667
+
668
+ # Evaluation
669
+ results = {}
670
+ if args.do_eval and args.local_rank in [-1, 0]:
671
+ global_step= 881
672
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
673
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
674
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
675
+
676
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
677
+ for checkpoint in checkpoints:
678
+ global_step = checkpoint[0].split('-')[-1] if len(checkpoints) > 1 else ""
679
+
680
+ model_encoder = encoder_model_class.from_pretrained(checkpoint[0])
681
+ model_encoder.to(args.device)
682
+ model_decoder = decoder_model_class.from_pretrained(checkpoint[1])
683
+ model_decoder.to(args.device)
684
+ result = evaluate(args, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
685
+ result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
686
+ results.update(result)
687
+
688
+ return results
689
+
690
+
691
+ if __name__ == "__main__":
692
+ main()
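# Illustrative sketch (editor's addition, not part of the diff): one plausible set of
# arguments for this script, assembled from the argparse flags defined above. All file
# paths are hypothetical placeholders, not paths taken from the repository.
def _example_cli_args():
    return [
        '--train_data_file', 'data/train.txt',
        '--eval_data_file', 'data/valid.txt',
        '--output_dir', 'output/ae_pretrain',
        '--encoder_model_type', 'bert',
        '--encoder_model_name_or_path', 'bert-base-cased',
        '--decoder_model_type', 'gpt2',
        '--decoder_model_name_or_path', 'gpt2',
        '--per_gpu_train_batch_size', '2',
        '--num_train_epochs', '1',
        '--do_train',
        '--do_eval',
    ]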
Optimus/code/examples/big_ae/run_lm_causal_pretraining.py ADDED
@@ -0,0 +1,692 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+
25
+ import pdb
26
+ import argparse
27
+ import glob
28
+ import logging
29
+ import os
30
+ import pickle
31
+ import random
32
+
33
+ import numpy as np
34
+ import torch
35
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
36
+ from torch.utils.data.distributed import DistributedSampler
37
+ from tensorboardX import SummaryWriter
38
+ from tqdm import tqdm, trange
39
+
40
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
41
+ BertConfig, BertModel, BertTokenizer,
42
+ GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
43
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
44
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
45
+
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+ MODEL_CLASSES = {
51
+ 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
52
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
53
+ 'bert': (BertConfig, BertModel, BertTokenizer),
54
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
55
+ }
56
+
57
+
58
+ class TextDataset(Dataset):
59
+ def __init__(self, tokenizer, file_path='train', block_size=512):
60
+ assert os.path.isfile(file_path)
61
+ directory, filename = os.path.split(file_path)
62
+ cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
63
+
64
+ if os.path.exists(cached_features_file):
65
+ logger.info("Loading features from cached file %s", cached_features_file)
66
+ with open(cached_features_file, 'rb') as handle:
67
+ self.examples = pickle.load(handle)
68
+ else:
69
+ logger.info("Creating features from dataset file at %s", directory)
70
+
71
+ self.examples = []
72
+ with open(file_path, encoding="utf-8") as f:
73
+ text = f.read()
74
+
75
+
76
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
77
+
78
+ while len(tokenized_text) >= block_size: # Truncate in block of block_size
79
+ self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
80
+ tokenized_text = tokenized_text[block_size:]
81
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
82
+ # If your dataset is small, first you should look for a bigger one :-) and second you
83
+ # can change this behavior by adding (model specific) padding.
84
+
85
+ logger.info("Saving features into cached file %s", cached_features_file)
86
+ with open(cached_features_file, 'wb') as handle:
87
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
88
+
89
+ def __len__(self):
90
+ return len(self.examples)
91
+
92
+ def __getitem__(self, item):
93
+ return torch.tensor(self.examples[item])
94
+
95
+
96
+
97
+ class TextDataset_2Tokenizers(Dataset):
98
+ def __init__(self, tokenizers, file_path='train', block_size=512):
99
+ assert os.path.isfile(file_path)
100
+ directory, filename = os.path.split(file_path)
101
+ cached_features_file = os.path.join(directory, f'cached_lm_gpt_bert_{block_size}_{filename}')
102
+
103
+
104
+
105
+ if os.path.exists(cached_features_file):
106
+ logger.info("Loading features from cached file %s", cached_features_file)
107
+ with open(cached_features_file, 'rb') as handle:
108
+ self.examples = pickle.load(handle)
109
+ else:
110
+ logger.info("Creating features from dataset file at %s", directory)
111
+
112
+
113
+ with open(file_path, encoding="utf-8") as f:
114
+ text = f.read()
115
+
116
+ # pdb.set_trace()
117
+ self.examples = []
118
+ # Chunyuan: divide the linguistic text into the same length, then different tokenization schemes are applied
119
+ while len(text) >= block_size: # Truncate in block of block_size
120
+
121
+ tokenized_text0 = tokenizers[0].convert_tokens_to_ids(tokenizers[0].tokenize(text[:block_size]))
122
+ tokenized_text0 = tokenizers[0].add_special_tokens_single_sentence(tokenized_text0)
123
+ tokenized_text0_length = len(tokenized_text0)
124
+ pad_token=tokenizers[0].convert_tokens_to_ids([tokenizers[0].pad_token])[0]
125
+ tokenized_text0 = tokenized_text0 + ([pad_token] * (block_size - tokenized_text0_length) ) # Pad up to the sequence length.
126
+ assert len(tokenized_text0) == block_size
127
+
128
+ tokenized_text1 = tokenizers[1].convert_tokens_to_ids(tokenizers[1].tokenize(text[:block_size]))
129
+ tokenized_text1 = tokenizers[1].add_special_tokens_single_sentence(tokenized_text1)
130
+ tokenized_text1_length = len(tokenized_text1)
131
+ pad_token=tokenizers[1].convert_tokens_to_ids([tokenizers[1].pad_token])[0]
132
+ tokenized_text1 = tokenized_text1 + ([pad_token] * (block_size - tokenized_text1_length) ) # Pad up to the sequence length.
133
+ assert len(tokenized_text1) == block_size
134
+
135
+ self.examples.append([tokenized_text0, tokenized_text0_length, tokenized_text1, tokenized_text1_length])
136
+
137
+ text = text[block_size:]
138
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
139
+ # If your dataset is small, first you should look for a bigger one :-) and second you
140
+ # can change this behavior by adding (model specific) padding.
141
+
142
+ logger.info("Saving features into cached file %s", cached_features_file)
143
+ with open(cached_features_file, 'wb') as handle:
144
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
145
+
146
+ def __len__(self):
147
+ return len(self.examples)
148
+
149
+ def __getitem__(self, item):
150
+ # pdb.set_trace()
151
+ # Convert to Tensors and build dataset
152
+ tokenized_text0= torch.tensor(self.examples[item][0], dtype=torch.long)
153
+ tokenized_text1= torch.tensor(self.examples[item][2], dtype=torch.long)
154
+ tokenized_text_lengths = torch.tensor([self.examples[item][1], self.examples[item][3]], dtype=torch.long)
155
+ # pdb.set_trace()
156
+ return (tokenized_text0, tokenized_text1, tokenized_text_lengths)
157
+
158
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
159
+ if isinstance(tokenizer, list):
160
+ dataset = TextDataset_2Tokenizers(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
161
+ else:
162
+ dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
163
+ return dataset
164
+
165
+
166
+ def set_seed(args):
167
+ random.seed(args.seed)
168
+ np.random.seed(args.seed)
169
+ torch.manual_seed(args.seed)
170
+ if args.n_gpu > 0:
171
+ torch.cuda.manual_seed_all(args.seed)
172
+
173
+
174
+ def mask_tokens(inputs, tokenizer, args):
175
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
176
+ labels = inputs.clone()
177
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
178
+
179
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
180
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
181
+
182
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
183
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
184
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
185
+
186
+ # 10% of the time, we replace masked input tokens with random word
187
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
188
+ indices_random = indices_random
189
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
190
+ inputs[indices_random] = random_words[indices_random]
191
+
192
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
193
+ return inputs, labels
194
+
195
+
196
+ def train(args, train_dataset, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer):
197
+ """ Train the model """
198
+ if args.local_rank in [-1, 0]:
199
+ tb_writer = SummaryWriter()
200
+
201
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
202
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
203
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
204
+
205
+ if args.max_steps > 0:
206
+ t_total = args.max_steps
207
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
208
+ else:
209
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
210
+
211
+ # Prepare optimizer and schedule (linear warmup and decay)
212
+ no_decay = ['bias', 'LayerNorm.weight']
213
+ optimizer_grouped_encoder_parameters = [
214
+ {'params': [p for n, p in model_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
215
+ {'params': [p for n, p in model_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
216
+ ]
217
+
218
+ optimizer_grouped_decoder_parameters = [
219
+ {'params': [p for n, p in model_decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
220
+ {'params': [p for n, p in model_decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
221
+ ]
222
+
223
+
224
+ optimizer_encoder = AdamW(optimizer_grouped_encoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
225
+ optimizer_decoder = AdamW(optimizer_grouped_decoder_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
226
+ scheduler_encoder = WarmupLinearSchedule(optimizer_encoder, warmup_steps=args.warmup_steps, t_total=t_total)
227
+ scheduler_decoder = WarmupLinearSchedule(optimizer_decoder, warmup_steps=args.warmup_steps, t_total=t_total)
228
+
229
+ if args.fp16:
230
+ try:
231
+ from apex import amp
232
+ except ImportError:
233
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
234
+ model_encoder, optimizer_encoder = amp.initialize(model_encoder, optimizer_encoder, opt_level=args.fp16_opt_level)
235
+ model_decoder, optimizer_decoder = amp.initialize(model_decoder, optimizer_decoder, opt_level=args.fp16_opt_level)
236
+
237
+ # multi-gpu training (should be after apex fp16 initialization)
238
+ if args.n_gpu > 1:
239
+ model_encoder = torch.nn.DataParallel(model_encoder)
240
+ model_decoder = torch.nn.DataParallel(model_decoder)
241
+
242
+ # Distributed training (should be after apex fp16 initialization)
243
+ if args.local_rank != -1:
244
+ model_encoder = torch.nn.parallel.DistributedDataParallel(model_encoder, device_ids=[args.local_rank],
245
+ output_device=args.local_rank,
246
+ find_unused_parameters=True)
247
+ model_decoder = torch.nn.parallel.DistributedDataParallel(model_decoder, device_ids=[args.local_rank],
248
+ output_device=args.local_rank,
249
+ find_unused_parameters=True)
250
+
251
+ # Train!
252
+ logger.info("***** Running training *****")
253
+ logger.info(" Num examples = %d", len(train_dataset))
254
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
255
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
256
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
257
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
258
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
259
+ logger.info(" Total optimization steps = %d", t_total)
260
+
261
+ global_step = 0
262
+ tr_loss, logging_loss = 0.0, 0.0
263
+ model_encoder.zero_grad()
264
+ model_decoder.zero_grad()
265
+
266
+ train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
267
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
268
+ for _ in train_iterator:
269
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
270
+ for step, batch in enumerate(epoch_iterator):
271
+
272
+ tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
273
+ # tokenized_text0 = tokenized_text0.to(args.device)
274
+ # tokenized_text1 = tokenized_text1.to(args.device)
275
+ # prepare input-output data for reconstruction
276
+ inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
277
+ labels = tokenized_text1
278
+
279
+ inputs = inputs.to(args.device)
280
+ labels = labels.to(args.device)
281
+
282
+ model_encoder.train()
283
+ model_decoder.train()
284
+
285
+
286
+ # Encoding
287
+ outputs = model_encoder(inputs)
288
+ pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
289
+
290
+
291
+ # Decoding
292
+ outputs = model_decoder(input_ids=tokenized_text1, past=None, labels=labels)
293
+ loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
294
+
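# Editor's note (illustrative): unlike run_lm_ae_pretraining.py, this causal-LM variant
# does not feed the pooled encoder feature to the decoder (past=None above), so GPT-2 is
# optimized as a plain language model on the same text blocks; the BERT encoding computed
# earlier in the loop does not enter the loss.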
295
+
296
+ if args.n_gpu > 1:
297
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
298
+ if args.gradient_accumulation_steps > 1:
299
+ loss = loss / args.gradient_accumulation_steps
300
+
301
+ if args.fp16:
302
+ with amp.scale_loss(loss, [optimizer_encoder, optimizer_decoder]) as scaled_loss:  # scale the loss w.r.t. both amp-initialized optimizers
303
+ scaled_loss.backward()
304
+ else:
305
+ loss.backward()
306
+
307
+ tr_loss += loss.item()
308
+ if (step + 1) % args.gradient_accumulation_steps == 0:
309
+ if args.fp16:
310
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_encoder), args.max_grad_norm)
311
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_decoder), args.max_grad_norm)
312
+ else:
313
+ torch.nn.utils.clip_grad_norm_(model_encoder.parameters(), args.max_grad_norm)
314
+ torch.nn.utils.clip_grad_norm_(model_decoder.parameters(), args.max_grad_norm)
315
+ optimizer_encoder.step()
316
+ optimizer_decoder.step()
317
+ scheduler_encoder.step() # Update learning rate schedule
318
+ scheduler_decoder.step()
319
+ model_encoder.zero_grad()
320
+ model_decoder.zero_grad()
321
+ global_step += 1
322
+
323
+
324
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
325
+ # Log metrics
326
+ if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
327
+ results = evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer)
328
+ for key, value in results.items():
329
+ tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
330
+ tb_writer.add_scalar('lr_encoder', scheduler_encoder.get_lr()[0], global_step)
331
+ tb_writer.add_scalar('lr_decoder', scheduler_decoder.get_lr()[0], global_step)
332
+ tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
333
+ logging_loss = tr_loss
334
+
335
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
336
+ # Save model checkpoint
337
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
338
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
339
+ if not os.path.exists(output_encoder_dir):
340
+ os.makedirs(output_encoder_dir)
341
+ if not os.path.exists(output_decoder_dir):
342
+ os.makedirs(output_decoder_dir)
343
+
344
+ model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
345
+ model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
346
+
347
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
348
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
349
+
350
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
351
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
352
+
353
+ logger.info("Saving model checkpoint to %s", output_encoder_dir)
354
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
355
+
356
+ if args.max_steps > 0 and global_step > args.max_steps:
357
+ epoch_iterator.close()
358
+ break
359
+ if args.max_steps > 0 and global_step > args.max_steps:
360
+ train_iterator.close()
361
+ break
362
+
363
+ if args.local_rank in [-1, 0]:
364
+ tb_writer.close()
365
+
366
+ return global_step, tr_loss / global_step
367
+
368
+
369
+ def evaluate(args, model_encoder, model_decoder, encoder_tokenizer, decoder_tokenizer, prefix=""):
370
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
371
+ eval_output_dir = args.output_dir
372
+
373
+ eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
374
+
375
+ if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
376
+ os.makedirs(eval_output_dir)
377
+
378
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
379
+ # Note that DistributedSampler samples randomly
380
+ eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
381
+ eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
382
+
383
+ # Eval!
384
+ logger.info("***** Running evaluation {} *****".format(prefix))
385
+ logger.info(" Num examples = %d", len(eval_dataset))
386
+ logger.info(" Batch size = %d", args.eval_batch_size)
387
+ eval_loss = 0.0
388
+ nb_eval_steps = 0
389
+ model_encoder.eval()
390
+ model_decoder.eval()
391
+
392
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
393
+ # pdb.set_trace()
394
+ tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
395
+ # prepare input-output data for evaluation
396
+ inputs, labels = tokenized_text0, tokenized_text1
397
+
398
+ tokenized_text1 = tokenized_text1.to(args.device)
399
+ inputs = inputs.to(args.device)
400
+ labels = labels.to(args.device)
401
+
402
+ with torch.no_grad():
403
+ # Encoding
404
+ outputs = model_encoder(inputs)
405
+ pooled_hidden_fea = outputs[1] # model outputs are always tuple in pytorch-transformers (see doc)
406
+
407
+ # Decoding
408
+ outputs = model_decoder(input_ids=tokenized_text1, past=None, labels=labels)
409
+ lm_loss = outputs[0]
410
+
411
+ eval_loss += lm_loss.mean().item()
412
+ nb_eval_steps += 1
413
+
414
+ eval_loss = eval_loss / nb_eval_steps
415
+ perplexity = torch.exp(torch.tensor(eval_loss))
416
+
417
+ result = {
418
+ "perplexity": perplexity
419
+ }
420
+
421
+ output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
422
+ with open(output_eval_file, "w") as writer:
423
+ logger.info("***** Eval results {} *****".format(prefix))
424
+ for key in sorted(result.keys()):
425
+ logger.info(" %s = %s", key, str(result[key]))
426
+ writer.write("%s = %s\n" % (key, str(result[key])))
427
+
428
+ return result
429
+
430
+
431
+ def main():
432
+ parser = argparse.ArgumentParser()
433
+
434
+ ## Required parameters
435
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
436
+ help="The input training data file (a text file).")
437
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
438
+ help="The output directory where the model predictions and checkpoints will be written.")
439
+
440
+ ## Other parameters
441
+ parser.add_argument("--eval_data_file", default=None, type=str,
442
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
443
+
444
+ ## Encoder options
445
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
446
+ help="The encoder model architecture to be fine-tuned.")
447
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
448
+ help="The encoder model checkpoint for weights initialization.")
449
+ parser.add_argument("--encoder_config_name", default="", type=str,
450
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
451
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
452
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
453
+
454
+ ## Decoder options
455
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
456
+ help="The decoder model architecture to be fine-tuned.")
457
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
458
+ help="The decoder model checkpoint for weights initialization.")
459
+ parser.add_argument("--decoder_config_name", default="", type=str,
460
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
461
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
462
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
463
+
464
+ ## Objective functions
465
+ parser.add_argument("--mlm", action='store_true',
466
+ help="Train with masked-language modeling loss instead of language modeling.")
467
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
468
+ help="Ratio of tokens to mask for masked language modeling loss")
469
+
470
+
471
+
472
+ parser.add_argument("--cache_dir", default="", type=str,
473
+ help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)")
474
+ parser.add_argument("--block_size", default=-1, type=int,
475
+ help="Optional input sequence length after tokenization."
476
+ "The training dataset will be truncated in block of this size for training."
477
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
478
+ parser.add_argument("--do_train", action='store_true',
479
+ help="Whether to run training.")
480
+ parser.add_argument("--do_eval", action='store_true',
481
+ help="Whether to run eval on the dev set.")
482
+ parser.add_argument("--evaluate_during_training", action='store_true',
483
+ help="Run evaluation during training at each logging step.")
484
+ parser.add_argument("--do_lower_case", action='store_true',
485
+ help="Set this flag if you are using an uncased model.")
486
+
487
+
488
+ # Training Schedules
489
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
490
+ help="Batch size per GPU/CPU for training.")
491
+ parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
492
+ help="Batch size per GPU/CPU for evaluation.")
493
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
494
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
495
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
496
+ help="The initial learning rate for Adam.")
497
+ parser.add_argument("--weight_decay", default=0.0, type=float,
498
+ help="Weight decay if we apply some.")
499
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
500
+ help="Epsilon for Adam optimizer.")
501
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
502
+ help="Max gradient norm.")
503
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
504
+ help="Total number of training epochs to perform.")
505
+ parser.add_argument("--max_steps", default=-1, type=int,
506
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
507
+ parser.add_argument("--warmup_steps", default=0, type=int,
508
+ help="Linear warmup over warmup_steps.")
509
+
510
+
511
+ ## IO: Logging and Saving
512
+ parser.add_argument('--logging_steps', type=int, default=50,
513
+ help="Log every X updates steps.")
514
+ parser.add_argument('--save_steps', type=int, default=50,
515
+ help="Save checkpoint every X updates steps.")
516
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
517
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with the step number")
518
+ parser.add_argument("--no_cuda", action='store_true',
519
+ help="Avoid using CUDA when available")
520
+ parser.add_argument('--overwrite_output_dir', action='store_true',
521
+ help="Overwrite the content of the output directory")
522
+ parser.add_argument('--overwrite_cache', action='store_true',
523
+ help="Overwrite the cached training and evaluation sets")
524
+ parser.add_argument('--seed', type=int, default=42,
525
+ help="random seed for initialization")
526
+
527
+ # Precision & Distributed Training
528
+ parser.add_argument('--fp16', action='store_true',
529
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
530
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
531
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
532
+ "See details at https://nvidia.github.io/apex/amp.html")
533
+ parser.add_argument("--local_rank", type=int, default=-1,
534
+ help="For distributed training: local_rank")
535
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
536
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
537
+ args = parser.parse_args()
538
+
539
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
540
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
541
+ "flag (masked language modeling).")
542
+ if args.eval_data_file is None and args.do_eval:
543
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
544
+ "or remove the --do_eval argument.")
545
+
546
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
547
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
548
+
549
+ # Setup distant debugging if needed
550
+ if args.server_ip and args.server_port:
551
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
552
+ import ptvsd
553
+ print("Waiting for debugger attach")
554
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
555
+ ptvsd.wait_for_attach()
556
+
557
+ # Setup CUDA, GPU & distributed training
558
+ if args.local_rank == -1 or args.no_cuda:
559
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
560
+ args.n_gpu = torch.cuda.device_count()
561
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
562
+ torch.cuda.set_device(args.local_rank)
563
+ device = torch.device("cuda", args.local_rank)
564
+ torch.distributed.init_process_group(backend='nccl')
565
+ args.n_gpu = 1
566
+ args.device = device
567
+
568
+ # Setup logging
569
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
570
+ datefmt = '%m/%d/%Y %H:%M:%S',
571
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
572
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
573
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
574
+
575
+ # Set seed
576
+ set_seed(args)
577
+
578
+ # Load pretrained model and tokenizer
579
+ if args.local_rank not in [-1, 0]:
580
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training downloads the model & vocab
581
+
582
+ ## Encoder
583
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
584
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
585
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
586
+ if args.block_size <= 0:
587
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
588
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
589
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config)
590
+ model_encoder.to(args.device)
591
+
592
+ ## Decoder
593
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
594
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
595
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
596
+ if args.block_size <= 0:
597
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
598
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
599
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config)
600
+
601
+ # Chunyuan: Add Padding token to GPT2
602
+ special_tokens_dict = {'pad_token': '<PAD>'}
603
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
604
+ print('We have added', num_added_toks, 'tokens')
605
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
606
+ assert tokenizer_decoder.pad_token == '<PAD>'
607
+
608
+ model_decoder.to(args.device)
609
+
610
+ if args.local_rank == 0:
611
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training downloads the model & vocab
612
+
613
+ logger.info("Training/evaluation parameters %s", args)
614
+
615
+ global_step= 0
616
+ # Training
617
+ if args.do_train:
618
+ if args.local_rank not in [-1, 0]:
619
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training processes the dataset, and the others will use the cache
620
+
621
+ train_dataset = load_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
622
+
623
+ if args.local_rank == 0:
624
+ torch.distributed.barrier()
625
+
626
+ global_step, tr_loss = train(args, train_dataset, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder)
627
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
628
+
629
+
630
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
631
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
632
+ # Create output directory if needed
633
+ # Save model checkpoint
634
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
635
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
636
+ if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
637
+ os.makedirs(output_encoder_dir)
638
+ if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
639
+ os.makedirs(output_decoder_dir)
640
+
641
+ logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
642
+ logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
643
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
644
+ # They can then be reloaded using `from_pretrained()`
645
+
646
+ model_encoder_to_save = model_encoder.module if hasattr(model_encoder, 'module') else model_encoder # Take care of distributed/parallel training
647
+ model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
648
+
649
+ # Good practice: save your training arguments together with the trained model
650
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
651
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
652
+
653
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
654
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
655
+
656
+
657
+ # Load a trained model and vocabulary that you have fine-tuned
658
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir)
659
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(output_encoder_dir, do_lower_case=args.do_lower_case)
660
+ model_encoder.to(args.device)
661
+
662
+ # Load a trained model and vocabulary that you have fine-tuned
663
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
664
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(output_decoder_dir, do_lower_case=args.do_lower_case)
665
+ model_decoder.to(args.device)
666
+
667
+
668
+ # Evaluation
669
+ results = {}
670
+ if args.do_eval and args.local_rank in [-1, 0]:
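+ # NOTE: the checkpoint step below is hardcoded; change it to match the checkpoint you actually want to evaluate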
671
+ global_step= 881
672
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
673
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
674
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
675
+
676
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
677
+ for checkpoint in checkpoints:
678
+ global_step = checkpoint[0].split('-')[-1] if len(checkpoints) > 1 else ""
679
+
680
+ model_encoder = encoder_model_class.from_pretrained(checkpoint[0])
681
+ model_encoder.to(args.device)
682
+ model_decoder = decoder_model_class.from_pretrained(checkpoint[1])
683
+ model_decoder.to(args.device)
684
+ result = evaluate(args, model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, prefix=global_step)
685
+ result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
686
+ results.update(result)
687
+
688
+ return results
689
+
690
+
691
+ if __name__ == "__main__":
692
+ main()
Optimus/code/examples/big_ae/run_lm_finetuning_baseline.py ADDED
@@ -0,0 +1,573 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+ import pdb
25
+
26
+ import sys
27
+ sys.path.insert(0, '.')
28
+
29
+ import argparse
30
+ import glob
31
+ import logging
32
+ import os
33
+ import pickle
34
+ import random
35
+
36
+ import numpy as np
37
+ import torch
38
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
39
+ from torch.utils.data.distributed import DistributedSampler
40
+ from tensorboardX import SummaryWriter
41
+ from tqdm import tqdm, trange
42
+
43
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
44
+ BertConfig, BertForMaskedLM, BertTokenizer,
45
+ GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
46
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
47
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
48
+
49
+ from utils import (calc_iwnll, calc_mi, calc_au, TextDataset_Split, TextDataset_2Tokenizers)
50
+
51
+ import pdb
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ MODEL_CLASSES = {
57
+ 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
58
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
59
+ 'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
60
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
61
+ }
62
+
63
+
64
+ class TextDataset(Dataset):
65
+ def __init__(self, tokenizer, file_path='train', block_size=512):
66
+ assert os.path.isfile(file_path)
67
+ directory, filename = os.path.split(file_path)
68
+ cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
69
+
70
+ if os.path.exists(cached_features_file):
71
+ logger.info("Loading features from cached file %s", cached_features_file)
72
+ with open(cached_features_file, 'rb') as handle:
73
+ self.examples = pickle.load(handle)
74
+ else:
75
+ logger.info("Creating features from dataset file at %s", directory)
76
+
77
+ self.examples = []
78
+ with open(file_path, encoding="utf-8") as f:
79
+ text = f.read()
80
+
81
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
82
+
83
+ while len(tokenized_text) >= block_size: # Truncate in block of block_size
84
+ self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
85
+ tokenized_text = tokenized_text[block_size:]
86
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
87
+ # If your dataset is small, first you should look for a bigger one :-) and second you
88
+ # can change this behavior by adding (model specific) padding.
89
+
90
+ logger.info("Saving features into cached file %s", cached_features_file)
91
+ with open(cached_features_file, 'wb') as handle:
92
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
93
+
94
+ def __len__(self):
95
+ return len(self.examples)
96
+
97
+ def __getitem__(self, item):
98
+ return torch.tensor(self.examples[item])
99
+
100
+
101
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
102
+ if isinstance(tokenizer, list):
103
+ dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
104
+ else:
105
+ dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
106
+ return dataset
107
+
108
+
109
+ def set_seed(args):
110
+ random.seed(args.seed)
111
+ np.random.seed(args.seed)
112
+ torch.manual_seed(args.seed)
113
+ if args.n_gpu > 0:
114
+ torch.cuda.manual_seed_all(args.seed)
115
+
116
+
117
+ def mask_tokens(inputs, tokenizer, args):
118
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
119
+ labels = inputs.clone()
120
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
121
+
122
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
123
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
124
+
125
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
126
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
127
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
128
+
129
+ # 10% of the time, we replace masked input tokens with random word
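+ # (a 0.5 draw over the ~20% of masked tokens not already replaced with [MASK] gives ~10% of masked tokens overall)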
130
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
131
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
132
+ inputs[indices_random] = random_words[indices_random]
133
+
134
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
135
+ return inputs, labels
136
+
137
+
138
+ def train(args, train_dataset, model, tokenizer):
139
+ """ Train the model """
140
+ if args.local_rank in [-1, 0]:
141
+ tb_writer = SummaryWriter()
142
+
143
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
144
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
145
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
146
+
147
+ if args.max_steps > 0:
148
+ t_total = args.max_steps
149
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
150
+ else:
151
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
152
+
153
+ # Prepare optimizer and schedule (linear warmup and decay)
154
+ no_decay = ['bias', 'LayerNorm.weight']
155
+ optimizer_grouped_parameters = [
156
+ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
157
+ {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
158
+ ]
159
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
160
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
161
+ if args.fp16:
162
+ try:
163
+ from apex import amp
164
+ except ImportError:
165
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
166
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
167
+
168
+ # multi-gpu training (should be after apex fp16 initialization)
169
+ if args.n_gpu > 1:
170
+ model = torch.nn.DataParallel(model)
171
+
172
+ # Distributed training (should be after apex fp16 initialization)
173
+ if args.local_rank != -1:
174
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
175
+ output_device=args.local_rank,
176
+ find_unused_parameters=True)
177
+
178
+
179
+ # Train!
180
+ logger.info("***** Running training *****")
181
+ logger.info(" Num examples = %d", len(train_dataset))
182
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
183
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
184
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
185
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
186
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
187
+ logger.info(" Total optimization steps = %d", t_total)
188
+
189
+ global_step = 0
190
+ tr_loss, logging_loss = 0.0, 0.0
191
+ model.zero_grad()
192
+ train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
193
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
194
+ for epoch in train_iterator:  # the epoch index is used below for Philly progress reporting
195
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
196
+ for step, batch in enumerate(epoch_iterator):
197
+
198
+ tokenized_text1, tokenized_text_lengths = batch
199
+
200
+ inputs, labels = tokenized_text1, tokenized_text1
201
+
202
+ inputs = inputs.to(args.device)
203
+ labels = labels.to(args.device)
204
+
205
+ model.train()
206
+
207
+ outputs = model(inputs, labels=labels, label_ignore=tokenizer.pad_token_id)
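+ # label_ignore tells this repo's modified GPT-2 LM head which token id (here the pad token) to exclude from the loss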
208
+
209
+ # pdb.set_trace()
210
+ loss = outputs[0].mean() # model outputs are always tuple in pytorch-transformers (see doc)
211
+
212
+ if args.use_philly:
213
+ print("PROGRESS: {}%".format(round(100 * (step + epoch*len(epoch_iterator) ) /(int(args.num_train_epochs) * len(epoch_iterator)) , 4)))
214
+ print("EVALERR: {}%".format(loss))
215
+
216
+
217
+
218
+ if args.n_gpu > 1:
219
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
220
+ if args.gradient_accumulation_steps > 1:
221
+ loss = loss / args.gradient_accumulation_steps
222
+
223
+ if args.fp16:
224
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
225
+ scaled_loss.backward()
226
+ else:
227
+ loss.backward()
228
+
229
+ tr_loss += loss.item()
230
+ if (step + 1) % args.gradient_accumulation_steps == 0:
231
+ if args.fp16:
232
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
233
+ else:
234
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
235
+ optimizer.step()
236
+ scheduler.step() # Update learning rate schedule
237
+ model.zero_grad()
238
+ global_step += 1
239
+
240
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
241
+ # Log metrics
242
+ if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
243
+ results = evaluate(args, model, tokenizer)
244
+ for key, value in results.items():
245
+ tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
246
+ tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
247
+ tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
248
+ logging_loss = tr_loss
249
+
250
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
251
+ # Save model checkpoint
252
+ output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
253
+ if not os.path.exists(output_dir):
254
+ os.makedirs(output_dir)
255
+ model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
256
+ model_to_save.save_pretrained(output_dir)
257
+ torch.save(args, os.path.join(output_dir, 'training_args.bin'))
258
+ logger.info("Saving model checkpoint to %s", output_dir)
259
+
260
+ if args.max_steps > 0 and global_step > args.max_steps:
261
+ epoch_iterator.close()
262
+ break
263
+ if args.max_steps > 0 and global_step > args.max_steps:
264
+ train_iterator.close()
265
+ break
266
+
267
+ if args.local_rank in [-1, 0]:
268
+ tb_writer.close()
269
+
270
+ return global_step, tr_loss / global_step
271
+
272
+
273
+ def evaluate(args, model, tokenizer, prefix=""):
274
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
275
+ eval_output_dir = args.output_dir
276
+
277
+ eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
278
+
279
+ if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
280
+ os.makedirs(eval_output_dir)
281
+
282
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
283
+ # Note that DistributedSampler samples randomly
284
+ eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
285
+ eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
286
+
287
+ # Eval!
288
+ logger.info("***** Running evaluation {} *****".format(prefix))
289
+ logger.info(" Num examples = %d", len(eval_dataset))
290
+ logger.info(" Batch size = %d", args.eval_batch_size)
291
+ eval_loss = 0.0
292
+ eval_loss_sum = 0.0
293
+ nb_eval_steps = 0
294
+ report_num_words = 0
295
+
296
+ model.eval()
297
+
298
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
299
+
300
+ tokenized_text1, x_lengths = batch
301
+ x_lengths = x_lengths.to(args.device)
302
+ report_num_words += x_lengths.sum().item()
303
+
304
+ inputs, labels = tokenized_text1, tokenized_text1
305
+
306
+ inputs = inputs.to(args.device)
307
+ labels = labels.to(args.device)
308
+
309
+
310
+ with torch.no_grad():
311
+ outputs = model(inputs, labels=labels, label_ignore=tokenizer.pad_token_id)
312
+ lm_loss = outputs[0]
313
+
314
+
315
+ eval_loss += lm_loss.mean().item()/x_lengths.sum().item()
316
+ eval_loss_sum += lm_loss.sum().item()
317
+
318
+
319
+ nb_eval_steps += 1
320
+
321
+ # pdb.set_trace()
322
+
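+ # perplexity1 exponentiates the average of per-batch (loss / #tokens) ratios;
+ # perplexity2 exponentiates the corpus-level ratio (total summed loss / total #tokens)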
323
+ eval_loss = eval_loss / nb_eval_steps
324
+ perplexity1 = torch.exp(torch.tensor(eval_loss))
325
+ perplexity2 = torch.exp(torch.tensor(eval_loss_sum / report_num_words))
326
+
327
+
328
+
329
+ result = {
330
+ "perplexity1": perplexity1, "perplexity2": perplexity2
331
+ }
332
+
333
+ output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
334
+ with open(output_eval_file, "w") as writer:
335
+ logger.info("***** Eval results {} *****".format(prefix))
336
+ for key in sorted(result.keys()):
337
+ logger.info(" %s = %s", key, str(result[key]))
338
+ writer.write("%s = %s\n" % (key, str(result[key])))
339
+
340
+ return result
341
+
342
+
343
+ def main():
344
+ parser = argparse.ArgumentParser()
345
+
346
+ ## Required parameters
347
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
348
+ help="The input training data file (a text file).")
349
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
350
+ help="The output directory where the model predictions and checkpoints will be written.")
351
+ parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
352
+
353
+
354
+ ## Other parameters
355
+ parser.add_argument("--eval_data_file", default=None, type=str,
356
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
357
+
358
+ parser.add_argument("--model_type", default="bert", type=str,
359
+ help="The model architecture to be fine-tuned.")
360
+ parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
361
+ help="The model checkpoint for weights initialization.")
362
+
363
+
364
+ parser.add_argument("--use_philly", action='store_true',
365
+ help="Use Philly for computing.")
366
+
367
+ parser.add_argument("--mlm", action='store_true',
368
+ help="Train with masked-language modeling loss instead of language modeling.")
369
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
370
+ help="Ratio of tokens to mask for masked language modeling loss")
371
+
372
+ parser.add_argument("--config_name", default="", type=str,
373
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
374
+ parser.add_argument("--tokenizer_name", default="", type=str,
375
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
376
+ parser.add_argument("--cache_dir", default="", type=str,
377
+ help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
378
+ parser.add_argument("--block_size", default=-1, type=int,
379
+ help="Optional input sequence length after tokenization."
380
+ "The training dataset will be truncated in block of this size for training."
381
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
382
+ parser.add_argument("--do_train", action='store_true',
383
+ help="Whether to run training.")
384
+ parser.add_argument("--do_eval", action='store_true',
385
+ help="Whether to run eval on the dev set.")
386
+ parser.add_argument("--evaluate_during_training", action='store_true',
387
+ help="Run evaluation during training at each logging step.")
388
+ parser.add_argument("--do_lower_case", action='store_true',
389
+ help="Set this flag if you are using an uncased model.")
390
+
391
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
392
+ help="Batch size per GPU/CPU for training.")
393
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
394
+ help="Batch size per GPU/CPU for evaluation.")
395
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
396
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
397
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
398
+ help="The initial learning rate for Adam.")
399
+ parser.add_argument("--weight_decay", default=0.0, type=float,
400
+ help="Weight deay if we apply some.")
401
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
402
+ help="Epsilon for Adam optimizer.")
403
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
404
+ help="Max gradient norm.")
405
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
406
+ help="Total number of training epochs to perform.")
407
+ parser.add_argument("--max_steps", default=-1, type=int,
408
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
409
+ parser.add_argument("--warmup_steps", default=0, type=int,
410
+ help="Linear warmup over warmup_steps.")
411
+
412
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
413
+ help="Evaluate the results at the given global step")
414
+
415
+
416
+
417
+ parser.add_argument('--logging_steps', type=int, default=100,
418
+ help="Log every X updates steps.")
419
+ parser.add_argument('--save_steps', type=int, default=100,
420
+ help="Save checkpoint every X updates steps.")
421
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
422
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
423
+ parser.add_argument("--no_cuda", action='store_true',
424
+ help="Avoid using CUDA when available")
425
+ parser.add_argument('--overwrite_output_dir', action='store_true',
426
+ help="Overwrite the content of the output directory")
427
+ parser.add_argument('--overwrite_cache', action='store_true',
428
+ help="Overwrite the cached training and evaluation sets")
429
+ parser.add_argument('--seed', type=int, default=42,
430
+ help="random seed for initialization")
431
+
432
+ parser.add_argument('--fp16', action='store_true',
433
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
434
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
435
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
436
+ "See details at https://nvidia.github.io/apex/amp.html")
437
+ parser.add_argument("--local_rank", type=int, default=-1,
438
+ help="For distributed training: local_rank")
439
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
440
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
441
+ args = parser.parse_args()
442
+
443
+ if args.model_type in ["bert", "roberta"] and not args.mlm:
444
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
445
+ "flag (masked language modeling).")
446
+ if args.eval_data_file is None and args.do_eval:
447
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
448
+ "or remove the --do_eval argument.")
449
+
450
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
451
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
452
+
453
+ # Setup distant debugging if needed
454
+ if args.server_ip and args.server_port:
455
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
456
+ import ptvsd
457
+ print("Waiting for debugger attach")
458
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
459
+ ptvsd.wait_for_attach()
460
+
461
+ # Setup CUDA, GPU & distributed training
462
+ if args.local_rank == -1 or args.no_cuda:
463
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
464
+ args.n_gpu = torch.cuda.device_count()
465
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
466
+ torch.cuda.set_device(args.local_rank)
467
+ device = torch.device("cuda", args.local_rank)
468
+ torch.distributed.init_process_group(backend='nccl')
469
+ args.n_gpu = 1
470
+ args.device = device
471
+
472
+ # Setup logging
473
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
474
+ datefmt = '%m/%d/%Y %H:%M:%S',
475
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
476
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
477
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
478
+
479
+ # Set seed
480
+ set_seed(args)
481
+
482
+ # Load pretrained model and tokenizer
483
+ if args.local_rank not in [-1, 0]:
484
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training downloads the model & vocab
485
+
486
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
487
+ config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
488
+ tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
489
+ if args.block_size <= 0:
490
+ args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
491
+ args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
492
+ model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
493
+ model.to(args.device)
494
+
495
+ # Chunyuan: Add Padding token to GPT2
496
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
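+ # GPT-2 has no padding token by default, so PAD/BOS/EOS are registered here and the embedding matrix is resized below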
497
+ num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
498
+ print('We have added', num_added_toks, 'tokens to GPT2')
499
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
500
+ assert tokenizer.pad_token == '<PAD>'
501
+
502
+
503
+ # pdb.set_trace()
504
+
505
+ if args.local_rank == 0:
506
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training downloads the model & vocab
507
+
508
+ logger.info("Training/evaluation parameters %s", args)
509
+
510
+ # Training
511
+ global_step= 0
512
+ if args.do_train:
513
+ if args.local_rank not in [-1, 0]:
514
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training processes the dataset, and the others will use the cache
515
+
516
+ train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
517
+
518
+ if args.local_rank == 0:
519
+ torch.distributed.barrier()
520
+
521
+ global_step, tr_loss = train(args, train_dataset, model, tokenizer)
522
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
523
+
524
+
525
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
526
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
527
+ # Create output directory if needed
528
+ if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
529
+ os.makedirs(args.output_dir)
530
+
531
+ logger.info("Saving model checkpoint to %s", args.output_dir)
532
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
533
+ # They can then be reloaded using `from_pretrained()`
534
+ model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
535
+ model_to_save.save_pretrained(args.output_dir)
536
+ tokenizer.save_pretrained(args.output_dir)
537
+
538
+ # Good practice: save your training arguments together with the trained model
539
+ torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
540
+
541
+ # Load a trained model and vocabulary that you have fine-tuned
542
+ model = model_class.from_pretrained(args.output_dir)
543
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
544
+ model.to(args.device)
545
+
546
+
547
+ # Evaluation
548
+ results = {}
549
+ if args.do_eval and args.local_rank in [-1, 0]:
550
+
551
+ if global_step == 0:
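+ # training was skipped in this run, so fall back to the checkpoint step given by the --gloabl_step_eval argument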
552
+ global_step = args.gloabl_step_eval
553
+ output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
554
+
555
+ checkpoints = [args.output_dir]
556
+ if args.eval_all_checkpoints:
557
+ checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
558
+ logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
559
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
560
+ print("Evaluate the following checkpoints: %s", checkpoints)
561
+ for checkpoint in checkpoints:
562
+ global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
563
+ model = model_class.from_pretrained(checkpoint)
564
+ model.to(args.device)
565
+ result = evaluate(args, model, tokenizer, prefix=global_step)
566
+ result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
567
+ results.update(result)
568
+
569
+ return results
570
+
571
+
572
+ if __name__ == "__main__":
573
+ main()
Optimus/code/examples/big_ae/run_lm_gpt2_training.py ADDED
@@ -0,0 +1,658 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+
25
+ import pdb
26
+ import argparse
27
+ import glob
28
+ import logging
29
+
30
+ import os
31
+ import pickle
32
+ import random
33
+
34
+ import numpy as np
35
+ import torch
36
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
37
+ from torch.utils.data.distributed import DistributedSampler
38
+ from tensorboardX import SummaryWriter
39
+ from tqdm import tqdm, trange
40
+ from collections import defaultdict
41
+
42
+ # from azure.cosmosdb.table.tableservice import TableService
43
+ # from azure.cosmosdb.table.models import Entity
44
+ from datetime import datetime
45
+
46
+
47
+
48
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
49
+ BertConfig, BertForLatentConnector, BertTokenizer,
50
+ GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
51
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
52
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
53
+
54
+ from utils import (BucketingDataLoader, TextDataset_Split, TextDataset_2Tokenizers)
55
+
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ MODEL_CLASSES = {
61
+ 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
62
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
63
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
64
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
65
+ }
66
+
67
+
68
+ storage_name="textae"
69
+ key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
70
+ # ts = TableService(account_name=storage_name, account_key=key)
71
+
72
+
73
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
74
+ if isinstance(tokenizer, list):
75
+ dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
76
+ else:
77
+ dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
78
+ return dataset
79
+
80
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
81
+ if isinstance(tokenizer, list):
82
+ if not evaluate:
83
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
84
+ file_path=args.train_data_file
85
+ else:
86
+ args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
87
+ file_path=args.eval_data_file
88
+ dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True)
89
+ else:
90
+ pass
91
+ return dataloader
92
+
93
+
94
+
95
+
96
+ def set_seed(args):
97
+ random.seed(args.seed)
98
+ np.random.seed(args.seed)
99
+ torch.manual_seed(args.seed)
100
+ if args.n_gpu > 0:
101
+ torch.cuda.manual_seed_all(args.seed)
102
+
103
+
104
+ def mask_tokens(inputs, tokenizer, args):
105
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
106
+ labels = inputs.clone()
107
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
108
+
109
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
110
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
111
+
112
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
113
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
114
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
115
+
116
+ # 10% of the time, we replace masked input tokens with random word
117
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
118
+ indices_random = indices_random
119
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
120
+ inputs[indices_random] = random_words[indices_random]
121
+
122
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
123
+ return inputs, labels
124
+
125
+
126
+ def train(args, train_dataloader, model, encoder_tokenizer, decoder_tokenizer, table_name):
127
+ """ Train the model """
128
+ if args.local_rank in [-1, 0]:
129
+ tb_writer = SummaryWriter()
130
+
131
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
132
+ # train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
133
+ # train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
134
+
135
+ if args.max_steps > 0:
136
+ t_total = args.max_steps
137
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
138
+ else:
139
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
140
+
141
+ # Prepare optimizer and schedule (linear warmup and decay)
142
+
143
+
144
+ no_decay = ['bias', 'LayerNorm.weight']
145
+ optimizer_grouped_parameters = [
146
+ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
147
+ {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
148
+ ]
149
+
150
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
151
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
152
+
153
+
154
+ if args.fp16:
155
+ try:
156
+ from apex import amp
157
+ except ImportError:
158
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
159
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
160
+
161
+ # multi-gpu training (should be after apex fp16 initialization)
162
+ if args.n_gpu > 1:
163
+ model = torch.nn.DataParallel(model, device_ids=range(args.n_gpu)).to(args.device)
164
+
165
+ # Distributed training (should be after apex fp16 initialization)
166
+ if args.local_rank != -1:
167
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
168
+ output_device=args.local_rank,
169
+ find_unused_parameters=True)
170
+
171
+
172
+ # Train!
173
+ logger.info("***** Running training *****")
174
+ logger.info(" Num examples = %d", train_dataloader.num_examples)
175
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
176
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
177
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
178
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
179
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
180
+ logger.info(" Total optimization steps = %d", t_total)
181
+
182
+ global_step = 0
183
+ tr_loss, logging_loss = 0.0, 0.0
184
+
185
+
186
+ model.zero_grad()
187
+ train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
188
+
189
+ n_iter = int(args.num_train_epochs) * len(train_dataloader)
190
+
191
+ tmp_list = []
192
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
193
+ for epoch in train_iterator:
194
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
195
+ for step, batch in enumerate(epoch_iterator):
196
+
197
+ tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
198
+ inputs, labels = tokenized_text1.to(args.device), tokenized_text1.to(args.device)
199
+
200
+ model.train()
201
+
202
+ outputs = model(inputs, labels=labels, label_ignore=decoder_tokenizer.pad_token_id)
203
+ loss = outputs[0].mean() # model outputs are always tuple in pytorch-transformers (see doc)
204
+
205
+ if args.n_gpu > 1:
206
+ loss = loss.mean()
207
+
208
+ if args.use_philly:
209
+ print("PROGRESS: {}%".format(round(100 * (step + epoch*len(epoch_iterator) ) /(int(args.num_train_epochs) * len(epoch_iterator)) , 4)))
210
+ print("EVALERR: {}%".format(loss))
211
+
212
+ epoch_iterator.set_description(
213
+ (
214
+ f'iter: {step + epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
215
+ )
216
+ )
217
+
218
+ if args.gradient_accumulation_steps > 1:
219
+ loss = loss / args.gradient_accumulation_steps
220
+
221
+ if args.fp16:
222
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
223
+ scaled_loss.backward()
224
+ else:
225
+ loss.backward()
226
+
227
+ tr_loss += loss.item()
228
+ if (step + 1) % args.gradient_accumulation_steps == 0:
229
+ if args.fp16:
230
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
231
+ else:
232
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
233
+
234
+ optimizer.step()
235
+
236
+ scheduler.step() # Update learning rate schedule
237
+
238
+ model.zero_grad()
239
+
240
+ global_step += 1
241
+
242
+
243
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
244
+ # Log metrics
245
+ if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
246
+ results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer)
247
+ for key, value in results.items():
248
+ tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
249
+ tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
250
+ tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
251
+ logging_loss = tr_loss
252
+
253
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
254
+
255
+ # Save decoder model checkpoint
256
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
257
+
258
+ if not os.path.exists(output_decoder_dir):
259
+ os.makedirs(output_decoder_dir)
260
+
261
+ model_decoder_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
262
+ if args.use_philly:
263
+ save_solid = False
264
+ while not save_solid:
265
+ try:
266
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
267
+ torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
268
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
269
+ save_solid = True
270
+ except:
271
+ pass
272
+ else:
273
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
274
+ torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
275
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
276
+
277
+
278
+ if args.max_steps > 0 and global_step > args.max_steps:
279
+ epoch_iterator.close()
280
+ break
281
+
282
+
283
+ if args.max_steps > 0 and global_step > args.max_steps:
284
+ train_iterator.close()
285
+ break
286
+
287
+ if args.local_rank in [-1, 0]:
288
+ tb_writer.close()
289
+
290
+ return global_step, tr_loss / global_step
291
+
292
+
293
+ def evaluate(args, model, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test"):
294
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
295
+ eval_output_dir = args.output_dir
296
+
297
+ logger.info("***** Running evaluation on {} dataset *****".format(subset))
298
+
299
+ if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
300
+ os.makedirs(eval_output_dir)
301
+
302
+ args.per_gpu_eval_batch_size = 1
303
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
304
+
305
+ eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
306
+
307
+ # Eval!
308
+ logger.info("***** Running evaluation {} *****".format(prefix))
309
+ logger.info(" Num examples = %d", len(eval_dataloader))
310
+ logger.info(" Batch size = %d", args.eval_batch_size)
311
+ eval_loss = 0.0
312
+ eval_loss_sum = 0.0
313
+ nb_eval_steps = 0
314
+ report_num_words = 0
315
+
316
+ model.eval()
317
+
318
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
319
+
320
+ _, tokenized_text1, tokenized_text_lengths = batch
321
+ inputs, labels = tokenized_text1.to(args.device), tokenized_text1.to(args.device)
322
+
323
+ x_lengths = tokenized_text_lengths[:,1].to(args.device)
324
+ report_num_words += x_lengths.sum().item()
325
+
326
+
327
+ with torch.no_grad():
328
+ outputs = model(inputs, labels=labels, label_ignore=decoder_tokenizer.pad_token_id)
329
+ lm_loss = outputs[0]
330
+
331
+ eval_loss += lm_loss.mean().item()/x_lengths.sum().item()
332
+ eval_loss_sum += lm_loss.sum().item()
333
+
334
+ nb_eval_steps += 1
335
+
336
+ eval_loss = eval_loss / nb_eval_steps
337
+ perplexity1 = torch.exp(torch.tensor(eval_loss))
338
+ perplexity2 = torch.exp(torch.tensor(eval_loss_sum / report_num_words))
339
+
340
+
341
+ result = {
342
+ "perplexity1": perplexity1, "perplexity2": perplexity2
343
+ }
344
+
345
+ output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
346
+ with open(output_eval_file, "w") as writer:
347
+ logger.info("***** Eval results {} *****".format(prefix))
348
+ for key in sorted(result.keys()):
349
+ logger.info(" %s = %s", key, str(result[key]))
350
+ writer.write("%s = %s\n" % (key, str(result[key])))
351
+
352
+
353
+
354
+
355
+ return result
356
+
357
+
358
+ def main():
359
+ parser = argparse.ArgumentParser()
360
+
361
+ ## Required parameters
362
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
363
+ help="The input training data file (a text file).")
364
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
365
+ help="The output directory where the model predictions and checkpoints will be written.")
366
+ parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
367
+
368
+ ## Other parameters
369
+ parser.add_argument("--eval_data_file", default=None, type=str,
370
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
371
+ parser.add_argument("--ExpName", default="", type=str,
372
+ help="The experiment name used in Azure Table.")
373
+ parser.add_argument("--save_bert_gpt_init", action='store_true',
374
+ help="Use Philly for computing.")
375
+
376
+
377
+ ## Encoder options
378
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
379
+ help="The encoder model architecture to be fine-tuned.")
380
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
381
+ help="The encoder model checkpoint for weights initialization.")
382
+ parser.add_argument("--encoder_config_name", default="", type=str,
383
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
384
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
385
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
386
+
387
+ ## Decoder options
388
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
389
+ help="The decoder model architecture to be fine-tuned.")
390
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
391
+ help="The decoder model checkpoint for weights initialization.")
392
+ parser.add_argument("--decoder_config_name", default="", type=str,
393
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
394
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
395
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
396
+
397
+ ## Variational auto-encoder
398
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
399
+ parser.add_argument("--use_deterministic_connect", action='store_true',
400
+ help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
401
+ parser.add_argument("--use_pretrained_model", action='store_true',
402
+ help="Use pre-trained auto-encoder models as the initialization")
403
+
404
+ ## Objective functions
405
+ parser.add_argument("--mlm", action='store_true',
406
+ help="Train with masked-language modeling loss instead of language modeling.")
407
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
408
+ help="Ratio of tokens to mask for masked language modeling loss")
409
+ parser.add_argument("--beta", type=float, default=1.0,
410
+ help="The weighting hyper-parameter of the KL term in VAE")
411
+
412
+
413
+ parser.add_argument("--cache_dir", default="", type=str,
414
+ help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
415
+ parser.add_argument("--max_seq_length", default=512, type=int,
416
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
417
+ parser.add_argument("--block_size", default=-1, type=int,
418
+ help="Optional input sequence length after tokenization."
419
+ "The training dataset will be truncated in block of this size for training."
420
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
421
+ parser.add_argument("--do_train", action='store_true',
422
+ help="Whether to run training.")
423
+ parser.add_argument("--do_eval", action='store_true',
424
+ help="Whether to run eval on the dev set.")
425
+ parser.add_argument("--evaluate_during_training", action='store_true',
426
+ help="Run evaluation during training at each logging step.")
427
+ parser.add_argument("--do_lower_case", action='store_true',
428
+ help="Set this flag if you are using an uncased model.")
429
+
430
+
431
+ # Training Schedules
432
+ parser.add_argument("--ratio_increase", default=0.25, type=float,
433
+ help="Learning schedule, the percentage for the annealing stage.")
434
+ parser.add_argument("--ratio_zero", default=0.25, type=float,
435
+ help="Learning schedule, the percentage for the pure auto-encoding stage.")
436
+ parser.add_argument("--fb_mode", default=0, type=int,
437
+ help="free bit training mode.")
438
+ parser.add_argument("--dim_target_kl", default=3.0, type=float,
439
+ help="dim_target_kl free bit training mode.")
440
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
441
+ help="Batch size per GPU/CPU for training.")
442
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
443
+ help="Batch size per GPU/CPU for evaluation.")
444
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
445
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
446
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
447
+ help="The initial learning rate for Adam.")
448
+ parser.add_argument("--weight_decay", default=0.0, type=float,
449
+ help="Weight deay if we apply some.")
450
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
451
+ help="Epsilon for Adam optimizer.")
452
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
453
+ help="Max gradient norm.")
454
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
455
+ help="Total number of training epochs to perform.")
456
+ parser.add_argument("--max_steps", default=-1, type=int,
457
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
458
+ parser.add_argument("--warmup_steps", default=0, type=int,
459
+ help="Linear warmup over warmup_steps.")
460
+ parser.add_argument("--use_philly", action='store_true',
461
+ help="Use Philly for computing.")
462
+
463
+
464
+ ## IO: Logging and Saving
465
+ parser.add_argument('--logging_steps', type=int, default=50,
466
+ help="Log every X updates steps.")
467
+ parser.add_argument('--save_steps', type=int, default=50,
468
+ help="Save checkpoint every X updates steps.")
469
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
470
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
471
+ parser.add_argument("--no_cuda", action='store_true',
472
+ help="Avoid using CUDA when available")
473
+ parser.add_argument('--overwrite_output_dir', action='store_true',
474
+ help="Overwrite the content of the output directory")
475
+ parser.add_argument('--overwrite_cache', action='store_true',
476
+ help="Overwrite the cached training and evaluation sets")
477
+ parser.add_argument('--seed', type=int, default=42,
478
+ help="random seed for initialization")
479
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
480
+ help="Evaluate the results at the given global step")
481
+
482
+ # Precision & Distributed Training
483
+ parser.add_argument('--fp16', action='store_true',
484
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
485
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
486
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
487
+ "See details at https://nvidia.github.io/apex/amp.html")
488
+ parser.add_argument("--local_rank", type=int, default=-1,
489
+ help="For distributed training: local_rank")
490
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
491
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
492
+ args = parser.parse_args()
493
+
494
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
495
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
496
+ "flag (masked language modeling).")
497
+ if args.eval_data_file is None and args.do_eval:
498
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
499
+ "or remove the --do_eval argument.")
500
+
501
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
502
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
503
+
504
+ # Setup distant debugging if needed
505
+ if args.server_ip and args.server_port:
506
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
507
+ import ptvsd
508
+ print("Waiting for debugger attach")
509
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
510
+ ptvsd.wait_for_attach()
511
+
512
+ # Setup CUDA, GPU & distributed training
513
+ if args.local_rank == -1 or args.no_cuda:
514
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
515
+ args.n_gpu = torch.cuda.device_count()
516
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
517
+ torch.cuda.set_device(args.local_rank)
518
+ device = torch.device("cuda", args.local_rank)
519
+ torch.distributed.init_process_group(backend='nccl')
520
+ args.n_gpu = 1
521
+ args.device = device
522
+
523
+ # Setup logging
524
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
525
+ datefmt = '%m/%d/%Y %H:%M:%S',
526
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
527
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
528
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
529
+
530
+ args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
531
+ table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
532
+ try:
533
+ ts.create_table(table_name)
534
+ except:
535
+ pass
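+ # The bare except means any failure here (e.g. the Azure TableService client ts being unavailable) silently skips table creation.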
536
+
537
+
538
+ # Set seed
539
+ set_seed(args)
540
+
541
+ # Load pretrained model and tokenizer
542
+ if args.local_rank not in [-1, 0]:
543
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training downloads the model & vocab
544
+
545
+
546
+ ## Encoder
547
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
548
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
549
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
550
+ if args.block_size <= 0:
551
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
552
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
553
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
554
+ # model_encoder.to(args.device)
555
+
556
+ ## Decoder
557
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
558
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
559
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
560
+ if args.block_size <= 0:
561
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
562
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
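+ # After both clamps above, block_size cannot exceed the smaller of the encoder and decoder max single-sentence lengths.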
563
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config)
564
+
565
+ # Chunyuan: Add Padding token to GPT2
566
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
567
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
568
+ print('We have added', num_added_toks, 'tokens to GPT2')
569
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
570
+ assert tokenizer_decoder.pad_token == '<PAD>'
571
+
572
+ model_decoder.to(args.device)
573
+
574
+
575
+ if args.local_rank == 0:
576
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training downloads the model & vocab
577
+
578
+ logger.info("Training/evaluation parameters %s", args)
579
+
580
+ global_step= 0
581
+ # Training
582
+ if args.do_train:
583
+ if args.local_rank not in [-1, 0]:
584
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training processes the dataset; the others will use the cache
585
+
586
+ train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
587
+
588
+ if args.local_rank == 0:
589
+ torch.distributed.barrier()
590
+
591
+ global_step, tr_loss = train(args, train_dataloader, model_decoder, tokenizer_encoder, tokenizer_decoder, table_name)
592
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
593
+
594
+
595
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
596
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
597
+ # Create output directory if needed
598
+ # Save model checkpoint
599
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
600
+ if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
601
+ os.makedirs(output_decoder_dir)
602
+
603
+
604
+ logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
605
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
606
+ # They can then be reloaded using `from_pretrained()`
607
+
608
+ model_decoder_to_save = model_decoder.module if hasattr(model_decoder, 'module') else model_decoder # Take care of distributed/parallel training
609
+
610
+ # Good practice: save your training arguments together with the trained model
611
+
612
+ if args.use_philly:
613
+ save_solid = False
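+ # On the Philly cluster, checkpoint writes can fail transiently, so keep retrying until save_pretrained succeeds.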
614
+ while not save_solid:
615
+ try:
616
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
617
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
618
+ save_solid = True
619
+ except:
620
+ pass
621
+ else:
622
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
623
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
624
+
625
+ # Load a trained model and vocabulary that you have fine-tuned
626
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir)
627
+ model_decoder.to(args.device)
628
+
629
+
630
+ # Evaluation
631
+ results = {}
632
+ if args.do_eval and args.local_rank in [-1, 0]:
633
+ if global_step == 0:
634
+ global_step = args.gloabl_step_eval
635
+
636
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
637
+ checkpoints = [ output_decoder_dir ]
638
+
639
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
640
+ for checkpoint in checkpoints:
641
+ global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
642
+
643
+ model_decoder = decoder_model_class.from_pretrained(checkpoint)
644
+ model_decoder.to(args.device)
645
+
646
+ result = evaluate(args, model_decoder, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='test')
647
+ result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
648
+ results.update(result)
649
+
650
+ # result = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='train')
651
+ # result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
652
+ # results.update(result)
653
+
654
+ return results
655
+
656
+
657
+ if __name__ == "__main__":
658
+ main()
Optimus/code/examples/big_ae/run_lm_vae_label_ctrl_gen.py ADDED
@@ -0,0 +1,875 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+ import pdb
24
+ import argparse
25
+ import glob
26
+ import logging
27
+ import os
28
+ import pickle
29
+ import random
30
+ import numpy as np
31
+ import torch
32
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
33
+ from torch.utils.data.distributed import DistributedSampler
34
+ from tensorboardX import SummaryWriter
35
+ from tqdm import tqdm, trange
36
+ from collections import defaultdict
37
+ # from azure.cosmosdb.table.tableservice import TableService
38
+ # from azure.cosmosdb.table.models import Entity
39
+ from datetime import datetime
40
+ import sys
41
+ import json
42
+ import nltk
43
+ nltk.download('punkt')
44
+
45
+ sys.path.append('../../')
46
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
47
+ BertConfig, BertForLatentConnector, BertTokenizer,
48
+ GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
49
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
50
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
51
+ from utils import (TextDataset_Split, TextDataset_2Tokenizers_LCtrlG,
52
+ frange_cycle_linear, frange_cycle_zero_linear, AverageValueMeter)
53
+ # from modules import ARAE
54
+ from modules import CARA
55
+ # logging.getLogger("azure").setLevel(logging.WARNING)
56
+ # logging.getLogger("TableService").setLevel(logging.WARNING)
57
+ logger = logging.getLogger(__name__)
58
+ import time
59
+ def get_time_str():
60
+ return time.ctime().replace(' ', '_').replace(':', '-')
61
+
62
+ MODEL_CLASSES = {
63
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
64
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
65
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
66
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
67
+ }
68
+
69
+
70
+ storage_name="textae"
71
+ key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
72
+ # ts = TableService(account_name=storage_name, account_key=key)
73
+
74
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
75
+ if isinstance(tokenizer, list):
76
+ dataset = TextDataset_2Tokenizers_LCtrlG(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file,
77
+ block_size=args.block_size, create_new=args.create_new)
78
+ else:
79
+ raise NotImplementedError
80
+ # dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
81
+ return dataset
82
+
83
+ def set_seed(args):
84
+ random.seed(args.seed)
85
+ np.random.seed(args.seed)
86
+ torch.manual_seed(args.seed)
87
+ if args.n_gpu > 0:
88
+ torch.cuda.manual_seed_all(args.seed)
89
+
90
+ def mask_tokens(inputs, tokenizer, args):
91
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
92
+ labels = inputs.clone()
93
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
94
+
95
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
96
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
97
+
98
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
99
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
100
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
101
+
102
+ # 10% of the time, we replace masked input tokens with random word
103
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
104
+ indices_random = indices_random
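+ # 0.5 of the masked positions not already replaced with [MASK] (i.e. 10% of all masked positions) receive a random token instead.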
105
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
106
+ inputs[indices_random] = random_words[indices_random]
107
+
108
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
109
+ return inputs, labels
110
+
111
+ def train(args, train_dataset, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, logff):
112
+ """ Train the model """
113
+ if args.local_rank in [-1, 0]:
114
+ tb_writer = SummaryWriter()
115
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
116
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
117
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
118
+ if args.max_steps > 0:
119
+ t_total = args.max_steps
120
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
121
+ else:
122
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
123
+ # Prepare optimizer and schedule (linear warmup and decay)
124
+ # model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
125
+ no_decay = ['bias', 'LayerNorm.weight']
126
+ optimizer_grouped_parameters = [
127
+ {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
128
+ {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
129
+ ]
130
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
131
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
132
+ if args.fp16:
133
+ try:
134
+ from apex import amp
135
+ except ImportError:
136
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
137
+ model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
138
+ # multi-gpu training (should be after apex fp16 initialization)
139
+ if args.n_gpu > 1:
140
+ model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
141
+ # Distributed training (should be after apex fp16 initialization)
142
+ if args.local_rank != -1:
143
+ model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
144
+ # model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
145
+
146
+ # Train!
147
+ logger.info("***** Running training *****")
148
+ logff.write("***** Running training *****\n")
149
+ logger.info(" Num examples = {}".format(len(train_dataset)))
150
+ logff.write(" Num examples = {}\n".format(len(train_dataset)))
151
+ logger.info(" Num Epochs = {}".format(args.num_train_epochs))
152
+ logff.write(" Num Epochs = {}\n".format(args.num_train_epochs))
153
+ logger.info(" Instantaneous batch size per GPU = {}".format(args.per_gpu_train_batch_size))
154
+ logff.write(" Instantaneous batch size per GPU = {}\n".format(args.per_gpu_train_batch_size))
155
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
156
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
157
+ logff.write(" Total train batch size (w. parallel, distributed & accumulation) = {}\n".format(
158
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
159
+ logger.info(" Gradient Accumulation steps = {}".format(args.gradient_accumulation_steps))
160
+ logff.write(" Gradient Accumulation steps = {}\n".format(args.gradient_accumulation_steps))
161
+ logger.info(" Total optimization steps = {}".format( t_total))
162
+ logff.write(" Total optimization steps = {}\n".format(t_total))
163
+ logff.flush()
164
+ global_step = 0
165
+ tr_loss, logging_loss = 0.0, 0.0
166
+ model_vae.zero_grad()
167
+ train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
168
+ n_iter = int(args.num_train_epochs) * len(train_dataloader)
169
+ beta_t_list = frange_cycle_zero_linear(n_iter, start=1.0, stop=args.beta_cls, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
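+ # beta_t_list is a cyclical annealing schedule for the classifier weight (frange_cycle_zero_linear in utils): roughly zero for the first ratio_zero fraction, then ramped toward beta_cls; each step's value overwrites model_vae.module.args.beta_cls in the loop below.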
170
+
171
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
172
+ accmeter = {
173
+ 'acc_encode_z_dis': AverageValueMeter(),
174
+ 'acc_gen_z_dis': AverageValueMeter(),
175
+ 'acc_encode_z_cls': AverageValueMeter(),
176
+ 'acc_cls': AverageValueMeter(),
177
+ # 'acc_at_soft_cls': AverageValueMeter(),
178
+ }
179
+ lossmeter = {
180
+ 'loss': AverageValueMeter(),
181
+ 'loss_rec': AverageValueMeter(),
182
+ 'loss_encoder': AverageValueMeter(),
183
+ 'loss_lsc': AverageValueMeter(),
184
+ 'loss_lsd': AverageValueMeter(),
185
+ 'loss_lsg': AverageValueMeter(),
186
+ 'loss_cls': AverageValueMeter(),
187
+ # 'loss_at_soft_cls': AverageValueMeter(),
188
+ }
189
+ for epoch in train_iterator:
190
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
191
+ # pbar = tqdm(total=(len(train_dataloader)+1) // args.gradient_accumulation_steps)
192
+ for step, batch in enumerate(train_dataloader):
193
+
194
+ # if step > 100:
195
+ # break
196
+
197
+ # Data
198
+ input_seq_ids, tgt_seq_ids, tokenized_text_lengths, cond_labels = batch
199
+ max_len_values, _ = tokenized_text_lengths.max(0)
200
+ input_seq_ids = input_seq_ids[:,:max_len_values[0]]
201
+ tgt_seq_ids = tgt_seq_ids[:,:max_len_values[1]]
202
+ input_seq_ids, tgt_seq_ids = mask_tokens(input_seq_ids, encoder_tokenizer, args) if args.mlm else (input_seq_ids, tgt_seq_ids)
203
+ input_seq_ids = input_seq_ids.to(args.device)
204
+ tgt_seq_ids = tgt_seq_ids.to(args.device)
205
+ cond_labels = cond_labels.to(args.device)
206
+ input_mask = torch.where(torch.arange(max_len_values[0].item()).unsqueeze(0).repeat(input_seq_ids.size(0), 1).type_as(tokenized_text_lengths).to(args.device)
207
+ < tokenized_text_lengths[:, 0].unsqueeze(1).to(args.device), torch.ones_like(input_seq_ids), torch.zeros_like(input_seq_ids))
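+ # input_mask marks real tokens with 1 (positions before each sequence's tokenized length) and padding positions with 0.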
208
+
209
+ # Configs
210
+ model_vae.train()
211
+ beta_t = beta_t_list[step + epoch*len(epoch_iterator)]
212
+ model_vae.module.args.beta_cls = beta_t
213
+ # if beta_t == 0.0:
214
+ # model_vae.args.fb_mode = 0
215
+ # else:
216
+ # model_vae.args.fb_mode = 1
217
+ # if args.use_deterministic_connect:
218
+ # model_vae.args.fb_mode = 2
219
+
220
+ # Model
221
+ loss_dict, acc_dict = model_vae(input_seq_ids=input_seq_ids, tgt_seq_ids=tgt_seq_ids, cond_labels=cond_labels, attention_mask=input_mask)
222
+
223
+ # Loss
224
+ for key, value in loss_dict.items():
225
+ loss_dict[key] = value.mean()
226
+
227
+ loss = loss_dict['loss']
228
+ if args.gradient_accumulation_steps > 1:
229
+ loss = loss / args.gradient_accumulation_steps
230
+ if args.fp16:
231
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
232
+ scaled_loss.backward()
233
+ else:
234
+ loss.backward()
235
+ tr_loss += loss.item()
236
+
237
+ # Log
238
+ for key, value in loss_dict.items():
239
+ lossmeter[key].add(value.item())
240
+
241
+ for key, value in acc_dict.items():
242
+ value = value.cpu().tolist()
243
+ for v in value:
244
+ accmeter[key].add(float(v))
245
+
246
+ # Optimize
247
+ if (step + 1) % args.gradient_accumulation_steps == 0:
248
+ # Optimize
249
+ if args.fp16:
250
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
251
+ else:
252
+ torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)
253
+ optimizer.step()
254
+ scheduler.step() # Update learning rate schedule
255
+ model_vae.zero_grad()
256
+ global_step += 1
257
+ # pbar.update(1)
258
+
259
+ # Log
260
+ if global_step % args.logging_steps == 0:
261
+ logger.info("\n")
262
+ logger.info("global_step: {}, avg loss: {:3f}".format(global_step, tr_loss/global_step))
263
+ logff.write("global_step: {}, avg loss: {:3f}\n".format(global_step, tr_loss/global_step))
264
+ logger.info("loss: {}".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in lossmeter.items())))
265
+ logff.write("loss: {}\n".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in lossmeter.items())))
266
+ logger.info("acc: {}".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in accmeter.items())))
267
+ logff.write("acc: {}\n".format(', '.join(key + ': ' + str(round(meter.mean, 3)) for key, meter in accmeter.items())))
268
+ logff.flush()
269
+
270
+
271
+ if args.use_philly:
272
+ #if args.local_rank in [-1, 0]:
273
+ if args.logging_steps > 0 and global_step % args.logging_steps == 0:
274
+ logger.info("PROGRESS: {}%".format(round(100 * (step + epoch*len(train_dataloader) ) /(int(args.num_train_epochs) * len(train_dataloader)) , 4)))
275
+ logger.info("EVALERR: {}%".format(tr_loss / global_step))
276
+
277
+
278
+ if args.local_rank in [-1, 0] and args.eval_steps > 0 and global_step % args.eval_steps == 0:
279
+ # Log metrics
280
+ if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
281
+ results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, epoch=epoch)
282
+ for key, value in results.items():
283
+ tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
284
+ tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
285
+ tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.eval_steps, global_step)
286
+ logging_loss = tr_loss
287
+
288
+ # Save checkpoints
289
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
290
+ # Save encoder model checkpoint
291
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
292
+ if not os.path.exists(output_encoder_dir):
293
+ os.makedirs(output_encoder_dir)
294
+ model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
295
+ if args.use_philly:
296
+ save_solid = False
297
+ while not save_solid:
298
+ try:
299
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
300
+ torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
301
+ logger.info("Saving model checkpoint to %s", output_encoder_dir)
302
+ save_solid = True
303
+ except:
304
+ pass
305
+ else:
306
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
307
+ torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
308
+ logger.info("Saving model checkpoint to %s", output_encoder_dir)
309
+
310
+ # Save decoder model checkpoint
311
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
312
+ if not os.path.exists(output_decoder_dir):
313
+ os.makedirs(output_decoder_dir)
314
+ model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
315
+ if args.use_philly:
316
+ save_solid = False
317
+ while not save_solid:
318
+ try:
319
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
320
+ torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
321
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
322
+ save_solid = True
323
+ except:
324
+ pass
325
+ else:
326
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
327
+ torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
328
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
329
+
330
+ if args.max_steps > 0 and global_step > args.max_steps:
331
+ break
332
+
333
+ if args.max_steps > 0 and global_step > args.max_steps:
334
+ train_iterator.close()
335
+ break
336
+
337
+ if args.local_rank in [-1, 0]:
338
+ tb_writer.close()
339
+
340
+ return global_step, tr_loss / global_step
341
+
342
+
343
+ def evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test", epoch=None):
344
+
345
+ eval_output_dir = args.output_dir
346
+
347
+ if subset == 'test':
348
+ eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
349
+ elif subset == 'train':
350
+ eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
351
+ else:
352
+ raise ValueError
353
+
354
+ args.label_size = len(eval_dataset.get_labels())
355
+
356
+ if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
357
+ os.makedirs(eval_output_dir)
358
+
359
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
360
+ # Note that DistributedSampler samples randomly
361
+ eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
362
+ eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
363
+
364
+ # Eval!
365
+ logger.info("***** Running evaluation {} *****".format(prefix))
366
+ logger.info(" Num examples = %d", len(eval_dataset))
367
+ logger.info(" Batch size = %d", args.eval_batch_size)
368
+ logger.info(" Num steps = %d", len(eval_dataset) // args.eval_batch_size)
369
+ logger.info(" eval_output_dir = %s", eval_output_dir)
370
+
371
+ model_vae.eval()
372
+ model_vae_module = model_vae.module if hasattr(model_vae, 'module') else model_vae # Take care of distributed/parallel training
373
+
374
+ outputs = {
375
+ 'sampled_cond_labels': None,
376
+ 'cond_labels': None,
377
+ 'tgt_seq_ids': None,
378
+ 'generated': None,
379
+ 'at_generated': None,
380
+ 'cg_generated': None,
381
+ 'pred_cls': None,
382
+ 'pred_ge_cls': None,
383
+ 'pred_at_cls': None,
384
+ 'pred_cg_cls': None,
385
+ }
386
+
387
+ for bi, batch in enumerate(tqdm(eval_dataloader, desc="#Sentences", disable=args.local_rank not in [-1, 0]) ):
388
+ # if bi == 3:
389
+ # break
390
+
391
+ # Data
392
+ input_seq_ids, tgt_seq_ids, tokenized_text_lengths, cond_labels = batch
393
+ max_len_values, _ = tokenized_text_lengths.max(0)
394
+ input_seq_ids = input_seq_ids[:,:max_len_values[0]]
395
+ tgt_seq_ids = tgt_seq_ids[:,:max_len_values[1]]
396
+ input_seq_ids = input_seq_ids.to(args.device)
397
+ tgt_seq_ids = tgt_seq_ids.to(args.device)
398
+ cond_labels = cond_labels.to(args.device)
399
+ input_mask = torch.where(torch.arange(max_len_values[0].item()).unsqueeze(0).repeat(input_seq_ids.size(0), 1).type_as(tokenized_text_lengths).to(args.device)
400
+ < tokenized_text_lengths[:, 0].unsqueeze(1).to(args.device), torch.ones_like(input_seq_ids), torch.zeros_like(input_seq_ids))
401
+
402
+ # Model
403
+ with torch.no_grad():
404
+ result = model_vae(input_seq_ids=input_seq_ids, tgt_seq_ids=tgt_seq_ids, cond_labels=cond_labels, attention_mask=input_mask)
405
+ if bi == 0:
406
+ for key in outputs.keys():
407
+ outputs[key] = result[key].cpu().tolist()
408
+ else:
409
+ for key in outputs.keys():
410
+ outputs[key].extend(result[key].cpu().tolist())
411
+
412
+ # compute accuracies and store in results
413
+ acc = np.mean(np.array(np.array(outputs['pred_cls']) == np.array(outputs['cond_labels']), dtype=np.float))
414
+ acc_ge = np.mean(np.array(np.array(outputs['pred_ge_cls']) == np.array(outputs['cond_labels']), dtype=np.float))
415
+ acc_at = np.mean(np.array(np.array(outputs['pred_at_cls']) == np.array(outputs['sampled_cond_labels']), dtype=np.float))
416
+ acc_cg = np.mean(np.array(np.array(outputs['pred_cg_cls']) == np.array(outputs['sampled_cond_labels']), dtype=np.float))
417
+ metrics = {'acc': acc, 'acc_ge': acc_ge, 'acc_at': acc_at, 'acc_cg': acc_cg}
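+ # acc / acc_ge compare pred_cls / pred_ge_cls against the ground-truth cond_labels; acc_at / acc_cg compare pred_at_cls / pred_cg_cls against the sampled_cond_labels produced by the CARA model.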
418
+
419
+ # dump generated outputs to file.
420
+ json.dump(outputs, open(os.path.join(eval_output_dir, "outputs_{}.json".format(epoch) if epoch is not None else "outputs.json"), 'w'))
421
+
422
+ # compute BLEU
423
+ bos_token_id = model_vae_module.tokenizer_decoder.encode('<BOS>')[0]
424
+ eos_token_id = model_vae_module.tokenizer_decoder.encode('<EOS>')[0]
425
+ pad_token_id = model_vae_module.tokenizer_decoder.encode('<PAD>')[0]
426
+
427
+ generated_ids = []
428
+ generated_text = []
429
+ for g in outputs['generated']:
430
+ if g and g[0] in [eos_token_id, bos_token_id]:
431
+ g = g[1:]
432
+ if g and g[0] in [eos_token_id, bos_token_id]:
433
+ g = g[1:]
434
+ g = g[:g.index(eos_token_id)] if eos_token_id in g else g
435
+ g = g[:g.index(pad_token_id)] if pad_token_id in g else g
436
+ g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
437
+ generated_ids.append(g)
438
+ generated_text.append(g_text)
439
+
440
+ tgt_seq_ids = []
441
+ tgt_seq_text = []
442
+ for g in outputs['tgt_seq_ids']:
443
+ if g and g[0] in [eos_token_id, bos_token_id]:
444
+ g = g[1:]
445
+ if g and g[0] in [eos_token_id, bos_token_id]:
446
+ g = g[1:]
447
+ g = g[:g.index(eos_token_id)] if eos_token_id in g else g
448
+ g = g[:g.index(pad_token_id)] if pad_token_id in g else g
449
+ g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
450
+ tgt_seq_ids.append(g)
451
+ tgt_seq_text.append(g_text)
452
+
453
+ at_generated_ids = []
454
+ at_generated_text = []
455
+ for g in outputs['at_generated']:
456
+ if g and g[0] in [eos_token_id, bos_token_id]:
457
+ g = g[1:]
458
+ if g and g[0] in [eos_token_id, bos_token_id]:
459
+ g = g[1:]
460
+ g = g[:g.index(eos_token_id)] if eos_token_id in g else g
461
+ g = g[:g.index(pad_token_id)] if pad_token_id in g else g
462
+ g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
463
+ at_generated_ids.append(g)
464
+ at_generated_text.append(g_text)
465
+
466
+ cg_generated_ids = []
467
+ cg_generated_text = []
468
+ for g in outputs['cg_generated']:
469
+ if g and g[0] in [eos_token_id, bos_token_id]:
470
+ g = g[1:]
471
+ if g and g[0] in [eos_token_id, bos_token_id]:
472
+ g = g[1:]
473
+ g = g[:g.index(eos_token_id)] if eos_token_id in g else g
474
+ g = g[:g.index(pad_token_id)] if pad_token_id in g else g
475
+ g_text = model_vae_module.tokenizer_decoder.decode(g, clean_up_tokenization_spaces=True)
476
+ cg_generated_ids.append(g)
477
+ cg_generated_text.append(g_text)
478
+
479
+ f = open(os.path.join(eval_output_dir, "reconstruction{}.txt".format(('_'+str(epoch)) if epoch is not None else '')), 'w')
480
+ f.write('\n'.join([g + '\n' + t for g, t in zip(generated_text, tgt_seq_text)]))
481
+ fat = open(os.path.join(eval_output_dir, "attribute_transfer{}.txt".format(('_'+str(epoch)) if epoch is not None else '')), 'w')
482
+ fat.write('\n'.join([g + '\n' + t for g, t in zip(at_generated_text, tgt_seq_text)]))
483
+ fcg = open(os.path.join(eval_output_dir, "conditional_generation{}.txt".format(('_'+str(epoch)) if epoch is not None else '')), 'w')
484
+ fcg.write('\n'.join(cg_generated_text))
485
+
486
+ rec_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t)] for t in tgt_seq_text],
487
+ hypotheses=[nltk.word_tokenize(g) for g in generated_text])
488
+
489
+ at_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t)] for t in tgt_seq_text],
490
+ hypotheses=[nltk.word_tokenize(g) for g in at_generated_text])
491
+
492
+ cg_generated_text_subset = cg_generated_text[:500] # use a subset, otherwise it takes a long time to compute.
493
+ cg_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t) for t in tgt_seq_text] for _ in range(len(cg_generated_text_subset))],
494
+ hypotheses=[nltk.word_tokenize(g) for g in cg_generated_text_subset])
495
+
496
+ cg_self_bleu = nltk.translate.bleu_score.corpus_bleu(list_of_references=[[nltk.word_tokenize(t) for t in cg_generated_text_subset[:i]+cg_generated_text_subset[i+1:]]
497
+ for i in range(len(cg_generated_text_subset))],
498
+ hypotheses=[nltk.word_tokenize(g) for g in cg_generated_text_subset])
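+ # cg_self_bleu scores each conditionally generated sentence against all the other generations, so a lower value indicates more diverse outputs.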
499
+
500
+ metrics['rec_bleu'] = rec_bleu
501
+ metrics['at_bleu'] = at_bleu
502
+ metrics['cg_bleu'] = cg_bleu
503
+ metrics['cg_self_bleu'] = cg_self_bleu
504
+
505
+ output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
506
+ writer = open(output_eval_file, "w")
507
+ logger.info("***** Eval results, global steps: {} *****".format(prefix))
508
+ for key, value in metrics.items():
509
+ logger.info(" %s = %s", key, str(value))
510
+ writer.write("%s = %s\n" % (key, str(value)))
511
+
512
+ return metrics
513
+
514
+ def main():
515
+ parser = argparse.ArgumentParser()
516
+
517
+ ## Required parameters
518
+ parser.add_argument("--output_dir", default='results_cara', type=str, help="The output directory where the model predictions and checkpoints will be written.")
519
+ parser.add_argument("--temperature", type=float, default=1.0)
520
+ parser.add_argument("--soft_temperature", type=float, default=0.5)
521
+ parser.add_argument("--top_k", type=int, default=5)
522
+ parser.add_argument("--top_p", type=float, default=0.0)
523
+ parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform.")
524
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
525
+ parser.add_argument("--lambda", default=0, type=float, help="")
526
+
527
+ ## Data parameters
528
+ parser.add_argument("--dataset", default='yelp', type=str, help="The dataset.")
529
+ # parser.add_argument("--train_data_file", default='../../../data/yelp/sentiment.train.tiny.text', type=str, help="The input training data file (a text file).")
530
+ parser.add_argument("--train_data_file", default='../../../data/yelp/sentiment.train.text', type=str, help="The input training data file (a text file).")
531
+ # parser.add_argument("--eval_data_file", default='../../../data/yelp/sentiment.dev.tiny.text', type=str, help="")
532
+ parser.add_argument("--eval_data_file", default='../../../data/yelp/sentiment.dev.small.text', type=str, help="2000 samples.")
533
+ parser.add_argument("--ExpName", default="local_lctrlg_yelp", type=str, help="The experiment name used in Azure Table.")
534
+ parser.add_argument("--create_new", default=0, type=int, help="")
535
+
536
+ # Training parameters
537
+ parser.add_argument("--checkpoint_dir", default='results_arae/checkpoint-47501/pytorch_model.bin', type=str, help='results/checkpoint-1212/pytorch_model.bin')
538
+ # parser.add_argument("--checkpoint", default='', type=str, help='results/checkpoint-1212/pytorch_model.bin')
539
+ parser.add_argument("--start_global_step", default=1001, type=int, help='')
540
+ parser.add_argument("--do_train", action='store_true',
541
+ help="Whether to run training.")
542
+ parser.add_argument("--do_eval", action='store_true',
543
+ help="Whether to run eval on the dev set.")
544
+ parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
545
+ parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.")
546
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
547
+ parser.add_argument("--evaluate_during_training", action='store_true', help="Run evaluation during training at each logging step.")
548
+ parser.add_argument('--gloabl_step_eval', type=int, default=0, help="Evaluate the results at the given global step")
549
+ # parser.add_argument('--logging_steps', type=int, default=2000, help="ARAE")
550
+ parser.add_argument('--logging_steps', type=int, default=10, help="CARA")
551
+ parser.add_argument('--eval_steps', type=int, default=500, help="CARA")
552
+ # parser.add_argument('--save_steps', type=int, default=5000, help="ARAE")
553
+ parser.add_argument('--save_steps', type=int, default=1000, help="CARA")
554
+ parser.add_argument("--eval_all_checkpoints", action='store_true', help="")
555
+
556
+ ## Encoder options
557
+ # parser.add_argument("--encoder_model_name_or_path", default="bert-base-uncased", type=str, )
558
+ parser.add_argument("--encoder_model_name_or_path", default="results_cara/checkpoint-encoder-1000", type=str)
559
+ # parser.add_argument("--encoder_model_name_or_path", default="results/checkpoint-encoder-55000", type=str")
560
+ parser.add_argument("--encoder_config_name", default="", type=str, help="Optional pretrained config name or path if not the same as model_name_or_path")
561
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str, help="Keep empty. Will default to decoder_model_name_or_path")
562
+ parser.add_argument("--encoder_model_type", default="bert", type=str, help="The encoder model architecture to be fine-tuned.")
563
+
564
+ ## Decoder options
565
+ # parser.add_argument("--decoder_model_name_or_path", default="gpt2", type=str)
566
+ parser.add_argument("--decoder_model_name_or_path", default="results_cara/checkpoint-decoder-1000", type=str)
567
+ # parser.add_argument("--decoder_model_name_or_path", default="results/checkpoint-decoder-55000", type=str)
568
+ parser.add_argument("--decoder_config_name", default="", type=str, help="Optional pretrained config name or path if not the same as model_name_or_path")
569
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str, help="Keep empty. Will default to decoder_model_name_or_path")
570
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str, help="The decoder model architecture to be fine-tuned.")
571
+
572
+ ## Variational auto-encoder
573
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
574
+ parser.add_argument("--use_deterministic_connect", action='store_true', help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
575
+
576
+ ## Objective functions
577
+ parser.add_argument("--mlm", action='store_true', help="Train with masked-language modeling loss instead of language modeling.")
578
+ parser.add_argument("--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss")
579
+ parser.add_argument("--cache_dir", default="", type=str, help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
580
+ parser.add_argument("--block_size", default=21, type=int, help="21 for Yelp and Yahoo on label-conditional text generation")
581
+ parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
582
+
583
+ # Training Schedules
584
+ parser.add_argument("--ratio_increase", default=0.25, type=float, help="Learning schedule, the percentage for the annealing stage.")
585
+ parser.add_argument("--ratio_zero", default=0.5, type=float, help="Learning schedule, the percentage for the pure auto-encoding stage.")
586
+ parser.add_argument("--fb_mode", default=1, type=int, help="free bit training mode.")
587
+ parser.add_argument("--dim_target_kl", default=3.0, type=float, help="dim_target_kl free bit training mode.")
588
+ parser.add_argument("--learning_rate", default=5e-6, type=float, help="The initial learning rate for Adam.")
589
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
590
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
591
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
592
+ parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
593
+ parser.add_argument("--use_philly", action='store_true', help="Use Philly for computing.")
594
+ parser.add_argument("--use_pretrained_model", action='store_true',
595
+ help="Use pre-trained auto-encoder models as the initialization")
596
+ parser.add_argument("--use_pretrained_vae", action='store_true',
597
+ help="Use use_pretrained_vae as initialization, where beta value is specified in the folder")
598
+
599
+ parser.add_argument("--beta", type=float, default=1.0, help="The weighting hyper-parameter of the KL term in VAE")
600
+ parser.add_argument("--beta_cls", type=float, default=1.0, help="The weighting hyper-parameter for the classifier on the generated sentences")
601
+
602
+ ## IO: Logging and Saving
603
+ parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available")
604
+ parser.add_argument('--overwrite_output_dir', type=int, default=1, help="Overwrite the content of the output directory")
605
+ parser.add_argument('--overwrite_cache', action='store_true', help="Overwrite the cached training and evaluation sets")
606
+ parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
607
+
608
+ # Precision & Distributed Training
609
+ parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
610
+ parser.add_argument('--fp16_opt_level', type=str, default='O1', help="")
611
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
612
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
613
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
614
+
615
+ # New parameters
616
+ parser.add_argument('--label_size', type=int, default=2, help="This depends on which dataset is used.")
617
+ args = parser.parse_args()
618
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
619
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm flag (masked language modeling).")
620
+ if args.eval_data_file is None and args.do_eval:
621
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file or remove the --do_eval argument.")
622
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
623
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
624
+ # Setup distant debugging if needed
625
+ if args.server_ip and args.server_port:
626
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
627
+ import ptvsd
628
+ logger.info("Waiting for debugger attach")
629
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
630
+ ptvsd.wait_for_attach()
631
+ # Setup CUDA, GPU & distributed training
632
+ if args.local_rank == -1 or args.no_cuda:
633
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
634
+ args.n_gpu = torch.cuda.device_count()
635
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
636
+ torch.cuda.set_device(args.local_rank)
637
+ device = torch.device("cuda", args.local_rank)
638
+ torch.distributed.init_process_group(backend='nccl')
639
+ args.n_gpu = 1
640
+ args.device = device
641
+ # pdb.set_trace()
642
+ # Setup logging
643
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S',
644
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
645
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
646
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
647
+
648
+ args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + \
649
+ '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
650
+ table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
651
+ set_seed(args)
652
+
653
+ # Load pretrained model and tokenizer
654
+ if args.local_rank not in [-1, 0]:
655
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training downloads the model & vocab
656
+
657
+
658
+
659
+
660
+ if args.use_pretrained_model:
661
+ args.encoder_model_type = args.encoder_model_type.lower()
662
+ args.decoder_model_type = args.decoder_model_type.lower()
663
+
664
+ global_step = args.gloabl_step_eval
665
+
666
+ if args.use_pretrained_vae:
667
+ output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}-1.0'.format(global_step))
668
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}-1.0'.format(global_step))
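+ # The trailing "-1.0" in these folder names is the beta value baked into the pre-trained VAE checkpoint directories (see --use_pretrained_vae).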
669
+ else:
670
+ output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
671
+ output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
672
+
673
+ checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
674
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
675
+
676
+ # Load a trained Encoder model and vocabulary
677
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
678
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
679
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
680
+
681
+ model_encoder.to(args.device)
682
+ if args.block_size <= 0:
683
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
684
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
685
+
686
+ # Load a trained Decoder model and vocabulary
687
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
688
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
689
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
690
+ model_decoder.to(args.device)
691
+ if args.block_size <= 0:
692
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
693
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
694
+
695
+ else:
696
+ ## Encoder
697
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
698
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
699
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
700
+ if args.block_size <= 0:
701
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
702
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
703
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
704
+ # model_encoder = encoder_model_class(config=encoder_config, latent_size=args.latent_size)
705
+
706
+ ## Decoder
707
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
708
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
709
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
710
+ if args.block_size <= 0:
711
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
712
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
713
+ setattr(decoder_config, "latent_size", args.latent_size)
714
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
715
+ # model_decoder = decoder_model_class(config=decoder_config, latent_size=args.latent_size)
716
+
717
+ # Chunyuan: Add Padding token to GPT2
718
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
719
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
720
+ logger.info('We have added {} tokens to GPT2'.format(num_added_toks))
721
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
722
+ assert tokenizer_decoder.pad_token == '<PAD>'
723
+
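A quick sanity check can confirm that the added pad token resolves to a valid id within the resized vocabulary. This is an editorial sketch, not part of the original script, and uses only tokenizer-level calls from the pytorch_transformers API already imported above:

pad_id = tokenizer_decoder.convert_tokens_to_ids('<PAD>')
assert isinstance(pad_id, int) and 0 <= pad_id < len(tokenizer_decoder)
# resize_token_embeddings(len(tokenizer_decoder)) above guarantees the embedding matrix has a row for pad_id.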
724
+
725
+ # on_gpu = next(model_vae.parameters()).is_cuda
726
+ if args.local_rank == 0:
727
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
728
+ logger.info("Training/evaluation parameters %s", args)
729
+
730
+ if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)
731
+ # Training
732
+
733
+ logff = open(os.path.join(args.output_dir, 'log_{}'.format(get_time_str())), 'a')
734
+
735
+ if args.do_train:
736
+ global_step = args.start_global_step
737
+ model_vae = CARA(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)
738
+
739
+ # if args.checkpoint:
740
+ # logger.info("Loading checkpoint from {}".format(args.checkpoint))
741
+ # model_vae.load_state_dict(torch.load(args.checkpoint))
742
+
743
+ if args.local_rank not in [-1, 0]:
744
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
745
+ if args.local_rank == 0:
746
+ torch.distributed.barrier()
747
+
748
+ train_dataset = load_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
749
+
750
+ # logger.info("Test evaluate before training.")
751
+ # evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=0, subset='test')
752
+
753
+ # Train
754
+ global_step, tr_loss = train(args, train_dataset, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, logff=logff)
755
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
756
+
757
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
758
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
759
+ # Create output directory if needed
760
+ # Save model checkpoint
761
+ output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
762
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
763
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
764
+ if not os.path.exists(output_dir) and args.local_rank in [-1, 0]:
765
+ os.makedirs(output_dir)
766
+ if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
767
+ os.makedirs(output_encoder_dir)
768
+ if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
769
+ os.makedirs(output_decoder_dir)
770
+
771
+ logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
772
+ logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
773
+
774
+ model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
775
+ model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
776
+ model_to_save = model_vae.module if hasattr(model_vae, "module") else model_vae
777
+
778
+ # Good practice: save your training arguments together with the trained model
779
+ if args.use_philly:
780
+ save_solid = False
781
+ while not save_solid:
782
+ try:
783
+ torch.save(args, os.path.join(output_dir, 'training_args.bin'))
784
+ torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
785
+ save_solid = True
786
+ except:
787
+ pass
788
+ else:
789
+ torch.save(args, os.path.join(output_dir, 'training_args.bin'))
790
+ torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
791
+ args.checkpoint = os.path.join(output_dir, 'pytorch_model.bin')
792
+
793
+ if args.use_philly:
794
+ save_solid = False
795
+ while not save_solid:
796
+ try:
797
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
798
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
799
+ save_solid = True
800
+ except:
801
+ pass
802
+ else:
803
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
804
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
805
+
806
+ if args.use_philly:
807
+ save_solid = False
808
+ while not save_solid:
809
+ try:
810
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
811
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
812
+ save_solid = True
813
+ except:
814
+ pass
815
+ else:
816
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
817
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
818
+
819
+ # Load a trained model and vocabulary that you have fine-tuned
820
+ # model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
821
+ # model_encoder.to(args.device)
822
+ #
823
+ # # Load a trained model and vocabulary that you have fine-tuned
824
+ # model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
825
+ # model_decoder.to(args.device)
826
+
827
+ # Evaluation
828
+ results = {}
829
+ if args.do_eval and args.local_rank in [-1, 0]:
830
+ # if global_step == 0:
831
+ # global_step = args.gloabl_step_eval
832
+
833
+ # output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
834
+ # output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
835
+ # checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
836
+
837
+ # logger.info("Evaluate the following checkpoints: %s", checkpoints)
838
+ # for checkpoint in checkpoints:
839
+
840
+ # global_step = args.checkpoint_dir.split('/')[-2].split('-')[-1] if args.checkpoint_dir else ""
841
+
842
+ # model_encoder = encoder_model_class.from_pretrained(checkpoint[0], latent_size=args.latent_size)
843
+ # model_encoder.to(args.device)
844
+ # model_decoder = decoder_model_class.from_pretrained(checkpoint[1], latent_size=args.latent_size)
845
+ # model_decoder.to(args.device)
846
+
847
+ model_vae = CARA(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)
848
+
849
+ if args.gloabl_step_eval < 1:
850
+ args.gloabl_step_eval = global_step
851
+ args.checkpoint_dir = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(args.gloabl_step_eval))
852
+ else:
853
+ global_step = args.gloabl_step_eval
854
+ args.checkpoint_dir = os.path.join(args.checkpoint_dir, 'checkpoint-{}/pytorch_model.bin'.format(args.gloabl_step_eval))
855
+
856
+
857
+ # if args.checkpoint_dir and os.path.exists(args.checkpoint_dir):
858
+ # logger.info("Loading checkpoint from {}".format(args.checkpoint_dir))
859
+ # model_vae.load_state_dict(torch.load(args.checkpoint_dir))
860
+ # else:
861
+ # raise ValueError("Cannot find checkpoint at: {}".format(args.checkpoint))
862
+
863
+ metrics = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='test')
864
+ metrics = dict((k + '_{}'.format(global_step), v) for k, v in metrics.items())
865
+ results.update(metrics)
866
+
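The dictionary comprehension above only suffixes each metric name with the evaluation step, so a hypothetical key such as 'bleu' becomes 'bleu_2000' when global_step is 2000 (values illustrative); this keeps results from different checkpoints distinguishable when merged into results.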
867
+ # result = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='train')
868
+ # result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
869
+ # results.update(result)
870
+
871
+ return results
872
+
873
+
874
+ if __name__ == "__main__":
875
+ main()
Optimus/code/examples/big_ae/run_lm_vae_pretraining.py ADDED
@@ -0,0 +1,669 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+ from __future__ import absolute_import, division, print_function
23
+
24
+
25
+ import pdb
26
+ import argparse
27
+ import glob
28
+ import logging
29
+
30
+ import os
31
+ import pickle
32
+ import random
33
+ from pathlib import Path
34
+
35
+ import numpy as np
36
+ import torch
37
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
38
+ from torch.utils.data.distributed import DistributedSampler
39
+ from tensorboardX import SummaryWriter
40
+ from tqdm import tqdm, trange
41
+ from collections import defaultdict
42
+
43
+ # from azure.cosmosdb.table.tableservice import TableService
44
+ # from azure.cosmosdb.table.models import Entity
45
+ from datetime import datetime
46
+
47
+
48
+
49
+ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
50
+ BertConfig, BertForLatentConnector, BertTokenizer,
51
+ GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
52
+ OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
53
+ RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
54
+
55
+ from utils import (calc_iwnll, calc_mi, calc_au, BucketingDataLoader, BucketingMultipleFiles_DataLoader, frange_cycle_linear, frange_cycle_zero_linear)
56
+
57
+ from modules import VAE
58
+
59
+
60
+ # logging.getLogger("azure").setLevel(logging.WARNING)
61
+ # logging.getLogger("TableService").setLevel(logging.WARNING)
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ MODEL_CLASSES = {
67
+ 'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
68
+ 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
69
+ 'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
70
+ 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
71
+ }
72
+
73
+
74
+ storage_name="textae"
75
+ key=r"6yBCXlblof8DVFJ4BD3eNFTrGQCej6cKfCf5z308cKnevyHaG+yl/m+ITVErB9yt0kvN3ToqxLIh0knJEfFmPA=="
76
+ # ts = TableService(account_name=storage_name, account_key=key)
77
+
78
+
79
+
80
+ def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):
81
+ if isinstance(tokenizer, list):
82
+ args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
83
+ file_path=args.train_data_file
84
+ dataloader = BucketingMultipleFiles_DataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=True)
85
+ else:
86
+ pass
87
+ return dataloader
88
+
89
+
90
+
91
+
92
+ def set_seed(args):
93
+ random.seed(args.seed)
94
+ np.random.seed(args.seed)
95
+ torch.manual_seed(args.seed)
96
+ if args.n_gpu > 0:
97
+ torch.cuda.manual_seed_all(args.seed)
98
+
99
+
100
+ def mask_tokens(inputs, tokenizer, args):
101
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
102
+ labels = inputs.clone()
103
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability, defaulting to 0.15 as in BERT/RoBERTa)
104
+
105
+ masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).to(torch.uint8)
106
+ labels[masked_indices==1] = -1 # We only compute loss on masked tokens
107
+
108
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
109
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(torch.uint8) & masked_indices
110
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
111
+
112
+ # 10% of the time, we replace masked input tokens with random word
113
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(torch.uint8) & masked_indices & ~indices_replaced
114
+ indices_random = indices_random
115
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
116
+ inputs[indices_random] = random_words[indices_random]
117
+
118
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
119
+ return inputs, labels
120
+
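The branch probabilities above compose the standard BERT recipe: each token is selected with probability args.mlm_probability (0.15 by default); of the selected tokens, 80% are replaced by the mask token, and the Bernoulli(0.5) draw splits the remaining 20% evenly, since 0.5 * (1 - 0.8) = 0.1, giving 10% random tokens and 10% left unchanged. A small self-contained check of those proportions, as an editorial sketch independent of any tokenizer:

import torch

torch.manual_seed(0)
n = 1_000_000
masked = torch.bernoulli(torch.full((n,), 0.15)).bool()
replaced = torch.bernoulli(torch.full((n,), 0.8)).bool() & masked
randomized = torch.bernoulli(torch.full((n,), 0.5)).bool() & masked & ~replaced
kept = masked & ~replaced & ~randomized
# Expected fractions of all tokens: ~0.12 replaced with the mask token, ~0.015 randomized, ~0.015 kept as-is.
print(replaced.float().mean().item(), randomized.float().mean().item(), kept.float().mean().item())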
121
+
122
+ def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, table_name):
123
+ """ Train the model """
124
+ if args.local_rank in [-1, 0]:
125
+ tb_writer = SummaryWriter()
126
+
127
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
128
+ # train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
129
+ # train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
130
+
131
+ if args.max_steps > 0:
132
+ t_total = args.max_steps
133
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
134
+ else:
135
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
136
+
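As a concrete illustration of the step accounting (numbers hypothetical): with 10,000 batches per epoch in the dataloader, gradient_accumulation_steps = 2, and num_train_epochs = 1, t_total = 10,000 // 2 * 1 = 5,000 optimizer updates; setting max_steps = 3,000 instead would cap training at 3,000 updates and recompute num_train_epochs from that cap.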
137
+ # Prepare optimizer and schedule (linear warmup and decay)
138
+
139
+
140
+ # model_encoder, model_decoder, model_connector = model_vae.encoder, model_vae.decoder, model_vae.linear
141
+ no_decay = ['bias', 'LayerNorm.weight']
142
+ optimizer_grouped_parameters = [
143
+ {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
144
+ {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
145
+ ]
146
+
147
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
148
+ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
149
+
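The two parameter groups above follow the usual BERT-style setup: weight decay applies to all weights except biases and LayerNorm parameters. Assuming the standard linear warmup-then-decay behavior of WarmupLinearSchedule, the learning rate evolves roughly as sketched below (editorial comment, not part of the script):

# lr(step) = learning_rate * step / warmup_steps                          while step < warmup_steps
# lr(step) = learning_rate * (t_total - step) / (t_total - warmup_steps)  afterwards, reaching 0 at t_total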
150
+
151
+ if args.fp16:
152
+ try:
153
+ from apex import amp
154
+ except ImportError:
155
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
156
+ model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)
157
+
158
+ # multi-gpu training (should be after apex fp16 initialization)
159
+ if args.n_gpu > 1:
160
+ model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)
161
+
162
+ # Distributed training (should be after apex fp16 initialization)
163
+ if args.local_rank != -1:
164
+ model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
165
+ output_device=args.local_rank,
166
+ find_unused_parameters=True)
167
+
168
+
169
+
170
+
171
+ files = Path(args.train_data_file)
172
+ num_files = len(list(files.glob('*seq64*.json')))
173
+
174
+
175
+ # Train!
176
+ logger.info("***** Running training *****")
177
+ logger.info(" Num files = %d", num_files)
178
+ logger.info(" Num examples of first file = %d", train_dataloader.num_examples)
179
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
180
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
181
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
182
+ args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
183
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
184
+ logger.info(" Total optimization steps = %d", t_total)
185
+
186
+
187
+ global_step = 0
188
+ tr_loss, logging_loss = 0.0, 0.0
189
+
190
+ model_vae.zero_grad()
191
+ num_train_epochs_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
192
+
193
+ n_iter = int(args.num_train_epochs) * len(train_dataloader)
194
+ beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=1, ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)
195
+
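frange_cycle_zero_linear (imported from utils above) produces the per-iteration KL-weight schedule; note that the training loop below currently pins beta_t to 0.0, so the schedule is computed but not applied in this pretraining script. A minimal editorial sketch of an equivalent schedule, assuming the intended behavior is to hold beta at 0 for ratio_zero of each cycle, ramp linearly to stop over ratio_increase, and hold at stop for the remainder:

import numpy as np

def cyclical_beta_schedule(n_iter, stop=1.0, n_cycle=1, ratio_increase=0.25, ratio_zero=0.25):
    # Assumed behavior, not the library implementation.
    betas = np.full(n_iter, stop, dtype=np.float64)
    period = n_iter // n_cycle
    for c in range(n_cycle):
        start = c * period
        zero_end = start + int(period * ratio_zero)
        ramp_end = zero_end + int(period * ratio_increase)
        betas[start:zero_end] = 0.0  # pure auto-encoding stage
        if ramp_end > zero_end:      # annealing stage
            betas[zero_end:ramp_end] = np.linspace(0.0, stop, ramp_end - zero_end, endpoint=False)
    return betas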
196
+ tmp_list = []
197
+ dict_token_length = defaultdict(int)
198
+
199
+ set_seed(args) # Added here for reproducibility (even between python 2 and 3)
200
+ for epoch in num_train_epochs_iterator:
201
+ train_dataloader.reset()
202
+ for idx_file in range(num_files-1):
203
+ logger.info(f"Epoch {epoch}, File idx {train_dataloader.file_idx}")
204
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
205
+ for step, batch in enumerate(epoch_iterator):
206
+
207
+ tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
208
+
209
+ dict_token_length[ tokenized_text_lengths[0,0].item() ] += 1
210
+
211
+ # continue
212
+
213
+
214
+ # tokenized_text0 = tokenized_text0.to(args.device)
215
+ # tokenized_text1 = tokenized_text1.to(args.device)
216
+ # prepare input-output data for reconstruction
217
+
218
+
219
+
220
+ inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
221
+ labels = tokenized_text1
222
+
223
+ tokenized_text1 = tokenized_text1.to(args.device)
224
+ inputs = inputs.to(args.device)
225
+ labels = labels.to(args.device)
226
+
227
+ model_vae.train()
228
+
229
+ beta_t = 0.0 # beta_t_list[step + epoch*len(epoch_iterator)]
230
+ model_vae.module.args.beta = beta_t
231
+
232
+ if beta_t == 0.0:
233
+ model_vae.module.args.fb_mode = 0
234
+ else:
235
+ model_vae.module.args.fb_mode = 1
236
+
237
+ if args.use_deterministic_connect:
238
+ model_vae.module.args.fb_mode = 2
239
+
240
+ loss_rec, loss_kl, loss = model_vae(inputs, labels)
241
+
242
+ loss_rec = loss_rec.mean() # mean() to average on multi-gpu parallel training
243
+ loss_kl = loss_kl.mean()
244
+ loss = loss.mean()
245
+
246
+ if args.use_philly:
247
+ print("PROGRESS: {}%".format(round(100 * (step + epoch*len(epoch_iterator) ) /(int(args.num_train_epochs) * len(epoch_iterator)) , 4)))
248
+ print("EVALERR: {}%".format(loss_rec))
249
+
250
+ epoch_iterator.set_description(
251
+ (
252
+ f'iter: {step + epoch*len(epoch_iterator) }; file:{idx_file}; loss: {loss.item():.3f}; '
253
+ f'loss_rec: {loss_rec.item():.3f}; loss_kl: {loss_kl.item():.3f}; '
254
+ f'beta: {model_vae.module.args.beta:.3f}'
255
+ )
256
+ )
257
+
258
+ # if global_step % 5 == 0:
259
+ # row = {
260
+ # 'PartitionKey': 'MILU_Rule_Rule_Template',
261
+ # 'RowKey': str(datetime.now()),
262
+ # 'ExpName' : args.ExpName,
263
+ # 'iter': str( step + epoch*len(epoch_iterator) ),
264
+ # 'loss': str( loss.item()),
265
+ # 'loss_rec': str(loss_rec.item()),
266
+ # 'loss_kl': str(loss_kl.item()),
267
+ # 'beta': str(model_vae.args.beta)
268
+ # }
269
+ # # pdb.set_trace()
270
+ # ts.insert_entity(table_name, row)
271
+
272
+ # pdb.set_trace()
273
+
274
+ if args.gradient_accumulation_steps > 1:
275
+ loss = loss / args.gradient_accumulation_steps
276
+
277
+ if args.fp16:
278
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
279
+ scaled_loss.backward()
280
+ else:
281
+ loss.backward()
282
+
283
+ tr_loss += loss.item()
284
+ if (step + 1) % args.gradient_accumulation_steps == 0:
285
+ if args.fp16:
286
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
287
+ else:
288
+ torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)
289
+
290
+ optimizer.step()
291
+
292
+ scheduler.step() # Update learning rate schedule
293
+
294
+ model_vae.zero_grad()
295
+
296
+ global_step += 1
297
+
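Note that with gradient accumulation, global_step counts optimizer updates rather than micro-batches: for gradient_accumulation_steps = 4 (value illustrative), four forward/backward passes feed one clipped gradient step, one scheduler step, and one increment of global_step, so logging_steps and save_steps are both measured in updates.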
298
+
299
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
300
+ # Log metrics
301
+ if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
302
+ results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer)
303
+ for key, value in results.items():
304
+ tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
305
+ tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
306
+ tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
307
+ logging_loss = tr_loss
308
+
309
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
310
+
311
+ # Save encoder model checkpoint
312
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
313
+
314
+ if not os.path.exists(output_encoder_dir):
315
+ os.makedirs(output_encoder_dir)
316
+
317
+ model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
318
+ if args.use_philly:
319
+ save_solid = False
320
+ while not save_solid:
321
+ try:
322
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
323
+ torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
324
+ logger.info("Saving model checkpoint to %s", output_encoder_dir)
325
+ save_solid = True
326
+ except:
327
+ pass
328
+ else:
329
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
330
+ torch.save(args, os.path.join(output_encoder_dir, 'training_args.bin'))
331
+ logger.info("Saving model checkpoint to %s", output_encoder_dir)
332
+
333
+ # Save decoder model checkpoint
334
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
335
+
336
+ if not os.path.exists(output_decoder_dir):
337
+ os.makedirs(output_decoder_dir)
338
+
339
+ model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
340
+ if args.use_philly:
341
+ save_solid = False
342
+ while not save_solid:
343
+ try:
344
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
345
+ torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
346
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
347
+ save_solid = True
348
+ except:
349
+ pass
350
+ else:
351
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
352
+ torch.save(args, os.path.join(output_decoder_dir, 'training_args.bin'))
353
+ logger.info("Saving model checkpoint to %s", output_decoder_dir)
354
+
355
+
356
+ if args.max_steps > 0 and global_step > args.max_steps:
357
+ epoch_iterator.close()
358
+ break
359
+
360
+ if args.max_steps > 0 and global_step > args.max_steps:
361
+ num_train_epochs_iterator.close()  # the epoch-level trange defined above; 'train_iterator' is not defined in this script
362
+ break
363
+
364
+
365
+ # print(dict_token_length)
366
+ # with open('wikipedia_stats.json', 'w') as fp:
367
+ # json.dump(dict_token_length, fp)
368
+
369
+ if args.local_rank in [-1, 0]:
370
+ tb_writer.close()
371
+
372
+ return global_step, tr_loss / global_step
373
+
374
+
375
+ def main():
376
+ parser = argparse.ArgumentParser()
377
+
378
+ ## Required parameters
379
+ parser.add_argument("--train_data_file", default=None, type=str, required=True,
380
+ help="The input training data file (a text file).")
381
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
382
+ help="The output directory where the model predictions and checkpoints will be written.")
383
+ parser.add_argument("--dataset", default=None, type=str, help="The dataset.")
384
+
385
+ ## Other parameters
386
+ parser.add_argument("--eval_data_file", default=None, type=str,
387
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
388
+ parser.add_argument("--ExpName", default="", type=str,
389
+ help="The experiment name used in Azure Table.")
390
+
391
+ ## Encoder options
392
+ parser.add_argument("--encoder_model_type", default="bert", type=str,
393
+ help="The encoder model architecture to be fine-tuned.")
394
+ parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
395
+ help="The encoder model checkpoint for weights initialization.")
396
+ parser.add_argument("--encoder_config_name", default="", type=str,
397
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
398
+ parser.add_argument("--encoder_tokenizer_name", default="", type=str,
399
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
400
+
401
+ ## Decoder options
402
+ parser.add_argument("--decoder_model_type", default="gpt2", type=str,
403
+ help="The decoder model architecture to be fine-tuned.")
404
+ parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
405
+ help="The decoder model checkpoint for weights initialization.")
406
+ parser.add_argument("--decoder_config_name", default="", type=str,
407
+ help="Optional pretrained config name or path if not the same as model_name_or_path")
408
+ parser.add_argument("--decoder_tokenizer_name", default="", type=str,
409
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
410
+
411
+ ## Variational auto-encoder
412
+ parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
413
+ parser.add_argument("--use_deterministic_connect", action='store_true',
414
+ help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
415
+
416
+ ## Objective functions
417
+ parser.add_argument("--mlm", action='store_true',
418
+ help="Train with masked-language modeling loss instead of language modeling.")
419
+ parser.add_argument("--mlm_probability", type=float, default=0.15,
420
+ help="Ratio of tokens to mask for masked language modeling loss")
421
+ parser.add_argument("--beta", type=float, default=1.0,
422
+ help="The weighting hyper-parameter of the KL term in VAE")
423
+
424
+
425
+ parser.add_argument("--cache_dir", default="", type=str,
426
+ help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
427
+ parser.add_argument("--max_seq_length", default=512, type=int,
428
+ help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
429
+ parser.add_argument("--block_size", default=-1, type=int,
430
+ help="Optional input sequence length after tokenization."
431
+ "The training dataset will be truncated in block of this size for training."
432
+ "Default to the model max input length for single sentence inputs (take into account special tokens).")
433
+ parser.add_argument("--do_train", action='store_true',
434
+ help="Whether to run training.")
435
+ parser.add_argument("--do_eval", action='store_true',
436
+ help="Whether to run eval on the dev set.")
437
+ parser.add_argument("--evaluate_during_training", action='store_true',
438
+ help="Run evaluation during training at each logging step.")
439
+ parser.add_argument("--do_lower_case", action='store_true',
440
+ help="Set this flag if you are using an uncased model.")
441
+
442
+
443
+ # Training Schedules
444
+ parser.add_argument("--ratio_increase", default=0.25, type=float,
445
+ help="Learning schedule, the percentage for the annealing stage.")
446
+ parser.add_argument("--ratio_zero", default=0.25, type=float,
447
+ help="Learning schedule, the percentage for the pure auto-encoding stage.")
448
+ parser.add_argument("--fb_mode", default=0, type=int,
449
+ help="free bit training mode.")
450
+ parser.add_argument("--dim_target_kl", default=3.0, type=float,
451
+ help="dim_target_kl free bit training mode.")
452
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
453
+ help="Batch size per GPU/CPU for training.")
454
+ parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
455
+ help="Batch size per GPU/CPU for evaluation.")
456
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
457
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
458
+ parser.add_argument("--learning_rate", default=5e-5, type=float,
459
+ help="The initial learning rate for Adam.")
460
+ parser.add_argument("--weight_decay", default=0.0, type=float,
461
+ help="Weight deay if we apply some.")
462
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float,
463
+ help="Epsilon for Adam optimizer.")
464
+ parser.add_argument("--max_grad_norm", default=1.0, type=float,
465
+ help="Max gradient norm.")
466
+ parser.add_argument("--num_train_epochs", default=1.0, type=float,
467
+ help="Total number of training epochs to perform.")
468
+ parser.add_argument("--max_steps", default=-1, type=int,
469
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
470
+ parser.add_argument("--warmup_steps", default=0, type=int,
471
+ help="Linear warmup over warmup_steps.")
472
+ parser.add_argument("--use_philly", action='store_true',
473
+ help="Use Philly for computing.")
474
+
475
+ ## IO: Logging and Saving
476
+ parser.add_argument('--logging_steps', type=int, default=50,
477
+ help="Log every X updates steps.")
478
+ parser.add_argument('--save_steps', type=int, default=50,
479
+ help="Save checkpoint every X updates steps.")
480
+ parser.add_argument("--eval_all_checkpoints", action='store_true',
481
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
482
+ parser.add_argument("--no_cuda", action='store_true',
483
+ help="Avoid using CUDA when available")
484
+ parser.add_argument('--overwrite_output_dir', action='store_true',
485
+ help="Overwrite the content of the output directory")
486
+ parser.add_argument('--overwrite_cache', action='store_true',
487
+ help="Overwrite the cached training and evaluation sets")
488
+ parser.add_argument('--seed', type=int, default=42,
489
+ help="random seed for initialization")
490
+ parser.add_argument('--gloabl_step_eval', type=int, default=661,
491
+ help="Evaluate the results at the given global step")
492
+
493
+ # Precision & Distributed Training
494
+ parser.add_argument('--fp16', action='store_true',
495
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
496
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
497
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
498
+ "See details at https://nvidia.github.io/apex/amp.html")
499
+ parser.add_argument("--local_rank", type=int, default=-1,
500
+ help="For distributed training: local_rank")
501
+ parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
502
+ parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
503
+ args = parser.parse_args()
504
+
505
+ if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
506
+ raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
507
+ "flag (masked language modeling).")
508
+ if args.eval_data_file is None and args.do_eval:
509
+ raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
510
+ "or remove the --do_eval argument.")
511
+
512
+ if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
513
+ raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
514
+
515
+ # Setup distant debugging if needed
516
+ if args.server_ip and args.server_port:
517
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
518
+ import ptvsd
519
+ print("Waiting for debugger attach")
520
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
521
+ ptvsd.wait_for_attach()
522
+
523
+ # Setup CUDA, GPU & distributed training
524
+ if args.local_rank == -1 or args.no_cuda:
525
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
526
+ args.n_gpu = torch.cuda.device_count()
527
+ else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
528
+ torch.cuda.set_device(args.local_rank)
529
+ device = torch.device("cuda", args.local_rank)
530
+ torch.distributed.init_process_group(backend='nccl')
531
+ args.n_gpu = 1
532
+ args.device = device
533
+
534
+ # Setup logging
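Under torch.distributed.launch each process receives its own local_rank, pins one GPU, and sets args.n_gpu = 1, so multi-GPU parallelism in that mode comes from DistributedDataParallel inside train(); only the local_rank == -1 path uses every visible GPU in a single process via DataParallel.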
535
+ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
536
+ datefmt = '%m/%d/%Y %H:%M:%S',
537
+ level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
538
+ logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
539
+ args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
540
+
541
+ args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
542
+ table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
543
+ try:
544
+ ts.create_table(table_name)
545
+ except:
546
+ pass
547
+
548
+
549
+ # Set seed
550
+ set_seed(args)
551
+
552
+ # Load pretrained model and tokenizer
553
+ if args.local_rank not in [-1, 0]:
554
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
555
+
556
+ ## Encoder
557
+ encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
558
+ encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
559
+ tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
560
+ if args.block_size <= 0:
561
+ args.block_size = tokenizer_encoder.max_len_single_sentence # Our input block size will be the max possible for the model
562
+ args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
563
+ model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
564
+ # model_encoder.to(args.device)
565
+
566
+ ## Decoder
567
+ decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
568
+ decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
569
+ tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
570
+ if args.block_size <= 0:
571
+ args.block_size = tokenizer_decoder.max_len_single_sentence # Our input block size will be the max possible for the model
572
+ args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
573
+ model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
574
+
575
+ # Chunyuan: Add Padding token to GPT2
576
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
577
+ num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
578
+ print('We have added', num_added_toks, 'tokens to GPT2')
579
+ model_decoder.resize_token_embeddings(len(tokenizer_decoder)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
580
+ assert tokenizer_decoder.pad_token == '<PAD>'
581
+
582
+ # model_decoder.to(args.device)
583
+
584
+ model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
585
+
586
+ # on_gpu = next(model_vae.parameters()).is_cuda
587
+
588
+
589
+
590
+ if args.local_rank == 0:
591
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
592
+
593
+ logger.info("Training/evaluation parameters %s", args)
594
+
595
+ global_step= 0
596
+ # Training
597
+ if args.do_train:
598
+ if args.local_rank not in [-1, 0]:
599
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
600
+
601
+ train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
602
+
603
+ if args.local_rank == 0:
604
+ torch.distributed.barrier()
605
+
606
+ global_step, tr_loss = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
607
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
608
+
609
+
610
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
611
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
612
+ # Create output directory if needed
613
+ # Save model checkpoint
614
+ output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
615
+ output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
616
+ if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
617
+ os.makedirs(output_encoder_dir)
618
+ if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
619
+ os.makedirs(output_decoder_dir)
620
+
621
+ logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
622
+ logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
623
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
624
+ # They can then be reloaded using `from_pretrained()`
625
+
626
+ model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder # Take care of distributed/parallel training
627
+ model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder # Take care of distributed/parallel training
628
+
629
+ # Good practice: save your training arguments together with the trained model
630
+ if args.use_philly:
631
+ save_solid = False
632
+ while not save_solid:
633
+ try:
634
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
635
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
636
+ save_solid = True
637
+ except:
638
+ pass
639
+ else:
640
+ model_encoder_to_save.save_pretrained(output_encoder_dir)
641
+ torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
642
+
643
+
644
+ if args.use_philly:
645
+ save_solid = False
646
+ while not save_solid:
647
+ try:
648
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
649
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
650
+ save_solid = True
651
+ except:
652
+ pass
653
+ else:
654
+ model_decoder_to_save.save_pretrained(output_decoder_dir)
655
+ torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
656
+
657
+
658
+ # Load a trained model and vocabulary that you have fine-tuned
659
+ model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
660
+ model_encoder.to(args.device)
661
+
662
+ # Load a trained model and vocabulary that you have fine-tuned
663
+ model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
664
+ model_decoder.to(args.device)
665
+
666
+
667
+
668
+ if __name__ == "__main__":
669
+ main()