![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen) ![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)

# Adversarial Text Classification

Code for [*Adversarial Training Methods for Semi-Supervised Text Classification*](https://arxiv.org/abs/1605.07725) and [*Semi-Supervised Sequence Learning*](https://arxiv.org/abs/1511.01432).

## Requirements

* TensorFlow >= v1.3

## End-to-end IMDB Sentiment Classification

### Fetch data

```bash
$ wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \
    -O /tmp/imdb.tar.gz
$ tar -xf /tmp/imdb.tar.gz -C /tmp
```

The directory `/tmp/aclImdb` contains the raw IMDB data.

### Generate vocabulary

```bash
$ IMDB_DATA_DIR=/tmp/imdb
$ python gen_vocab.py \
    --output_dir=$IMDB_DATA_DIR \
    --dataset=imdb \
    --imdb_input_dir=/tmp/aclImdb \
    --lowercase=False
```

Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.

### Generate training, validation, and test data

```bash
$ python gen_data.py \
    --output_dir=$IMDB_DATA_DIR \
    --dataset=imdb \
    --imdb_input_dir=/tmp/aclImdb \
    --lowercase=False \
    --label_gain=False
```

`$IMDB_DATA_DIR` contains TFRecord files.

### Pretrain IMDB Language Model

```bash
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ python pretrain.py \
    --train_dir=$PRETRAIN_DIR \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --num_candidate_samples=1024 \
    --batch_size=256 \
    --learning_rate=0.001 \
    --learning_rate_decay_factor=0.9999 \
    --max_steps=100000 \
    --max_grad_norm=1.0 \
    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings
```

`$PRETRAIN_DIR` contains checkpoints of the pretrained language model.

### Train classifier

Most flags are the same as for pretraining. Candidate sampling is removed, and two kinds of flags are added: `pretrained_model_dir`, from which the classifier loads the pretrained embedding and LSTM variables, and flags controlling adversarial training and classification.

```bash
$ TRAIN_DIR=/tmp/models/imdb_classify
$ python train_classifier.py \
    --train_dir=$TRAIN_DIR \
    --pretrained_model_dir=$PRETRAIN_DIR \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --cl_num_layers=1 \
    --cl_hidden_size=30 \
    --batch_size=64 \
    --learning_rate=0.0005 \
    --learning_rate_decay_factor=0.9998 \
    --max_steps=15000 \
    --max_grad_norm=1.0 \
    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings \
    --adv_training_method=vat \
    --perturb_norm_length=5.0
```

### Evaluate on test data

```bash
$ EVAL_DIR=/tmp/models/imdb_eval
$ python evaluate.py \
    --eval_dir=$EVAL_DIR \
    --checkpoint_dir=$TRAIN_DIR \
    --eval_data=test \
    --run_once \
    --num_examples=25000 \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --batch_size=256 \
    --num_timesteps=400 \
    --normalize_embeddings
```

## Code Overview

The main entry points are the binaries listed below. Each training binary builds a `VatxtModel`, defined in `graphs.py`, which in turn uses graph building blocks defined in `inputs.py` (defines input data reading and parsing), `layers.py` (defines core model components), and `adversarial_losses.py` (defines adversarial training losses). The training loop itself is defined in `train_utils.py`.
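
For orientation, here is a minimal sketch of the adversarial-training idea from the paper: perturb the embedded inputs in the direction that most increases the classification loss, cap the perturbation's L2 norm at `--perturb_norm_length`, and train on the loss at the perturbed point. This is an illustrative snippet only, not the code in `adversarial_losses.py`; the function and argument names below are hypothetical.

```python
import tensorflow as tf  # TensorFlow 1.x, per the requirements above


def adversarial_loss_sketch(embedded, loss, loss_fn, perturb_norm_length=5.0):
  """Illustrative adversarial loss on embeddings (hypothetical helper).

  Args:
    embedded: [batch, time, embedding_dim] tensor of embedded tokens.
    loss: scalar classification loss computed from `embedded`.
    loss_fn: callable mapping perturbed embeddings to a scalar loss.
    perturb_norm_length: L2 norm of the perturbation (cf. the
      --perturb_norm_length flag used in classifier training above).
  """
  # Direction in embedding space that most increases the loss; the
  # gradient is treated as a constant when building the perturbation.
  grad, = tf.gradients(loss, embedded)
  grad = tf.stop_gradient(grad)
  # Scale to a fixed L2 norm: r_adv = epsilon * g / ||g||_2.
  r_adv = perturb_norm_length * grad / (tf.norm(grad) + 1e-12)
  # The adversarial loss is the classification loss on perturbed inputs.
  return loss_fn(embedded + r_adv)
```

Virtual adversarial training (`--adv_training_method=vat`) follows the same pattern but perturbs in the direction that most changes the model's predictive distribution (a KL divergence term), so it requires no labels and can use the unlabeled IMDB reviews.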
### Binaries

* Pretraining: `pretrain.py`
* Classifier Training: `train_classifier.py`
* Evaluation: `evaluate.py`

### Command-Line Flags

Flags related to distributed training and the training loop itself are defined in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/train_utils.py).

Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/graphs.py).

Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/adversarial_losses.py).

Flags particular to each job are defined in the main binary files.

### Data Generation

* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_data.py)

Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py) control which dataset is processed and how.

## Contact for Issues

* Ryan Sepassi, @rsepassi
* Andrew M. Dai, @a-dai
* Takeru Miyato, @takerum (original implementation)