Dusan Svilarkovic committed on
Commit
fc5ecba
1 Parent(s): 7872a22

Adding Fudge

Files changed (46)
  1. naacl-2021-fudge-controlled-generation/LICENSE +21 -0
  2. naacl-2021-fudge-controlled-generation/README.md +155 -0
  3. naacl-2021-fudge-controlled-generation/clickbait_classifier.py +128 -0
  4. naacl-2021-fudge-controlled-generation/constants.py +32 -0
  5. naacl-2021-fudge-controlled-generation/data.py +415 -0
  6. naacl-2021-fudge-controlled-generation/eval_formality_metrics.py +73 -0
  7. naacl-2021-fudge-controlled-generation/eval_poetry_metrics.py +135 -0
  8. naacl-2021-fudge-controlled-generation/eval_topic_metrics.py +134 -0
  9. naacl-2021-fudge-controlled-generation/evaluate_clickbait.py +200 -0
  10. naacl-2021-fudge-controlled-generation/evaluate_formality.py +104 -0
  11. naacl-2021-fudge-controlled-generation/evaluate_poetry.py +115 -0
  12. naacl-2021-fudge-controlled-generation/evaluate_topic.py +143 -0
  13. naacl-2021-fudge-controlled-generation/formality_data/README.md +2 -0
  14. naacl-2021-fudge-controlled-generation/formality_data/fisher_test_oracle.es +0 -0
  15. naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_0 +0 -0
  16. naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_1 +0 -0
  17. naacl-2021-fudge-controlled-generation/main.py +192 -0
  18. naacl-2021-fudge-controlled-generation/model.py +182 -0
  19. naacl-2021-fudge-controlled-generation/poetry_data/README.md +1 -0
  20. naacl-2021-fudge-controlled-generation/poetry_data/couplet_ends.txt +154 -0
  21. naacl-2021-fudge-controlled-generation/poetry_data/couplet_prefixes.txt +154 -0
  22. naacl-2021-fudge-controlled-generation/poetry_util.py +83 -0
  23. naacl-2021-fudge-controlled-generation/predict_clickbait.py +199 -0
  24. naacl-2021-fudge-controlled-generation/predict_formality.py +404 -0
  25. naacl-2021-fudge-controlled-generation/predict_poetry.py +219 -0
  26. naacl-2021-fudge-controlled-generation/predict_topic.py +126 -0
  27. naacl-2021-fudge-controlled-generation/requirements.txt +7 -0
  28. naacl-2021-fudge-controlled-generation/topic_data/README.md +3 -0
  29. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/computers.txt +163 -0
  30. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/legal.txt +108 -0
  31. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/military.txt +136 -0
  32. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/politics.txt +40 -0
  33. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/religion.txt +207 -0
  34. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/science.txt +47 -0
  35. naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/space.txt +16 -0
  36. naacl-2021-fudge-controlled-generation/topic_data/topic_prefixes.txt +20 -0
  37. naacl-2021-fudge-controlled-generation/topic_data/val_wordlists/fantasy.txt +26 -0
  38. naacl-2021-fudge-controlled-generation/topic_data/wordlists/computers.txt +176 -0
  39. naacl-2021-fudge-controlled-generation/topic_data/wordlists/legal.txt +131 -0
  40. naacl-2021-fudge-controlled-generation/topic_data/wordlists/military.txt +149 -0
  41. naacl-2021-fudge-controlled-generation/topic_data/wordlists/politics.txt +47 -0
  42. naacl-2021-fudge-controlled-generation/topic_data/wordlists/religion.txt +232 -0
  43. naacl-2021-fudge-controlled-generation/topic_data/wordlists/science.txt +48 -0
  44. naacl-2021-fudge-controlled-generation/topic_data/wordlists/space.txt +18 -0
  45. naacl-2021-fudge-controlled-generation/transcript.txt +415 -0
  46. naacl-2021-fudge-controlled-generation/util.py +110 -0
naacl-2021-fudge-controlled-generation/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Kevin Yang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
naacl-2021-fudge-controlled-generation/README.md ADDED
@@ -0,0 +1,155 @@
1
+ # FUDGE: Controlled Text Generation With Future Discriminators
2
+
3
+ This repo contains code corresponding to the paper FUDGE: Controlled Text Generation With Future Discriminators (https://arxiv.org/abs/2104.05218) by Kevin Yang and Dan Klein, published at NAACL 2021.
4
+
5
+ You can also find a video presentation at http://somup.com/crhlVPFKN7 and the corresponding slides in `slides.pptx`.
6
+
7
+ ## Setup/Installation
8
+
9
+ We tested on Python 3.8.5 but earlier versions of Python 3 are almost certainly fine. To get the required packages (other versions likely to work too):
10
+
11
+ ```
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ Additionally, to get our pre-trained predictor checkpoints and training data, run:
16
+
17
+ ```
18
+ wget https://naacl2021-fudge-files.s3.amazonaws.com/large_files.zip
19
+ ```
20
+
21
+ and extract the zip to the top-level `lm-prediction/` folder. (There should be three folders, `ckpt/`, `train_data/`, and `topic_human_evals/`. The zip is 7GB.) Note: the zip does not work for some people; if that's the case, you can get the files directly from https://drive.google.com/drive/folders/1GZfOGqpQxDmIfD2RvuhUQla9eX2OHUXU?usp=sharing (13GB).
22
+
23
+ `ckpt/` contains predictor checkpoints for each task if you are just interested in running inference. (Note that for the paper results, we used predictors trained with an older version of the code, but the new checkpoints get similar results, so you are OK to use the new predictors provided here if e.g. you just want to use FUDGE as a baseline. You can just run the evaluation commands provided below; it should take maybe 5-60 minutes depending on the task and your compute, assuming you have a GPU.)
24
+
25
+ `train_data/` contains our GPT2-generated training data for the poetry and topic tasks' predictors. See https://github.com/raosudha89/GYAFC-corpus for instructions on gaining access to the GYAFC data used for the machine translation formality task; replace our dummy folders with the corresponding folders/files if you want to train our formality predictor.
26
+
27
+ ## Clickbait
28
+ To generate outputs, run:
29
+
30
+ ```
31
+ python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --length_cutoff 80 --device cpu
32
+
33
+ python -u evaluate_clickbait.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es
34
+
35
+ python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file clickbait_preds.log
36
+ ```
37
+
38
+ Then evaluate metrics using:
39
+
40
+ ```
41
+ python eval_topic_metrics.py --log_file clickbait_preds.log --tw_dir topic_data/test_wordlists
42
+ ```
43
+
44
+
45
+ ## Poetry Couplet Completion
46
+
47
+ ### Evaluation
48
+
49
+ To generate outputs, run:
50
+
51
+ ```
52
+ python -u evaluate_poetry.py --iambic_ckpt ckpt/poetry/iambic_predictor/model.pth.tar --rhyme_ckpt ckpt/poetry/rhyme_predictor/model.pth.tar --newline_ckpt ckpt/poetry/newline_predictor/model.pth.tar --dataset_info ckpt/poetry/rhyme_predictor/dataset_info --rhyme_info ckpt/poetry/rhyme_predictor/rhyme_info --prefix_file poetry_data/couplet_prefixes.txt --precondition_topk 200 > poetry_preds.log
53
+ ```
54
+
55
+ Then evaluate metrics using:
56
+
57
+ ```
58
+ python eval_poetry_metrics.py --pred_file poetry_preds.log --prefix_file poetry_data/couplet_prefixes.txt
59
+ ```
60
+
61
+ ### Training your own predictors
62
+
63
+ Example commands for all three predictors used in the poetry task are below. (You probably don't need this many epochs for iambic and rhyme; in any case, the commands save intermediate ckpts, so you can stop them early if needed by inspecting the log.)
64
+
65
+ Iambic predictor:
66
+
67
+ ```
68
+ python -u main.py --task iambic --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/iambic_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > iambic_retrain_predictor.log
69
+ ```
70
+
71
+ Rhyme predictor:
72
+
73
+ ```
74
+ python -u main.py --task rhyme --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/rhyme_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > rhyme_retrain_predictor.log
75
+ ```
76
+
77
+ End of sentence predictor (referred to as "newline" in the code; 50 epochs is more than enough for this one):
78
+
79
+ ```
80
+ python -u main.py --task newline --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/newline_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 50 > newline_retrain_predictor.log
81
+ ```
82
+
83
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
84
+
85
+ ## Topic Control
86
+
87
+ ### Evaluation
88
+
89
+ To generate outputs, run:
90
+
91
+ ```
92
+ python -u evaluate_topic.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --prefix_file topic_data/topic_prefixes.txt --wordlist_dir topic_data/wordlists --condition_lambda 4.0 --verbose --precondition_topk 200 --topk 10 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file topic_preds.log
93
+ ```
94
+
95
+ Then evaluate metrics using:
96
+
97
+ ```
98
+ python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
99
+ ```
100
+
101
+ You can also find our original generations and baselines in `topic_human_evals/`.
102
+
103
+ ### Training your own predictors
104
+
105
+ Example command below.
106
+
107
+ ```
108
+ python -u main.py --task topic --data_dir train_data/gpt2_generations --save_dir ckpt/topic/future_word_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 500 --glove_file train_data/glove.840B.300d.txt > future_word_retrain_predictor.log
109
+ ```
110
+
111
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
112
+
113
+ ## Machine Translation Formality
114
+
115
+ ### Evaluation
116
+
117
+ To generate outputs, run:
118
+
119
+ ```
120
+ python -u evaluate_formality.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es --model_path ckpt/formality/marian_finetune_fisher > formality_preds.log
121
+ ```
122
+
123
+ The above command generates predictions using the Marian model finetuned on the Fisher dataset; remove the `--model_path` argument to get predictions with the un-finetuned Marian model from HuggingFace (referred to as 0-shot in the paper).
124
+
125
+ Then evaluate metrics using:
126
+
127
+ ```
128
+ python eval_formality_metrics.py --pred formality_preds.log --ref formality_data/test.noid.cleaned_0 formality_data/test.noid.cleaned_1 --ckpt ckpt/formality/test_evaluator_gyafc_family_relationships/model.pth.tar --dataset_info ckpt/formality/test_evaluator_gyafc_family_relationships/dataset_info
129
+ ```
130
+
131
+ ### Training your own predictors
132
+
133
+ Example command below. (Reminder: you need to go get the GYAFC dataset following the instructions in https://github.com/raosudha89/GYAFC-corpus.)
134
+
135
+ ```
136
+ python -u main.py --task formality --data_dir train_data/GYAFC_Corpus/Entertainment_Music --save_dir ckpt/formality/formality_retrain_predictor --num_workers 20 --batch_size 32 --epoch_max_len 1000000 --validation_freq 1 --lr 2e-5 --epochs 20 > formality_retrain_predictor.log
137
+ ```
138
+
139
+ (The test-time formality evaluator is trained in the same way, just using the Family/Relationships half of the GYAFC dataset.)
140
+
141
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
142
+
143
+ ## Running FUDGE on your own data
144
+
145
+ The code has been refactored so that the iambic (poetry), rhyme (poetry), newline (poetry), future word (topic), and formality (machine translation) tasks are all controlled by the `--task` flag to `main.py`. You should add your task as another option there, then modify the data processing in `data.py` and the model in `model.py` as needed. (In `data.py` you probably won't need all the entries of the tuple the loader is expected to return; you can just put dummy values in the ones you don't need, as in the sketch below.) You might also need to modify the loss computation in the `train` and `validate` functions in `main.py`. You'll probably want to write new evaluation scripts, though the existing poetry/topic/formality ones are hopefully helpful as references.
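+
+ For a simple binary-attribute task, the new branch can mostly mirror the existing formality branch in `data.py`: tokenize the sentence, keep the whole thing so all prefixes are trained on at once, and fill the unused tuple slots with dummy values. Here is a rough sketch of the 9-entry example tuple that `collate()` in `data.py` unpacks; the helper name and arguments are ours for illustration, not part of the repo:
+
+ ```
+ # Sketch only: build one training example for a hypothetical binary-attribute task.
+ # The tuple layout matches what collate() in data.py expects; unused slots get dummy values.
+ def make_attribute_example(tokenizer, gpt_pad_id, text, label):
+     inp = tokenizer.encode(text, return_tensors='pt')[0]   # token ids; all prefixes trained at once
+     length = len(inp)
+     future_word, word_log_prob = 0, 0                       # unused for a plain attribute classifier
+     syllables_to_go = future_word_num_syllables = rhyme_group_index = -1
+     return (inp, length, future_word, word_log_prob, gpt_pad_id,
+             label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
+ ```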
146
+
147
+ Alternatively, the general FUDGE framework is pretty simple, so you could always try reimplementing things yourself. A few additional details based on questions I've received:
148
+
149
+ (1) The formality task setup is likely closest to what you want if you're just trying to run the simplest form of FUDGE (take a language model, and use a classifier to optimize toward a single attribute), although you may need to swap out the Marian translation model/tokenizer we use.
150
+
151
+ (2) When you construct your training data, if you have an example in your data, e.g. "This movie is great!" for positive sentiment, you want to learn on all the prefix pairs (This, +), (This movie, +), (This movie is, +), etc., as that's one of the main points of our approach; see the first sketch at the end of this section.
152
+
153
+ (3) For computational efficiency, we first filter the base model's next-token probabilities down to the top 200 (Sec. 3.1 in the paper) before adding the classifier logits, so you only need to evaluate your classifier on 200 candidate continuations per step. Afterward, you filter down again to whatever top-k/greedy/nucleus sampling you're using for evaluation (we use top-k with k=10 for poetry and topic, and greedy for formality); see the second sketch at the end of this section.
154
+
155
+ (4) You can use a pretrained LM backbone instead of a simple LSTM backbone for the predictor as well. This should work better when your dataset is smaller.
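+
+ To make (2) concrete, here is a minimal sketch of how such prefix pairs could be enumerated when building predictor training data; the function and label format are placeholders for illustration, not the repo's actual data pipeline:
+
+ ```
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained('gpt2-medium')
+
+ def prefix_pairs(text, label):
+     # one (token-prefix, label) training pair per prefix of the example
+     tokens = tokenizer.encode(text)
+     return [(tokens[:i], label) for i in range(1, len(tokens) + 1)]
+
+ # "This movie is great!" with positive label 1 yields pairs covering
+ # "This", "This movie", "This movie is", ..., up to the full sentence
+ pairs = prefix_pairs("This movie is great!", 1)
+ ```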
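+
+ And here is a rough sketch of the decoding step described in (3). The `attribute_log_prob_fn` argument is a placeholder for your own future discriminator (returning the log-probability of the target attribute for each candidate continuation); it is not an API from this repo:
+
+ ```
+ import torch
+ import torch.nn.functional as F
+
+ def fudge_step(lm, input_ids, attribute_log_prob_fn,
+                precondition_topk=200, topk=10, condition_lambda=1.0):
+     # 1) base LM log-probs for the next token (lm is e.g. a GPT2LMHeadModel; input_ids is 1 x seq)
+     with torch.no_grad():
+         lm_logits = lm(input_ids).logits[0, -1]                 # (vocab,)
+     lm_log_probs = F.log_softmax(lm_logits, dim=-1)
+     # 2) keep only the top precondition_topk candidates for efficiency
+     top_log_probs, top_ids = lm_log_probs.topk(precondition_topk)
+     # 3) add lambda-weighted attribute log-probs for each candidate continuation
+     cond_log_probs = attribute_log_prob_fn(input_ids, top_ids)  # (precondition_topk,)
+     combined = top_log_probs + condition_lambda * cond_log_probs
+     # 4) final top-k sampling over the reweighted candidates
+     kept_scores, kept_idx = combined.topk(topk)
+     sampled = kept_idx[torch.multinomial(F.softmax(kept_scores, dim=-1), 1)]
+     return top_ids[sampled]                                     # next token id
+ ```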
naacl-2021-fudge-controlled-generation/clickbait_classifier.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ from transformers import BertModel, BertConfig, PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.modeling_outputs import TokenClassifierOutput,SequenceClassifierOutput
5
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, BCELoss
6
+ import torch.nn as nn
7
+ # from modeling_mpnet import MPNetModel, MPnetConfig
8
+
9
+ class ClickbaitConfig(PretrainedConfig):
10
+ def __init__(
11
+ self,
12
+ model_type: str = "bert",
13
+ pretrained_model: str = "bert-base-uncased",
14
+ num_labels: int = 1,
15
+ dropout: float = 0.1,
16
+ inner_dim1: int = 256,
17
+ inner_dim2: int = 32,
18
+ max_length: int = 512,
19
+ load_pretrained: bool = True,
20
+ freeze_bert: bool = True,
21
+ **kwargs
22
+ ):
23
+ super(ClickbaitConfig, self).__init__(num_labels=num_labels, **kwargs)
24
+ self.model_type = model_type
25
+ self.pretrained_model = pretrained_model
26
+ self.dropout = dropout
27
+ self.inner_dim1 = inner_dim1
28
+ self.inner_dim2 = inner_dim2
29
+ self.max_length = max_length
30
+ self.load_pretrained = load_pretrained
31
+ self.freeze_bert = freeze_bert
32
+
33
+
34
+ class BertClickbaitClassifier(PreTrainedModel):
35
+ """
36
+ Taken and extended from BertForSequenceClassification: https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/models/bert/modeling_bert.py#L1508
37
+ """
38
+ config_class = ClickbaitConfig
39
+ def __init__(self, config: ClickbaitConfig):
40
+ super(BertClickbaitClassifier, self).__init__(config)
41
+ self.num_labels = config.num_labels
42
+ self.config = config
43
+ # self.bert_config = BertConfig.from_pretrained(config.pretrained_model)
44
+ self.bert_config = AutoConfig.from_pretrained(config.pretrained_model)
45
+
46
+ # self.bert = BertModel(self.bert_config)
47
+ self.bert = AutoModel.from_pretrained(config.pretrained_model, config=self.bert_config)
48
+ # self.bert = SentenceTransformer(config.pretrained_model, config=self.bert_config)
49
+ # self.bert = MPNetModel(config.pretrained_model, config=self.bert_config)
50
+ if config.load_pretrained:
51
+ print("Load pretrained weights from {}".format(config.pretrained_model))
52
+ self.bert = self.bert.from_pretrained(config.pretrained_model)
53
+ if config.freeze_bert:
54
+ print("Freeze weights in the BERT model. Just the classifier will be trained")
55
+ for param in self.bert.parameters():
56
+ param.requires_grad = False
57
+
58
+ self.linear_1 = nn.Linear(self.bert.config.hidden_size, config.inner_dim1)
59
+ self.dropout_1 = nn.Dropout(config.dropout)
60
+ self.relu_1 = nn.ReLU()
61
+ self.dropout_2 = nn.Dropout(config.dropout)
62
+ self.linear_2 = nn.Linear(config.inner_dim1, config.inner_dim2)
63
+ self.relu_2 = nn.ReLU()
64
+ self.dropout_3 = nn.Dropout(config.dropout)
65
+ self.classifier = nn.Linear(config.inner_dim2, config.num_labels)
66
+ self.sigmoid = nn.Sigmoid()
67
+
68
+
69
+ def forward(
70
+ self,
71
+ input_ids: Optional[torch.Tensor] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ token_type_ids: Optional[torch.Tensor] = None,
74
+ position_ids: Optional[torch.Tensor] = None,
75
+ head_mask: Optional[torch.Tensor] = None,
76
+ inputs_embeds: Optional[torch.Tensor] = None,
77
+ labels: Optional[torch.Tensor] = None,
78
+ output_attentions: Optional[bool] = None,
79
+ output_hidden_states: Optional[bool] = None,
80
+ return_dict: Optional[bool] = None,
81
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
82
+ r"""
83
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
84
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
85
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
86
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
87
+ """
88
+
89
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
90
+
91
+ outputs = self.bert(
92
+ input_ids,
93
+ attention_mask=attention_mask,
94
+ token_type_ids=token_type_ids,
95
+ position_ids=position_ids,
96
+ head_mask=head_mask,
97
+ inputs_embeds=inputs_embeds,
98
+ output_attentions=output_attentions,
99
+ output_hidden_states=output_hidden_states,
100
+ return_dict=return_dict,
101
+ )
102
+
103
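+ # use the hidden state of the first token ([CLS] for BERT-style models) as the sequence representation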
+ output = outputs[0][:,0,:]
104
+
105
+ x = self.dropout_1(output)
106
+ x = self.linear_1(x)
107
+ x = self.relu_1(x)
108
+ x = self.dropout_2(x)
109
+ x = self.linear_2(x)
110
+ x = self.relu_2(x)
111
+ x = self.dropout_3(x)
112
+
113
+ logits = self.classifier(x)
114
+ logits = self.sigmoid(logits)
115
+
116
+ loss = None
117
+ if labels is not None:
118
+ loss_fct = BCELoss()  # unweighted binary cross-entropy on the sigmoid outputs
119
+ labels = 1.0*labels
120
+ loss = loss_fct(logits.view(-1), labels.view(-1))
121
+ if not return_dict:
122
+ output = (logits,) + outputs[2:]
123
+ return ((loss,) + output) if loss is not None else output
124
+
125
+ return SequenceClassifierOutput(
126
+ loss=loss,
127
+ logits=logits
128
+ )
naacl-2021-fudge-controlled-generation/constants.py ADDED
@@ -0,0 +1,32 @@
1
+ PAD_TOKEN = '[PAD]'
2
+ EOT_TOKEN = '<|endoftext|>'
3
+ SEP = 50256 # just use the weird eot token
4
+
5
+ TOPIC_MODEL_STRING = 'gpt2-medium'
6
+ FORMALITY_MODEL_STRING = 'Helsinki-NLP/opus-mt-es-en'
7
+
8
+ DIR_END_SPLIT_POSITIONS = 32
9
+
10
+ TOPIC_VAL_SIZE = 100000
11
+ FORMALITY_VAL_SIZE = 2000
12
+ VOCAB_SIZE = 50000
13
+
14
+ FORMALITY_MAX_LEN = 200
15
+
16
+ GLOVE_PRINT_PROGRESS_FREQ = 1000000
17
+ GLOVE_DIM = 300
18
+ HIDDEN_DIM = 300
19
+ RNN_DIM = 150
20
+
21
+ MIN_SENTENCE_LENGTH = 3
22
+
23
+ POETRY_LINE_SYLLABLES = 10
24
+ MAX_SYLLABLES_PER_WORD = 10 # no way anything is more
25
+ MAX_COUNT_SYLLABLE_DIST = 10
26
+ MAX_COUNT_SYLLABLE_INPUT_LENGTH = 25 # for just a couplet, shouldn't need more
27
+ COUNT_SYLLABLE_DIM = 100
28
+ UNKNOWN_RHYME_GROUP = 'UNKNOWN_RHYME_GROUP'
29
+ PHRASE_ENDS = '.?!'
30
+
31
+ POETRY_BANNED_TOKENS = [198, 50256, 628, 220] # newlines and eos and such
32
+
naacl-2021-fudge-controlled-generation/data.py ADDED
@@ -0,0 +1,415 @@
1
+ import random
2
+ import math
3
+ import os
4
+ import pickle
5
+ from collections import defaultdict, namedtuple
6
+ import string
7
+
8
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false' # turn off since we're using multiple threads for loading anyway
9
+
10
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ import torch
14
+
15
+ from util import suppress_stdout
16
+ from poetry_util import is_iambic, count_syllables, get_rhymes, get_rhyme_group
17
+ from constants import *
18
+
19
+ DatasetInfo = namedtuple('DatasetInfo',
20
+ ['index2word', 'word2index', 'total_words', 'vocab', 'glove_embeddings'])
21
+ RhymeInfo = namedtuple('RhymeInfo',
22
+ ['word2rhyme_group', 'rhyme_group_counts', 'rhyme_groups', 'index2rhyme_group', 'rhyme_group2index', 'total_rhyme_groups'])
23
+
24
+ def collate(batch):
25
+ pad_id = batch[0][4]
26
+ inputs = [b[0] for b in batch]
27
+ lengths = torch.LongTensor([b[1] for b in batch])
28
+ max_length = lengths.max()
29
+ for i in range(len(inputs)):
30
+ if len(inputs[i]) < max_length:
31
+ inputs[i] = torch.cat([inputs[i], torch.zeros(max_length - len(inputs[i])).long()], dim=0) # actually 0 is fine as pad since it's masked out
32
+ inputs = torch.stack(inputs, dim=0)
33
+ future_words = torch.LongTensor([b[2] for b in batch]).unsqueeze(0).expand(len(batch), -1).clone() # batch x N=batch
34
+ labels = torch.zeros_like(future_words).long()
35
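+ # labels is the identity over the batch: example i's true future word sits at column i (in-batch negatives)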
+ labels = labels.scatter(1, torch.arange(len(batch)).unsqueeze(1), torch.ones(len(batch)).long().unsqueeze(1)).clone()
36
+ log_probs = torch.Tensor([b[3] for b in batch])
37
+ classification_labels = [b[5] for b in batch] # batch
38
+ if type(classification_labels[0]) == list:
39
+ for i in range(len(classification_labels)):
40
+ assert len(classification_labels[i]) == lengths[i]
41
+ if len(classification_labels[i]) < max_length:
42
+ classification_labels[i] = torch.cat([torch.LongTensor(classification_labels[i]), -1 + torch.zeros(max_length - len(classification_labels[i])).long()], dim=0)
43
+ else:
44
+ classification_labels[i] = torch.LongTensor(classification_labels[i])
45
+ classification_labels = torch.stack(classification_labels, dim=0) # batch x seq
46
+ else:
47
+ assert type(classification_labels[0]) == int
48
+ classification_labels = torch.LongTensor(classification_labels) # they're just int labels
49
+ syllables_to_go = torch.LongTensor([b[6] for b in batch])
50
+ future_word_num_syllables = torch.LongTensor([b[7] for b in batch])
51
+ rhyme_group_index = torch.LongTensor([b[8] for b in batch])
52
+ return (inputs, lengths, future_words, log_probs, labels, classification_labels, syllables_to_go, future_word_num_syllables, rhyme_group_index)
53
+
54
+
55
+ def load_rhyme_info(index2word, vocab):
56
+ word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP)
57
+ rhyme_group_counts = defaultdict(lambda: 0)
58
+ rhyme_groups = set()
59
+ for word in index2word:
60
+ try:
61
+ rhyme_group = get_rhyme_group(word)
62
+ word2rhyme_group[word] = rhyme_group
63
+ rhyme_group_counts[rhyme_group] += (vocab[word] if word in vocab else 1) # for rare words not in vocab, just use 1
64
+ rhyme_groups.add(rhyme_group)
65
+ except:
66
+ rhyme_group_counts[UNKNOWN_RHYME_GROUP] += (vocab[word] if word in vocab else 1)
67
+ index2rhyme_group = [UNKNOWN_RHYME_GROUP] + sorted(list(rhyme_groups))
68
+ rhyme_group2index = {s: i for i, s in enumerate(index2rhyme_group)}
69
+ total_rhyme_groups = sum(rhyme_group_counts.values())
70
+
71
+ return RhymeInfo(word2rhyme_group=dict(word2rhyme_group),
72
+ rhyme_group_counts=dict(rhyme_group_counts),
73
+ rhyme_groups=rhyme_groups,
74
+ index2rhyme_group=index2rhyme_group,
75
+ rhyme_group2index=rhyme_group2index,
76
+ total_rhyme_groups=total_rhyme_groups)
77
+
78
+
79
+ class Dataset:
80
+ def __init__(self, args):
81
+ print('loading data')
82
+ random.seed(args.seed)
83
+ self.batch_size = args.batch_size
84
+ self.data_dir = args.data_dir
85
+ self.topic = args.task == 'topic'
86
+ self.formality = args.task == 'formality'
87
+ self.iambic = args.task == 'iambic'
88
+ self.rhyme = args.task == 'rhyme'
89
+ self.newline = args.task == 'newline'
90
+
91
+ self.tokenizer = AutoTokenizer.from_pretrained(FORMALITY_MODEL_STRING if self.formality else TOPIC_MODEL_STRING)
92
+ self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
93
+ self.gpt_pad_id = self.tokenizer.encode(PAD_TOKEN)[0] # actually just the vocab size
94
+ sentences = []
95
+ self.vocab = defaultdict(lambda: 0)
96
+ if self.formality:
97
+ self.vocab['placeholder'] = 1 # anything so we don't crash
98
+ train, val, test = [], [], []
99
+ for category, label in [('formal', 1), ('informal', 0)]:
100
+ with open(os.path.join(args.data_dir, 'train', category), 'r') as rf:
101
+ for i, line in enumerate(rf):
102
+ if len(line) > FORMALITY_MAX_LEN:
103
+ line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len; chosen so only ~20 examples affected in dataset
104
+ if i < FORMALITY_VAL_SIZE // 2:
105
+ val.append((line.strip(), label))
106
+ else:
107
+ train.append((line.strip(), label))
108
+ with open(os.path.join(args.data_dir, 'test', category), 'r') as rf:
109
+ for line in rf:
110
+ if len(line) > FORMALITY_MAX_LEN:
111
+ line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len
112
+ test.append((line.strip(), label))
113
+ self.splits = {}
114
+ self.splits['train'], self.splits['val'], self.splits['test'] = train, val, test
115
+ else: # topic / poetry
116
+ for root, _, filenames in os.walk(args.data_dir):
117
+ for fname in filenames:
118
+ with open(os.path.join(root, fname), 'r') as rf:
119
+ for line in rf:
120
+ sentences.append(line.strip())
121
+ for word in line.strip().split(' '):
122
+ self.vocab[word] += 1
123
+ random.shuffle(sentences)
124
+ self.splits = {}
125
+ if args.debug:
126
+ self.splits['val'] = sentences
127
+ self.splits['test'] = sentences
128
+ self.splits['train'] = sentences
129
+ else:
130
+ self.splits['val'] = sentences[:TOPIC_VAL_SIZE]
131
+ self.splits['test'] = sentences[TOPIC_VAL_SIZE:2*TOPIC_VAL_SIZE]
132
+ self.splits['train'] = sentences[2*TOPIC_VAL_SIZE:]
133
+
134
+ if args.dataset_info is not None:
135
+ print('loading dataset info from file')
136
+ with open(args.dataset_info, 'rb') as rf:
137
+ dataset_info = pickle.load(rf)
138
+ self.vocab, self.total_words, self.index2word, self.word2index, self.glove_embeddings = \
139
+ dataset_info.vocab, dataset_info.total_words, dataset_info.index2word, dataset_info.word2index, dataset_info.glove_embeddings
140
+ self.dataset_info = dataset_info
141
+ else:
142
+ print('generating dataset info from scratch')
143
+ words_values = list(self.vocab.items())
144
+ words_values = sorted(words_values, key=lambda x: x[1], reverse=True)
145
+ if args.glove_file is None:
146
+ print('no glove embeddings given')
147
+ for word, _ in words_values[VOCAB_SIZE:]: # only use somewhat common tokens
148
+ del self.vocab[word]
149
+ glove_embeddings = None
150
+ else:
151
+ print('loading glove embeddings')
152
+ glove_embeddings = {}
153
+ with open(args.glove_file, 'r') as rf:
154
+ for i, line in enumerate(rf):
155
+ if i % GLOVE_PRINT_PROGRESS_FREQ == 0:
156
+ print(i)
157
+ line = line.strip().split()
158
+ if len(line) != GLOVE_DIM + 1:
159
+ continue # skip multi-word embeddings which are rare anyway
160
+ glove_embeddings[line[0]] = [float(x) for x in line[1:]]
161
+ for word, _ in words_values:
162
+ if word not in glove_embeddings:
163
+ del self.vocab[word]
164
+ self.total_words = sum(self.vocab.values())
165
+ self.index2word = [PAD_TOKEN] + sorted(list(self.vocab.keys()))
166
+ self.word2index = {s: i for i, s in enumerate(self.index2word)}
167
+ self.vocab = dict(self.vocab) # so we can pickle later
168
+ if glove_embeddings is None:
169
+ self.glove_embeddings = None
170
+ else:
171
+ self.glove_embeddings = torch.stack([torch.zeros(GLOVE_DIM)] + [torch.Tensor(glove_embeddings[word]) for word in self.index2word[1:]], dim=0)
172
+
173
+ self.dataset_info = DatasetInfo(index2word=self.index2word,
174
+ word2index=self.word2index,
175
+ total_words=self.total_words,
176
+ vocab=self.vocab,
177
+ glove_embeddings=self.glove_embeddings)
178
+
179
+ if self.rhyme:
180
+ if args.rhyme_info is not None:
181
+ print('loading rhyme info from file')
182
+ with open(args.rhyme_info, 'rb') as rf:
183
+ self.rhyme_info = pickle.load(rf)
184
+ else:
185
+ self.rhyme_info = load_rhyme_info(self.index2word, self.vocab)
186
+ self.word2rhyme_group, self.rhyme_group_counts, self.rhyme_groups, self.index2rhyme_group, self.rhyme_group2index, self.total_rhyme_groups = \
187
+ defaultdict(lambda: UNKNOWN_RHYME_GROUP, self.rhyme_info.word2rhyme_group), self.rhyme_info.rhyme_group_counts, self.rhyme_info.rhyme_groups, self.rhyme_info.index2rhyme_group, self.rhyme_info.rhyme_group2index, self.rhyme_info.total_rhyme_groups
188
+
189
+ print('done loading data')
190
+ print('split sizes:')
191
+ for key in ['train', 'val', 'test']:
192
+ print(key, len(self.splits[key]))
193
+ if not self.formality:
194
+ print('total words', self.total_words)
195
+ print('vocab size', len(self.index2word))
196
+
197
+
198
+ def shuffle(self, split, seed=None):
199
+ assert split in ['train', 'val', 'test']
200
+ if seed is not None:
201
+ random.seed(seed)
202
+ random.shuffle(self.splits[split])
203
+
204
+
205
+ def loader(self, split, num_workers=20, indices=None):
206
+ assert split in ['train', 'val', 'test']
207
+ data = self.splits[split] if indices is None else [self.splits[split][i] for i in indices]
208
+ return torch.utils.data.DataLoader(SplitLoader(data, self), batch_size=self.batch_size, pin_memory=True, collate_fn=collate, num_workers=num_workers)
209
+
210
+
211
+ class SplitLoader(torch.utils.data.IterableDataset):
212
+ def __init__(self, data, parent):
213
+ super(SplitLoader).__init__()
214
+ self.data = data
215
+ self.pos = 0
216
+ self.parent = parent
217
+
218
+
219
+ def __len__(self):
220
+ return len(self.data)
221
+
222
+
223
+ def __iter__(self):
224
+ return self
225
+
226
+
227
+ def __next__(self):
228
+ increment = 1
229
+ worker_info = torch.utils.data.get_worker_info()
230
+ if worker_info is not None: # in a worker process
231
+ increment = worker_info.num_workers
232
+ worker_id = worker_info.id
233
+ if self.pos == 0:
234
+ self.pos = worker_id
235
+ valid = False
236
+ while not valid:
237
+ if self.pos >= len(self):
238
+ raise StopIteration
239
+ if self.parent.topic:
240
+ failed = False
241
+ future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
242
+ raw_sentence, classification_label = self.data[self.pos], -1
243
+ original_sentence = raw_sentence.split()
244
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
245
+ length = len(sentence)
246
+ min_sentence_length = MIN_SENTENCE_LENGTH
247
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
248
+ pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
249
+ inp = sentence[:pos_to_split]
250
+ length = len(inp)
251
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
252
+ if not failed and num_words_in_input < len(original_sentence):
253
+ future_word_position_max = len(original_sentence) - 1
254
+ future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
255
+ future_word = original_sentence[future_word_position]
256
+ unstripped_future_word = future_word
257
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
258
+ if not failed and future_word in self.parent.word2index.keys():
259
+ word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
260
+ future_word = self.parent.word2index[future_word]
261
+ pad_id = self.parent.gpt_pad_id
262
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
263
+ valid = not failed
264
+ elif self.parent.formality:
265
+ future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
266
+ raw_sentence, classification_label = self.data[self.pos]
267
+ original_sentence = raw_sentence.split()
268
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
269
+ length = len(sentence)
270
+ min_sentence_length = MIN_SENTENCE_LENGTH
271
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
272
+ pos_to_split = length # no need to split; we're going to train on all possible prefixes simultaneously for efficiency
273
+ inp = sentence[:pos_to_split]
274
+ length = len(inp)
275
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
276
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
277
+ future_word_position_max = len(original_sentence) - 1
278
+ future_word_position = 0
279
+ future_word = 'placeholder'
280
+ unstripped_future_word = future_word
281
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
282
+ word_log_prob, future_word = 0, 0
283
+ pad_id = self.parent.gpt_pad_id
284
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
285
+ valid = True
286
+ elif self.parent.iambic:
287
+ failed = False
288
+ future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
289
+ raw_sentence, classification_label = self.data[self.pos], -1
290
+ original_sentence = raw_sentence.split()
291
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
292
+ length = len(sentence)
293
+ min_sentence_length = MIN_SENTENCE_LENGTH
294
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
295
+ pos_to_split = random.randint(0, length - 1)
296
+ # try to get a subseq of exactly 10 syllables
297
+ inp = sentence[pos_to_split:]
298
+ num_syllables = 0
299
+ checked = False
300
+ for i in range(1, len(inp)):
301
+ decoded = self.parent.tokenizer.decode(inp[:i])
302
+ num_syllables = count_syllables(decoded)
303
+ if num_syllables > POETRY_LINE_SYLLABLES:
304
+ inp = inp[:i-1] # might get a few data points where the split is in the middle of a word, but it should be ok for learning.
305
+ last_line_length = i-1
306
+ decoded = self.parent.tokenizer.decode(inp)
307
+ num_syllables = count_syllables(decoded)
308
+ checked = True
309
+ break
310
+ if not checked or num_syllables != POETRY_LINE_SYLLABLES:
311
+ failed = True
312
+ length = len(inp)
313
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
314
+ classification_label = [is_iambic(self.parent.tokenizer.decode(inp)) for _ in range(length)] # predict for whole seq including future
315
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
316
+ future_word_position_max = len(original_sentence) - 1
317
+ future_word_position = 0
318
+ future_word = 'placeholder'
319
+ unstripped_future_word = future_word
320
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
321
+ if not failed:
322
+ word_log_prob, future_word = 0, 0
323
+ pad_id = self.parent.gpt_pad_id
324
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
325
+ valid = not failed
326
+ elif self.parent.rhyme:
327
+ failed = False
328
+ future_word_num_syllables, rhyme_group_index = -1, -1
329
+ raw_sentence, classification_label = self.data[self.pos], -1
330
+ original_sentence = raw_sentence.split()
331
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
332
+ length = len(sentence)
333
+ min_sentence_length = MIN_SENTENCE_LENGTH
334
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
335
+ pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
336
+ inp = sentence[:pos_to_split]
337
+ length = len(inp)
338
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
339
+ if not failed and num_words_in_input < len(original_sentence):
340
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
341
+ future_word_position_max = min(len(original_sentence) - 1, num_words_in_input + MAX_COUNT_SYLLABLE_DIST)
342
+ future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
343
+ future_word = original_sentence[future_word_position]
344
+ unstripped_future_word = future_word
345
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
346
+
347
+ words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
348
+ syllables_to_go = count_syllables(' '.join(words_in_between))
349
+ if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
350
+ failed = True
351
+ future_word_num_syllables = count_syllables(future_word)
352
+ rhyme_group = self.parent.word2rhyme_group[future_word]
353
+ rhyme_group_index = self.parent.rhyme_group2index[rhyme_group]
354
+ # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
355
+ desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
356
+ inp = inp[-desired_length:]
357
+ length = len(inp)
358
+
359
+ if not failed and future_word in self.parent.word2index.keys():
360
+ word_log_prob = math.log(self.parent.rhyme_group_counts[rhyme_group] / self.parent.total_rhyme_groups)
361
+ future_word = rhyme_group_index # future conditioning is just the rhyme group in this case
362
+ pad_id = self.parent.gpt_pad_id
363
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
364
+ valid = not failed
365
+ elif self.parent.newline:
366
+ failed = False
367
+ future_word_num_syllables, rhyme_group_index = -1, -1
368
+ raw_sentence, classification_label = self.data[self.pos], -1
369
+ original_sentence = raw_sentence.split()
370
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
371
+ length = len(sentence)
372
+ min_sentence_length = MIN_SENTENCE_LENGTH
373
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
374
+ pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
375
+ inp = sentence[:pos_to_split]
376
+ while pos_to_split < len(sentence):
377
+ if len(self.parent.tokenizer.decode(inp).split()) == len(self.parent.tokenizer.decode(sentence[:pos_to_split + 1]).split()):
378
+ pos_to_split += 1
379
+ inp = sentence[:pos_to_split]
380
+ else:
381
+ break
382
+ length = len(inp)
383
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
384
+ if not failed and num_words_in_input < len(original_sentence):
385
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
386
+ future_word_position_max = len(original_sentence) - 1
387
+ future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
388
+ future_word = original_sentence[future_word_position]
389
+ unstripped_future_word = future_word
390
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
391
+
392
+ # future_word = original_sentence[-1] # useful for debugging
393
+ words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
394
+ syllables_to_go = count_syllables(' '.join(words_in_between))
395
+ if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
396
+ failed = True
397
+ # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
398
+ desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
399
+ # desired_length = 10 # useful for debugging
400
+ inp = inp[-desired_length:]
401
+ length = len(inp)
402
+ true_label = 1 if unstripped_future_word.strip()[-1] in PHRASE_ENDS else 0 # common ways to end a phrase
403
+ classification_label = [-1 for _ in range(length)]
404
+ classification_label[-1] = true_label # only learn at the last position
405
+ if not failed and future_word in self.parent.word2index.keys():
406
+ word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
407
+ future_word = self.parent.word2index[future_word]
408
+ pad_id = self.parent.gpt_pad_id
409
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
410
+ valid = not failed
411
+ else:
412
+ raise NotImplementedError
413
+
414
+ self.pos += increment
415
+ return example
naacl-2021-fudge-controlled-generation/eval_formality_metrics.py ADDED
@@ -0,0 +1,73 @@
1
+ from argparse import ArgumentParser
2
+ import pickle
3
+ import os
4
+ import math
5
+
6
+ import sacrebleu
7
+ import numpy as np
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
10
+
11
+ from constants import *
12
+ from model import Model
13
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
14
+
15
+ def avg_formality(preds, model, tokenizer, device='cuda'):
16
+ probs = []
17
+ for sent in preds:
18
+ encoded_input = tokenizer.encode(sent, return_tensors='pt').to(device)
19
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
20
+ scores = model(encoded_input, lengths=lengths) # batch x seq
21
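+ # take the classifier's score at the last token as the sentence-level formality score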
+ score = scores.flatten()[-1].item()
22
+ probs.append(math.exp(score) / (1 + math.exp(score))) # sigmoided score = prob
23
+ return np.mean(probs)
24
+
25
+ if __name__=='__main__':
26
+ parser = ArgumentParser()
27
+ parser.add_argument('--pred', type=str)
28
+ parser.add_argument('--ref', type=str, nargs='*', help='bleu refs')
29
+ parser.add_argument('--ckpt', type=str, help='formality classifier')
30
+ parser.add_argument('--dataset_info', type=str)
31
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
32
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
33
+
34
+ args = parser.parse_args()
35
+
36
+ # refs = [['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'],
37
+ # ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.']]
38
+ # sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
39
+ print('num ref files', len(args.ref))
40
+ pred = []
41
+ with open(args.pred, 'r') as rf:
42
+ for line in rf:
43
+ pred.append(line.strip())
44
+ refs = []
45
+ for ref_file in args.ref:
46
+ ref = []
47
+ with open(ref_file, 'r') as rf:
48
+ for line in rf:
49
+ ref.append(line.strip())
50
+ assert len(ref) == len(pred)
51
+ refs.append(ref)
52
+ bleu = sacrebleu.corpus_bleu(pred, refs)
53
+ print('BLEU score:', bleu.score)
54
+
55
+ with open(args.dataset_info, 'rb') as rf:
56
+ dataset_info = pickle.load(rf)
57
+
58
+ tokenizer = MarianTokenizer.from_pretrained(args.model_string)
59
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
60
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
61
+
62
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
63
+ model_args = checkpoint['args']
64
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
65
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
66
+ conditioning_model = conditioning_model.to(args.device)
67
+ conditioning_model.eval()
68
+ print("=> loaded checkpoint '{}' (epoch {})"
69
+ .format(args.ckpt, checkpoint['epoch']))
70
+ print('num params', num_params(conditioning_model))
71
+
72
+ print('avg formality prob according to model', avg_formality(pred, conditioning_model, tokenizer, device=args.device))
73
+
naacl-2021-fudge-controlled-generation/eval_poetry_metrics.py ADDED
@@ -0,0 +1,135 @@
1
+ from argparse import ArgumentParser
2
+ import math
3
+ import string
4
+
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
10
+
11
+ from poetry_util import is_iambic, perfect_rhyme_end, count_syllables
12
+ from constants import *
13
+
14
+
15
+ def conditional_perplexity(prefix, pred, tokenizer, model, device='cuda', sep_losses=False):
16
+ # calculate perplexity on pred only, conditioned on prefix
17
+ sentence = prefix + pred
18
+ sos_token = tokenizer.decode([0])
19
+ prefix_tensor_input = tokenizer.encode(sos_token + prefix.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
20
+ full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
21
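+ # sep_losses=True sums per-token losses (used for Transformer-XL below); otherwise the scalar mean loss is rescaled by (length - 1) to recover a total negative log-likelihood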
+ if sep_losses:
22
+ prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0].sum()
23
+ full_loss = model(full_tensor_input, labels=full_tensor_input)[0].sum()
24
+ else:
25
+ prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0] * (prefix_tensor_input.shape[1]-1) # neg log prob of prefix
26
+ full_loss = model(full_tensor_input, labels=full_tensor_input)[0] * (full_tensor_input.shape[1]-1) # neg log prob of full seq
27
+ pred_loss = full_loss - prefix_loss # neg log prob of preds given prefix
28
+ avg_pred_loss = pred_loss / (full_tensor_input.shape[1] - prefix_tensor_input.shape[1])
29
+ return math.exp(avg_pred_loss.item())
30
+
31
+
32
+ def grammaticality(sentences, tokenizer, model, device='cuda'):
33
+ with torch.no_grad():
34
+ total_good = 0
35
+ for sent in tqdm(sentences, total=len(sentences)):
36
+ good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
37
+ total_good += good_prob
38
+ return total_good / len(sentences) # avg probability of grammaticality according to model
39
+
40
+
41
+ def distinctness(sentences):
42
+ d1 = set()
43
+ d2 = set()
44
+ d3 = set()
45
+ total_words = 0
46
+ for sentence in sentences:
47
+ o = sentence.split(' ')
48
+ total_words += len(o)
49
+ d1.update(o)
50
+ for i in range(len(o) - 1):
51
+ d2.add(o[i] + '_' + o[i+1])
52
+ for i in range(len(o) - 2):
53
+ d3.add(o[i] + '_' + o[i+1] + '_' + o[i+2])
54
+ return len(d1) / total_words, len(d2) / total_words, len(d3) / total_words
55
+
56
+
57
+ if __name__=='__main__':
58
+ parser = ArgumentParser()
59
+ parser.add_argument('--pred_file', type=str)
60
+ parser.add_argument('--prefix_file', type=str)
61
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
62
+ args = parser.parse_args()
63
+
64
+ preds = []
65
+ with open(args.pred_file, 'r') as rf:
66
+ for line in rf:
67
+ preds.append(line[:-1]) # drop \n but not beginning spaces if any
68
+ prefixes = []
69
+ with open(args.prefix_file, 'r') as rf:
70
+ for line in rf:
71
+ prefixes.append(line.strip())
72
+ assert len(prefixes) == len(preds)
73
+ rhymes = 0
74
+ iambic = 0
75
+ ten_syllables = 0
76
+ end = 0
77
+ diff_rhymes = 0
78
+ all_success = 0
79
+ total = len(prefixes)
80
+ for prefix, pred in zip(prefixes, preds):
81
+ if is_iambic(pred):
82
+ iambic += 1
83
+ if perfect_rhyme_end(prefix, pred):
84
+ rhymes += 1
85
+ if prefix.split()[-1].strip(string.punctuation) != pred.split()[-1].strip(string.punctuation):
86
+ diff_rhymes += 1
87
+ if count_syllables(pred) == 10:
88
+ ten_syllables += 1
89
+ if pred.strip()[-1] in PHRASE_ENDS:
90
+ end += 1
91
+ if is_iambic(pred) and perfect_rhyme_end(prefix, pred) and count_syllables(pred) == 10 and pred.strip()[-1] in PHRASE_ENDS:
92
+ all_success += 1
93
+ print('iambic', iambic, 'out of', total, ', frac', iambic / total)
94
+ print('rhymes', rhymes, 'out of', total, ', frac', rhymes / total)
95
+ print('end sentence', end, 'out of', total, ', frac', end / total)
96
+ print('10 syllables', ten_syllables, 'out of', total, ', frac', ten_syllables / total)
97
+ print('all success', all_success, 'out of', total, ', frac', all_success / total)
98
+ print('rhymes with diff word', diff_rhymes, 'out of', total, ', frac', diff_rhymes / total)
99
+
100
+ print('distinctness', distinctness(preds))
101
+
102
+ grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
103
+ grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
104
+ grammar_model.eval()
105
+ print('grammaticality', grammaticality(preds, grammar_tokenizer, grammar_model, device=args.device))
106
+
107
+ perplexities = []
108
+ eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
109
+ eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
110
+ eval_model.eval()
111
+ for prefix, pred in zip(prefixes, preds):
112
+ perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device, sep_losses=True))
113
+ print('transformer xl perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
114
+
115
+ perplexities = []
116
+ eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
117
+ eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
118
+ eval_model.eval()
119
+ for prefix, pred in zip(prefixes, preds):
120
+ perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
121
+ print('gpt perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
122
+
123
+ # NOTE: uncomment this section with the path to the Shakespeare-finetuned GPT to evaluate this metric. it's in ckpt/poetry/gpt_finetune_shakespeare.pth.tar.
124
+ # eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
125
+ # eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
126
+ # checkpoint = torch.load('***PATH_TO_SHAKESPEARE_FINETUNED_GPT***', map_location=args.device)
127
+ # mod_dict = {}
128
+ # for key in checkpoint['state_dict']:
129
+ # mod_dict[key.replace('classifier.', '')] = checkpoint['state_dict'][key]
130
+ # eval_model.load_state_dict(mod_dict)
131
+ # eval_model.eval()
132
+ # perplexities = []
133
+ # for prefix, pred in zip(prefixes, preds):
134
+ # perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
135
+ # print('shakespeare finetuned perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
naacl-2021-fudge-controlled-generation/eval_topic_metrics.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ from collections import defaultdict
8
+ import string
9
+ import csv
10
+
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
21
+ from predict import predict
22
+ from constants import *
23
+
24
+ def tw_topic_eval(sentences, category, tw_dir, cap=None):
25
+ # count how many distinct wordlist words appear in each sentence (optionally capped per sentence)
26
+ words = []
27
+ with open(os.path.join(tw_dir, category + '.txt'), 'r') as rf:
28
+ for line in rf:
29
+ words.append(line.strip().lower())
30
+ num_match = 0
31
+ for sent in sentences:
32
+ sent_match = 0
33
+ sent = sent.strip().lower().split()
34
+ sent = [tok.strip(string.punctuation) for tok in sent]
35
+ for word in words:
36
+ if word in sent:
37
+ sent_match += 1
38
+ if cap is None:
39
+ num_match += sent_match
40
+ else:
41
+ num_match += min(cap, sent_match)
42
+ return num_match
43
+
44
+
45
+ def perplexity(sentences, tokenizer, model, device='cuda'):
46
+ # calculate perplexity
47
+ with torch.no_grad():
48
+ ppl = []
49
+ sos_token = tokenizer.decode([0])
50
+ for sentence in tqdm(sentences, total=len(sentences)):
51
+ full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
52
+ full_loss = model(full_tensor_input, labels=full_tensor_input)[0].mean()
53
+ ppl.append(torch.exp(full_loss).flatten().cpu().item())
54
+ return np.mean(ppl), np.std(ppl)
55
+
56
+
57
+ def grammaticality(sentences, tokenizer, model, device='cuda'):
58
+ with torch.no_grad():
59
+ total_good = 0
60
+ for sent in tqdm(sentences, total=len(sentences)):
61
+ good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
62
+ total_good += good_prob
63
+ return total_good / len(sentences) # avg probability of grammaticality according to model
64
+
65
+
66
+ def distinctness(results):
67
+ d1, d2, d3 = defaultdict(lambda: set()), defaultdict(lambda: set()), defaultdict(lambda: set())
68
+ total_words = defaultdict(lambda: 0)
69
+ for cw, outputs in results.items():
70
+ for o in outputs:
71
+ o = o.replace(EOT_TOKEN, ' ').strip().split(' ')
72
+ o = [str(x) for x in o]
73
+ total_words[cw] += len(o)
74
+ d1[cw].update(o)
75
+ for i in range(len(o) - 1):
76
+ d2[cw].add(o[i] + ' ' + o[i+1])
77
+ for i in range(len(o) - 2):
78
+ d3[cw].add(o[i] + ' ' + o[i+1] + ' ' + o[i+2])
79
+ return_info = []
80
+ avg_d1, avg_d2, avg_d3 = 0, 0, 0
81
+ for cw in total_words.keys():
82
+ return_info.append((cw, 'DISTINCTNESS', len(d1[cw]) / total_words[cw], len(d2[cw]) / total_words[cw], len(d3[cw]) / total_words[cw]))
83
+ avg_d1 += len(d1[cw]) / total_words[cw]
84
+ avg_d2 += len(d2[cw]) / total_words[cw]
85
+ avg_d3 += len(d3[cw]) / total_words[cw]
86
+ avg_d1, avg_d2, avg_d3 = avg_d1 / len(total_words.keys()), avg_d2 / len(total_words.keys()), avg_d3 / len(total_words.keys())
87
+ return return_info, (avg_d1, avg_d2, avg_d3)
88
+
89
+
90
+ if __name__=='__main__':
91
+ parser = ArgumentParser()
92
+ parser.add_argument('--log_file', type=str, required=True, help='where to load results from')
93
+ parser.add_argument('--tw_dir', type=str, default='test_wordlists', help='test wordlists')
94
+ parser.add_argument('--batch_size', type=int, default=8, help='max samples at a time')
95
+ parser.add_argument('--cap_per_example', type=int, default=None, help='max matches to count per sentence')
96
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
97
+ args = parser.parse_args()
98
+
99
+ tw_topic_match_c_total = 0
100
+ category_totals_c = defaultdict(lambda:0)
101
+ results = defaultdict(lambda: [])
102
+ with open(args.log_file, 'r') as rf:
103
+ data = list(csv.DictReader(rf))
104
+ for line in data:
105
+ results[line['category']].append(line['generation'])
106
+
107
+ all_c_sents = []
108
+ for category, condition_results in results.items():
109
+ tw_topic_match_c = tw_topic_eval(condition_results, category, args.tw_dir, cap=args.cap_per_example)
110
+ tw_topic_match_c_total += tw_topic_match_c
111
+ category_totals_c[category] += tw_topic_match_c
112
+ all_c_sents += condition_results
113
+
114
+ print('Test wordlist matches (divide by num outputs to get the Success metric):', tw_topic_match_c_total)
115
+ print('per category:', category_totals_c)
116
+
117
+ dist_info_by_category, dist_overall = distinctness(results)
118
+ print('Overall avg distinctness:', dist_overall)
119
+ print('per category:', dist_info_by_category)
120
+
121
+ grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
122
+ grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
123
+ grammar_model.eval()
124
+ print('grammaticality:', grammaticality(all_c_sents, grammar_tokenizer, grammar_model, device=args.device))
125
+
126
+ eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
127
+ eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
128
+ eval_model.eval()
129
+ print('GPT perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model))
130
+
131
+ eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
132
+ eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
133
+ eval_model.eval()
134
+ print('TFXL perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model))
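For reference, `--log_file` is expected to be the CSV written by `evaluate_topic.py` (fields `category`, `input_text`, `generation`); only `category` and `generation` are read here. A minimal sketch of that layout, with made-up text:

```python
import csv

# Hypothetical rows; the real file is produced by evaluate_topic.py.
rows = [
    {'category': 'space', 'input_text': 'The issue focused on', 'generation': 'The issue focused on the launch ...'},
]
with open('example_log.csv', 'w', newline='') as wf:
    writer = csv.DictWriter(wf, fieldnames=['category', 'input_text', 'generation'])
    writer.writeheader()
    writer.writerows(rows)
```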
naacl-2021-fudge-controlled-generation/evaluate_clickbait.py ADDED
@@ -0,0 +1,200 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from typing import Iterable, List, Optional, Tuple
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead
16
+ from torch import Tensor
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import num_params
21
+ from constants import *
22
+
23
+
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
26
+ classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
27
+
28
+
29
+ def main(args):
30
+ with open(args.dataset_info, 'rb') as rf:
31
+ dataset_info = pickle.load(rf)
32
+
33
+ article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
34
+ Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
35
+ The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
36
+ Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
37
+ 'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
38
+ to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again'
39
+ , even though he's had a chance to catch-up with other cast members."""
40
+
41
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
42
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
43
+
44
+ # For loading the clickbait summarizer
45
+ model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
46
+
47
+ model.eval()
48
+
49
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
50
+ model_args = checkpoint['args']
51
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
52
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
53
+ conditioning_model = conditioning_model.to(args.device)
54
+ conditioning_model.eval()
55
+ print("=> loaded checkpoint '{}' (epoch {})"
56
+ .format(args.ckpt, checkpoint['epoch']))
57
+ print('num params', num_params(conditioning_model))
58
+
59
+ while True:
60
+ results = generate_clickbait(model,
61
+ tokenizer,
62
+ conditioning_model,
63
+ [args.input_text],
64
+ dataset_info,
65
+ precondition_topk=args.precondition_topk,
66
+ do_sample=args.do_sample,
67
+ length_cutoff=args.length_cutoff,
68
+ condition_lambda=args.condition_lambda,
69
+ article_content=article_content,
70
+ device=args.device)
71
+ # print(results)
72
+ import pdb; pdb.set_trace()
73
+
74
+
75
+ def generate_clickbait(model,
76
+ tokenizer,
77
+ conditioning_model,
78
+ input_text,
79
+ dataset_info,
80
+ precondition_topk,
81
+ length_cutoff,
82
+ condition_lambda=1.0,
83
+ article_content=None,
84
+ device='cuda'):
85
+ with torch.no_grad():
86
+ batch_size = len(input_text)
87
+ # encoded_input_article = [tokenizer.encode(article_content, return_tensors='pt',add_special_tokens=False).to(device)] # batch x seq
88
+ encoded_input_article = tokenizer(article_content, return_tensors='pt',add_special_tokens=False, max_length=512).to(device) # batch x seq
89
+ # encoded_input_article = torch.cat(encoded_input_article, dim=0)
90
+ # attention_mask = encoded_input_article.new_ones(encoded_input_article.shape).to(device)
91
+
92
+ # CHANGE=ko
93
+ encoded_input = tokenizer('<pad>', return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
94
+ # encoded_input = tokenizer('<pad>'+ input_text[0], return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
95
+ # encoded_input = torch.cat(encoded_input, dim=0)
96
+ encoded_input = encoded_input['input_ids']
97
+
98
+
99
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
100
+ # lengths = 1
101
+
102
+ past = None
103
+ use_cache = True
104
+
105
+ # CHANGE
106
+ # model_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input_article, attention_mask=attention_mask)}
107
+ # print(encoded_input_article)
108
+ # print(encoded_input_article['input_ids'].shape, encoded_input_article['attention_mask'].shape)
109
+ model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
110
+ attention_mask=encoded_input_article['attention_mask'],
111
+ return_dict=True,
112
+ output_attentions=False,
113
+ output_hidden_states=False),
114
+ }
115
+
116
+ while lengths.max() < length_cutoff:
117
+ model_inputs = model.prepare_inputs_for_generation(
118
+ input_ids = encoded_input_article['input_ids'],
119
+ decoder_input_ids=encoded_input,
120
+ # past=past,
121
+ attention_mask=encoded_input_article['attention_mask'],
122
+ use_cache=use_cache,
123
+ **model_kwargs
124
+ )
125
+
126
+ outputs = model(**model_inputs, return_dict=True)
127
+ logits = outputs.logits[:, -1, :]
128
+
129
+ if "past_key_values" in outputs:
130
+ model_kwargs["past"] = outputs.past_key_values
131
+
132
+ # logits = model(encoded_input)[0][:, -1, :] # batch x vocab
133
+ top_logits, top_indices = logits.topk(precondition_topk, dim=1) # batch x topk
134
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
135
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
136
+
137
+ if condition_lambda == 0:
138
+ condition_logits = torch.zeros_like(top_logits).float()
139
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
140
+ else:
141
+ decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
142
+ resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
143
+ encoded_with_classifier = resulting_tokenization['input_ids']
144
+ attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
145
+ tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)
146
+
147
+ condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1), # batch*topk x seq+1
148
+ expanded_lengths.flatten(0, 1), # batch*topk
149
+ None,
150
+ None,
151
+ None,
152
+ attention_mask=attention_mask
153
+ )
154
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
155
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
156
+
157
+ condition_logits = torch.mean(condition_logits, dim=2)
158
+ full_logits = top_logits + condition_logits * condition_lambda # batch x topk
159
+ post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
160
+ post_probs = F.softmax(post_logits, dim=1)
161
+ # index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
162
+ index_into_top_indices = post_indices[:, torch.multinomial(post_probs, 1).flatten()] # batch
163
+
164
+ # next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
165
+ next_indices = top_indices[:, index_into_top_indices] # batch
166
+
167
+ # encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
168
+ encoded_input = torch.cat([encoded_input, next_indices.squeeze(1)], dim=1)
169
+ lengths = lengths + 1 # batch
170
+
171
+ # print(tokenizer.decode(encoded_input[0], add_special_tokens=False))
172
+ return [tokenizer.decode(s) for s in encoded_input]
173
+
174
+
175
+ if __name__=='__main__':
176
+ parser = ArgumentParser()
177
+
178
+ # DATA
179
+ parser.add_argument('--ckpt', type=str, required=True)
180
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
181
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
182
+
183
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run pred on')
184
+
185
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
186
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
187
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
188
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
189
+
190
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
191
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
192
+ parser.add_argument('--debug', action='store_true', default=False)
193
+
194
+ args = parser.parse_args()
195
+
196
+ random.seed(args.seed)
197
+ np.random.seed(args.seed)
198
+ torch.manual_seed(args.seed)
199
+
200
+ main(args)
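The heart of `generate_clickbait` above is the FUDGE reranking step: at each decoding position, the base model's top-k logits are shifted by the conditioning classifier's log-probabilities, scaled by `condition_lambda`, before a candidate is sampled. A toy, self-contained illustration of just that arithmetic (shapes and values are made up; only the formula mirrors the loop):

```python
import torch
import torch.nn.functional as F

batch_size, k = 1, 5
top_logits = torch.randn(batch_size, k)         # base model scores for its top-k next tokens
condition_logits = torch.randn(batch_size, k)   # classifier logit for each candidate continuation
condition_lambda = 1.0

# log sigmoid, as in the loop above: log p(attribute | candidate)
condition_logprobs = condition_logits - torch.log1p(torch.exp(condition_logits))
full_logits = top_logits + condition_lambda * condition_logprobs
probs = F.softmax(full_logits, dim=1)
choice = torch.multinomial(probs, 1)            # which of the k candidates to emit next
```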
naacl-2021-fudge-controlled-generation/evaluate_formality.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ from collections import namedtuple
8
+
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
15
+
16
+ from data import Dataset
17
+ from model import Model
18
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
19
+ from constants import *
20
+ from predict_formality import predict_formality
21
+
22
+ def main(args):
23
+ with open(args.dataset_info, 'rb') as rf:
24
+ dataset_info = pickle.load(rf)
25
+ tokenizer = MarianTokenizer.from_pretrained(args.model_string)
26
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
27
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
28
+ model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
29
+ if args.model_path is not None:
30
+ if os.path.isdir(args.model_path):
31
+ for _, _, files in os.walk(args.model_path):
32
+ for fname in files:
33
+ if fname.endswith('.ckpt'):
34
+ args.model_path = os.path.join(args.model_path, fname)
35
+ break
36
+ ckpt = torch.load(args.model_path, map_location=torch.device(args.device))
37
+ try:
38
+ model.load_state_dict(ckpt['state_dict'], strict=False)
39
+ except:
40
+ state_dict = {}
41
+ for key in ckpt['state_dict'].keys():
42
+ assert key.startswith('model.')
43
+ state_dict[key[6:]] = ckpt['state_dict'][key]
44
+ model.load_state_dict(state_dict)
45
+ model.eval()
46
+
47
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
48
+ model_args = checkpoint['args']
49
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
50
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
51
+ conditioning_model = conditioning_model.to(args.device)
52
+ conditioning_model.eval()
53
+ if args.verbose:
54
+ print("=> loaded checkpoint '{}' (epoch {})"
55
+ .format(args.ckpt, checkpoint['epoch']))
56
+ print('num params', num_params(conditioning_model))
57
+
58
+ inputs = []
59
+ with open(args.in_file, 'r') as rf:
60
+ for line in rf:
61
+ inputs.append(line.strip())
62
+
63
+ for inp in tqdm(inputs, total=len(inputs)):
64
+ results = predict_formality(model,
65
+ tokenizer,
66
+ conditioning_model,
67
+ [inp],
68
+ dataset_info,
69
+ precondition_topk=args.precondition_topk,
70
+ do_sample=args.do_sample,
71
+ length_cutoff=args.length_cutoff,
72
+ condition_lambda=args.condition_lambda,
73
+ device=args.device)
74
+ print(results[0])
75
+
76
+
77
+ if __name__=='__main__':
78
+ parser = ArgumentParser()
79
+
80
+ # DATA
81
+ parser.add_argument('--ckpt', type=str, required=True)
82
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
83
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
84
+ parser.add_argument('--model_path', type=str, default=None)
85
+
86
+ parser.add_argument('--in_file', type=str, default=None, required=True, help='file containing text to run pred on')
87
+
88
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
89
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample or greedy; only greedy implemented')
90
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
91
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
92
+
93
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
94
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
95
+ parser.add_argument('--debug', action='store_true', default=False)
96
+ parser.add_argument('--verbose', action='store_true', default=False)
97
+
98
+ args = parser.parse_args()
99
+
100
+ random.seed(args.seed)
101
+ np.random.seed(args.seed)
102
+ torch.manual_seed(args.seed)
103
+
104
+ main(args)
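The `try`/`except` around `load_state_dict` in `main()` above handles translation checkpoints whose keys carry a `model.` prefix (e.g., saved from a wrapper module). The same idea in isolation, with a placeholder path:

```python
import torch

ckpt = torch.load('PATH_TO_FINETUNED_MARIAN.ckpt', map_location='cpu')  # placeholder path
state_dict = {}
for key, value in ckpt['state_dict'].items():
    # strip the wrapper prefix so keys line up with MarianMTModel's own parameter names
    state_dict[key[len('model.'):] if key.startswith('model.') else key] = value
# model.load_state_dict(state_dict)
```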
naacl-2021-fudge-controlled-generation/evaluate_poetry.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ import string
8
+ from collections import defaultdict
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
16
+
17
+ from data import Dataset, load_rhyme_info
18
+ from model import Model
19
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
20
+ from constants import *
21
+ from poetry_util import get_rhymes, count_syllables
22
+ from predict_poetry import predict_couplet
23
+
24
+ def main(args):
25
+ with open(args.dataset_info, 'rb') as rf:
26
+ dataset_info = pickle.load(rf)
27
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
28
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
29
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
30
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
31
+ gpt_model.eval()
32
+
33
+ checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
34
+ model_args = checkpoint['args']
35
+ iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
36
+ iambic_model.load_state_dict(checkpoint['state_dict'])
37
+ iambic_model = iambic_model.to(args.device)
38
+ iambic_model.eval()
39
+ if args.verbose:
40
+ print("=> loaded checkpoint '{}' (epoch {})"
41
+ .format(args.iambic_ckpt, checkpoint['epoch']))
42
+ print('iambic model num params', num_params(iambic_model))
43
+
44
+ with open(args.rhyme_info, 'rb') as rf:
45
+ rhyme_info = pickle.load(rf)
46
+ checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
47
+ model_args = checkpoint['args']
48
+ rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group), verbose=args.verbose) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
49
+ rhyme_model.load_state_dict(checkpoint['state_dict'])
50
+ rhyme_model = rhyme_model.to(args.device)
51
+ rhyme_model.eval()
52
+ if args.verbose:
53
+ print("=> loaded checkpoint '{}' (epoch {})"
54
+ .format(args.rhyme_ckpt, checkpoint['epoch']))
55
+ print('rhyme model num params', num_params(rhyme_model))
56
+
57
+ checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
58
+ model_args = checkpoint['args']
59
+ newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
60
+ newline_model.load_state_dict(checkpoint['state_dict'])
61
+ newline_model = newline_model.to(args.device)
62
+ newline_model.eval()
63
+ if args.verbose:
64
+ print("=> loaded checkpoint '{}' (epoch {})"
65
+ .format(args.newline_ckpt, checkpoint['epoch']))
66
+ print('newline model num params', num_params(newline_model))
67
+
68
+ with open(args.prefix_file, 'r') as rf:
69
+ lines = rf.readlines()
70
+ for line in tqdm(lines, total=len(lines)):
71
+ couplet = predict_couplet(gpt_model,
72
+ gpt_tokenizer,
73
+ iambic_model,
74
+ rhyme_model,
75
+ newline_model,
76
+ [line],
77
+ dataset_info,
78
+ rhyme_info,
79
+ args.precondition_topk,
80
+ args.topk,
81
+ condition_lambda=args.condition_lambda,
82
+ device=args.device)
83
+ assert len(couplet) == 2
84
+ print(couplet[1].strip().replace('\n', ''))
85
+
86
+
87
+ if __name__=='__main__':
88
+ parser = ArgumentParser()
89
+
90
+ # DATA
91
+ parser.add_argument('--iambic_ckpt', type=str, required=True)
92
+ parser.add_argument('--rhyme_ckpt', type=str, required=True)
93
+ parser.add_argument('--newline_ckpt', type=str, required=True)
94
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
95
+ parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info')
96
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
97
+
98
+ parser.add_argument('--prefix_file', type=str, default=None, required=True, help='file of prefix lines for couplets')
99
+
100
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
101
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
102
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
103
+
104
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
105
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
106
+ parser.add_argument('--debug', action='store_true', default=False)
107
+ parser.add_argument('--verbose', action='store_true', default=False)
108
+
109
+ args = parser.parse_args()
110
+
111
+ random.seed(args.seed)
112
+ np.random.seed(args.seed)
113
+ torch.manual_seed(args.seed)
114
+
115
+ main(args)
naacl-2021-fudge-controlled-generation/evaluate_topic.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ from collections import defaultdict
8
+ import string
9
+ import csv
10
+
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
21
+ from predict_topic import predict
22
+ from constants import *
23
+
24
+
25
+ def main(args):
26
+ with open(args.dataset_info, 'rb') as rf:
27
+ dataset_info = pickle.load(rf)
28
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
29
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
30
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
31
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
32
+ gpt_model.eval()
33
+
34
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
35
+ model_args = checkpoint['args']
36
+ conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
37
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
38
+ conditioning_model = conditioning_model.to(args.device)
39
+ conditioning_model.eval()
40
+ if args.verbose:
41
+ print("=> loaded checkpoint '{}' (epoch {})"
42
+ .format(args.ckpt, checkpoint['epoch']))
43
+ print('num params', num_params(conditioning_model))
44
+
45
+ input_texts, conditions, categories = [], [], []
46
+
47
+ if args.condition_file is not None:
48
+ with open(args.condition_file, 'r') as rf:
49
+ for line in rf:
50
+ input_texts.append(line.strip().split('\t')[0])
51
+ conditions.append(line.strip().split('\t')[1])
52
+ categories.append(None)
53
+ for cw in conditions[-1].split():
54
+ assert cw in dataset_info.word2index
55
+ else:
56
+ prefixes = []
57
+ with open(args.prefix_file, 'r') as rf:
58
+ for line in rf:
59
+ prefixes.append(line.strip())
60
+ condition_wordlists = []
61
+ for root, _, files in os.walk(args.wordlist_dir):
62
+ for fname in files:
63
+ words = []
64
+ with open(os.path.join(root, fname), 'r') as rf:
65
+ for line in rf:
66
+ word = line.strip()
67
+ if word in dataset_info.word2index:
68
+ words.append(word)
69
+ else:
70
+ if args.verbose:
71
+ print('word not found:', word)
72
+ condition_wordlists.append((' '.join(words), fname.split('.')[0]))
73
+ for p in prefixes:
74
+ for c, category in condition_wordlists:
75
+ input_texts.append(p)
76
+ conditions.append(c)
77
+ categories.append(category)
78
+
79
+ all_cr = []
80
+ pair_num = 0
81
+ for input_text, condition_words, category in tqdm(zip(input_texts, conditions, categories), total=len(conditions)):
82
+ predict_function = predict
83
+ condition_results = []
84
+ for i in range(0, args.sample_size, args.max_sample_batch):
85
+ num_samples = min(args.max_sample_batch, args.sample_size - i)
86
+ condition_results += predict_function(gpt_model,
87
+ gpt_tokenizer,
88
+ conditioning_model,
89
+ [input_text for _ in range(num_samples)],
90
+ condition_words,
91
+ dataset_info,
92
+ args.precondition_topk,
93
+ args.topk,
94
+ args.length_cutoff,
95
+ condition_lambda=args.condition_lambda,
96
+ device=args.device)
97
+ all_cr.append((input_text, category, condition_results))
98
+ pair_num += 1
99
+ if args.max_pairs > 0 and pair_num >= args.max_pairs:
100
+ break
101
+ with open(args.log_file, 'w') as wf:
102
+ writer = csv.DictWriter(wf, fieldnames=['category', 'input_text', 'generation'])
103
+ writer.writeheader()
104
+ for cr_group in all_cr:
105
+ for cr in cr_group[2]:
106
+ writer.writerow({'category': cr_group[1], 'input_text': cr_group[0], 'generation': cr})
107
+
108
+
109
+ if __name__=='__main__':
110
+ parser = ArgumentParser()
111
+
112
+ # DATA
113
+ parser.add_argument('--ckpt', type=str, required=True)
114
+ parser.add_argument('--log_file', type=str, required=True, help='file to write outputs to (csv format)')
115
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
116
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
117
+
118
+ parser.add_argument('--condition_file', type=str, default=None, help='file of inputs and conditions')
119
+ parser.add_argument('--prefix_file', type=str, default=None, help='prefix set')
120
+ parser.add_argument('--wordlist_dir', type=str, default=None, help='dir of bow wordlists for categories')
121
+ parser.add_argument('--sample_size', type=int, default=3, help='samples per input text-condition pair')
122
+ parser.add_argument('--max_sample_batch', type=int, default=3, help='max samples at a time')
123
+ parser.add_argument('--max_pairs', type=int, default=-1, help='max input-condition pairs, for debugging quickly')
124
+
125
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
126
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
127
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
128
+ parser.add_argument('--length_cutoff', type=int, default=80, help='max length')
129
+
130
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
131
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
132
+ parser.add_argument('--debug', action='store_true', default=False)
133
+ parser.add_argument('--verbose', action='store_true', default=False)
134
+
135
+ args = parser.parse_args()
136
+
137
+ assert (args.condition_file is not None) != (args.prefix_file is not None and args.wordlist_dir is not None) # one of two interfaces for specifying
138
+
139
+ random.seed(args.seed)
140
+ np.random.seed(args.seed)
141
+ torch.manual_seed(args.seed)
142
+
143
+ main(args)
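For the `--condition_file` interface above, each line pairs a generation prefix with its space-separated condition words, separated by a tab, and every condition word must appear in `dataset_info.word2index`. A sketch of writing such a file (the prefixes and words are placeholders):

```python
# Hypothetical examples; real condition words must exist in dataset_info.word2index.
examples = [
    ('The issue focused on', 'space rocket planet'),
    ('The book was about', 'science research theory'),
]
with open('example_conditions.txt', 'w') as wf:
    for prefix, condition_words in examples:
        wf.write(prefix + '\t' + condition_words + '\n')
```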
naacl-2021-fudge-controlled-generation/formality_data/README.md ADDED
@@ -0,0 +1,2 @@
1
+ `fisher_test_oracle.es` is the source-side Spanish test set.
2
+ `test.noid.cleaned_0` and `test.noid.cleaned_1` are the fluent English test-time references from Salesky et al. (2019).
naacl-2021-fudge-controlled-generation/formality_data/fisher_test_oracle.es ADDED
The diff for this file is too large to render. See raw diff
 
naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_0 ADDED
The diff for this file is too large to render. See raw diff
 
naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_1 ADDED
The diff for this file is too large to render. See raw diff
 
naacl-2021-fudge-controlled-generation/main.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from data import Dataset
14
+ from model import Model
15
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
16
+ from constants import *
17
+
18
+
19
+ def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
20
+ model.train()
21
+ if data_start_index == 0:
22
+ dataset.shuffle('train', seed=epoch + args.seed)
23
+ if args.epoch_max_len is not None:
24
+ data_end_index = min(data_start_index + args.epoch_max_len, len(dataset.splits['train']))
25
+ loader = dataset.loader('train', num_workers=args.num_workers, indices=list(range(data_start_index, data_end_index)))
26
+ data_start_index = data_end_index if data_end_index < len(dataset.splits['train']) else 0
27
+ else:
28
+ loader = dataset.loader('train', num_workers=args.num_workers)
29
+ loss_meter = AverageMeter('loss', ':6.4f')
30
+ total_length = len(loader)
31
+ progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')
32
+ for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
33
+ batch = [tensor.to(args.device) for tensor in batch]
34
+ inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
35
+ if args.task not in ['formality', 'iambic']:
36
+ if not args.debug and len(inputs) != args.batch_size: # it'll screw up the bias...?
37
+ continue
38
+ scores = model(inputs, lengths, future_words, log_probs, syllables_to_go, future_word_num_syllables, rhyme_group_index, run_classifier=True)
39
+ if args.task == 'formality': # we're learning for all positions at once. scores are batch x seq
40
+ expanded_labels = classification_targets.unsqueeze(1).expand(-1, scores.shape[1]) # batch x seq
41
+ length_mask = pad_mask(lengths).permute(1, 0) # batch x seq
42
+ loss = criterion(scores.flatten()[length_mask.flatten()==1], expanded_labels.flatten().float()[length_mask.flatten()==1])
43
+ elif args.task in ['iambic', 'newline']:
44
+ use_indices = classification_targets.flatten() != -1
45
+ loss = criterion(scores.flatten()[use_indices], classification_targets.flatten().float()[use_indices])
46
+ else: # topic, rhyme
47
+ loss = criterion(scores.flatten(), labels.flatten().float())
48
+ optimizer.zero_grad()
49
+ loss.backward()
50
+ optimizer.step()
51
+ loss_meter.update(loss.detach(), len(labels))
52
+ if batch_num % args.train_print_freq == 0:
53
+ progress.display(batch_num)
54
+ progress.display(total_length)
55
+ return data_start_index
56
+
57
+
58
+ def validate(model, dataset, criterion, epoch, args):
59
+ model.eval()
60
+ random.seed(0)
61
+ loader = dataset.loader('val', num_workers=args.num_workers)
62
+ loss_meter = AverageMeter('loss', ':6.4f')
63
+ total_length = len(loader)
64
+ progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
65
+ with torch.no_grad():
66
+ for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
67
+ batch = [tensor.to(args.device) for tensor in batch]
68
+ inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
69
+ if args.task not in ['formality', 'iambic']: # topic predictor
70
+ if not args.debug and len(inputs) != args.batch_size:
71
+ continue
72
+ scores = model(inputs, lengths, future_words, log_probs, syllables_to_go, future_word_num_syllables, rhyme_group_index, run_classifier=True)
73
+ if args.task == 'formality': # we're learning for all positions at once. scores are batch x seq
74
+ expanded_labels = classification_targets.unsqueeze(1).expand(-1, scores.shape[1]) # batch x seq
75
+ length_mask = pad_mask(lengths).permute(1, 0) # batch x seq
76
+ loss = criterion(scores.flatten()[length_mask.flatten()==1], expanded_labels.flatten().float()[length_mask.flatten()==1])
77
+ elif args.task in ['iambic', 'newline']:
78
+ use_indices = classification_targets.flatten() != -1
79
+ loss = criterion(scores.flatten()[use_indices], classification_targets.flatten().float()[use_indices])
80
+ else: # topic, rhyme
81
+ loss = criterion(scores.flatten(), labels.flatten().float())
82
+ loss_meter.update(loss.detach(), len(labels))
83
+ if batch_num % args.train_print_freq == 0:
84
+ progress.display(batch_num)
85
+ progress.display(total_length)
86
+ return loss_meter.avg
87
+
88
+
89
+ def main(args):
90
+ dataset = Dataset(args)
91
+ os.makedirs(args.save_dir, exist_ok=True)
92
+ with open(os.path.join(args.save_dir, 'dataset_info'), 'wb') as wf:
93
+ pickle.dump(dataset.dataset_info, wf)
94
+ if args.task == 'rhyme':
95
+ with open(os.path.join(args.save_dir, 'rhyme_info'), 'wb') as wf:
96
+ pickle.dump(dataset.rhyme_info, wf)
97
+ if args.ckpt:
98
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
99
+ start_epoch = checkpoint['epoch'] + 1
100
+ best_val_metric = checkpoint['best_metric']
101
+ model_args = checkpoint['args']
102
+ model = Model(model_args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=len(dataset.index2rhyme_group) if args.task == 'rhyme' else None) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
103
+ model.load_state_dict(checkpoint['state_dict'])
104
+ model = model.to(args.device)
105
+ optimizer = torch.optim.Adam(model.parameters(), lr=model_args.lr)
106
+ optimizer.load_state_dict(checkpoint['optimizer'])
107
+ data_start_index = checkpoint['data_start_index']
108
+ print("=> loaded checkpoint '{}' (epoch {})"
109
+ .format(args.ckpt, checkpoint['epoch']))
110
+ # NOTE: just import pdb after loading the model here if you want to play with it, it's easy
111
+ # model.eval()
112
+ # import pdb; pdb.set_trace()
113
+ else:
114
+ model = Model(args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=len(dataset.index2rhyme_group) if args.task == 'rhyme' else None, glove_embeddings=dataset.glove_embeddings)
115
+ model = model.to(args.device)
116
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
117
+ best_val_metric = 1e8 # lower is better for BCE
118
+ data_start_index = 0
119
+ print('num params', num_params(model))
120
+ criterion = nn.BCEWithLogitsLoss().to(args.device)
121
+
122
+ if args.evaluate:
123
+ epoch = 0
124
+ validate(model, dataset, criterion, epoch, args)
125
+ return
126
+ for epoch in range(args.epochs):
127
+ print("TRAINING: Epoch {} at {}".format(epoch, time.ctime()))
128
+ data_start_index = train(model, dataset, optimizer, criterion, epoch, args, data_start_index)
129
+ if epoch % args.validation_freq == 0:
130
+ print("VALIDATION: Epoch {} at {}".format(epoch, time.ctime()))
131
+ metric = validate(model, dataset, criterion, epoch, args)
132
+
133
+ if not args.debug:
134
+ if metric < best_val_metric:
135
+ print('new best val metric', metric)
136
+ best_val_metric = metric
137
+ save_checkpoint({
138
+ 'epoch': epoch,
139
+ 'state_dict': model.state_dict(),
140
+ 'best_metric': best_val_metric,
141
+ 'optimizer': optimizer.state_dict(),
142
+ 'data_start_index': data_start_index,
143
+ 'args': args
144
+ }, os.path.join(args.save_dir, 'model_best.pth.tar'))
145
+ save_checkpoint({
146
+ 'epoch': epoch,
147
+ 'state_dict': model.state_dict(),
148
+ 'best_metric': metric,
149
+ 'optimizer': optimizer.state_dict(),
150
+ 'data_start_index': data_start_index,
151
+ 'args': args
152
+ }, os.path.join(args.save_dir, 'model_epoch' + str(epoch) + '.pth.tar'))
153
+
154
+
155
+ if __name__=='__main__':
156
+ parser = ArgumentParser()
157
+
158
+ # DATA
159
+ parser.add_argument('--task', type=str, required=True, choices=['iambic', 'rhyme', 'newline', 'topic', 'formality', 'clickbait'])
160
+ parser.add_argument('--data_dir', type=str, required=True)
161
+ parser.add_argument('--glove_file', type=str, help='glove embedding init, for topic task')
162
+
163
+ # SAVE/LOAD
164
+ parser.add_argument('--save_dir', type=str, required=True, help='where to save ckpts')
165
+ parser.add_argument('--ckpt', type=str, default=None, help='load ckpt from file if given')
166
+ parser.add_argument('--dataset_info', type=str, help='saved dataset info')
167
+ parser.add_argument('--rhyme_info', type=str, help='saved dataset rhyme info, for a ckpt with task==rhyme')
168
+
169
+ # TRAINING
170
+ parser.add_argument('--batch_size', type=int, default=128)
171
+ parser.add_argument('--epochs', type=int, default=100)
172
+ parser.add_argument('--epoch_max_len', type=int, default=None, help='max batches per epoch if set, for more frequent validation')
173
+ parser.add_argument('--validation_freq', type=int, default=1, help='validate every X epochs')
174
+ parser.add_argument('--lr', type=float, default=1e-3, help='Adam learning rate')
175
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
176
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
177
+ parser.add_argument('--num_workers', type=int, default=20, help='num workers for data loader')
178
+ parser.add_argument('--evaluate', action='store_true', default=False)
179
+ parser.add_argument('--debug', action='store_true', default=False)
180
+
181
+ # PRINTING
182
+ parser.add_argument('--train_print_freq', type=int, default=100, help='how often to print metrics (every X batches)')
183
+
184
+ args = parser.parse_args()
185
+
186
+ random.seed(args.seed)
187
+ np.random.seed(args.seed)
188
+ torch.manual_seed(args.seed)
189
+ if args.evaluate:
190
+ assert args.ckpt is not None
191
+
192
+ main(args)
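For the formality task above, the classifier emits one score per position, the sentence-level label is broadcast across the sequence, and padded positions are masked out of the BCE loss. `pad_mask` lives in `util.py`, which is not part of this diff; the toy example below builds an equivalent mask directly to show the masking arithmetic:

```python
import torch
import torch.nn as nn

batch_size, max_len = 2, 5
scores = torch.randn(batch_size, max_len)                  # one logit per position
labels = torch.tensor([1.0, 0.0])                          # sentence-level targets
lengths = torch.tensor([5, 3])                             # true lengths before padding

expanded_labels = labels.unsqueeze(1).expand(-1, max_len)  # batch x seq
length_mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).long()

criterion = nn.BCEWithLogitsLoss()
loss = criterion(scores.flatten()[length_mask.flatten() == 1],
                 expanded_labels.flatten()[length_mask.flatten() == 1])
```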
naacl-2021-fudge-controlled-generation/model.py ADDED
@@ -0,0 +1,182 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
7
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForSequenceClassification, GPT2LMHeadModel, MarianTokenizer
8
+
9
+ from constants import *
10
+ from util import pad_mask
11
+ from clickbait_classifier import BertClickbaitClassifier, ClickbaitConfig
12
+
13
+ class Model(nn.Module):
14
+ def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True):
15
+ super(Model, self).__init__()
16
+
17
+ # self.topic = args.task == 'topic'
18
+ self.formality = args.task == 'formality'
19
+ self.iambic = args.task == 'iambic'
20
+ self.rhyme = args.task == 'rhyme'
21
+ self.newline = args.task == 'newline'
22
+ self.clickbait = args.task == 'clickbait'
23
+ # if self.topic:
24
+ # self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
25
+ # if glove_embeddings is None:
26
+ # if verbose:
27
+ # print('initializing word embeddings from scratch')
28
+ # self.word_embed = nn.Embedding(vocab_size, GLOVE_DIM, padding_idx=0)
29
+ # else:
30
+ # if verbose:
31
+ # print('initializing word embeddings from glove')
32
+ # self.word_embed = nn.Embedding.from_pretrained(glove_embeddings, padding_idx=0)
33
+ # self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
34
+ # self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
35
+ # large_hidden_dim = HIDDEN_DIM
36
+ # self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
37
+ # self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
38
+ # self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
39
+ # self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
40
+ # self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
41
+ # self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
42
+ # self.nonlinear = nn.ReLU()
43
+ # elif self.formality:
44
+ if self.formality:
45
+ self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is ''
46
+ self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions
47
+ self.out_linear = nn.Linear(HIDDEN_DIM, 1)
48
+ elif self.iambic:
49
+ self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
50
+ self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0) # want it to be causal so we can learn all positions
51
+ self.out_linear = nn.Linear(HIDDEN_DIM, 1)
52
+ elif self.rhyme:
53
+ self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
54
+ self.word_embed = nn.Embedding(rhyme_group_size+1, GLOVE_DIM, padding_idx=0) # this embedding for future words will actually embed the rhyme group idx
55
+ self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
56
+ self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
57
+ large_hidden_dim = HIDDEN_DIM + COUNT_SYLLABLE_DIM
58
+ self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
59
+ self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
60
+ self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
61
+ self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
62
+ self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
63
+ self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
64
+ self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
65
+ self.nonlinear = nn.ReLU()
66
+ elif self.newline:
67
+ self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
68
+ self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False)
69
+ self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
70
+ self.out_linear = nn.Linear(HIDDEN_DIM + COUNT_SYLLABLE_DIM, HIDDEN_DIM)
71
+ self.out_linear2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
72
+ self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
73
+ self.nonlinear = nn.ReLU()
74
+ elif self.clickbait:
75
+ # mpnet_config = ClickbaitConfig(
76
+ # model_type="mpnet",
77
+ # pretrained_model="sentence-transformers/all-mpnet-base-v2",
78
+ # num_labels=1,
79
+ # dropout=0.2,
80
+ # inner_dim1=256,
81
+ # inner_dim2=32,
82
+ # max_length=25,
83
+ # load_pretrained=True,
84
+ # freeze_bert=False,
85
+ # )
86
+ #TODO add a checkpoint to Classifier
87
+ # print('add a checkpoint to Classifier')
88
+ checkpoint = args.checkpoint #'ckpt/clickbait_classifier/checkpoint-1464'
89
+ # self.classifier = BertClickbaitClassifier(config=mpnet_config).to(torch.device(args.device))
90
+ self.classifier = BertClickbaitClassifier.from_pretrained(checkpoint).to(torch.device(args.device))
91
+ else:
92
+ raise NotImplementedError # TODO honestly this can/should be refactored into different models
93
+
94
+
95
+ def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False, attention_mask=None):
96
+ """
97
+ inputs: token ids, batch x seq, right-padded with 0s
98
+ lengths: lengths of inputs; batch
99
+ future_words: batch x N words to check if not predict next token, else batch
100
+ log_probs: N
101
+ syllables_to_go: batch
102
+ """
103
+ # if self.topic:
104
+ # inputs = self.gpt_embed(inputs) # batch x seq x 300
105
+ # inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
106
+ # rnn_output, _ = self.rnn(inputs)
107
+ # rnn_output, _ = pad_packed_sequence(rnn_output)
108
+ # rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
109
+ # hidden = rnn_output
110
+ # attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
111
+ # embed = self.word_embed(future_words) # batch x N x 300
112
+ # embed_query = self.embed_key_linear(embed)
113
+ # attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
114
+ # attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
115
+ # attention_weights = attention_weights * attention_mask.unsqueeze(2)
116
+ # hidden = self.attention_value_linear(hidden)
117
+ # weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
118
+ # unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
119
+ # unnormalized_scores = torch.cat([unnormalized_scores, embed], dim=2)
120
+ # unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
121
+ # unnormalized_scores = self.out_linear3(unnormalized_scores)
122
+ # scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
123
+ # return scores # batch x N of normalized scores or batch x
124
+ # elif self.formality:
125
+ if self.formality:
126
+ inputs = self.marian_embed(inputs)
127
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
128
+ rnn_output, _ = self.rnn(inputs)
129
+ rnn_output, _ = pad_packed_sequence(rnn_output)
130
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
131
+ return self.out_linear(rnn_output).squeeze(2)
132
+ elif self.iambic:
133
+ inputs = self.gpt_embed(inputs)
134
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
135
+ rnn_output, _ = self.rnn(inputs)
136
+ rnn_output, _ = pad_packed_sequence(rnn_output)
137
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
138
+ return self.out_linear(rnn_output).squeeze(2)
139
+ elif self.rhyme:
140
+ inputs = self.gpt_embed(inputs) # batch x seq x 300
141
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
142
+ rnn_output, _ = self.rnn(inputs)
143
+ rnn_output, _ = pad_packed_sequence(rnn_output)
144
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
145
+ hidden = rnn_output
146
+ attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
147
+ embed = self.word_embed(future_words) # batch x N x 300
148
+ embedded_syllables_to_go = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1) # batch x N x 100
149
+ auxiliary_embed = embedded_syllables_to_go
150
+ embed_query = self.embed_key_linear(torch.cat([embed, auxiliary_embed], dim=2))
151
+ attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
152
+ attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
153
+ attention_weights = attention_weights * attention_mask.unsqueeze(2)
154
+ hidden = self.attention_value_linear(hidden)
155
+ weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
156
+ unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
157
+ unnormalized_scores = torch.cat([unnormalized_scores, embed, auxiliary_embed], dim=2)
158
+ unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
159
+ unnormalized_scores = self.out_linear3(unnormalized_scores)
160
+ scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
161
+ return scores # batch x N of normalized scores or batch x
162
+ elif self.newline:
163
+ inputs = self.gpt_embed(inputs) # batch x seq x 300
164
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
165
+ rnn_output, _ = self.rnn(inputs)
166
+ rnn_output, _ = pad_packed_sequence(rnn_output)
167
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
168
+ hidden = torch.cat([rnn_output, self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, rnn_output.shape[1], -1)], dim=2)
169
+ return self.out_linear3(self.nonlinear(self.out_linear2(self.nonlinear(self.out_linear(hidden))))).squeeze(2)
170
+ elif self.clickbait:
171
+
172
+ input_ids = torch.tensor(inputs)
173
+ classifier_output = self.classifier(input_ids=input_ids, attention_mask=attention_mask).logits
174
+
175
+ classifier_output = classifier_output[None, :, :] # add a leading batch dimension
176
+ # return self.out_linear(rnn_output).squeeze(2)
177
+ return classifier_output.squeeze(2)
178
+
179
+ else:
180
+ raise NotImplementedError
181
+
182
+
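The formality, iambic, and newline branches above share one pattern: embed the token ids, run a unidirectional (causal) LSTM, and project every hidden state to a single logit, so each prefix of the sequence is scored in one forward pass. A stripped-down, self-contained version of that pattern (dimensions are illustrative; the real model also packs padded sequences and works seq-first):

```python
import torch
import torch.nn as nn

vocab_size, hidden_dim = 100, 32
embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers=1, batch_first=True)  # unidirectional: causal
out_linear = nn.Linear(hidden_dim, 1)

tokens = torch.randint(1, vocab_size, (2, 7))   # batch x seq of token ids
hidden, _ = rnn(embed(tokens))                  # batch x seq x hidden
prefix_scores = out_linear(hidden).squeeze(2)   # batch x seq: one logit per prefix
```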
naacl-2021-fudge-controlled-generation/poetry_data/README.md ADDED
@@ -0,0 +1 @@
1
+ `couplet_prefixes.txt` contains the 13th line of each of Shakespeare's sonnets. `couplet_ends.txt` contains the 14th. (Each 14-line sonnet ends with a couplet in the last two lines). The prefixes are our test set prefixes for the couplet completion task; the ends are Shakespeare's outputs.
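A minimal way to pair the two files for evaluation, assuming they are line-aligned (one sonnet per line) and read from inside `poetry_data/`:

```python
with open('couplet_prefixes.txt') as pf, open('couplet_ends.txt') as ef:
    pairs = [(prefix.strip(), end.strip()) for prefix, end in zip(pf, ef)]
```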
naacl-2021-fudge-controlled-generation/poetry_data/couplet_ends.txt ADDED
@@ -0,0 +1,154 @@
1
+ To eat the world's due, by the grave and thee.
2
+ And see thy blood warm when thou feel'st it cold.
3
+ Die single, and thine image dies with thee.
4
+ Which, used, lives th' executor to be.
5
+ Leese but their show; their substance still lives sweet.
6
+ To be death's conquest and make worms thine heir.
7
+ Unlook'd on diest, unless thou get a son.
8
+ Sings this to thee: 'thou single wilt prove none.'
9
+ That on himself such murderous shame commits.
10
+ That beauty still may live in thine or thee.
11
+ Thou shouldst print more, not let that copy die.
12
+ Save breed, to brave him when he takes thee hence.
13
+ You had a father: let your son say so.
14
+ Thy end is truth's and beauty's doom and date.
15
+ As he takes from you, I engraft you new.
16
+ And you must live, drawn by your own sweet skill.
17
+ You should live twice; in it and in my rhyme.
18
+ So long lives this and this gives life to thee.
19
+ My love shall in my verse ever live young.
20
+ Mine be thy love and thy love's use their treasure.
21
+ I will not praise that purpose not to sell.
22
+ Thou gavest me thine, not to give back again.
23
+ To hear with eyes belongs to love's fine wit.
24
+ They draw but what they see, know not the heart.
25
+ Where I may not remove nor be removed.
26
+ Till then not show my head where thou mayst prove me.
27
+ For thee and for myself no quiet find.
28
+ And night doth nightly make grief's strength seem stronger.
29
+ That then I scorn to change my state with kings.
30
+ All losses are restored and sorrows end.
31
+ And thou, all they, hast all the all of me.
32
+ Theirs for their style I'll read, his for his love.'
33
+ Suns of the world may stain when heaven's sun staineth.
34
+ And they are rich and ransom all ill deeds.
35
+ To that sweet thief which sourly robs from me.
36
+ As, thou being mine, mine is thy good report.
37
+ This wish I have; then ten times happy me!
38
+ The pain be mine, but thine shall be the praise.
39
+ By praising him here who doth hence remain!
40
+ Kill me with spites; yet we must not be foes.
41
+ Thine, by thy beauty being false to me.
42
+ Sweet flattery! then she loves but me alone.
43
+ And nights bright days when dreams do show thee me.
44
+ But heavy tears, badges of either's woe.
45
+ I send them back again and straight grow sad.
46
+ And my heart's right thy inward love of heart.
47
+ Awakes my heart to heart's and eye's delight.
48
+ For truth proves thievish for a prize so dear.
49
+ Since why to love I can allege no cause.
50
+ My grief lies onward and my joy behind.
51
+ Towards thee I'll run, and give him leave to go.
52
+ Being had, to triumph, being lack'd, to hope.
53
+ But you like none, none you, for constant heart.
54
+ When that shall fade, my verse distills your truth.
55
+ You live in this, and dwell in lover's eyes.
56
+ Makes summer's welcome thrice more wish'd, more rare.
57
+ Though you do any thing, he thinks no ill.
58
+ Not blame your pleasure, be it ill or well.
59
+ To subjects worse have given admiring praise.
60
+ Praising thy worth, despite his cruel hand.
61
+ From me far off, with others all too near.
62
+ Painting my age with beauty of thy days.
63
+ And they shall live, and he in them still green.
64
+ But weep to have that which it fears to lose.
65
+ That in black ink my love may still shine bright.
66
+ Save that, to die, I leave my love alone.
67
+ In days long since, before these last so bad.
68
+ To show false Art what beauty was of yore.
69
+ The solve is this, that thou dost common grow.
70
+ Then thou alone kingdoms of hearts shouldst owe.
71
+ And mock you with me after I am gone.
72
+ And so should you, to love things nothing worth.
73
+ To love that well which thou must leave ere long.
74
+ And that is this, and this with thee remains.
75
+ Or gluttoning on all, or all away.
76
+ So is my love still telling what is told.
77
+ Shall profit thee and much enrich thy book.
78
+ As high as learning my rude ignorance.
79
+ Since what he owes thee thou thyself dost pay.
80
+ The worst was this; my love was my decay.
81
+ Where breath most breathes, even in the mouths of men.
82
+ Where cheeks need blood; in thee it is abused.
83
+ Than both your poets can in praise devise.
84
+ Being fond on praise, which makes your praises worse.
85
+ Me for my dumb thoughts, speaking in effect.
86
+ Then lack'd I matter; that enfeebled mine.
87
+ In sleep a king, but waking no such matter.
88
+ That for thy right myself will bear all wrong.
89
+ For I must ne'er love him whom thou dost hate.
90
+ Compared with loss of thee will not seem so.
91
+ All this away and me most wretched make.
92
+ Thou mayst be false, and yet I know it not.
93
+ if thy sweet virtue answer not thy show!
94
+ Lilies that fester smell far worse than weeds.
95
+ The hardest knife ill-used doth lose his edge.
96
+ As, thou being mine, mine is thy good report.
97
+ That leaves look pale, dreading the winter's near.
98
+ As with your shadow I with these did play:
99
+ But sweet or colour it had stol'n from thee.
100
+ So thou prevent'st his scythe and crooked knife.
101
+ To make him seem long hence as he shows now.
102
+ Because I would not dull you with my song.
103
+ Your own glass shows you when you look in it.
104
+ Ere you were born was beauty's summer dead.
105
+ Which three till now never kept seat in one.
106
+ Had eyes to wonder, but lack tongues to praise.
107
+ When tyrants' crests and tombs of brass are spent.
108
+ Where time and outward form would show it dead.
109
+ Save thou, my rose; in it thou art my all.
110
+ Even to thy pure and most most loving breast.
111
+ Even that your pity is enough to cure me.
112
+ That all the world besides methinks are dead.
113
+ My most true mind thus makes mine eye untrue.
114
+ That mine eye loves it and doth first begin.
115
+ To give full growth to that which still doth grow?
116
+ I never writ, nor no man ever loved.
117
+ The constancy and virtue of your love.
118
+ Drugs poison him that so fell sick of you.
119
+ And gain by ill thrice more than I have spent.
120
+ Mine ransoms yours, and yours must ransom me.
121
+ All men are bad, and in their badness reign.
122
+ Were to import forgetfulness in me.
123
+ I will be true, despite thy scythe and thee.
124
+ Which die for goodness, who have lived for crime.
125
+ When most impeach'd stands least in thy control.
126
+ And her quietus is to render thee.
127
+ That every tongue says beauty should look so.
128
+ Give them thy fingers, me thy lips to kiss.
129
+ To shun the heaven that leads men to this hell.
130
+ As any she belied with false compare.
131
+ And thence this slander, as I think, proceeds.
132
+ And all they foul that thy complexion lack.
133
+ Perforce am thine, and all that is in me.
134
+ He pays the whole, and yet am I not free.
135
+ Think all but one, and me in that one 'Will.'
136
+ And then thou lovest me, for my name is 'Will.'
137
+ And to this false plague are they now transferr'd.
138
+ And in our faults by lies we flatter'd be.
139
+ Kill me outright with looks and rid my pain.
140
+ Bear thine eyes straight, though thy proud heart go wide.
141
+ That she that makes me sin awards me pain.
142
+ By self-example mayst thou be denied!
143
+ If thou turn back, and my loud crying still.
144
+ Till my bad angel fire my good one out.
145
+ And saved my life, saying 'not you.'
146
+ And Death once dead, there's no more dying then.
147
+ Who art as black as hell, as dark as night.
148
+ Lest eyes well-seeing thy foul faults should find.
149
+ Those that can see thou lovest, and I am blind.
150
+ More worthy I to be beloved of thee.
151
+ Her 'love' for whose dear love I rise and fall.
152
+ To swear against the truth so foul a lie!
153
+ Where Cupid got new fire--my mistress' eyes.
154
+ Love's fire heats water, water cools not love.
naacl-2021-fudge-controlled-generation/poetry_data/couplet_prefixes.txt ADDED
@@ -0,0 +1,154 @@
1
+ Pity the world, or else this glutton be,
2
+ This were to be new made when thou art old,
3
+ But if thou live, remember'd not to be,
4
+ Thy unused beauty must be tomb'd with thee,
5
+ But flowers distill'd though they with winter meet,
6
+ Be not self-will'd, for thou art much too fair
7
+ So thou, thyself out-going in thy noon,
8
+ Whose speechless song, being many, seeming one,
9
+ No love toward others in that bosom sits
10
+ Make thee another self, for love of me,
11
+ She carved thee for her seal, and meant thereby
12
+ And nothing 'gainst Time's scythe can make defence
13
+ O, none but unthrifts! Dear my love, you know
14
+ Or else of thee this I prognosticate:
15
+ And all in war with Time for love of you,
16
+ To give away yourself keeps yourself still,
17
+ But were some child of yours alive that time,
18
+ So long as men can breathe or eyes can see,
19
+ Yet, do thy worst, old Time: despite thy wrong,
20
+ But since she prick'd thee out for women's pleasure,
21
+ Let them say more than like of hearsay well;
22
+ Presume not on thy heart when mine is slain;
23
+ O, learn to read what silent love hath writ:
24
+ Yet eyes this cunning want to grace their art;
25
+ Then happy I, that love and am beloved
26
+ Then may I dare to boast how I do love thee;
27
+ Lo! thus, by day my limbs, by night my mind,
28
+ But day doth daily draw my sorrows longer
29
+ For thy sweet love remember'd such wealth brings
30
+ But if the while I think on thee, dear friend,
31
+ Their images I loved I view in thee,
32
+ But since he died and poets better prove,
33
+ Yet him for this my love no whit disdaineth;
34
+ Ah! but those tears are pearl which thy love sheds,
35
+ That I an accessary needs must be
36
+ But do not so; I love thee in such sort
37
+ Look, what is best, that best I wish in thee:
38
+ If my slight Muse do please these curious days,
39
+ And that thou teachest how to make one twain,
40
+ Lascivious grace, in whom all ill well shows,
41
+ Hers by thy beauty tempting her to thee,
42
+ But here's the joy; my friend and I are one;
43
+ All days are nights to see till I see thee,
44
+ Receiving nought by elements so slow
45
+ This told, I joy; but then no longer glad,
46
+ As thus; mine eye's due is thy outward part,
47
+ Or, if they sleep, thy picture in my sight
48
+ And even thence thou wilt be stol'n, I fear,
49
+ To leave poor me thou hast the strength of laws,
50
+ For that same groan doth put this in my mind;
51
+ Since from thee going he went wilful-slow,
52
+ Blessed are you, whose worthiness gives scope,
53
+ In all external grace you have some part,
54
+ And so of you, beauteous and lovely youth,
55
+ So, till the judgment that yourself arise,
56
+ Else call it winter, which being full of care
57
+ So true a fool is love that in your will,
58
+ I am to wait, though waiting so be hell;
59
+ O, sure I am, the wits of former days
60
+ And yet to times in hope my verse shall stand,
61
+ For thee watch I whilst thou dost wake elsewhere,
62
+ 'Tis thee, myself, that for myself I praise,
63
+ His beauty shall in these black lines be seen,
64
+ This thought is as a death, which cannot choose
65
+ O, none, unless this miracle have might,
66
+ Tired with all these, from these would I be gone,
67
+ O, him she stores, to show what wealth she had
68
+ And him as for a map doth Nature store,
69
+ But why thy odour matcheth not thy show,
70
+ If some suspect of ill mask'd not thy show,
71
+ Lest the wise world should look into your moan
72
+ For I am shamed by that which I bring forth,
73
+ This thou perceivest, which makes thy love more strong,
74
+ The worth of that is that which it contains,
75
+ Thus do I pine and surfeit day by day,
76
+ For as the sun is daily new and old,
77
+ These offices, so oft as thou wilt look,
78
+ But thou art all my art and dost advance
79
+ Then thank him not for that which he doth say,
80
+ Then if he thrive and I be cast away,
81
+ You still shall live--such virtue hath my pen--
82
+ And their gross painting might be better used
83
+ There lives more life in one of your fair eyes
84
+ You to your beauteous blessings add a curse,
85
+ Then others for the breath of words respect,
86
+ But when your countenance fill'd up his line,
87
+ Thus have I had thee, as a dream doth flatter,
88
+ Such is my love, to thee I so belong,
89
+ For thee against myself I'll vow debate,
90
+ And other strains of woe, which now seem woe,
91
+ Wretched in this alone, that thou mayst take
92
+ But what's so blessed-fair that fears no blot?
93
+ How like Eve's apple doth thy beauty grow,
94
+ For sweetest things turn sourest by their deeds;
95
+ Take heed, dear heart, of this large privilege;
96
+ But do not so; I love thee in such sort
97
+ Or, if they sing, 'tis with so dull a cheer
98
+ Yet seem'd it winter still, and, you away,
99
+ More flowers I noted, yet I none could see
100
+ Give my love fame faster than Time wastes life;
101
+ Then do thy office, Muse; I teach thee how
102
+ Therefore like her I sometime hold my tongue,
103
+ And more, much more, than in my verse can sit
104
+ For fear of which, hear this, thou age unbred;
105
+ 'Fair, kind, and true,' have often lived alone,
106
+ For we, which now behold these present days,
107
+ And thou in this shalt find thy monument,
108
+ Finding the first conceit of love there bred
109
+ For nothing this wide universe I call,
110
+ Then give me welcome, next my heaven the best,
111
+ Pity me then, dear friend, and I assure ye
112
+ You are so strongly in my purpose bred
113
+ Incapable of more, replete with you,
114
+ If it be poison'd, 'tis the lesser sin
115
+ Love is a babe; then might I not say so,
116
+ If this be error and upon me proved,
117
+ Since my appeal says I did strive to prove
118
+ But thence I learn, and find the lesson true,
119
+ So I return rebuked to my content
120
+ But that your trespass now becomes a fee;
121
+ Unless this general evil they maintain,
122
+ To keep an adjunct to remember thee
123
+ This I do vow and this shall ever be;
124
+ To this I witness call the fools of time,
125
+ Hence, thou suborn'd informer! a true soul
126
+ Her audit, though delay'd, answer'd must be,
127
+ Yet so they mourn, becoming of their woe,
128
+ Since saucy jacks so happy are in this,
129
+ All this the world well knows; yet none knows well
130
+ And yet, by heaven, I think my love as rare
131
+ In nothing art thou black save in thy deeds,
132
+ Then will I swear beauty herself is black
133
+ And yet thou wilt; for I, being pent in thee,
134
+ Him have I lost; thou hast both him and me:
135
+ Let no unkind, no fair beseechers kill;
136
+ Make but my name thy love, and love that still,
137
+ In things right true my heart and eyes have erred,
138
+ Therefore I lie with her and she with me,
139
+ Yet do not so; but since I am near slain,
140
+ That I may not be so, nor thou belied,
141
+ Only my plague thus far I count my gain,
142
+ If thou dost seek to have what thou dost hide,
143
+ So will I pray that thou mayst have thy 'Will,'
144
+ Yet this shall I ne'er know, but live in doubt,
145
+ 'I hate' from hate away she threw,
146
+ So shalt thou feed on Death, that feeds on men,
147
+ For I have sworn thee fair and thought thee bright,
148
+ O cunning Love! with tears thou keep'st me blind,
149
+ But, love, hate on, for now I know thy mind;
150
+ If thy unworthiness raised love in me,
151
+ No want of conscience hold it that I call
152
+ For I have sworn thee fair; more perjured I,
153
+ But found no cure: the bath for my help lies
154
+ Came there for cure, and this by that I prove,
naacl-2021-fudge-controlled-generation/poetry_util.py ADDED
@@ -0,0 +1,83 @@
1
+ import string
2
+
3
+ import pronouncing
4
+ from Phyme import Phyme
5
+ phyme = Phyme()
6
+
7
+ from constants import *
8
+
9
+ def is_iambic(phrase):
10
+ """
11
+ check that we satisfy iambic meter.
12
+ return 1 if so, otherwise 0.
13
+ definitely an imperfect check...
14
+ if we end up needing to check a word that's not in the CMU dictionary, just return 0.
15
+ """
16
+ meter = ''
17
+ for word in phrase.split():
18
+ word = word.strip().strip(string.punctuation).lower()
19
+ try:
20
+ phones_list = pronouncing.phones_for_word(word)
21
+ stresses = pronouncing.stresses(phones_list[0])
22
+ if len(stresses) == 1:
23
+ if stresses == '1':
24
+ stresses = '2' # allow ambiguity for 1-syllable words with stress 1
25
+ meter += stresses # just default to the first pronunciation if > 1 given
26
+ except:
27
+ return 0 # word not found
28
+ meter = [int(x) for x in meter]
29
+ even_stresses_full = [meter[i] for i in range(0, len(meter), 2)]
30
+ odd_stresses_full = [meter[i] for i in range(1, len(meter), 2)]
31
+ even_stresses = set(even_stresses_full)
32
+ odd_stresses = set(odd_stresses_full)
33
+ if 0 in odd_stresses:
34
+ return 0
35
+ if 1 in even_stresses:
36
+ return 0
37
+ return 1
38
+
39
+
40
+ def count_syllables(words):
41
+ syllables = 0
42
+ for word in words.split():
43
+ word = word.strip().strip(string.punctuation)
44
+ try:
45
+ phones_list = pronouncing.phones_for_word(word)
46
+ stresses = pronouncing.stresses(phones_list[0])
47
+ syllables += min(MAX_SYLLABLES_PER_WORD, len(stresses))
48
+ except:
49
+ # if we don't know, just do a quick approximation here; it shouldn't come up too often
50
+ syllables += min(MAX_SYLLABLES_PER_WORD, round(len(word) / 3))
51
+ return syllables
52
+
53
+
54
+ def get_rhymes(word):
55
+ # throws exception if word not in the rhyme dict (rare)
56
+ rhymes = []
57
+ rhyme_dict = phyme.get_perfect_rhymes(word)
58
+ for length_dict in rhyme_dict.values():
59
+ for word in length_dict:
60
+ if '(' in word: # sometimes you have stuff like preferred(1) where they indicate a particular pronunciation
61
+ rhymes.append(word.split('(')[0])
62
+ else:
63
+ rhymes.append(word)
64
+ return sorted(list(set(rhymes)))
65
+
66
+
67
+ def get_rhyme_group(word):
68
+ sorted_rhyme_list = get_rhymes(word)
69
+ return ' '.join(sorted_rhyme_list)
70
+
71
+
72
+ def perfect_rhyme_end(s1, s2):
73
+ ending_word1 = s1.split()[-1].strip(string.punctuation)
74
+ ending_word2 = s2.split()[-1].strip(string.punctuation)
75
+ try:
76
+ return get_rhyme_group(ending_word1) == get_rhyme_group(ending_word2)
77
+ except:
78
+ return False # unknown words
79
+
80
+ if __name__=='__main__':
81
+ result = is_iambic('Shall I compare thee to a summer day')
82
+ result2 = count_syllables('Shall I compare thee to a summer day')
83
+ import pdb; pdb.set_trace()
naacl-2021-fudge-controlled-generation/predict_clickbait.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from typing import Iterable, List, Optional, Tuple
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead
16
+ from torch import Tensor
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import num_params
21
+ from constants import *
22
+
23
+
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
26
+ classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
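+ # note: the summarizer (pegasus-xsum) and the FUDGE classifier use different tokenizers, so inside
+ # generate_clickbait the candidate continuations are decoded back to text and re-tokenized with
+ # classifier_tokenizer before being scored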
27
+
28
+
29
+ def main(args):
30
+ with open(args.dataset_info, 'rb') as rf:
31
+ dataset_info = pickle.load(rf)
32
+
33
+ article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
34
+ Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
35
+ The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
36
+ Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
37
+ 'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
38
+ to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again'
39
+ , even though he's had a chance to catch-up with other cast members."""
40
+
41
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
42
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
43
+
44
+ #For loading Clickbait summarizer
45
+ model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
46
+
47
+ model.eval()
48
+
49
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
50
+ model_args = checkpoint['args']
51
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
52
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
53
+ conditioning_model = conditioning_model.to(args.device)
54
+ conditioning_model.eval()
55
+ print("=> loaded checkpoint '{}' (epoch {})"
56
+ .format(args.ckpt, checkpoint['epoch']))
57
+ print('num params', num_params(conditioning_model))
58
+
59
+ while True:
60
+ results = generate_clickbait(model,
61
+ tokenizer,
62
+ conditioning_model,
63
+ [args.input_text],
64
+ dataset_info,
65
+ precondition_topk=args.precondition_topk,
66
+ do_sample=args.do_sample,
67
+ length_cutoff=args.length_cutoff,
68
+ condition_lambda=args.condition_lambda,
69
+ article_content=article_content,
70
+ device=args.device)
71
+ # print(results)
72
+ import pdb; pdb.set_trace()
73
+
74
+
75
+ def generate_clickbait(model,
76
+ tokenizer,
77
+ conditioning_model,
78
+ input_text,
79
+ dataset_info,
80
+ precondition_topk,
81
+ length_cutoff,
+ do_sample=False, # accepted but currently unused; this function always samples via multinomial below
82
+ condition_lambda=1.0,
83
+ article_content=None,
84
+ device='cuda'):
85
+ with torch.no_grad():
86
+ batch_size = len(input_text)
87
+ # encoded_input_article = [tokenizer.encode(article_content, return_tensors='pt',add_special_tokens=False).to(device)] # batch x seq
88
+ max_input_length = 512
89
+ encoded_input_article = tokenizer(article_content, return_tensors='pt',add_special_tokens=False, max_length = max_input_length).to(device) # batch x seq
90
+ # encoded_input_article = torch.cat(encoded_input_article, dim=0)
91
+ # attention_mask = encoded_input_article.new_ones(encoded_input_article.shape).to(device)
92
+
93
+ # CHANGE: start the decoder from the pad token only; the article itself is fed through the encoder below
94
+ encoded_input = tokenizer('<pad>', return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
95
+ # encoded_input = tokenizer('<pad>'+ input_text[0], return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
96
+ # encoded_input = torch.cat(encoded_input, dim=0)
97
+ encoded_input = encoded_input['input_ids']
98
+
99
+
100
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
101
+ # lengths = 1
102
+
103
+ past = None
104
+ use_cache = True
105
+
106
+ # CHANGE
107
+ # model_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input_article, attention_mask=attention_mask)}
108
+ model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
109
+ attention_mask=encoded_input_article['attention_mask'],
110
+ return_dict=True,
111
+ output_attentions=False,
112
+ output_hidden_states=False),
113
+ }
114
+
115
+ while lengths.max() < length_cutoff:
116
+ model_inputs = model.prepare_inputs_for_generation(
117
+ input_ids = encoded_input_article['input_ids'],
118
+ decoder_input_ids=encoded_input,
119
+ # past=past,
120
+ attention_mask=encoded_input_article['attention_mask'],
121
+ use_cache=use_cache,
122
+ **model_kwargs
123
+ )
124
+
125
+ outputs = model(**model_inputs, return_dict=True)
126
+ logits = outputs.logits[:, -1, :]
127
+
128
+ if "past_key_values" in outputs:
129
+ model_kwargs["past"] = outputs.past_key_values
130
+
131
+ # logits = model(encoded_input)[0][:, -1, :] # batch x vocab
132
+ top_logits, top_indices = logits.topk(precondition_topk, dim=1) # batch x topk
133
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
134
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
135
+
136
+ if condition_lambda == 0:
137
+ condition_logits = torch.zeros_like(top_logits).float()
138
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
139
+ else:
140
+ decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
141
+ resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
142
+ encoded_with_classifier = resulting_tokenization['input_ids']
143
+ attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
144
+ tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)
145
+
146
+ condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1), # batch*topk x seq+1
147
+ expanded_lengths.flatten(0, 1), # batch*topk
148
+ None,
149
+ None,
150
+ None,
151
+ attention_mask=attention_mask
152
+ )
153
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
154
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
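+ # the subtraction above is the log-sigmoid identity: x - log(1 + e^x) = log(sigmoid(x)),
+ # turning the raw classifier logit into log p(clickbait | candidate prefix)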
155
+
156
+ condition_logits = torch.mean(condition_logits, dim=2)
157
+ full_logits = top_logits + condition_logits * condition_lambda # batch x topk
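+ # FUDGE re-ranking step: the summarizer's top-k token logits are shifted by lambda times the
+ # classifier's log-probability for each candidate continuation; the next token is then sampled
+ # from a softmax over just these k adjusted scores (see multinomial below)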
158
+ post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
159
+ post_probs = F.softmax(post_logits, dim=1)
160
+ # index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
161
+ index_into_top_indices = post_indices[:, torch.multinomial(post_probs, 1).flatten()] # batch
162
+
163
+ # next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
164
+ next_indices = top_indices[:, index_into_top_indices] # batch
165
+
166
+ # encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
167
+ encoded_input = torch.cat([encoded_input, next_indices.squeeze(1)], dim=1)
168
+ lengths = lengths + 1 # batch
169
+
170
+ # print(tokenizer.decode(encoded_input[0], add_special_tokens=False))
171
+ return [tokenizer.decode(s) for s in encoded_input]
172
+
173
+
174
+ if __name__=='__main__':
175
+ parser = ArgumentParser()
176
+
177
+ # DATA
178
+ parser.add_argument('--ckpt', type=str, required=True)
179
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
180
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
181
+
182
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run pred on')
183
+
184
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
185
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
186
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
187
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
188
+
189
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
190
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
191
+ parser.add_argument('--debug', action='store_true', default=False)
192
+
193
+ args = parser.parse_args()
194
+
195
+ random.seed(args.seed)
196
+ np.random.seed(args.seed)
197
+ torch.manual_seed(args.seed)
198
+
199
+ main(args)
naacl-2021-fudge-controlled-generation/predict_formality.py ADDED
@@ -0,0 +1,404 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from typing import Iterable, List, Optional, Tuple
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
16
+ from torch import Tensor
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
21
+ from constants import *
22
+
23
+ def main(args):
24
+ with open(args.dataset_info, 'rb') as rf:
25
+ dataset_info = pickle.load(rf)
26
+ tokenizer = MarianTokenizer.from_pretrained(args.model_string)
27
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
28
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
29
+ model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
30
+ model.eval()
31
+
32
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
33
+ model_args = checkpoint['args']
34
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
35
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
36
+ conditioning_model = conditioning_model.to(args.device)
37
+ conditioning_model.eval()
38
+ print("=> loaded checkpoint '{}' (epoch {})"
39
+ .format(args.ckpt, checkpoint['epoch']))
40
+ print('num params', num_params(conditioning_model))
41
+
42
+ while True:
43
+ results = predict_formality(model,
44
+ tokenizer,
45
+ conditioning_model,
46
+ [args.input_text],
47
+ dataset_info,
48
+ precondition_topk=args.precondition_topk,
49
+ do_sample=args.do_sample,
50
+ length_cutoff=args.length_cutoff,
51
+ condition_lambda=args.condition_lambda,
52
+ device=args.device)
53
+ print(results)
54
+ import pdb; pdb.set_trace()
55
+
56
+
57
+ def predict_formality(model, tokenizer, conditioning_model, input_text, dataset_info, precondition_topk=200, do_sample=False, length_cutoff=512, condition_lambda=1.0, device='cuda'):
58
+ with torch.no_grad():
59
+ batch_size = len(input_text)
60
+
61
+ # assumes initially all same length.
62
+ # encode each input sentence x_i into its respective token ids
63
+ encoded_input = [tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
64
+ encoded_input = torch.cat(encoded_input, dim=0)
65
+
66
+ input_ids = torch.LongTensor([[58100]]).to(device)
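+ # token id 58100 is used throughout this script as the Marian pad / decoder-start token
+ # (see pad_token_id and bad_words_ids below); it is stripped from the decoded output at the end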
67
+ cur_len = 1
68
+ max_length = length_cutoff
69
+ min_length = 0
70
+ temperature = 1.0
71
+ top_k = 50
72
+ top_p = 1.0
73
+ repetition_penalty = 1.0
74
+ no_repeat_ngram_size = 0
75
+ bad_words_ids = [[58100]]
76
+ pad_token_id = 58100
77
+ eos_token_id = 0
78
+ effective_batch_size = batch_size
79
+ attention_mask = encoded_input.new_ones(encoded_input.shape)
80
+ use_cache = True
81
+ model_specific_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input, attention_mask=attention_mask)}
82
+
83
+ output = _generate_no_beam_search(model,
84
+ conditioning_model,
85
+ condition_lambda,
86
+ precondition_topk,
87
+ input_ids,
88
+ cur_len,
89
+ max_length,
90
+ min_length,
91
+ do_sample,
92
+ temperature,
93
+ top_k,
94
+ top_p,
95
+ repetition_penalty,
96
+ no_repeat_ngram_size,
97
+ bad_words_ids,
98
+ pad_token_id,
99
+ eos_token_id,
100
+ batch_size,
101
+ attention_mask,
102
+ use_cache,
103
+ model_specific_kwargs)
104
+
105
+ return [tokenizer.decode(s[1:]) for s in output] # 1: to delete the pad token
106
+
107
+
108
+ # hack of code from transformers/generation_utils.py
109
+ # to get our conditioning
110
+ def postprocess_next_token_scores(
111
+ model,
112
+ scores,
113
+ input_ids,
114
+ no_repeat_ngram_size,
115
+ bad_words_ids,
116
+ cur_len,
117
+ min_length,
118
+ max_length,
119
+ eos_token_id,
120
+ repetition_penalty,
121
+ batch_size,
122
+ num_beams,
123
+ ):
124
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
125
+ if repetition_penalty != 1.0:
126
+ model.enforce_repetition_penalty_(
127
+ scores,
128
+ batch_size,
129
+ num_beams,
130
+ input_ids,
131
+ repetition_penalty,
132
+ )
133
+
134
+ # set eos token prob to zero if min_length is not reached
135
+ if eos_token_id is not None and cur_len < min_length:
136
+ scores[:, eos_token_id] = -float("inf")
137
+
138
+ if no_repeat_ngram_size > 0:
139
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
140
+ num_batch_hypotheses = batch_size * num_beams
141
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
142
+ banned_batch_tokens = calc_banned_ngram_tokens(
143
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
144
+ )
145
+ for i, banned_tokens in enumerate(banned_batch_tokens):
146
+ scores[i, banned_tokens] = -float("inf")
147
+
148
+ if bad_words_ids is not None:
149
+ # Exclude EOS token (already processed)
150
+ bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
151
+ # calculate a list of banned tokens according to bad words
152
+ banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
153
+ # Modify the scores in place by setting the banned tokens logits to `-inf`
154
+ set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
155
+
156
+ return scores
157
+
158
+ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
159
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
160
+ if cur_len + 1 < no_repeat_ngram_size:
161
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
162
+ return [[] for _ in range(num_hypos)]
163
+ generated_ngrams = [{} for _ in range(num_hypos)]
164
+ for idx in range(num_hypos):
165
+ gen_tokens = prev_input_ids[idx].tolist()
166
+ generated_ngram = generated_ngrams[idx]
167
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
168
+ prev_ngram_tuple = tuple(ngram[:-1])
169
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
170
+
171
+ def _get_generated_ngrams(hypo_idx):
172
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
173
+ start_idx = cur_len + 1 - no_repeat_ngram_size
174
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
175
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
176
+
177
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
178
+ return banned_tokens
179
+
180
+
181
+ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
182
+ banned_tokens = []
183
+
184
+ def _tokens_match(prev_tokens, tokens):
185
+ if len(tokens) == 0:
186
+ # if bad word tokens is just one token always ban it
187
+ return True
188
+ if len(tokens) > len(prev_tokens):
189
+ # if bad word tokens are longer than prev tokens they can't be equal
190
+ return False
191
+
192
+ if prev_tokens[-len(tokens) :] == tokens:
193
+ # if tokens match
194
+ return True
195
+ else:
196
+ return False
197
+
198
+ for prev_input_ids_slice in prev_input_ids:
199
+ banned_tokens_slice = []
200
+
201
+ for banned_token_seq in bad_words_ids:
202
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
203
+ bad_words_ids
204
+ )
205
+
206
+ if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
207
+ # if tokens do not match continue
208
+ continue
209
+
210
+ banned_tokens_slice.append(banned_token_seq[-1])
211
+
212
+ banned_tokens.append(banned_tokens_slice)
213
+
214
+ return banned_tokens
215
+
216
+ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
217
+ """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
218
+ a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
219
+ Args:
220
+ scores: logits distribution of shape (batch size, vocabulary size)
221
+ banned_tokens: list of list of tokens to ban of length (batch_size)
222
+ """
223
+ banned_mask_list = []
224
+ for idx, batch_banned_tokens in enumerate(banned_tokens):
225
+ for token in batch_banned_tokens:
226
+ banned_mask_list.append([idx, token])
227
+ if not banned_mask_list:
228
+ return
229
+ banned_mask = torch.LongTensor(banned_mask_list)
230
+ indices = torch.ones(len(banned_mask))
231
+ # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
232
+ # [ 0 1 1 ]
233
+ # [ 0 0 0 ]
234
+ # [ 1 0 0 ]
235
+
236
+ banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
237
+ scores.masked_fill_(banned_mask, -float("inf"))
238
+
239
+ def _generate_no_beam_search(
240
+ model,
241
+ conditioning_model,
242
+ condition_lambda,
243
+ precondition_topk,
244
+ input_ids,
245
+ cur_len,
246
+ max_length,
247
+ min_length,
248
+ do_sample,
249
+ temperature,
250
+ top_k,
251
+ top_p,
252
+ repetition_penalty,
253
+ no_repeat_ngram_size,
254
+ bad_words_ids,
255
+ pad_token_id,
256
+ eos_token_id,
257
+ batch_size,
258
+ attention_mask,
259
+ use_cache,
260
+ model_kwargs,
261
+ ):
262
+ """Generate sequences for each example without beam search (num_beams == 1).
263
+ All returned sequences are generated independently.
264
+ """
265
+ # length of generated sentences / unfinished sentences
266
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
267
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
268
+ past = None
269
+ while cur_len < max_length:
270
+ model_inputs = model.prepare_inputs_for_generation(
271
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
272
+ )
273
+
274
+ outputs = model(**model_inputs, return_dict=True)
275
+ next_token_logits = outputs.logits[:, -1, :]
276
+
277
+ # scores = model.postprocess_next_token_scores(
278
+ # scores=next_token_logits,
279
+ # input_ids=input_ids,
280
+ # no_repeat_ngram_size=no_repeat_ngram_size,
281
+ # bad_words_ids=bad_words_ids,
282
+ # cur_len=cur_len,
283
+ # min_length=min_length,
284
+ # max_length=max_length,
285
+ # eos_token_id=eos_token_id,
286
+ # repetition_penalty=repetition_penalty,
287
+ # batch_size=batch_size,
288
+ # num_beams=1,
289
+ # )
290
+
291
+ scores = postprocess_next_token_scores(
292
+ model=model,
293
+ scores=next_token_logits,
294
+ input_ids=input_ids,
295
+ no_repeat_ngram_size=no_repeat_ngram_size,
296
+ bad_words_ids=bad_words_ids,
297
+ cur_len=cur_len,
298
+ min_length=min_length,
299
+ max_length=max_length,
300
+ eos_token_id=eos_token_id,
301
+ repetition_penalty=repetition_penalty,
302
+ batch_size=batch_size,
303
+ num_beams=1,
304
+ )
305
+
306
+ # if model has past, then set the past variable to speed up decoding
307
+ if "past_key_values" in outputs:
308
+ past = outputs.past_key_values
309
+ elif "mems" in outputs:
310
+ past = outputs.mems
311
+
312
+ top_logits, top_indices = scores.topk(precondition_topk, dim=1) # batch x topk
313
+ tplus1_candidates = torch.cat([input_ids.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2)[:, :, 1:] # batch x topk x seq+1, with pad dropped
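+ # the [:, :, 1:] slice drops the seeded pad/start token (58100) so the formality classifier
+ # only scores the actual generated tokens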
314
+ expanded_lengths = torch.LongTensor([[cur_len for _ in range(precondition_topk)] for _ in range(batch_size)]).to(scores.device)
315
+ if condition_lambda == 0:
316
+ condition_logits = torch.zeros_like(top_logits).float()
317
+ else:
318
+ condition_logits = conditioning_model(tplus1_candidates.flatten(0, 1), # batch*topk x seq+1
319
+ expanded_lengths.flatten(0, 1), # batch*topk
320
+ None,
321
+ None,
322
+ None)
323
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1)[:, :, -1] # batch x topk of last formality pred
324
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
325
+ # condition_logits = - torch.log(1 + torch.exp(condition_logits)) # for informal
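+ # x - log(1 + e^x) = log(sigmoid(x)) = log p(formal | prefix); the commented-out variant
+ # -log(1 + e^x) = log(1 - sigmoid(x)) would instead steer generation toward the informal class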
326
+ full_logits = top_logits + condition_lambda * condition_logits
327
+ if do_sample:
328
+ raise NotImplementedError
329
+ else:
330
+ # Greedy decoding
331
+ next_token = top_indices[torch.arange(batch_size).to(top_indices.device), torch.argmax(full_logits, dim=-1)]
332
+
333
+ # if do_sample:
334
+ # # Temperature (higher temperature => more likely to sample low probability tokens)
335
+ # if temperature != 1.0:
336
+ # scores = scores / temperature
337
+ # # Top-p/top-k filtering
338
+ # next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
339
+ # # Sample
340
+ # probs = F.softmax(next_token_logscores, dim=-1)
341
+ # next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
342
+ # else:
343
+ # # Greedy decoding
344
+ # next_token = torch.argmax(next_token_logits, dim=-1)
345
+
346
+ # update generations and finished sentences
347
+ if eos_token_id is not None:
348
+ # pad finished sentences if eos_token_id exist
349
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
350
+ else:
351
+ tokens_to_add = next_token
352
+
353
+ # add token and increase length by one
354
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
355
+ cur_len = cur_len + 1
356
+
357
+ if eos_token_id is not None:
358
+ eos_in_sents = tokens_to_add == eos_token_id
359
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
360
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
361
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
362
+ # unfinished_sents is set to zero if eos in sentence
363
+ unfinished_sents.mul_((~eos_in_sents).long())
364
+
365
+ # stop when there is a </s> in each sentence, or if we exceed the maximum length
366
+ if unfinished_sents.max() == 0:
367
+ break
368
+
369
+ # extend attention_mask for new generated input if only decoder
370
+ if model.config.is_encoder_decoder is False:
371
+ attention_mask = torch.cat(
372
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
373
+ )
374
+
375
+ return input_ids
376
+
377
+ if __name__=='__main__':
378
+ parser = ArgumentParser()
379
+
380
+ # DATA
381
+ parser.add_argument('--ckpt', type=str, required=True)
382
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
383
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
384
+
385
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run pred on')
386
+
387
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
388
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
389
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
390
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
391
+
392
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
393
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
394
+ parser.add_argument('--debug', action='store_true', default=False)
395
+
396
+ args = parser.parse_args()
397
+
398
+ random.seed(args.seed)
399
+ np.random.seed(args.seed)
400
+ torch.manual_seed(args.seed)
401
+
402
+ main(args)
403
+
404
+
naacl-2021-fudge-controlled-generation/predict_poetry.py ADDED
@@ -0,0 +1,219 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ import string
8
+ from collections import defaultdict
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
16
+
17
+ from data import Dataset, load_rhyme_info
18
+ from model import Model
19
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
20
+ from constants import *
21
+ from poetry_util import get_rhymes, count_syllables
22
+
23
+ def main(args):
24
+ with open(args.dataset_info, 'rb') as rf:
25
+ dataset_info = pickle.load(rf)
26
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
27
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
28
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
29
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
30
+ gpt_model.eval()
31
+
32
+ checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
33
+ model_args = checkpoint['args']
34
+ iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
35
+ iambic_model.load_state_dict(checkpoint['state_dict'])
36
+ iambic_model = iambic_model.to(args.device)
37
+ iambic_model.eval()
38
+ print("=> loaded checkpoint '{}' (epoch {})"
39
+ .format(args.iambic_ckpt, checkpoint['epoch']))
40
+ print('iambic model num params', num_params(iambic_model))
41
+
42
+ with open(args.rhyme_info, 'rb') as rf:
43
+ rhyme_info = pickle.load(rf)
44
+ checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
45
+ model_args = checkpoint['args']
46
+ rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
47
+ rhyme_model.load_state_dict(checkpoint['state_dict'])
48
+ rhyme_model = rhyme_model.to(args.device)
49
+ rhyme_model.eval()
50
+ print("=> loaded checkpoint '{}' (epoch {})"
51
+ .format(args.rhyme_ckpt, checkpoint['epoch']))
52
+ print('rhyme model num params', num_params(rhyme_model))
53
+
54
+ checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
55
+ model_args = checkpoint['args']
56
+ newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
57
+ newline_model.load_state_dict(checkpoint['state_dict'])
58
+ newline_model = newline_model.to(args.device)
59
+ newline_model.eval()
60
+ print("=> loaded checkpoint '{}' (epoch {})"
61
+ .format(args.newline_ckpt, checkpoint['epoch']))
62
+ print('newline model num params', num_params(newline_model))
63
+
64
+ while True:
65
+ results = predict_couplet(gpt_model,
66
+ gpt_tokenizer,
67
+ iambic_model,
68
+ rhyme_model,
69
+ newline_model,
70
+ [args.input_text],
71
+ dataset_info,
72
+ rhyme_info,
73
+ args.precondition_topk,
74
+ args.topk,
75
+ condition_lambda=args.condition_lambda,
76
+ device=args.device)
77
+ for line in results:
78
+ print(line)
79
+ import pdb; pdb.set_trace()
80
+
81
+
82
+ def predict_couplet(gpt_model, gpt_tokenizer, iambic_model, rhyme_model, newline_model, input_text, dataset_info, rhyme_info, precondition_topk, postcondition_topk, condition_lambda=1.0, device='cuda'):
83
+ assert len(input_text) == 1 # only do one at a time for now
84
+ current_text = input_text[0]
85
+ current_line_text = ''
86
+ all_lines = [current_text]
87
+ ending_word = current_text.split()[-1].strip(string.punctuation)
88
+ word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP, rhyme_info.word2rhyme_group)
89
+ rhyme_group = word2rhyme_group[ending_word]
90
+
91
+ line = predict_iambic_pentameter_line(gpt_model,
92
+ gpt_tokenizer,
93
+ iambic_model,
94
+ rhyme_model,
95
+ newline_model,
96
+ current_text,
97
+ current_line_text,
98
+ rhyme_group,
99
+ dataset_info,
100
+ rhyme_info,
101
+ precondition_topk,
102
+ postcondition_topk,
103
+ condition_lambda=condition_lambda,
104
+ device=device)
105
+ all_lines.append(line)
106
+
107
+ return all_lines
108
+
109
+
110
+ def predict_iambic_pentameter_line(gpt_model, gpt_tokenizer, iambic_model, rhyme_model, newline_model, current_text, current_line_text, rhyme_group, dataset_info, rhyme_info, precondition_topk, postcondition_topk, banned_tokens=POETRY_BANNED_TOKENS, condition_lambda=1.0, device='cuda', length_cutoff=30):
111
+ # TODO(poetry) delete banned tokens?
112
+ with torch.no_grad():
113
+ batch_size = 1
114
+
115
+ rhyme_group_index = rhyme_info.rhyme_group2index[rhyme_group]
116
+ future_words = torch.LongTensor([rhyme_group_index]).to(device) # 1
117
+ log_probs = torch.Tensor([math.log(rhyme_info.rhyme_group_counts[rhyme_group] / rhyme_info.total_rhyme_groups)]).to(device) # 1
118
+
119
+ # assumes initially all same length.
120
+ previous_encoded_text = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in [current_text]]
121
+ previous_enc_len = previous_encoded_text[0].shape[1]
122
+ encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in [current_text + current_line_text]] # batch x seq
123
+ encoded_input = torch.cat(encoded_input, dim=0)
124
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
125
+
126
+ line_syllable_count = count_syllables(current_line_text)
127
+ assert line_syllable_count < POETRY_LINE_SYLLABLES # assume we started with less than one full line
128
+ syllables_to_go = POETRY_LINE_SYLLABLES - line_syllable_count
129
+
130
+ for _ in range(length_cutoff): # really shouldn't have a line this long anyway
131
+ gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
132
+ gpt_logits[:, banned_tokens] = -1e8
133
+ top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1)
134
+
135
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
136
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
137
+ expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
138
+ candidate_syllables_to_go = []
139
+ for candidate in new_input_candidates[0]:
140
+ candidate_until_last_word_text = ' '.join(gpt_tokenizer.decode(candidate[previous_enc_len:]).split()[:-1])
141
+ candidate_syllables_to_go.append(10 - count_syllables(candidate_until_last_word_text))
142
+ # usually these are all the same, but run them all for correctness. could do more efficiently but it's not too slow anyway.
143
+ expanded_syllables_to_go = torch.LongTensor(candidate_syllables_to_go).to(device).view(1, precondition_topk)
144
+
145
+ if condition_lambda == 0:
146
+ iambic_logits = torch.zeros_like(expanded_lengths).float()
147
+ else:
148
+ # truncate prefix because we trained on single lines
149
+ iambic_logits = iambic_model(new_input_candidates[:, :, previous_enc_len:].flatten(0, 1), expanded_lengths.flatten(0, 1) - previous_enc_len, None, None, None)[:, -1] # batch*topk x seq+1 -> batch*topk
150
+ iambic_logits = iambic_logits.view(batch_size, precondition_topk)
151
+ iambic_logits = iambic_logits - torch.log(1 + torch.exp(iambic_logits))
152
+ if condition_lambda == 0:
153
+ rhyme_logits = torch.zeros_like(expanded_lengths).float()
154
+ else:
155
+ rhyme_logits = rhyme_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
156
+ expanded_lengths.flatten(0, 1), # batch*topk
157
+ expanded_future_words.flatten(0, 1), # batch*topk x N
158
+ log_probs, # N
159
+ expanded_syllables_to_go.flatten(0, 1)) # batch*topk
160
+ rhyme_logits = rhyme_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
161
+ rhyme_logits = rhyme_logits - torch.log(1 + torch.exp(rhyme_logits)) # batch x topk x N
162
+ rhyme_logits = rhyme_logits.squeeze(2) # batch x topk
163
+ if condition_lambda == 0:
164
+ newline_logits = torch.zeros_like(expanded_lengths).float()
165
+ else:
166
+ newline_logits = newline_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
167
+ expanded_lengths.flatten(0, 1), # batch*topk
168
+ expanded_future_words.flatten(0, 1), # batch*topk x N
169
+ log_probs, # N
170
+ expanded_syllables_to_go.flatten(0, 1)) # batch*topk
171
+ newline_logits = newline_logits[:, -1].view(batch_size, precondition_topk, -1) # batch x topk x N
172
+ newline_logits = newline_logits - torch.log(1 + torch.exp(newline_logits)) # batch x topk x N
173
+ newline_logits = newline_logits.squeeze(2) # batch x topk
174
+
175
+ full_logits = top_logits + condition_lambda * iambic_logits + condition_lambda * rhyme_logits + condition_lambda * newline_logits
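+ # the three FUDGE discriminators (iambic meter, rhyme with the prefix's ending word, and
+ # end-of-line placement) each add a lambda-weighted log-probability to GPT-2's top-k logits;
+ # the next token is then sampled from the renormalized top postcondition_topk candidates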
176
+ post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
177
+ post_probs = F.softmax(post_logits, dim=1)
178
+ index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
179
+ next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
180
+ encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
181
+ lengths = lengths + 1
182
+ syllables_to_go = POETRY_LINE_SYLLABLES - count_syllables(gpt_tokenizer.decode(encoded_input[0][previous_enc_len:])) # if we get very unlucky with a partial word that the syllable counter doesn't recognize we might end early, but it's unlikely
183
+ if syllables_to_go <= 0 and [gpt_tokenizer.decode(s) for s in encoded_input][0][-1] in PHRASE_ENDS:
184
+ break
185
+ if syllables_to_go < 0:
186
+ # encoded_input = encoded_input[:, :-1]
187
+ break
188
+
189
+ return [gpt_tokenizer.decode(s) for s in encoded_input][0][len(current_text):]
190
+
191
+
192
+ if __name__=='__main__':
193
+ parser = ArgumentParser()
194
+
195
+ # DATA
196
+ parser.add_argument('--iambic_ckpt', type=str, required=True)
197
+ parser.add_argument('--rhyme_ckpt', type=str, required=True)
198
+ parser.add_argument('--newline_ckpt', type=str, required=True)
199
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
200
+ parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info')
201
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
202
+
203
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
204
+
205
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
206
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
207
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
208
+
209
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
210
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
211
+ parser.add_argument('--debug', action='store_true', default=False)
212
+
213
+ args = parser.parse_args()
214
+
215
+ random.seed(args.seed)
216
+ np.random.seed(args.seed)
217
+ torch.manual_seed(args.seed)
218
+
219
+ main(args)
naacl-2021-fudge-controlled-generation/predict_topic.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
14
+
15
+ from data import Dataset
16
+ from model import Model
17
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
18
+ from constants import *
19
+
20
+ def main(args):
21
+ with open(args.dataset_info, 'rb') as rf:
22
+ dataset_info = pickle.load(rf)
23
+ for cw in args.condition_words.split():
24
+ assert cw in dataset_info.word2index
25
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
26
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
27
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
28
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
29
+ gpt_model.eval()
30
+
31
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
32
+ model_args = checkpoint['args']
33
+ conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
34
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
35
+ conditioning_model = conditioning_model.to(args.device)
36
+ conditioning_model.eval()
37
+ print("=> loaded checkpoint '{}' (epoch {})"
38
+ .format(args.ckpt, checkpoint['epoch']))
39
+ print('num params', num_params(conditioning_model))
40
+
41
+ while True:
42
+ results = predict(gpt_model,
43
+ gpt_tokenizer,
44
+ conditioning_model,
45
+ [args.input_text],
46
+ args.condition_words,
47
+ dataset_info,
48
+ args.precondition_topk,
49
+ args.topk,
50
+ args.length_cutoff,
51
+ condition_lambda=args.condition_lambda,
52
+ device=args.device)
53
+ print(results)
54
+ import pdb; pdb.set_trace()
55
+
56
+ def predict(gpt_model, gpt_tokenizer, conditioning_model, input_text, condition_words, dataset_info, precondition_topk, postcondition_topk, length_cutoff, condition_lambda=1.0, device='cuda'):
57
+ with torch.no_grad():
58
+ batch_size = len(input_text)
59
+
60
+ condition_words = condition_words.split()
61
+ future_words = torch.LongTensor([dataset_info.word2index[cw] for cw in condition_words]).to(device) # N
62
+ log_probs = torch.Tensor([math.log(dataset_info.vocab[cw] / dataset_info.total_words) for cw in condition_words]).to(device) # N
63
+
64
+ # assumes initially all same length.
65
+ encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
66
+ encoded_input = torch.cat(encoded_input, dim=0)
67
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
68
+
69
+ gpt_encoded_future_words = [gpt_tokenizer.encode(' ' + cw, return_tensors='pt')[0].to(device) for cw in condition_words]
70
+ while lengths.max() < length_cutoff:
71
+ tokens_left = torch.LongTensor([length_cutoff - lengths.max() for _ in range(batch_size)]).to(device)
72
+ gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
73
+ top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1) # batch x topk
74
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
75
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
76
+ expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
77
+ expanded_tokens_left = tokens_left.unsqueeze(1).expand(-1, precondition_topk) # batch x topk
78
+ if condition_lambda == 0:
79
+ condition_logits = torch.zeros_like(expanded_future_words).float()
80
+ else:
81
+ condition_logits = conditioning_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
82
+ expanded_lengths.flatten(0, 1), # batch*topk
83
+ expanded_future_words.flatten(0, 1), # batch*topk x N
84
+ log_probs, # N
85
+ expanded_tokens_left.flatten(0, 1)) # batch*topk
86
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
87
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
88
+
89
+ condition_logits = torch.mean(condition_logits, dim=2)
90
+ full_logits = top_logits + condition_logits * condition_lambda # batch x topk
91
+ post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
92
+ post_probs = F.softmax(post_logits, dim=1)
93
+ index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
94
+ next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
95
+ encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
96
+ lengths = lengths + 1 # batch
97
+ return [gpt_tokenizer.decode(s) for s in encoded_input]
98
+
99
+
100
+ if __name__=='__main__':
101
+ parser = ArgumentParser()
102
+
103
+ # DATA
104
+ parser.add_argument('--ckpt', type=str, required=True)
105
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
106
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
107
+
108
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
109
+ parser.add_argument('--condition_words', type=str, default=None, required=True, help='word(s) to optimize for')
110
+
111
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
112
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
113
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
114
+ parser.add_argument('--length_cutoff', type=int, default=80, help='max length')
115
+
116
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
117
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
118
+ parser.add_argument('--debug', action='store_true', default=False)
119
+
120
+ args = parser.parse_args()
121
+
122
+ random.seed(args.seed)
123
+ np.random.seed(args.seed)
124
+ torch.manual_seed(args.seed)
125
+
126
+ main(args)
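
As a quick orientation for readers of this diff: below is a minimal, hedged sketch of how `predict_topic.py`'s `predict()` could be driven programmatically, mirroring what its `main()` does. The pickle and checkpoint paths are placeholders for artifacts produced by training, not files included in this commit, and `PAD_TOKEN` is assumed to come from `constants.py`.

```python
# Hedged usage sketch for predict_topic.py (paths below are placeholders, not shipped files).
import pickle
import torch
from transformers import AutoTokenizer, AutoModelWithLMHead

from constants import PAD_TOKEN          # assumed to be defined in constants.py
from model import Model
from predict_topic import predict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open('topic_dataset_info.pkl', 'rb') as rf:    # placeholder: saved dataset info
    dataset_info = pickle.load(rf)

tokenizer = AutoTokenizer.from_pretrained('gpt2-medium')
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
gpt_model = AutoModelWithLMHead.from_pretrained('gpt2-medium').to(device).eval()

ckpt = torch.load('topic_predictor.pth', map_location=device)   # placeholder: trained FUDGE classifier
conditioning_model = Model(ckpt['args'], tokenizer.encode(PAD_TOKEN)[0], len(dataset_info.index2word))
conditioning_model.load_state_dict(ckpt['state_dict'])
conditioning_model = conditioning_model.to(device).eval()

outputs = predict(gpt_model, tokenizer, conditioning_model,
                  ['The issue focused on'],   # batch of prefixes
                  'constitution',             # space-separated topic word(s)
                  dataset_info,
                  precondition_topk=200,
                  postcondition_topk=10,
                  length_cutoff=80,
                  condition_lambda=1.0,
                  device=device)
print(outputs[0])
```

The interactive `while True` loop with `pdb.set_trace()` in `main()` above is just a convenience for inspecting outputs; the sketch returns a single batch of generations instead.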
naacl-2021-fudge-controlled-generation/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ Phyme==0.0.9
2
+ pronouncing==0.2.0
3
+ pytorch-lightning==1.0.6
4
+ torch==1.7.0
5
+ tqdm==4.49.0
6
+ sacrebleu==1.4.14
7
+ sacremoses==0.0.43
naacl-2021-fudge-controlled-generation/topic_data/README.md ADDED
@@ -0,0 +1,3 @@
1
+ `topic_prefixes.txt` contains the 20 prefixes used at test time for starting the generations.
2
+
3
+ `wordlists/` contains the wordlists for each of the 7 topics used during testing. The heldout bags used to evaluate the generalization of the topic words to other related words are in `test_wordlists/`. `val_wordlists/` contains just one extra wordlist used for tuning.
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/computers.txt ADDED
@@ -0,0 +1,163 @@
1
+ sailor
2
+ memories
3
+ article
4
+ phishing
5
+ crucial
6
+ interactive
7
+ capabilities
8
+ ISP
9
+ query
10
+ signal
11
+ computation
12
+ detect
13
+ compiling
14
+ workstation
15
+ barcode
16
+ XP
17
+ cake
18
+ counterfeiting
19
+ decimal
20
+ back-up
21
+ reasoning
22
+ DSL
23
+ C++
24
+ DVD
25
+ Frequently
26
+ wifi
27
+ deleting
28
+ paper
29
+ DNS
30
+ CyanogenMod
31
+ overflow
32
+ Android
33
+ latency
34
+ creating
35
+ redirect
36
+ sites
37
+ sidebar
38
+ Jacket
39
+ prev
40
+ connections
41
+ PDF
42
+ torrent
43
+ original
44
+ gmail
45
+ rename
46
+ coder
47
+ mainboard
48
+ parasite
49
+ casing
50
+ lurks
51
+ pixels
52
+ touchpad
53
+ update
54
+ visuals
55
+ encyclopedia
56
+ mice
57
+ Solaris
58
+ caching
59
+ copies
60
+ usb
61
+ chew
62
+ fixes
63
+ house
64
+ operand
65
+ input
66
+ pull
67
+ iterative
68
+ educational
69
+ autocomplete
70
+ on-line
71
+ confidentiality
72
+ decrypt
73
+ beach
74
+ mails
75
+ rectangular
76
+ jQuery
77
+ Excel
78
+ point-in-time
79
+ Ubuntu
80
+ decryption
81
+ dialup
82
+ profit
83
+ off-line
84
+ developing
85
+ choice
86
+ notebook
87
+ storing
88
+ typeface
89
+ little
90
+ customer
91
+ step
92
+ text
93
+ run-time
94
+ interview
95
+ layout
96
+ computing
97
+ chairs
98
+ infected
99
+ must
100
+ tools
101
+ search
102
+ pane
103
+ gamepad
104
+ disc
105
+ initialize
106
+ display
107
+ button
108
+ Firefox
109
+ automatically
110
+ garbage
111
+ 512MB
112
+ cyber
113
+ logon
114
+ elements
115
+ restoring
116
+ writer
117
+ saving
118
+ parsing
119
+ execute
120
+ configuring
121
+ telephoto
122
+ popup
123
+ utilities
124
+ packet
125
+ pasting
126
+ guest
127
+ edit
128
+ glass
129
+ e-mail
130
+ components
131
+ binaries
132
+ subdirectory
133
+ restart
134
+ XSLT
135
+ inkjet
136
+ allows
137
+ functionality
138
+ debian
139
+ change
140
+ click
141
+ dialog
142
+ GPU
143
+ stored
144
+ attribute
145
+ deflate
146
+ cheat
147
+ direction
148
+ camera
149
+ hats
150
+ topic
151
+ journalists
152
+ taxi
153
+ console
154
+ identifier
155
+ VPN
156
+ flames
157
+ spyware
158
+ secure
159
+ shoe
160
+ Macs
161
+ php
162
+ demo
163
+ extract
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/legal.txt ADDED
@@ -0,0 +1,108 @@
1
+ waived
2
+ homicide
3
+ repress
4
+ statutory
5
+ sentencing
6
+ respondent
7
+ maintain
8
+ legislative
9
+ prosecution
10
+ whether
11
+ forgive
12
+ mandamus
13
+ democratic
14
+ treasurer
15
+ acquittal
16
+ offender
17
+ sued
18
+ edict
19
+ malpractice
20
+ debatable
21
+ criminal
22
+ injunctive
23
+ appellant
24
+ convicted
25
+ admit
26
+ proxies
27
+ aggrieved
28
+ enforcement
29
+ second-degree
30
+ ethical
31
+ knowing
32
+ liability
33
+ event
34
+ property
35
+ conviction
36
+ deposited
37
+ immune
38
+ assertion
39
+ assualt
40
+ regulations
41
+ exams
42
+ pixels
43
+ prosecuting
44
+ insolvent
45
+ felonies
46
+ families
47
+ mediator
48
+ rulings
49
+ heard
50
+ wrongs
51
+ wrongful
52
+ folder
53
+ federal
54
+ widget
55
+ restaurant
56
+ incarcerated
57
+ burglary
58
+ pants
59
+ land-use
60
+ quash
61
+ sitting
62
+ rescind
63
+ dispute
64
+ leave
65
+ requesting
66
+ appearing
67
+ testify
68
+ discoveries
69
+ championship
70
+ police
71
+ judgment
72
+ purchase
73
+ revelation
74
+ solicitor
75
+ disagree
76
+ judicial
77
+ reversing
78
+ jurors
79
+ decision
80
+ negligent
81
+ mutual
82
+ track
83
+ objecting
84
+ major
85
+ amendment
86
+ alleging
87
+ agreement
88
+ investment
89
+ custodial
90
+ accusation
91
+ passageways
92
+ asserted
93
+ authority
94
+ deputies
95
+ insolvency
96
+ sworn
97
+ defensive
98
+ embezzlement
99
+ disputes
100
+ findings
101
+ reservation
102
+ litem
103
+ inmates
104
+ step-by-step
105
+ innocence
106
+ parties
107
+ transcribed
108
+ inept
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/military.txt ADDED
@@ -0,0 +1,136 @@
1
+ team
2
+ threat
3
+ sloop
4
+ offensively
5
+ guerilla
6
+ invading
7
+ samurai
8
+ propel
9
+ sunk
10
+ concern
11
+ persuade
12
+ Maj.
13
+ wear
14
+ fatigues
15
+ subsidiary
16
+ glider
17
+ advancing
18
+ ICBM
19
+ won
20
+ cargo
21
+ groan
22
+ knowledge
23
+ proposal
24
+ terms
25
+ deputy
26
+ taken
27
+ bricks
28
+ operation
29
+ Iraq
30
+ zoning
31
+ offices
32
+ fought
33
+ detonated
34
+ adjutant
35
+ skipper
36
+ batteries
37
+ medical
38
+ strategic
39
+ armistice
40
+ rocket
41
+ enemies
42
+ tensions
43
+ forming
44
+ inundate
45
+ engaging
46
+ dormitories
47
+ flying
48
+ allies
49
+ cursor
50
+ casing
51
+ zone
52
+ scouts
53
+ stationed
54
+ pistol
55
+ paragraph
56
+ highest
57
+ tribute
58
+ strategy
59
+ pump
60
+ decoding
61
+ argue
62
+ public
63
+ policeman
64
+ lob
65
+ sword
66
+ bleeding
67
+ civilians
68
+ rifles
69
+ airmen
70
+ freedom
71
+ explosion
72
+ capturing
73
+ skirmish
74
+ conquered
75
+ frigate
76
+ armour
77
+ leaving
78
+ customer
79
+ expert
80
+ armies
81
+ aviation
82
+ armoury
83
+ rifleman
84
+ lace
85
+ khaki
86
+ barrage
87
+ civilian
88
+ secluded
89
+ casualties
90
+ injuries
91
+ academies
92
+ hires
93
+ dead
94
+ ATL
95
+ late
96
+ relinquish
97
+ naval
98
+ riflemen
99
+ seige
100
+ sonar
101
+ aboard
102
+ longtime
103
+ bottom
104
+ gatling
105
+ militia
106
+ clandestine
107
+ execute
108
+ assets
109
+ significant
110
+ personnel
111
+ escorting
112
+ manoeuvre
113
+ Sgt.
114
+ rear
115
+ shoulders
116
+ rescuing
117
+ hand-to-hand
118
+ howitzer
119
+ committee
120
+ rifle
121
+ victory
122
+ defensive
123
+ forcing
124
+ honour
125
+ companies
126
+ pirate
127
+ evacuating
128
+ sabotaging
129
+ citadel
130
+ cadre
131
+ camera
132
+ launchers
133
+ flames
134
+ encoding
135
+ visor
136
+ ship
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/politics.txt ADDED
@@ -0,0 +1,40 @@
1
+ credibility
2
+ Nazism
3
+ imported
4
+ remember
5
+ progressivism
6
+ legislative
7
+ communist
8
+ gender
9
+ democratic
10
+ immediate
11
+ capitalist
12
+ purchase
13
+ energy
14
+ referenda
15
+ ratify
16
+ lengthy
17
+ authorisation
18
+ aristocrats
19
+ jurisdiction
20
+ judge
21
+ socialist
22
+ excise
23
+ fascist
24
+ secondary
25
+ subsidies
26
+ autocratic
27
+ shortfall
28
+ appropriated
29
+ uphold
30
+ income
31
+ federated
32
+ federal
33
+ efforts
34
+ diplomatic
35
+ freedom
36
+ properties
37
+ ideologies
38
+ exporting
39
+ minority
40
+ cultural
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/religion.txt ADDED
@@ -0,0 +1,207 @@
1
+ Elegant
2
+ Catholicism
3
+ Metatron
4
+ Mind
5
+ Empires
6
+ SWF
7
+ Secular
8
+ Judas
9
+ Prime
10
+ Terrier
11
+ Preview
12
+ Existence
13
+ Silent
14
+ sanctuaries
15
+ Answer
16
+ Balancing
17
+ Mutual
18
+ Constantinople
19
+ Scrolls
20
+ Network
21
+ Almighty
22
+ Attorney
23
+ Liberation
24
+ Database
25
+ Practicing
26
+ St.
27
+ Eucharist
28
+ Glorious
29
+ Catholic
30
+ Compassion
31
+ Volume
32
+ Saviour
33
+ Meditation
34
+ Testament
35
+ Morality
36
+ Heart
37
+ Aramaic
38
+ Court
39
+ Baskets
40
+ Fervor
41
+ Date
42
+ Curriculum
43
+ Liberal
44
+ Creativity
45
+ Everlasting
46
+ PDF
47
+ Rev.
48
+ Thank
49
+ Nanak
50
+ Dangerous
51
+ Shari'a
52
+ Policy
53
+ Talmud
54
+ Best
55
+ Supply
56
+ Oneness
57
+ Punishment
58
+ Reincarnation
59
+ TransCanada
60
+ Forums
61
+ VoIP
62
+ Factors
63
+ Assistance
64
+ Charities
65
+ Calculator
66
+ Shadows
67
+ Him
68
+ Natural
69
+ Lamp
70
+ Thyme
71
+ Templar
72
+ Muhammad
73
+ Venue
74
+ Hell
75
+ Bunyan
76
+ Songs
77
+ Epistle
78
+ Suites
79
+ Economic
80
+ Intel
81
+ Spanish
82
+ Lives
83
+ Married
84
+ Hypothesis
85
+ Cosmic
86
+ Injunction
87
+ Involvement
88
+ Leviticus
89
+ Self
90
+ Truth
91
+ Mystical
92
+ Melody
93
+ Pure
94
+ Sermon
95
+ Atlantic
96
+ Excel
97
+ Sonata
98
+ SPCA
99
+ Saturday
100
+ Adventure
101
+ Honour
102
+ Resurrection
103
+ Emanuel
104
+ Connery
105
+ Rites
106
+ United
107
+ Pope
108
+ Mary
109
+ Chen
110
+ Lisa
111
+ ODST
112
+ Videos
113
+ Modernity
114
+ Sculpture
115
+ Jewish
116
+ Heavy
117
+ Remote
118
+ Praise
119
+ Foods
120
+ Merrell
121
+ Safety
122
+ Influencing
123
+ Tie
124
+ Outreach
125
+ Kenichi
126
+ Criminal
127
+ Stevie
128
+ Judgement
129
+ SQL
130
+ Basilica
131
+ Piano
132
+ Reiki
133
+ Understanding
134
+ Cognition
135
+ Maker
136
+ Diocese
137
+ Marital
138
+ Masjid
139
+ Militant
140
+ Methodist
141
+ Political
142
+ Appeals
143
+ Deities
144
+ Purchase
145
+ Rallies
146
+ Testing
147
+ Contemporary
148
+ Help
149
+ Sweet
150
+ Fallen
151
+ Spangled
152
+ Renewable
153
+ Laughter
154
+ Provider
155
+ Charitable
156
+ Ethical
157
+ Families
158
+ Cure
159
+ Significance
160
+ Communities
161
+ Cost
162
+ Demon
163
+ Motivation
164
+ Calvary
165
+ Double
166
+ Mysteries
167
+ Determining
168
+ Baptist
169
+ Mandir
170
+ Qi
171
+ Loss
172
+ Lust
173
+ Echoes
174
+ Lord
175
+ Vote
176
+ Glad
177
+ Dharma
178
+ Kombat
179
+ Prostitute
180
+ Wetlands
181
+ Queries
182
+ Always
183
+ Focus
184
+ EOS
185
+ Worship
186
+ Implications
187
+ Wiccan
188
+ Invitations
189
+ Theology
190
+ Hospital
191
+ Freedom
192
+ Mirror
193
+ Uncharted
194
+ Radiance
195
+ Serving
196
+ Buddhist
197
+ Kiss
198
+ Mother
199
+ Death
200
+ Episcopal
201
+ Impact
202
+ Shinto
203
+ Crisis
204
+ Secure
205
+ Learning
206
+ Dreams
207
+ Association
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/science.txt ADDED
@@ -0,0 +1,47 @@
1
+ astronomical
2
+ evolved
3
+ tests
4
+ reason
5
+ idea
6
+ component
7
+ jug
8
+ rain
9
+ renewable
10
+ scaling
11
+ phone
12
+ action
13
+ studies
14
+ humidity
15
+ siphon
16
+ warming
17
+ compounds
18
+ genomics
19
+ electrons
20
+ mathematics
21
+ clinical
22
+ physiology
23
+ hypotheses
24
+ stored
25
+ statutes
26
+ magnesium
27
+ measuring
28
+ fuels
29
+ scientific
30
+ bone
31
+ molecular
32
+ microscopy
33
+ observing
34
+ parameter
35
+ transition
36
+ system
37
+ bacterium
38
+ ligand
39
+ increasing
40
+ theories
41
+ physicist
42
+ flow
43
+ pounds
44
+ nothing
45
+ observatory
46
+ gravitational
47
+ electron
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/space.txt ADDED
@@ -0,0 +1,16 @@
1
+ cosmos
2
+ mothership
3
+ flyby
4
+ broadband
5
+ aeronautics
6
+ fireball
7
+ Romulan
8
+ room
9
+ cosmonaut
10
+ actress
11
+ worlds
12
+ heavens
13
+ lunar
14
+ interstellar
15
+ galaxies
16
+ lander
naacl-2021-fudge-controlled-generation/topic_data/topic_prefixes.txt ADDED
@@ -0,0 +1,20 @@
1
+ In summary
2
+ This essay discusses
3
+ Views on
4
+ The connection
5
+ Foundational to this is
6
+ To review,
7
+ In brief,
8
+ An illustration of
9
+ Furthermore,
10
+ The central theme
11
+ To conclude,
12
+ The key aspect
13
+ Prior to this
14
+ Emphasised are
15
+ To summarise
16
+ The relationship
17
+ More importantly,
18
+ It has been shown
19
+ The issue focused on
20
+ In this essay
naacl-2021-fudge-controlled-generation/topic_data/val_wordlists/fantasy.txt ADDED
@@ -0,0 +1,26 @@
1
+ beast
2
+ Cerberus
3
+ demon
4
+ dragon
5
+ fairy
6
+ Frankenstein
7
+ ghost
8
+ Godzilla
9
+ giant
10
+ horror
11
+ hydra
12
+ imp
13
+ monster
14
+ mummy
15
+ ogre
16
+ orc
17
+ savage
18
+ spirit
19
+ sprite
20
+ titan
21
+ troll
22
+ undead
23
+ unicorn
24
+ vampire
25
+ witch
26
+ zombie
naacl-2021-fudge-controlled-generation/topic_data/wordlists/computers.txt ADDED
@@ -0,0 +1,176 @@
1
+ algorithm
2
+ analog
3
+ app
4
+ application
5
+ array
6
+ backup
7
+ bandwidth
8
+ binary
9
+ bit
10
+ bite
11
+ blog
12
+ blogger
13
+ bookmark
14
+ boot
15
+ broadband
16
+ browser
17
+ buffer
18
+ bug
19
+ bus
20
+ byte
21
+ cache
22
+ caps
23
+ captcha
24
+ CD
25
+ client
26
+ command
27
+ compile
28
+ compress
29
+ computer
30
+ configure
31
+ cookie
32
+ copy
33
+ CPU
34
+ dashboard
35
+ data
36
+ database
37
+ debug
38
+ delete
39
+ desktop
40
+ development
41
+ digital
42
+ disk
43
+ document
44
+ domain
45
+ dot
46
+ download
47
+ drag
48
+ dynamic
49
+ email
50
+ encrypt
51
+ encryption
52
+ enter
53
+ FAQ
54
+ file
55
+ firewall
56
+ firmware
57
+ flaming
58
+ flash
59
+ folder
60
+ font
61
+ format
62
+ frame
63
+ graphics
64
+ hack
65
+ hacker
66
+ hardware
67
+ home
68
+ host
69
+ html
70
+ icon
71
+ inbox
72
+ integer
73
+ interface
74
+ Internet
75
+ IP
76
+ iteration
77
+ Java
78
+ joystick
79
+ kernel
80
+ key
81
+ keyboard
82
+ keyword
83
+ laptop
84
+ link
85
+ Linux
86
+ logic
87
+ login
88
+ lurking
89
+ Macintosh
90
+ macro
91
+ malware
92
+ media
93
+ memory
94
+ mirror
95
+ modem
96
+ monitor
97
+ motherboard
98
+ mouse
99
+ multimedia
100
+ net
101
+ network
102
+ node
103
+ offline
104
+ online
105
+ OS
106
+ option
107
+ output
108
+ page
109
+ password
110
+ paste
111
+ path
112
+ piracy
113
+ pirate
114
+ platform
115
+ podcast
116
+ portal
117
+ print
118
+ printer
119
+ privacy
120
+ process
121
+ program
122
+ programmer
123
+ protocol
124
+ RAM
125
+ reboot
126
+ resolution
127
+ restore
128
+ ROM
129
+ root
130
+ router
131
+ runtime
132
+ save
133
+ scan
134
+ scanner
135
+ screen
136
+ screenshot
137
+ script
138
+ scroll
139
+ security
140
+ server
141
+ shell
142
+ shift
143
+ snapshot
144
+ software
145
+ spam
146
+ spreadsheet
147
+ storage
148
+ surf
149
+ syntax
150
+ table
151
+ tag
152
+ template
153
+ thread
154
+ toolbar
155
+ trash
156
+ undo
157
+ Unix
158
+ upload
159
+ URL
160
+ user
161
+ UI
162
+ username
163
+ utility
164
+ version
165
+ virtual
166
+ virus
167
+ web
168
+ website
169
+ widget
170
+ wiki
171
+ window
172
+ Windows
173
+ wireless
174
+ worm
175
+ XML
176
+ Zip
naacl-2021-fudge-controlled-generation/topic_data/wordlists/legal.txt ADDED
@@ -0,0 +1,131 @@
1
+ affidavit
2
+ allegation
3
+ appeal
4
+ appearance
5
+ argument
6
+ arrest
7
+ assault
8
+ attorney
9
+ bail
10
+ bankrupt
11
+ bankruptcy
12
+ bar
13
+ bench
14
+ warrant
15
+ bond
16
+ booking
17
+ capital
18
+ crime
19
+ case
20
+ chambers
21
+ claim
22
+ complainant
23
+ complaint
24
+ confess
25
+ confession
26
+ constitution
27
+ constitutional
28
+ contract
29
+ counsel
30
+ court
31
+ custody
32
+ damages
33
+ decree
34
+ defendant
35
+ defense
36
+ deposition
37
+ discovery
38
+ equity
39
+ estate
40
+ ethics
41
+ evidence
42
+ examination
43
+ family
44
+ law
45
+ felony
46
+ file
47
+ fraud
48
+ grievance
49
+ guardian
50
+ guilty
51
+ hearing
52
+ immunity
53
+ incarceration
54
+ incompetent
55
+ indictment
56
+ injunction
57
+ innocent
58
+ instructions
59
+ jail
60
+ judge
61
+ judiciary
62
+ jurisdiction
63
+ jury
64
+ justice
65
+ law
66
+ lawsuit
67
+ lawyer
68
+ legal
69
+ legislation
70
+ liable
71
+ litigation
72
+ manslaughter
73
+ mediation
74
+ minor
75
+ misdemeanor
76
+ moot
77
+ murder
78
+ negligence
79
+ oath
80
+ objection
81
+ opinion
82
+ order
83
+ ordinance
84
+ pardon
85
+ parole
86
+ party
87
+ perjury
88
+ petition
89
+ plaintiff
90
+ plea
91
+ precedent
92
+ prison
93
+ probation
94
+ prosecute
95
+ prosecutor
96
+ proxy
97
+ record
98
+ redress
99
+ resolution
100
+ reverse
101
+ revoke
102
+ robbery
103
+ rules
104
+ sentence
105
+ settlement
106
+ sheriff
107
+ sidebar
108
+ standing
109
+ state
110
+ statute
111
+ stay
112
+ subpoena
113
+ suit
114
+ suppress
115
+ sustain
116
+ testimony
117
+ theft
118
+ title
119
+ tort
120
+ transcript
121
+ trial
122
+ trust
123
+ trustee
124
+ venue
125
+ verdict
126
+ waiver
127
+ warrant
128
+ will
129
+ witness
130
+ writ
131
+ zoning
naacl-2021-fudge-controlled-generation/topic_data/wordlists/military.txt ADDED
@@ -0,0 +1,149 @@
1
+ academy
2
+ advance
3
+ aircraft
4
+ ally
5
+ ammo
6
+ ammunition
7
+ armor
8
+ arms
9
+ army
10
+ arrow
11
+ arsenal
12
+ artillery
13
+ attack
14
+ attention
15
+ ballistic
16
+ barracks
17
+ base
18
+ battalion
19
+ battery
20
+ battle
21
+ battlefield
22
+ bomb
23
+ bombard
24
+ bombardment
25
+ brig
26
+ brigade
27
+ bullet
28
+ camouflage
29
+ camp
30
+ cannon
31
+ captain
32
+ capture
33
+ carrier
34
+ casualty
35
+ catapult
36
+ cavalry
37
+ colonel
38
+ combat
39
+ command
40
+ commander
41
+ commission
42
+ company
43
+ conflict
44
+ conquest
45
+ convoy
46
+ corps
47
+ covert
48
+ crew
49
+ decode
50
+ defeat
51
+ defend
52
+ defense
53
+ destroyer
54
+ division
55
+ draft
56
+ encode
57
+ enemy
58
+ engage
59
+ enlist
60
+ evacuate
61
+ explosive
62
+ fight
63
+ fire
64
+ fleet
65
+ force
66
+ formation
67
+ fort
68
+ front
69
+ garrison
70
+ general
71
+ grenade
72
+ grunt
73
+ guerrilla
74
+ gun
75
+ headquarters
76
+ helmet
77
+ honor
78
+ hospital
79
+ infantry
80
+ injury
81
+ intelligence
82
+ invade
83
+ invasion
84
+ jet
85
+ kill
86
+ leave
87
+ lieutenant
88
+ major
89
+ maneuver
90
+ marines
91
+ MIA
92
+ mid
93
+ military
94
+ mine
95
+ missile
96
+ mortar
97
+ navy
98
+ neutral
99
+ offense
100
+ officer
101
+ ordinance
102
+ parachute
103
+ peace
104
+ plane
105
+ platoon
106
+ private
107
+ radar
108
+ rank
109
+ recruit
110
+ regiment
111
+ rescue
112
+ reserves
113
+ retreat
114
+ ribbon
115
+ sabotage
116
+ sailor
117
+ salute
118
+ section
119
+ sergeant
120
+ service
121
+ shell
122
+ shoot
123
+ shot
124
+ siege
125
+ sniper
126
+ soldier
127
+ spear
128
+ specialist
129
+ squad
130
+ squadron
131
+ staff
132
+ submarine
133
+ surrender
134
+ tactical
135
+ tactics
136
+ tank
137
+ torpedo
138
+ troops
139
+ truce
140
+ uniform
141
+ unit
142
+ veteran
143
+ volley
144
+ war
145
+ warfare
146
+ warrior
147
+ weapon
148
+ win
149
+ wound
naacl-2021-fudge-controlled-generation/topic_data/wordlists/politics.txt ADDED
@@ -0,0 +1,47 @@
1
+ affirm
2
+ appropriation
3
+ aristocracy
4
+ authoritarian
5
+ authority
6
+ authorization
7
+ brief
8
+ capitalism
9
+ communism
10
+ constitution
11
+ conservatism
12
+ court
13
+ deficit
14
+ diplomacy
15
+ direct
16
+ democracy
17
+ equality
18
+ exports
19
+ fascism
20
+ federation
21
+ government
22
+ ideology
23
+ imports
24
+ initiative
25
+ legislature
26
+ legitimacy
27
+ liberalism
28
+ liberty
29
+ majority
30
+ order
31
+ political
32
+ culture
33
+ politics
34
+ power
35
+ primary
36
+ property
37
+ ratification
38
+ recall
39
+ referendum
40
+ republic
41
+ socialism
42
+ state
43
+ subsidy
44
+ tariff
45
+ imports
46
+ tax
47
+ totalitarian
naacl-2021-fudge-controlled-generation/topic_data/wordlists/religion.txt ADDED
@@ -0,0 +1,232 @@
1
+ absolute
2
+ affect
3
+ aid
4
+ angel
5
+ anthem
6
+ apostle
7
+ archangel
8
+ Archbishop
9
+ balance
10
+ ban
11
+ belief
12
+ benefit
13
+ Bible
14
+ bishop
15
+ bless
16
+ blessing
17
+ bliss
18
+ bond
19
+ bow
20
+ Buddhism
21
+ canon
22
+ Cantor
23
+ cathedral
24
+ celestial
25
+ chapel
26
+ charity
27
+ choice
28
+ Christianity
29
+ church
30
+ comfort
31
+ community
32
+ conflict
33
+ connection
34
+ conquest
35
+ conservative
36
+ control
37
+ conversion
38
+ convert
39
+ core
40
+ counsel
41
+ courage
42
+ Covenant
43
+ creative
44
+ Creator
45
+ creed
46
+ cross
47
+ Crusade
48
+ Darkness
49
+ decision
50
+ deity
51
+ destiny
52
+ Devil
53
+ disciple
54
+ discipline
55
+ discussion
56
+ divine
57
+ divinity
58
+ doctrine
59
+ duty
60
+ effect
61
+ elder
62
+ energy
63
+ essence
64
+ eternal
65
+ ethics
66
+ event
67
+ evidence
68
+ exile
69
+ Exodus
70
+ faith
71
+ family
72
+ fate
73
+ Father
74
+ favor
75
+ fundamental
76
+ gift
77
+ glory
78
+ God
79
+ gospel
80
+ grace
81
+ growth
82
+ guru
83
+ habit
84
+ hallow
85
+ halo
86
+ happiness
87
+ harmony
88
+ healing
89
+ Heaven
90
+ Hebrew
91
+ holy
92
+ honor
93
+ hope
94
+ host
95
+ humane
96
+ immortal
97
+ influence
98
+ insight
99
+ instruction
100
+ issue
101
+ Jesuit
102
+ Jesus
103
+ joy
104
+ Judaism
105
+ judgment
106
+ justice
107
+ karma
108
+ keen
109
+ Keystone
110
+ Kingdom
111
+ Latin
112
+ life
113
+ light
114
+ love
115
+ loving
116
+ marriage
117
+ meaning
118
+ mercy
119
+ Messiah
120
+ minister
121
+ miracle
122
+ mission
123
+ mortal
124
+ mosque
125
+ movement
126
+ music
127
+ mystery
128
+ nature
129
+ nun
130
+ official
131
+ oracle
132
+ order
133
+ organ
134
+ Orthodox
135
+ outlook
136
+ pacific
137
+ pagan
138
+ parish
139
+ participation
140
+ pastor
141
+ patriarch
142
+ peace
143
+ perception
144
+ personal
145
+ perspective
146
+ petition
147
+ pilgrim
148
+ politics
149
+ power
150
+ practice
151
+ prayer
152
+ prelude
153
+ presence
154
+ priest
155
+ principle
156
+ privacy
157
+ prophet
158
+ protection
159
+ purpose
160
+ query
161
+ quest
162
+ question
163
+ quiet
164
+ radiant
165
+ radical
166
+ rally
167
+ rebirth
168
+ redemption
169
+ refuge
170
+ relationship
171
+ relative
172
+ religion
173
+ religious
174
+ Revelation
175
+ ritual
176
+ role
177
+ Sacrament
178
+ sacred
179
+ sacrifice
180
+ sage
181
+ saint
182
+ salvation
183
+ sanctuary
184
+ savior
185
+ scripture
186
+ scriptures
187
+ sect
188
+ security
189
+ sense
190
+ serious
191
+ serve
192
+ service
193
+ Sharia
194
+ shepherd
195
+ shrine
196
+ silence
197
+ sin
198
+ society
199
+ soul
200
+ source
201
+ spirit
202
+ spiritual
203
+ split
204
+ statue
205
+ Sunday
206
+ support
207
+ Supreme
208
+ teaching
209
+ temple
210
+ tests
211
+ text
212
+ Torah
213
+ tradition
214
+ traditional
215
+ trust
216
+ unique
217
+ unity
218
+ unknown
219
+ value
220
+ vanity
221
+ virtue
222
+ vision
223
+ voice
224
+ voices
225
+ watch
226
+ weight
227
+ whole
228
+ wisdom
229
+ wonder
230
+ yang
231
+ yin
232
+ zeal
naacl-2021-fudge-controlled-generation/topic_data/wordlists/science.txt ADDED
@@ -0,0 +1,48 @@
1
+ astronomy
2
+ atom
3
+ biology
4
+ cell
5
+ chemical
6
+ chemistry
7
+ climate
8
+ control
9
+ data
10
+ electricity
11
+ element
12
+ energy
13
+ evolution
14
+ experiment
15
+ fact
16
+ flask
17
+ fossil
18
+ funnel
19
+ genetics
20
+ gravity
21
+ hypothesis
22
+ lab
23
+ laboratory
24
+ laws
25
+ mass
26
+ matter
27
+ measure
28
+ microscope
29
+ mineral
30
+ molecule
31
+ motion
32
+ observe
33
+ organism
34
+ particle
35
+ phase
36
+ physics
37
+ research
38
+ scale
39
+ science
40
+ scientist
41
+ telescope
42
+ temperature
43
+ theory
44
+ tissue
45
+ variable
46
+ volume
47
+ weather
48
+ weigh
naacl-2021-fudge-controlled-generation/topic_data/wordlists/space.txt ADDED
@@ -0,0 +1,18 @@
1
+ planet
2
+ galaxy
3
+ space
4
+ universe
5
+ orbit
6
+ spacecraft
7
+ earth
8
+ moon
9
+ comet
10
+ star
11
+ astronaut
12
+ aerospace
13
+ asteroid
14
+ spaceship
15
+ starship
16
+ galactic
17
+ satellite
18
+ meteor
naacl-2021-fudge-controlled-generation/transcript.txt ADDED
@@ -0,0 +1,415 @@
1
+ (Sorry, the slide numbers got a bit misaligned as I added slides. Not an exact transcript for the video but roughly correct.)
2
+
3
+ 1:
4
+ Hi! I'm Kevin from UC Berkeley, and today I'll be presenting my paper FUDGE: Controlled Text Generation with Future Discriminators, by me and my advisor Dan Klein.
5
+
6
+ 2:
7
+ So first a quick overview.
8
+
9
+ 3:
10
+ I'll start by explaining the problem of controlled text generation with some examples,
11
+
12
+ 4:
13
+ then describe our method, FUDGE, Future Discriminators for Generation,
14
+
15
+ 5:
16
+ and in doing so I'll also show experimental results and example model outputs on three diverse controlled generation tasks.
17
+
18
+ 6:
19
+ So what's controlled text generation?
20
+
21
+ 7:
22
+ Well let's start with our autoregressive language model that we use for text generation, without the controlled part.
23
+
24
+ 8:
25
+ The language model models a distribution over the next token x_{i+1} given the prefix x_1 to x_i. For example, you might tell it to
26
+
27
+ 9:
28
+ Generate text according to a prompt like
29
+
30
+ 9:
31
+ THIS, the issue focused on.
32
+
33
+ 10:
34
+ and then it'll chug along and generate text,
35
+
36
+ 11:
37
+ and these days language models are pretty good.
38
+ But in controlled generation, you have an additional attribute constraint
39
+
40
+ 12:
41
+ like wanting the output to be about politics.
42
+
43
+ 13:
44
+ Specifically, we have an attribute function a(X) which says whether or not the attribute a is true for your output X, in this case whether or not the output is on topic. There's no probabilities involved in a(X) since it operates on the completed generation output, not on partial sequences.
45
+
46
+ 14:
47
+ More precisely, the task of controlled text generation is to sample from the distribution P(X given a = True), so the distribution of outputs X which satisfy a.
48
+
49
+ 15:
50
+ By default the language model isn't equipped to handle this additional constraint a, so its output is not going to pass.
51
+ So we need a method for *controlled* text generation.
52
+
53
+ 15:
54
+ For example, our method FUDGE.
55
+
56
+ 16:
57
+ Given the same prompt with the politics topic,
58
+
59
+ 16:
60
+ Here's what FUDGE says.
61
+
62
+ 17:
63
+ It worked pretty well in this example. It's talking about institutions and constitutions, which seems clearly on topic.
64
+
65
+ 18:
66
+ And I'll point out here that controlled generation makes sense in addition to the usual conditioning on the input that you might see in translation or summarization.
67
+
68
+ 19:
69
+ Say we're translating Spanish to English. There's input conditioning on the original Spanish, but we're also imposing the additional constraint that the output be formal, which is where controlled text generation comes in.
70
+
71
+ 20:
72
+ So say we have this Spanish input
73
+
74
+ 20:
75
+ and let me just move it to the corner so we can still see it
76
+
77
+ 20:
78
+ If you ask your off-the-shelf translation model it'll get the meaning right,
79
+ but it copies some ungrammatical parts of the original Spanish
80
+
81
+ 21:
82
+ like these repeated words in bold.
83
+
84
+ 22:
85
+ So at the end when we ask our formality classifier,
86
+
87
+ 23:
88
+ it might not be super happy.
89
+
90
+ 24:
91
+ But if you use a controlled text generation approach like FUDGE,
92
+
93
+ 25:
94
+ You can get this translation which preserves the meaning, while also better matching the formal style constraint.
95
+
96
+ 26:
97
+
98
+
99
+ 27:
100
+ You might wonder, why don't we just do rejection sampling?
101
+
102
+ 28:
103
+ Just sample a bunch of times from the translator
104
+
105
+ 29:
106
+ until you get one that passes.
107
+ That might work for some simpler constraints like topics, but it's going to be totally intractable when you use constraints that are very rarely satisfied by your generator distribution.
108
+
109
+ 30:
110
+ What are some more difficult attribute constraints?
111
+
112
+ 31:
113
+ Consider this task, effectively, complete the poem.
114
+
115
+ 32:
116
+ Let’s see what the language model says when we give it this input from Shakespeare. And even thence thou wilt be stol'n I fear
117
+
118
+ 32:
119
+ and thou art a good friend of mine. The king's guard.
120
+
121
+ 33:
122
+ This is terrible! It doesn't roll off the tongue, it doesn't rhyme, it doesn't even end the sentence properly at the end.
123
+
124
+ 34:
125
+ Shakespeare hates it. You could generate any number of poems using your language model, and Shakespeare is gonna hate every last one.
126
+
127
+ 35:
128
+ But if you ask FUDGE, you get this. And even thence thou wilt be stol'n I fear, for this shall be the end. That's pretty clear.
129
+
130
+ 36:
131
+ So it's not Shakespeare, but it gets the meter, or rhythm right, it rhymes, and it ends the sentence in about the right place at the end. Not too bad.
132
+
133
+ 37:
134
+ Ok. So how does controlled generation work anyway? Let me give an incredibly oversimplified summary of some ideas in this line of work to put FUDGE in context.
135
+
136
+ 38:
137
+ First, you can finetune.
138
+
139
+ 39:
140
+ We'll use the politics topic as an example.
141
+
142
+ 39:
143
+ You can train on a bunch of text about politics. Depending on how good your data is, this can work great! Or it could be rather bad. It also might be annoying to have to finetune again next time, when you want to write about science instead.
144
+
145
+ 40:
146
+ Another idea is to use a classifier.
147
+
148
+ 41:
149
+ We're already using a classifier to evaluate.
150
+
151
+ 42:
152
+ We can use a classifier to help us generate too. There's many different ways to do this.
153
+
154
+ 43:
155
+ For example, you might propagate gradients to modify the model's activations,
156
+
157
+ 44:
158
+ or you could just directly modify the model's output probabilities. One advantage of the latter method is that you don't need to access the original language model's gradients at all, which is nice if you're using something like GPT3. You can also swap the generator out as better models become available, like GPT4. Our approach FUDGE falls in this category of just modifying the output logits.
159
+
160
+ 45:
161
+ Ok, so what's FUDGE?
162
+
163
+ 46:
164
+ FUDGE at its core learns a lightweight classifier for the attribute constraint, and then follows a Bayesian factorization to combine it with the original generator, like the pretrained language model.
165
+
166
+ 47:
167
+ A key difference from prior work is that we plan for the future, not the immediate present.
168
+
169
+ 48:
170
+ And finally, FUDGE can easily and flexibly compose multiple constraints.
171
+
172
+ 49:
173
+ Let's start with the classifier and Bayesian factorization.
174
+
175
+ 50:
176
+ Since FUDGE builds off the base language model, let's review:
177
+
178
+ 51:
179
+ You feed whatever tokens you have so far
180
+
181
+ 52:
182
+ into your model,
183
+
184
+ 53:
185
+ which models the distribution over possible next tokens.
186
+
187
+ 54:
188
+ And then you sample from this distribution to pick your continuation.
189
+
190
+ 55:
191
+ Now, we completely ignored the formal style constraint.
192
+
193
+ 56:
194
+ So it's gonna be unhappy.
195
+
196
+ 57:
197
+ So what do you want to do instead?
198
+
199
+ 58:
200
+ Well, what you really want is to use your classifier to judge continuations.
201
+
202
+ 59:
203
+ and mark which ones are acceptable given your constraint. So the classifier looks at each possible next continuation Do you want, Do you prefer, Do you thus, and so on maybe up to some limit, and judges each one individually to decide which it's ok with.
204
+
205
+ 60:
206
+ So putting it together, we throw out whatever the classifier didn't like,
207
+
208
+ 61:
209
+ and then we select from whatever the classifier is ok with depending on the base generator's probabilities.
210
+ And this gets you "Do you prefer" instead of "Do you want"
211
+
212
+ 62:
213
+ which sounds a bit more formal.
214
+
215
+ 63:
216
+ But there's a subtle problem in this diagram.
217
+ The classifier is supposed to judge the finished sentence, not the prefixes,
218
+
219
+ 64:
220
+ but here we've shoved it into our generation procedure where it's gonna operate on prefixes.
221
+ What we actually need is
222
+
223
+ 65:
224
+ kind of a future looking crystal ball version of the classifier, which judges whether the whole sentence will eventually be formal, given the current prefix.
225
+
226
+ 65:
227
+ And in practice, we implement the judge as a learned binary classifier, which runs on each possible continuation, and for each one outputs the probability that in the end the desired attribute a will be True, or in this case whether the finished sentence would be formal, given just the current prefix plus next token.
228
+ So in the red table, this 0.2 by "want" means it thinks that there's a 20% chance that the eventual sentence would be formal if we started with Do you want, whereas it assigns a much higher probability for Do you prefer and Do you thus because those are more formal.
229
+
230
+ 68:
231
+ And then we sample proportionally from the probabilities in the purple table,
232
+ which are now just the elementwise product of the blue and red tables' probabilities.
233
+ This corresponds exactly to a Bayesian factorization for the probability distribution over sentences generated by the language model that possess the desired attribute, and you can check the math in the paper.
234
+ But the Bayesian motivation is not new.
235
+
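
To make the table product above concrete, here is a minimal one-step sketch in the spirit of the released prediction scripts; the names `fudge_step`, `lm_logits`, and `future_logits` are illustrative, not taken from the code.

```python
# Minimal single-step sketch of the combination described above (illustrative names).
import torch
import torch.nn.functional as F

def fudge_step(lm_logits, future_logits, condition_lambda=1.0, topk=10):
    # lm_logits:     (vocab,) base LM scores for the next token (the blue table)
    # future_logits: (vocab,) raw logits of P(attribute eventually True | prefix + candidate) (the red table)
    log_p_lm = F.log_softmax(lm_logits, dim=-1)
    log_p_future = F.logsigmoid(future_logits)
    combined = log_p_lm + condition_lambda * log_p_future    # product of the two tables, in log space
    top_vals, top_idx = combined.topk(topk)                  # restrict to the top candidates
    next_token = top_idx[torch.multinomial(F.softmax(top_vals, dim=-1), 1)]
    return next_token  # sampled next-token id, shape (1,)
```

The predict_*.py scripts in this commit do essentially this, except they only run the classifier on the top `precondition_topk` candidates from GPT-2 at each step before re-pruning and sampling.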
236
+ 70:
237
+ What's really new in FUDGE is that we explicitly distinguish the final classifier from the crystal ball future-predicting version that we use during the generation procedure, and making this distinction is critical for performance.
238
+
239
+ 71:
240
+ Let's see FUDGE in action.
241
+
242
+ 72:
243
+ if you recall our Spanish to English formal translation example.
244
+
245
+ 73:
246
+ Let's backtrack FUDGE to this step.
247
+
248
+ 74:
249
+ Again we have the repeated Spanish que que in bold, which the base model translated verbatim as that, that.
250
+
251
+ 75:
252
+ But by having our classifier judge the formality of possible continuations, FUDGE is able to modify its continuation so that it doesn't repeat the words here.
253
+
254
+ 76:
255
+ And the end result preserves the meaning while being also a bit more formal.
256
+
257
+ 77:
258
+ And finally this all holds up in our experiments. So we have a classifier trained on a heldout dataset of formality, and it indeed judges FUDGE's outputs to be significantly more formal than those of the best prior method.
259
+
260
+ 78:
261
+ At the same time, FUDGE is able to preserve the content, based on measuring BLEU against cleaned reference translations.
262
+
263
+ 79:
264
+ Ok great. So next I'll elaborate more about planning for the future vs present,
265
+
266
+ 80:
267
+ and I'll try to show more clearly *why* we really need this crystal ball classifier.
268
+
269
+ 81:
270
+ Let's go back to our politics topic constraint.
271
+
272
+ 82:
273
+ For simplicity, let's pretend just for this talk that the politics topic just means whether or not you use the word "constitution."
274
+
275
+ 83:
276
+ So the constraint that we check at the end of generation is literally just grep for constitution.
277
+
278
+ 84:
279
+ The crystal ball classifier has a much harder task. For a given prefix, it needs to predict whether each possible word makes "constitution" more likely to appear later.
280
+
281
+ 85:
282
+ So how do we learn this?
283
+
284
+ 86:
285
+ Say you have this example in your training data containing "constitution"
286
+
287
+ 87:
288
+ The crystal ball classifier takes this and makes a bunch of prefix examples, labeled with the attribute function a(X)=True because we saw those prefixes led to the word "constitution" later.
289
+
290
+ 88:
291
+ And similarly if you have this example without the word "constitution"
292
+
293
+ 89:
294
+ It'll label those prefixes as False.
295
+
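
A minimal sketch of that labeling scheme, using the simplified grep-for-"constitution" attribute from this part of the talk (the repo's real data pipeline is more involved; the names here are illustrative):

```python
# Illustrative construction of (prefix, label) pairs for the future classifier.
def prefix_examples(tokens, target_word='constitution'):
    label = target_word in tokens              # a(X) evaluated on the full sentence
    # every prefix inherits the full-sentence label: True if the word shows up eventually
    return [(tokens[:i + 1], label) for i in range(len(tokens))]

pairs = prefix_examples('we should amend the constitution soon'.split())
# -> six prefixes, each labeled True; a sentence without the word yields all-False pairs
```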
296
+ 90:
297
+ Ok
298
+
299
+ 91:
300
+ So let's examine what FUDGE generates.
301
+
302
+ 92:
303
+ After a couple of steps, we have It has been shown whether the two
304
+
305
+ 93:
306
+ What if you hypothetically use the non crystal ball classifier to guide generation?
307
+
308
+ 94:
309
+ The issue focused on whether the two constitution
310
+ (pause) Maybe not. We don't really want to sacrifice fluency. But this classifier is too shortsighted. It's all or nothing, you either have to use constitution immediately or bust.
311
+
312
+ 95:
313
+ Ok
314
+
315
+ 96:
316
+ Good thing FUDGE is actually using the future looking classifier.
317
+
318
+ 97:
319
+ So instead, FUDGE is going to generate something which is still reasonably likely under the original language model, but makes constitution more likely to be generated later on. This classifier doesn't care whether constitution is generated now or later, as long as it shows up eventually.
320
+
321
+ 98:
322
+ So here it's going to write about institutions, so it's on the right topic
323
+
324
+ 99:
325
+ which eventually leads it to write about the constitution.
326
+
327
+ 100:
328
+ Great.
329
+
330
+ 101:
331
+ And indeed in our experiments, FUDGE is great according to human evaluations too. It substantially beats the best prior method in pairwise evaluations of being on topic,
332
+
333
+ 102:
334
+ while also beating it in fluency.
335
+
336
+ 103:
337
+ Cool. So I've now demonstrated the importance of planning for the future through this topic control task.
338
+ And finally, I'll highlight FUDGE's compositional potential, using a third task.
339
+
340
+ 104:
341
+ Ok.
342
+
343
+ 105:
344
+ So remember our schematic diagram where we have the judge of formality.
345
+
346
+ 106:
347
+ This works great when we have just one attribute we care about.
348
+
349
+ 107:
350
+ Now, what if you have another attribute? Maybe you want it to be formal but also about math
351
+
352
+ 108:
353
+ Now our old crystal ball classifier of just formality isn't good enough anymore.
354
+ Of course, you could construct a classifier which predicts both attributes simultaneously, but FUDGE lets you do something more scalable and also, I think, a bit more elegant.
355
+
356
+ 109:
357
+ Just reuse the formality predictor, while adding a second crystal ball for the math topic.
358
+ So now your generation is guided by one classifier for each constraint,
359
+
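
In other words, composition is just one additive log-probability term per constraint. A minimal sketch with illustrative names, where each classifier provides logits for its own attribute eventually holding:

```python
# Illustrative composition of several future classifiers (one per constraint).
import torch.nn.functional as F

def composed_scores(lm_logits, classifier_logits_list, lambdas):
    # lm_logits: (vocab,) base LM scores; each classifier_logits: (vocab,) logits of
    # P(its attribute eventually holds | prefix + candidate token)
    scores = F.log_softmax(lm_logits, dim=-1)
    for clf_logits, lam in zip(classifier_logits_list, lambdas):
        scores = scores + lam * F.logsigmoid(clf_logits)   # one additive term per constraint
    return scores   # rank or sample next tokens from these combined scores
```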
360
+ 110:
361
+ and it picks something which it thinks sounds more mathy.
362
+
363
+ 111:
364
+ So let's see this in practice.
365
+
366
+ 112:
367
+ Remember our poetry examples? where FUDGE's example isn't quite Shakespeare but is at least pretty well-formed.
368
+ This task actually uses three separate constraints:
369
+
370
+ 113:
371
+ We want iambic meter, which means that every other syllable should be a stressed syllable when we're reading it,
372
+
373
+ 114:
374
+ we want the two lines to rhyme, and since the first line is 10 syllables that means the second line should be 10 syllables too,
375
+
376
+ 115:
377
+ and the second line that we generate should end the sentence afterward too.
378
+
379
+ 116:
380
+ So let's backtrack to halfway through FUDGE's generation, before it's generated the last couple of words, pretty clear.
381
+
382
+ 117:
383
+ FUDGE is using its crystal ball poetry classifier, which is a combination of three classifiers, one for each of the three constraints.
384
+
385
+ 118:
386
+ It would be perfectly grammatical to just directly say "clear". This works for the iambic meter constraint. But this is only the 8th syllable, so you'd still have to rhyme and end a new sentence in just two more syllables.
387
+
388
+ 119:
389
+ Then we're probably back to angry Shakespeare.
390
+
391
+ 120:
392
+ So FUDGE first generates pretty
393
+
394
+ 121:
395
+ before finishing with clear and a period,
396
+
397
+ 122:
398
+ and this shows how FUDGE is able to compose multiple attributes using multiple classifiers, while simultaneously planning for the future as I described previously.
399
+
400
+ 123:
401
+ Finally, if we look at the experiments, FUDGE's performance holds up, with the success rate on simultaneously satisfying all three constraints being more than double that of the best prior method.
402
+
403
+ 124:
404
+ So that wraps things up. The takeaways are that FUDGE is a simple, flexible method for controlled text generation.
405
+
406
+ 125:
407
+ To reiterate our three main points from earlier, FUDGE learns a classifier in a Bayesian factorization to guide the generation,
408
+ it plans for the future rather than the present,
409
+ and it can easily and flexibly compose different constraints as needed while maintaining strong performance.
410
+
411
+ 126:
412
+ And our code is all publicly available.
413
+
414
+ 127:
415
+ Thanks for watching! And please check out our paper for the full details.
naacl-2021-fudge-controlled-generation/util.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ import time
3
+ import sys
4
+ from contextlib import contextmanager
5
+
6
+ import torch
7
+
8
+ from constants import *
9
+
10
+ @contextmanager
11
+ def suppress_stdout():
12
+ with open(os.devnull, "w") as devnull:
13
+ old_stdout = sys.stdout
14
+ sys.stdout = devnull
15
+ try:
16
+ yield
17
+ finally:
18
+ sys.stdout = old_stdout
19
+
20
+
21
+ def save_checkpoint(state, save_path):
22
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
23
+ torch.save(state, save_path)
24
+
25
+
26
+ def freeze(module):
27
+ for param in module.parameters():
28
+ param.requires_grad = False
29
+
30
+
31
+ def num_params(model):
32
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
33
+
34
+
35
+ def clamp(x, limit):
36
+ return max(-limit, min(x, limit))
37
+
38
+
39
+ def pad_to_length(tensor, length, dim, value=0):
40
+ """
41
+ Pad tensor to given length in given dim using given value (value should be numeric)
42
+ """
43
+ assert tensor.size(dim) <= length
44
+ if tensor.size(dim) < length:
45
+ zeros_shape = list(tensor.shape)
46
+ zeros_shape[dim] = length - tensor.size(dim)
47
+ zeros_shape = tuple(zeros_shape)
48
+ return torch.cat([tensor, torch.zeros(zeros_shape).type(tensor.type()).to(tensor.device).fill_(value)], dim=dim)
49
+ else:
50
+ return tensor
51
+
52
+
53
+ def pad_mask(lengths: torch.LongTensor) -> torch.ByteTensor:
54
+ """
55
+ Create a mask of seq x batch where seq = max(lengths), with 0 in padding locations and 1 otherwise.
56
+ """
57
+ # lengths: bs. Ex: [2, 3, 1]
58
+ max_seqlen = torch.max(lengths)
59
+ expanded_lengths = lengths.unsqueeze(0).repeat((max_seqlen, 1)) # [[2, 3, 1], [2, 3, 1], [2, 3, 1]]
60
+ indices = torch.arange(max_seqlen).unsqueeze(1).repeat((1, lengths.size(0))).to(lengths.device) # [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
61
+
62
+ return expanded_lengths > indices # pad locations are 0. #[[1, 1, 1], [1, 1, 0], [0, 1, 0]]. seqlen x bs
63
+
64
+
65
+ class ProgressMeter(object):
66
+ """
67
+ Display meter
68
+ """
69
+ def __init__(self, num_batches, meters, prefix=""):
70
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
71
+ self.meters = meters
72
+ self.prefix = prefix
73
+
74
+ def display(self, batch):
75
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
76
+ entries.append(time.ctime(time.time()))
77
+ entries += [str(meter) for meter in self.meters]
78
+ print('\t'.join(entries))
79
+
80
+ def _get_batch_fmtstr(self, num_batches):
81
+ num_digits = len(str(num_batches // 1))
82
+ fmt = '{:' + str(num_digits) + 'd}'
83
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
84
+
85
+
86
+ class AverageMeter(object):
87
+ """
88
+ Display meter
89
+ Computes and stores the average and current value
90
+ """
91
+ def __init__(self, name, fmt=':f'):
92
+ self.name = name
93
+ self.fmt = fmt
94
+ self.reset()
95
+
96
+ def reset(self):
97
+ self.val = 0
98
+ self.avg = 0
99
+ self.sum = 0
100
+ self.count = 0
101
+
102
+ def update(self, val, n=1):
103
+ self.val = val
104
+ self.sum += val * n
105
+ self.count += n
106
+ self.avg = self.sum / self.count
107
+
108
+ def __str__(self):
109
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
110
+ return fmtstr.format(**self.__dict__)
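
For reference, a small illustration of how the two padding helpers above behave (shapes only; the inputs are arbitrary):

```python
# Quick shape check for util.py's pad_mask and pad_to_length.
import torch
from util import pad_mask, pad_to_length

lengths = torch.LongTensor([2, 3, 1])
mask = pad_mask(lengths)                        # shape (3, 3): seqlen x batch, True at non-pad positions

x = torch.ones(2, 5)
padded = pad_to_length(x, 8, dim=1, value=0)    # shape (2, 8); the last 3 columns are filled with 0
```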