![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Adversarial Text Classification
Code for [*Adversarial Training Methods for Semi-Supervised Text Classification*](https://arxiv.org/abs/1605.07725) and [*Semi-Supervised Sequence Learning*](https://arxiv.org/abs/1511.01432).
## Requirements
* TensorFlow >= v1.3
## End-to-end IMDB Sentiment Classification
### Fetch data
```bash
$ wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \
    -O /tmp/imdb.tar.gz
$ tar -xf /tmp/imdb.tar.gz -C /tmp
```
The directory `/tmp/aclImdb` contains the raw IMDB data.
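If you want to confirm the download and extraction worked, a quick sketch like the following counts the reviews in each split. The directory names are the standard aclImdb_v1 layout (labeled `train`/`test` splits plus the unlabeled `train/unsup` reviews), not anything specific to this codebase.
```python
# Quick sanity check of the extracted archive. These directory names are the
# standard aclImdb_v1 layout, not anything specific to this codebase.
import glob
import os

imdb_dir = '/tmp/aclImdb'
for split in ('train/pos', 'train/neg', 'train/unsup', 'test/pos', 'test/neg'):
    num_reviews = len(glob.glob(os.path.join(imdb_dir, split, '*.txt')))
    print('%-12s %d reviews' % (split, num_reviews))
```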
### Generate vocabulary
```bash
$ IMDB_DATA_DIR=/tmp/imdb
$ python gen_vocab.py \
    --output_dir=$IMDB_DATA_DIR \
    --dataset=imdb \
    --imdb_input_dir=/tmp/aclImdb \
    --lowercase=False
```
Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.
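To sanity-check this step, you can count the vocabulary entries and compare against the `--vocab_size` flag (86934) used in the steps below. This is only a sketch: the filename `vocab.txt` is an assumption and may not match what `gen_vocab.py` actually writes in your checkout.
```python
# A rough sanity check on the generated vocabulary (assumption: the
# vocabulary is written to vocab.txt in $IMDB_DATA_DIR; adjust the path if
# gen_vocab.py names its outputs differently).
import os

imdb_data_dir = '/tmp/imdb'
vocab_path = os.path.join(imdb_data_dir, 'vocab.txt')

with open(vocab_path) as f:
    vocab = [line.rstrip('\n') for line in f]

# Should line up with the --vocab_size flag (86934) used in the steps below.
print('Vocabulary size: %d' % len(vocab))
print('Sample tokens: %s' % vocab[:10])
```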
### Generate training, validation, and test data
```bash
$ python gen_data.py \
    --output_dir=$IMDB_DATA_DIR \
    --dataset=imdb \
    --imdb_input_dir=/tmp/aclImdb \
    --lowercase=False \
    --label_gain=False
```
`$IMDB_DATA_DIR` now contains TFRecord files for the training, validation, and test data.
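To confirm the records were written, a minimal sketch like the one below counts the entries in one output file using stock TF 1.x utilities. The filename is a placeholder; substitute any of the TFRecord files `gen_data.py` produced.
```python
# Count the records in one of the generated files using stock TF 1.x
# utilities. The path below is a placeholder; point it at any TFRecord file
# that gen_data.py wrote into $IMDB_DATA_DIR.
import tensorflow as tf  # TensorFlow 1.x

record_path = '/tmp/imdb/REPLACE_WITH_A_GENERATED_TFRECORD_FILE'
num_records = sum(1 for _ in tf.python_io.tf_record_iterator(record_path))
print('%s: %d records' % (record_path, num_records))
```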
### Pretrain IMDB Language Model
```bash
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ python pretrain.py \
    --train_dir=$PRETRAIN_DIR \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --num_candidate_samples=1024 \
    --batch_size=256 \
    --learning_rate=0.001 \
    --learning_rate_decay_factor=0.9999 \
    --max_steps=100000 \
    --max_grad_norm=1.0 \
    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings
```
`$PRETRAIN_DIR` contains checkpoints of the pretrained language model.
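Before training the classifier, it can be useful to confirm that pretraining produced a restorable checkpoint. The sketch below lists the variables in the latest checkpoint using stock TF 1.x utilities; which of them the classifier actually restores from `pretrained_model_dir` is determined by `train_classifier.py`.
```python
# List the variables stored in the latest pretraining checkpoint using
# stock TF 1.x checkpoint utilities. Which of these variables the classifier
# restores from --pretrained_model_dir is determined by train_classifier.py.
import tensorflow as tf  # TensorFlow 1.x

pretrain_dir = '/tmp/models/imdb_pretrain'
ckpt = tf.train.latest_checkpoint(pretrain_dir)
print('Latest checkpoint: %s' % ckpt)

reader = tf.train.NewCheckpointReader(ckpt)
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print('%-60s %s' % (name, shape))
```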
### Train classifier
Most flags stay the same. Candidate sampling is removed, and the classifier adds
`pretrained_model_dir`, from which it loads the pretrained embedding and LSTM
variables, along with flags related to adversarial training and classification.
```bash
$ TRAIN_DIR=/tmp/models/imdb_classify
$ python train_classifier.py \
    --train_dir=$TRAIN_DIR \
    --pretrained_model_dir=$PRETRAIN_DIR \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --cl_num_layers=1 \
    --cl_hidden_size=30 \
    --batch_size=64 \
    --learning_rate=0.0005 \
    --learning_rate_decay_factor=0.9998 \
    --max_steps=15000 \
    --max_grad_norm=1.0 \
    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings \
    --adv_training_method=vat \
    --perturb_norm_length=5.0
```
### Evaluate on test data
```bash
$ EVAL_DIR=/tmp/models/imdb_eval
$ python evaluate.py \
    --eval_dir=$EVAL_DIR \
    --checkpoint_dir=$TRAIN_DIR \
    --eval_data=test \
    --run_once \
    --num_examples=25000 \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --batch_size=256 \
    --num_timesteps=400 \
    --normalize_embeddings
```
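To read the results outside TensorBoard, a sketch like the following dumps every scalar summary found in `$EVAL_DIR`, assuming `evaluate.py` writes TensorBoard event files there (as the `--eval_dir` flag suggests). The exact summary tag names are defined by the evaluation code, so nothing here depends on them.
```python
# Dump every scalar summary found in the evaluation event files. The exact
# tag names (e.g. an accuracy metric) are defined by the evaluation code,
# so nothing here depends on them.
import glob
import os
import tensorflow as tf  # TensorFlow 1.x

eval_dir = '/tmp/models/imdb_eval'
for event_file in sorted(glob.glob(os.path.join(eval_dir, 'events.out.tfevents.*'))):
    for event in tf.train.summary_iterator(event_file):
        for value in event.summary.value:
            if value.HasField('simple_value'):
                print('step %d  %s = %.4f' % (event.step, value.tag, value.simple_value))
```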
## Code Overview
The main entry points are the binaries listed below. Each training binary builds
a `VatxtModel`, defined in `graphs.py`, which in turn uses graph building blocks
defined in `inputs.py` (defines input data reading and parsing), `layers.py`
(defines core model components), and `adversarial_losses.py` (defines
adversarial training losses). The training loop itself is defined in
`train_utils.py`.
### Binaries
* Pretraining: `pretrain.py`
* Classifier Training: `train_classifier.py`
* Evaluation: `evaluate.py`
### Command-Line Flags
Flags related to distributed training and the training loop itself are defined
in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/train_utils.py).
Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/graphs.py).
Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/adversarial_losses.py).
Flags particular to each job are defined in the main binary files.
### Data Generation
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_data.py)
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py)
control which dataset is processed and how.
## Contact for Issues
* Ryan Sepassi, @rsepassi
* Andrew M. Dai, @a-dai <adai@google.com>
* Takeru Miyato, @takerum (Original implementation)