dsvilarko committed
Commit c4ebaf8
Parent: 7fa27b9

Initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +5 -5
  2. app.py +119 -0
  3. checkpoint-1464/config.json +20 -0
  4. checkpoint-1464/optimizer.pt +3 -0
  5. checkpoint-1464/pytorch_model.bin +3 -0
  6. checkpoint-1464/rng_state.pth +3 -0
  7. checkpoint-1464/scheduler.pt +3 -0
  8. checkpoint-1464/special_tokens_map.json +1 -0
  9. checkpoint-1464/tokenizer.json +0 -0
  10. checkpoint-1464/tokenizer_config.json +1 -0
  11. checkpoint-1464/trainer_state.json +106 -0
  12. checkpoint-1464/training_args.bin +3 -0
  13. checkpoint-1464/vocab.txt +0 -0
  14. checkpoint-150/config.json +58 -0
  15. checkpoint-150/optimizer.pt +3 -0
  16. checkpoint-150/pytorch_model.bin +3 -0
  17. checkpoint-150/rng_state.pth +3 -0
  18. checkpoint-150/scheduler.pt +3 -0
  19. checkpoint-150/special_tokens_map.json +110 -0
  20. checkpoint-150/spiece.model +3 -0
  21. checkpoint-150/tokenizer.json +0 -0
  22. checkpoint-150/tokenizer_config.json +117 -0
  23. checkpoint-150/trainer_state.json +56 -0
  24. checkpoint-150/training_args.bin +3 -0
  25. fudge/LICENSE +21 -0
  26. fudge/README.md +155 -0
  27. fudge/clickbait_classifier.py +128 -0
  28. fudge/constants.py +32 -0
  29. fudge/data.py +415 -0
  30. fudge/eval_formality_metrics.py +73 -0
  31. fudge/eval_poetry_metrics.py +135 -0
  32. fudge/eval_topic_metrics.py +134 -0
  33. fudge/evaluate_clickbait.py +200 -0
  34. fudge/evaluate_formality.py +104 -0
  35. fudge/evaluate_poetry.py +115 -0
  36. fudge/evaluate_topic.py +143 -0
  37. fudge/formality_data/README.md +2 -0
  38. fudge/formality_data/fisher_test_oracle.es +0 -0
  39. fudge/formality_data/test.noid.cleaned_0 +0 -0
  40. fudge/formality_data/test.noid.cleaned_1 +0 -0
  41. fudge/main.py +192 -0
  42. fudge/model.py +182 -0
  43. fudge/poetry_data/README.md +1 -0
  44. fudge/poetry_data/couplet_ends.txt +154 -0
  45. fudge/poetry_data/couplet_prefixes.txt +154 -0
  46. fudge/poetry_util.py +83 -0
  47. fudge/predict_clickbait.py +199 -0
  48. fudge/predict_formality.py +404 -0
  49. fudge/predict_poetry.py +219 -0
  50. fudge/predict_topic.py +126 -0
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
  title: Clickbaitonator
- emoji: 😻
+ emoji: 💩
- colorFrom: gray
+ colorFrom: purple
- colorTo: red
+ colorTo: yellow
  sdk: gradio
- sdk_version: 3.0.26
+ sdk_version: 3.0.24
  app_file: app.py
  pinned: false
- license: gpl
+ license: afl-3.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,97 @@
+ # import os
+ # os.chdir('naacl-2021-fudge-controlled-generation/')
+
+ import gradio as gr
+ from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
+ from datasets import load_dataset, DatasetDict, Dataset
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from sklearn.utils.class_weight import compute_class_weight
+ import torch
+ from fudge.model import Model
+ import os
+ from argparse import ArgumentParser
+ from collections import namedtuple
+ import mock
+
+ from tqdm import tqdm
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from fudge.data import Dataset
+ from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params
+ from fudge.constants import *
+
+
+ device = 'cpu'
+ pretrained_model = "checkpoint-150/"
+ generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
+
+ pad_id = 0
+
+ generation_model.eval()
+
+ model_args = mock.Mock()
+ model_args.task = 'clickbait'
+ model_args.device = device
+ model_args.checkpoint = 'checkpoint-1464/'
+
+ # No need to get the GloVe embeddings when reloading since they're saved in the model ckpt anyway.
+ conditioning_model = Model(model_args, pad_id, vocab_size=None)
+ conditioning_model = conditioning_model.to(device)
+ conditioning_model.eval()
+
+ condition_lambda = 5.0
+ length_cutoff = 50
+ precondition_topk = 200
+
+ classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint)
+
+
+ def rate_title(input_text, model, tokenizer, device=device):
+     tokenized_input = preprocess_function_title_only_classification(input_text, tokenizer=tokenizer)
+     dict_tokenized_input = {k: torch.tensor([v]).to(device) for k, v in tokenized_input.items() if k != 'labels'}
+     predicted_class = float(model(**dict_tokenized_input).logits)
+     return {'predicted_class': predicted_class}
+
+
+ def preprocess_function_title_only_classification(examples, tokenizer=None):
+     model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=25)
+     model_inputs['labels'] = examples['truthClass']
+     return model_inputs
+
+
+ def clickbait_generator(article_content, condition_lambda=5.0):
+     results = generate_clickbait(model=generation_model,
+                                  tokenizer=tokenizer,
+                                  conditioning_model=conditioning_model,
+                                  input_text=[None],
+                                  dataset_info=None,
+                                  precondition_topk=precondition_topk,
+                                  length_cutoff=length_cutoff,
+                                  condition_lambda=condition_lambda,
+                                  article_content=article_content,
+                                  device=device)
+     return results[0].replace('</s>', '').replace('<pad>', '')
+
+
+ title = "Clickbaitinator - Controllable Clickbait generator"
+ description = """
+ Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation, fine-tuned for our purposes, to create the news headline you are looking for! Use condition_lambda to steer the clickbaitiness higher (by increasing the slider value) or lower (by decreasing the slider value). <br/>
+ Note that this uses two Transformers and runs on CPU only, so it will take a minute or two to finish generating a title.
+ """
+
+ article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based on. You need collaborator access, to which you have probably been invited."
+
+
+ app = gr.Interface(
+     title=title,
+     description=description,
+     label='Article content or paragraph',
+     fn=clickbait_generator,
+     inputs=["text", gr.Slider(0, 15, step=0.1, value=5.0)],
+     outputs="text",
+     article=article,
+ )
+ app.launch()
checkpoint-1464/config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "architectures": [
+     "BertClickbaitClassifier"
+   ],
+   "dropout": 0.2,
+   "freeze_bert": false,
+   "id2label": {
+     "0": "LABEL_0"
+   },
+   "inner_dim1": 256,
+   "inner_dim2": 32,
+   "label2id": {
+     "LABEL_0": 0
+   },
+   "load_pretrained": true,
+   "max_length": 25,
+   "pretrained_model": "sentence-transformers/all-mpnet-base-v2",
+   "torch_dtype": "float32",
+   "transformers_version": "4.19.2"
+ }
checkpoint-1464/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e86af40c41b9c7bd860a34891976e7ccb5cb03e35f540bac3901e768a5e90947
+ size 872925589
checkpoint-1464/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:881e07cb4fc93116a9f5bee91fc6048ce366725537b34edf7fdf7f243d2ba240
+ size 438838053
checkpoint-1464/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af047df93f2e21bc6b802f06e57d980fb915986b57e8908ad2e2e43065125260
+ size 14503
checkpoint-1464/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c454d2b62b7deada05505a5f5f4607c60de6caa71a7d7b0a6c0c1821f97993c7
+ size 623
checkpoint-1464/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
checkpoint-1464/tokenizer.json ADDED
The diff for this file is too large to render.
 
checkpoint-1464/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": true, "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "[UNK]", "pad_token": "<pad>", "mask_token": "<mask>", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "sentence-transformers/all-mpnet-base-v2", "tokenizer_class": "MPNetTokenizer"}
checkpoint-1464/trainer_state.json ADDED
@@ -0,0 +1,106 @@
+ {
+   "best_metric": 1.0750963687896729,
+   "best_model_checkpoint": "drive/MyDrive/nlp_lss_data/mpnet_clickbait_classification_maxlen25/checkpoint-488",
+   "epoch": 6.0,
+   "global_step": 1464,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 1.0,
+       "eval_accuracy": 0.8498845265588915,
+       "eval_balanced_accuracy": 0.7757283651550713,
+       "eval_f1": 0.670793472144063,
+       "eval_loss": 1.1807881593704224,
+       "eval_precision": 0.7146282973621103,
+       "eval_recall": 0.6320254506892895,
+       "eval_runtime": 34.3115,
+       "eval_samples_per_second": 113.577,
+       "eval_steps_per_second": 113.577,
+       "step": 244
+     },
+     {
+       "epoch": 2.0,
+       "eval_accuracy": 0.8545034642032333,
+       "eval_balanced_accuracy": 0.7932135085090511,
+       "eval_f1": 0.6916802610114193,
+       "eval_loss": 1.0750963687896729,
+       "eval_precision": 0.7098214285714286,
+       "eval_recall": 0.6744432661717922,
+       "eval_runtime": 33.949,
+       "eval_samples_per_second": 114.79,
+       "eval_steps_per_second": 114.79,
+       "step": 488
+     },
+     {
+       "epoch": 2.05,
+       "learning_rate": 3.533724340175953e-05,
+       "loss": 1.3113,
+       "step": 500
+     },
+     {
+       "epoch": 3.0,
+       "eval_accuracy": 0.8601488324352066,
+       "eval_balanced_accuracy": 0.7835817278869854,
+       "eval_f1": 0.6873207114170969,
+       "eval_loss": 1.1083189249038696,
+       "eval_precision": 0.74875,
+       "eval_recall": 0.6352067868504772,
+       "eval_runtime": 33.7352,
+       "eval_samples_per_second": 115.517,
+       "eval_steps_per_second": 115.517,
+       "step": 732
+     },
+     {
+       "epoch": 4.0,
+       "eval_accuracy": 0.8534770336156018,
+       "eval_balanced_accuracy": 0.7993947132812708,
+       "eval_f1": 0.6964380648591175,
+       "eval_loss": 1.1579805612564087,
+       "eval_precision": 0.6982942430703625,
+       "eval_recall": 0.694591728525981,
+       "eval_runtime": 33.9178,
+       "eval_samples_per_second": 114.895,
+       "eval_steps_per_second": 114.895,
+       "step": 976
+     },
+     {
+       "epoch": 4.1,
+       "learning_rate": 1.7008797653958943e-05,
+       "loss": 0.7869,
+       "step": 1000
+     },
+     {
+       "epoch": 5.0,
+       "eval_accuracy": 0.8552732871439569,
+       "eval_balanced_accuracy": 0.8009405080804215,
+       "eval_f1": 0.6993603411513859,
+       "eval_loss": 1.2740588188171387,
+       "eval_precision": 0.7031082529474812,
+       "eval_recall": 0.6956521739130435,
+       "eval_runtime": 34.1758,
+       "eval_samples_per_second": 114.028,
+       "eval_steps_per_second": 114.028,
+       "step": 1220
+     },
+     {
+       "epoch": 6.0,
+       "eval_accuracy": 0.8555298947908647,
+       "eval_balanced_accuracy": 0.793168635227608,
+       "eval_f1": 0.6925177498634627,
+       "eval_loss": 1.3905503749847412,
+       "eval_precision": 0.713963963963964,
+       "eval_recall": 0.672322375397667,
+       "eval_runtime": 33.4993,
+       "eval_samples_per_second": 116.331,
+       "eval_steps_per_second": 116.331,
+       "step": 1464
+     }
+   ],
+   "max_steps": 1464,
+   "num_train_epochs": 6,
+   "total_flos": 1204353585477900.0,
+   "trial_name": null,
+   "trial_params": null
+ }
checkpoint-1464/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cf148502136d195161841d56eda30be54006d58e628e1439c9393fb6404ef4a
+ size 3311
checkpoint-1464/vocab.txt ADDED
The diff for this file is too large to render.
 
checkpoint-150/config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_name_or_path": "google/pegasus-xsum",
+   "activation_dropout": 0.1,
+   "activation_function": "relu",
+   "add_bias_logits": false,
+   "add_final_layer_norm": true,
+   "architectures": [
+     "PegasusForConditionalGeneration"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 0,
+   "classif_dropout": 0.0,
+   "classifier_dropout": 0.0,
+   "d_model": 1024,
+   "decoder_attention_heads": 16,
+   "decoder_ffn_dim": 4096,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 16,
+   "decoder_start_token_id": 0,
+   "do_blenderbot_90_layernorm": false,
+   "dropout": 0.1,
+   "encoder_attention_heads": 16,
+   "encoder_ffn_dim": 4096,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 16,
+   "eos_token_id": 1,
+   "extra_pos_embeddings": 0,
+   "force_bos_token_to_be_generated": false,
+   "forced_eos_token_id": 1,
+   "gradient_checkpointing": false,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "length_penalty": 0.6,
+   "max_length": 64,
+   "max_position_embeddings": 512,
+   "model_type": "pegasus",
+   "normalize_before": true,
+   "normalize_embedding": false,
+   "num_beams": 8,
+   "num_hidden_layers": 16,
+   "pad_token_id": 0,
+   "scale_embedding": true,
+   "static_position_embeddings": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.20.1",
+   "use_cache": true,
+   "vocab_size": 96103
+ }
checkpoint-150/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b31a813daf8949431f72c9672f50293c37c937f8239655269de86409df0a04ad
+ size 5839694
checkpoint-150/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6247ec114255a90ed5b84b8a94e1f9c20e3ff778c1cb853fb3758706d58deb78
+ size 2279605745
checkpoint-150/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:61da4aab34859d84193f6d36e1ca6db2cab8dd2449b8ba79a7f1af61aa8c44a5
+ size 14503
checkpoint-150/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e43ff59eb3184ee5df5457b7f99569ad47a950434bbef23184de1d9025687c8e
+ size 623
checkpoint-150/special_tokens_map.json ADDED
@@ -0,0 +1,110 @@
+ {
+   "additional_special_tokens": [
+     "<mask_1>",
+     "<unk_2>",
+     "<unk_3>",
+     "<unk_4>",
+     "<unk_5>",
+     "<unk_6>",
+     "<unk_7>",
+     "<unk_8>",
+     "<unk_9>",
+     "<unk_10>",
+     "<unk_11>",
+     "<unk_12>",
+     "<unk_13>",
+     "<unk_14>",
+     "<unk_15>",
+     "<unk_16>",
+     "<unk_17>",
+     "<unk_18>",
+     "<unk_19>",
+     "<unk_20>",
+     "<unk_21>",
+     "<unk_22>",
+     "<unk_23>",
+     "<unk_24>",
+     "<unk_25>",
+     "<unk_26>",
+     "<unk_27>",
+     "<unk_28>",
+     "<unk_29>",
+     "<unk_30>",
+     "<unk_31>",
+     "<unk_32>",
+     "<unk_33>",
+     "<unk_34>",
+     "<unk_35>",
+     "<unk_36>",
+     "<unk_37>",
+     "<unk_38>",
+     "<unk_39>",
+     "<unk_40>",
+     "<unk_41>",
+     "<unk_42>",
+     "<unk_43>",
+     "<unk_44>",
+     "<unk_45>",
+     "<unk_46>",
+     "<unk_47>",
+     "<unk_48>",
+     "<unk_49>",
+     "<unk_50>",
+     "<unk_51>",
+     "<unk_52>",
+     "<unk_53>",
+     "<unk_54>",
+     "<unk_55>",
+     "<unk_56>",
+     "<unk_57>",
+     "<unk_58>",
+     "<unk_59>",
+     "<unk_60>",
+     "<unk_61>",
+     "<unk_62>",
+     "<unk_63>",
+     "<unk_64>",
+     "<unk_65>",
+     "<unk_66>",
+     "<unk_67>",
+     "<unk_68>",
+     "<unk_69>",
+     "<unk_70>",
+     "<unk_71>",
+     "<unk_72>",
+     "<unk_73>",
+     "<unk_74>",
+     "<unk_75>",
+     "<unk_76>",
+     "<unk_77>",
+     "<unk_78>",
+     "<unk_79>",
+     "<unk_80>",
+     "<unk_81>",
+     "<unk_82>",
+     "<unk_83>",
+     "<unk_84>",
+     "<unk_85>",
+     "<unk_86>",
+     "<unk_87>",
+     "<unk_88>",
+     "<unk_89>",
+     "<unk_90>",
+     "<unk_91>",
+     "<unk_92>",
+     "<unk_93>",
+     "<unk_94>",
+     "<unk_95>",
+     "<unk_96>",
+     "<unk_97>",
+     "<unk_98>",
+     "<unk_99>",
+     "<unk_100>",
+     "<unk_101>",
+     "<unk_102>"
+   ],
+   "eos_token": "</s>",
+   "mask_token": "<mask_2>",
+   "pad_token": "<pad>",
+   "unk_token": "<unk>"
+ }
checkpoint-150/spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0015189ef36359283fec8b93cf6d9ce51bca37eb1101defc68a53b394913b96c
+ size 1912529
checkpoint-150/tokenizer.json ADDED
The diff for this file is too large to render.
 
checkpoint-150/tokenizer_config.json ADDED
@@ -0,0 +1,117 @@
+ {
+   "additional_special_tokens": [
+     "<mask_1>",
+     "<unk_2>",
+     "<unk_3>",
+     "<unk_4>",
+     "<unk_5>",
+     "<unk_6>",
+     "<unk_7>",
+     "<unk_8>",
+     "<unk_9>",
+     "<unk_10>",
+     "<unk_11>",
+     "<unk_12>",
+     "<unk_13>",
+     "<unk_14>",
+     "<unk_15>",
+     "<unk_16>",
+     "<unk_17>",
+     "<unk_18>",
+     "<unk_19>",
+     "<unk_20>",
+     "<unk_21>",
+     "<unk_22>",
+     "<unk_23>",
+     "<unk_24>",
+     "<unk_25>",
+     "<unk_26>",
+     "<unk_27>",
+     "<unk_28>",
+     "<unk_29>",
+     "<unk_30>",
+     "<unk_31>",
+     "<unk_32>",
+     "<unk_33>",
+     "<unk_34>",
+     "<unk_35>",
+     "<unk_36>",
+     "<unk_37>",
+     "<unk_38>",
+     "<unk_39>",
+     "<unk_40>",
+     "<unk_41>",
+     "<unk_42>",
+     "<unk_43>",
+     "<unk_44>",
+     "<unk_45>",
+     "<unk_46>",
+     "<unk_47>",
+     "<unk_48>",
+     "<unk_49>",
+     "<unk_50>",
+     "<unk_51>",
+     "<unk_52>",
+     "<unk_53>",
+     "<unk_54>",
+     "<unk_55>",
+     "<unk_56>",
+     "<unk_57>",
+     "<unk_58>",
+     "<unk_59>",
+     "<unk_60>",
+     "<unk_61>",
+     "<unk_62>",
+     "<unk_63>",
+     "<unk_64>",
+     "<unk_65>",
+     "<unk_66>",
+     "<unk_67>",
+     "<unk_68>",
+     "<unk_69>",
+     "<unk_70>",
+     "<unk_71>",
+     "<unk_72>",
+     "<unk_73>",
+     "<unk_74>",
+     "<unk_75>",
+     "<unk_76>",
+     "<unk_77>",
+     "<unk_78>",
+     "<unk_79>",
+     "<unk_80>",
+     "<unk_81>",
+     "<unk_82>",
+     "<unk_83>",
+     "<unk_84>",
+     "<unk_85>",
+     "<unk_86>",
+     "<unk_87>",
+     "<unk_88>",
+     "<unk_89>",
+     "<unk_90>",
+     "<unk_91>",
+     "<unk_92>",
+     "<unk_93>",
+     "<unk_94>",
+     "<unk_95>",
+     "<unk_96>",
+     "<unk_97>",
+     "<unk_98>",
+     "<unk_99>",
+     "<unk_100>",
+     "<unk_101>",
+     "<unk_102>"
+   ],
+   "eos_token": "</s>",
+   "full_tokenizer_file": null,
+   "mask_token": "<mask_2>",
+   "mask_token_sent": "<mask_1>",
+   "model_max_length": 512,
+   "name_or_path": "google/pegasus-xsum",
+   "offset": 103,
+   "pad_token": "<pad>",
+   "special_tokens_map_file": null,
+   "tokenizer_class": "PegasusTokenizer",
+   "unk_token": "<unk>"
+ }
checkpoint-150/trainer_state.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "best_metric": null,
+   "best_model_checkpoint": null,
+   "epoch": 4.982725527831094,
+   "global_step": 150,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 0.98,
+       "eval_loss": 2.3803367614746094,
+       "eval_runtime": 213.2043,
+       "eval_samples_per_second": 18.33,
+       "eval_steps_per_second": 18.33,
+       "step": 30
+     },
+     {
+       "epoch": 1.98,
+       "eval_loss": 2.2591161727905273,
+       "eval_runtime": 212.2667,
+       "eval_samples_per_second": 18.411,
+       "eval_steps_per_second": 18.411,
+       "step": 60
+     },
+     {
+       "epoch": 2.98,
+       "eval_loss": 2.203186511993408,
+       "eval_runtime": 212.608,
+       "eval_samples_per_second": 18.381,
+       "eval_steps_per_second": 18.381,
+       "step": 90
+     },
+     {
+       "epoch": 3.98,
+       "eval_loss": 2.1706554889678955,
+       "eval_runtime": 212.4689,
+       "eval_samples_per_second": 18.393,
+       "eval_steps_per_second": 18.393,
+       "step": 120
+     },
+     {
+       "epoch": 4.98,
+       "eval_loss": 2.1453042030334473,
+       "eval_runtime": 213.096,
+       "eval_samples_per_second": 18.339,
+       "eval_steps_per_second": 18.339,
+       "step": 150
+     }
+   ],
+   "max_steps": 900,
+   "num_train_epochs": 30,
+   "total_flos": 1.076296204604375e+17,
+   "trial_name": null,
+   "trial_params": null
+ }
checkpoint-150/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50f1c331773948f1aafffbec50322b3f07edf4e08988fbdb44798e3e9b3db9fd
+ size 3375
fudge/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Kevin Yang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # FUDGE: Controlled Text Generation With Future Discriminators
+
+ This repo contains code corresponding to the paper FUDGE: Controlled Text Generation With Future Discriminators (https://arxiv.org/abs/2104.05218) by Kevin Yang and Dan Klein, published at NAACL 2021.
+
+ You can also find a video presentation at http://somup.com/crhlVPFKN7 and the corresponding slides in `slides.pptx`.
+
+ ## Setup/Installation
+
+ We tested on Python 3.8.5, but earlier versions of Python 3 are almost certainly fine. To get the required packages (other versions are likely to work too):
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+ Additionally, to get our pre-trained predictor checkpoints and training data, run:
+
+ ```
+ wget https://naacl2021-fudge-files.s3.amazonaws.com/large_files.zip
+ ```
+
+ and extract the zip to the top-level `lm-prediction/` folder. (There should be three folders, `ckpt/`, `train_data/`, and `topic_human_evals/`. The zip is 7GB.) Note: the zip does not work for some people; if that is the case, you can get the files directly from https://drive.google.com/drive/folders/1GZfOGqpQxDmIfD2RvuhUQla9eX2OHUXU?usp=sharing (13GB).
+
+ `ckpt/` contains predictor checkpoints for each task if you are just interested in running inference. (Note that for the paper results, we used predictors trained with an older version of the code, but the new checkpoints get similar results, so you are OK to use the new predictors provided here if e.g. you just want to use FUDGE as a baseline. You can just run the evaluation commands provided below; it should take maybe 5-60 minutes depending on the task and your compute, assuming you have a GPU.)
+
+ `train_data/` contains our GPT2-generated training data for the poetry and topic tasks' predictors. See https://github.com/raosudha89/GYAFC-corpus for instructions on gaining access to the GYAFC data used for the machine translation formality task; replace our dummy folders with the corresponding folders/files if you want to train our formality predictor.
+
+ ## Clickbait
+ To generate outputs, run:
+
+ ```
+ python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --length_cutoff 80 --device cpu
+
+ python -u evaluate_clickbait.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es
+
+ python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file clickbait_preds.log
+ ```
+
+ Then evaluate metrics using:
+
+ ```
+ python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
+ ```
+
+
+ ## Poetry Couplet Completion
+
+ ### Evaluation
+
+ To generate outputs, run:
+
+ ```
+ python -u evaluate_poetry.py --iambic_ckpt ckpt/poetry/iambic_predictor/model.pth.tar --rhyme_ckpt ckpt/poetry/rhyme_predictor/model.pth.tar --newline_ckpt ckpt/poetry/newline_predictor/model.pth.tar --dataset_info ckpt/poetry/rhyme_predictor/dataset_info --rhyme_info ckpt/poetry/rhyme_predictor/rhyme_info --prefix_file poetry_data/couplet_prefixes.txt --precondition_topk 200 > poetry_preds.log
+ ```
+
+ Then evaluate metrics using:
+
+ ```
+ python eval_poetry_metrics.py --pred_file poetry_preds.log --prefix_file poetry_data/couplet_prefixes.txt
+ ```
+
+ ### Training your own predictors
+
+ Example commands for all three predictors used in the poetry task are below. (You probably don't need so many epochs for iambic and rhyme; in any case, the commands save intermediate ckpts, so you can stop them early if needed by inspecting the log.)
+
+ Iambic predictor:
+
+ ```
+ python -u main.py --task iambic --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/iambic_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > iambic_retrain_predictor.log
+ ```
+
+ Rhyme predictor:
+
+ ```
+ python -u main.py --task rhyme --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/rhyme_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > rhyme_retrain_predictor.log
+ ```
+
+ End of sentence predictor (referred to as "newline" in the code; 50 epochs is more than enough for this one):
+
+ ```
+ python -u main.py --task newline --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/newline_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 50 > newline_retrain_predictor.log
+ ```
+
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
+
+ ## Topic Control
+
+ ### Evaluation
+
+ To generate outputs, run:
+
+ ```
+ python -u evaluate_topic.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --prefix_file topic_data/topic_prefixes.txt --wordlist_dir topic_data/wordlists --condition_lambda 4.0 --verbose --precondition_topk 200 --topk 10 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file topic_preds.log
+ ```
+
+ Then evaluate metrics using:
+
+ ```
+ python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
+ ```
+
+ You can also find our original generations and baselines in `topic_human_evals/`.
+
+ ### Training your own predictors
+
+ Example command below.
+
+ ```
+ python -u main.py --task topic --data_dir train_data/gpt2_generations --save_dir ckpt/topic/future_word_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 500 --glove_file train_data/glove.840B.300d.txt > future_word_retrain_predictor.log
+ ```
+
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
+
+ ## Machine Translation Formality
+
+ ### Evaluation
+
+ To generate outputs, run:
+
+ ```
+ python -u evaluate_formality.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es --model_path ckpt/formality/marian_finetune_fisher > formality_preds.log
+ ```
+
+ The above command generates predictions using the Marian model finetuned on the Fisher dataset; remove the `--model_path` argument to get predictions with the un-finetuned Marian model from HuggingFace (referred to as 0-shot in the paper).
+
+ Then evaluate metrics using:
+
+ ```
+ python eval_formality_metrics.py --pred formality_preds.log --ref formality_data/test.noid.cleaned_0 formality_data/test.noid.cleaned_1 --ckpt ckpt/formality/test_evaluator_gyafc_family_relationships/model.pth.tar --dataset_info ckpt/formality/test_evaluator_gyafc_family_relationships/dataset_info
+ ```
+
+ ### Training your own predictors
+
+ Example command below. (Reminder: you need to go get the GYAFC dataset following the instructions in https://github.com/raosudha89/GYAFC-corpus.)
+
+ ```
+ python -u main.py --task formality --data_dir train_data/GYAFC_Corpus/Entertainment_Music --save_dir ckpt/formality/formality_retrain_predictor --num_workers 20 --batch_size 32 --epoch_max_len 1000000 --validation_freq 1 --lr 2e-5 --epochs 20 > formality_retrain_predictor.log
+ ```
+
+ (The test-time formality evaluator is trained in the same way, just using the Family/Relationships half of the GYAFC dataset.)
+
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
+
+ ## Running FUDGE on your own data
+
+ The code has been refactored so that the iambic (poetry), rhyme (poetry), newline (poetry), future word (topic), and formality (machine translation) tasks are controlled by the `--task` flag to `main.py`. You should add your task as another option here, then modify the data processing in `data.py` and the model in `model.py` as needed for your task. (In `data.py` you probably won't need all the entries of the tuple that is expected of the loader; you can just put dummy entries in the ones you don't need.) You might also need to modify the loss computation in the `train` and `validate` functions in `main.py`. You'll probably want to write new evaluation scripts, though the existing poetry/topic/formality ones are hopefully helpful as references.
+
+ Alternatively, the general FUDGE framework is pretty simple, so you could always try reimplementing things yourself. A few additional details based on questions I've received:
+
+ (1) The formality task setup is likely closest to what you want if you're just trying to run the simplest form of FUDGE (take a language model, and use a classifier to optimize toward a single attribute), although you may need to swap out the Marian translation model/tokenizer we use.
+
+ (2) When you construct your training data, if you have an example in your data e.g. "This movie is great!" for positive sentiment, you want to learn on all the pairs (This, +), (This movie, +), (This movie is, +), etc., as that's one of the main points of our approach.
+
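+ As a concrete illustration, here is a minimal sketch of that expansion (`expand_example` is a hypothetical helper for this README, not a function in this repo):
+
+ ```
+ def expand_example(text, label):
+     # "This movie is great!" labeled positive becomes
+     # (This, +), (This movie, +), (This movie is, +), ...
+     tokens = text.split()
+     return [(' '.join(tokens[:i]), label) for i in range(1, len(tokens) + 1)]
+
+ pairs = expand_example("This movie is great!", "+")
+ # [('This', '+'), ('This movie', '+'), ('This movie is', '+'), ('This movie is great!', '+')]
+ ```
+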
+ (3) For computational efficiency, we first filter the base model's next token probabilities down to the top 200 (Sec. 3.1 in the paper), before adding the classifier logits. This way you only need to evaluate your classifier on 200 continuations. Then afterward, you filter down again to whatever top-k/greedy/nucleus sampling you're using for evaluation (we use top-k with k=10 for poetry and topic, greedy for formality).
+
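+ A minimal sketch of one such decoding step (an illustration of the scheme above, assuming a 1-D `lm_logits` tensor over the vocabulary and a hypothetical `classifier_log_probs` function that scores each candidate continuation; this is not the exact code in this repo):
+
+ ```
+ import torch
+
+ def fudge_step(lm_logits, prefix_ids, classifier_log_probs,
+                condition_lambda=1.0, precondition_topk=200, topk=10):
+     # 1. Keep only the base LM's top-200 candidate next tokens.
+     top_logits, top_indices = lm_logits.topk(precondition_topk)
+     # 2. Score just those 200 continuations with the attribute classifier:
+     #    log P(attribute | prefix + candidate token).
+     cond_logprobs = classifier_log_probs(prefix_ids, top_indices)
+     # 3. Combine in log space, filter again to the final top-k, and sample.
+     combined = top_logits.log_softmax(dim=-1) + condition_lambda * cond_logprobs
+     final_logits, final_indices = combined.topk(topk)
+     choice = torch.multinomial(final_logits.softmax(dim=-1), 1)
+     return top_indices[final_indices[choice]]  # the sampled token id
+ ```
+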
+ (4) You can use a pretrained LM backbone instead of a simple LSTM backbone for the predictor as well. This should work better when your dataset is smaller.
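+
+ A minimal sketch of such a predictor (one reasonable way to do it; `PretrainedBackbonePredictor` is a hypothetical name, not this repo's `Model` class):
+
+ ```
+ import torch.nn as nn
+ from transformers import AutoModel
+
+ class PretrainedBackbonePredictor(nn.Module):
+     def __init__(self, backbone='gpt2', num_labels=1):
+         super().__init__()
+         self.backbone = AutoModel.from_pretrained(backbone)
+         self.head = nn.Linear(self.backbone.config.hidden_size, num_labels)
+
+     def forward(self, input_ids, attention_mask=None):
+         # One logit per prefix length, matching the pairs construction in (2).
+         hidden = self.backbone(input_ids, attention_mask=attention_mask).last_hidden_state
+         return self.head(hidden)  # (batch, seq_len, num_labels)
+ ```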
fudge/clickbait_classifier.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ from transformers import BertModel, BertConfig, PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
+ from typing import List, Optional, Tuple, Union
+ from transformers.modeling_outputs import TokenClassifierOutput, SequenceClassifierOutput
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, BCELoss
+ import torch.nn as nn
+
+
+ class ClickbaitConfig(PretrainedConfig):
+     def __init__(
+         self,
+         model_type: str = "bert",
+         pretrained_model: str = "bert-base-uncased",
+         num_labels: int = 1,
+         dropout: float = 0.1,
+         inner_dim1: int = 256,
+         inner_dim2: int = 32,
+         max_length: int = 512,
+         load_pretrained: bool = True,
+         freeze_bert: bool = True,
+         **kwargs
+     ):
+         super(ClickbaitConfig, self).__init__(num_labels=num_labels, **kwargs)
+         self.model_type = model_type
+         self.pretrained_model = pretrained_model
+         self.dropout = dropout
+         self.inner_dim1 = inner_dim1
+         self.inner_dim2 = inner_dim2
+         self.max_length = max_length
+         self.load_pretrained = load_pretrained
+         self.freeze_bert = freeze_bert
+
+
+ class BertClickbaitClassifier(PreTrainedModel):
+     """
+     Taken and extended from BertForSequenceClassification: https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/models/bert/modeling_bert.py#L1508
+     """
+     config_class = ClickbaitConfig
+
+     def __init__(self, config: ClickbaitConfig):
+         super(BertClickbaitClassifier, self).__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.bert_config = AutoConfig.from_pretrained(config.pretrained_model)
+         self.bert = AutoModel.from_pretrained(config.pretrained_model, config=self.bert_config)
+         if config.load_pretrained:
+             print("Load pretrained weights from {}".format(config.pretrained_model))
+             self.bert = self.bert.from_pretrained(config.pretrained_model)
+         if config.freeze_bert:
+             print("Freeze weights in the BERT model. Just the classifier will be trained")
+             for param in self.bert.parameters():
+                 param.requires_grad = False
+
+         # Two-layer MLP head on top of the first-token hidden state.
+         self.linear_1 = nn.Linear(self.bert.config.hidden_size, config.inner_dim1)
+         self.dropout_1 = nn.Dropout(config.dropout)
+         self.relu_1 = nn.ReLU()
+         self.dropout_2 = nn.Dropout(config.dropout)
+         self.linear_2 = nn.Linear(config.inner_dim1, config.inner_dim2)
+         self.relu_2 = nn.ReLU()
+         self.dropout_3 = nn.Dropout(config.dropout)
+         self.classifier = nn.Linear(config.inner_dim2, config.num_labels)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         output = outputs[0][:, 0, :]  # hidden state at the first token position
+
+         x = self.dropout_1(output)
+         x = self.linear_1(x)
+         x = self.relu_1(x)
+         x = self.dropout_2(x)
+         x = self.linear_2(x)
+         x = self.relu_2(x)
+         x = self.dropout_3(x)
+
+         logits = self.classifier(x)
+         logits = self.sigmoid(logits)
+
+         loss = None
+         if labels is not None:
+             loss_fct = BCELoss()  # pass weight= here if you need class weighting
+             labels = 1.0 * labels  # cast labels to float for BCELoss
+             loss = loss_fct(logits.view(-1), labels.view(-1))
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits
+         )
fudge/constants.py ADDED
@@ -0,0 +1,32 @@
+ PAD_TOKEN = '[PAD]'
+ EOT_TOKEN = '<|endoftext|>'
+ SEP = 50256 # just use the weird eot token
+
+ TOPIC_MODEL_STRING = 'gpt2-medium'
+ FORMALITY_MODEL_STRING = 'Helsinki-NLP/opus-mt-es-en'
+
+ DIR_END_SPLIT_POSITIONS = 32
+
+ TOPIC_VAL_SIZE = 100000
+ FORMALITY_VAL_SIZE = 2000
+ VOCAB_SIZE = 50000
+
+ FORMALITY_MAX_LEN = 200
+
+ GLOVE_PRINT_PROGRESS_FREQ = 1000000
+ GLOVE_DIM = 300
+ HIDDEN_DIM = 300
+ RNN_DIM = 150
+
+ MIN_SENTENCE_LENGTH = 3
+
+ POETRY_LINE_SYLLABLES = 10
+ MAX_SYLLABLES_PER_WORD = 10 # no way anything is more
+ MAX_COUNT_SYLLABLE_DIST = 10
+ MAX_COUNT_SYLLABLE_INPUT_LENGTH = 25 # for just a couplet, shouldn't need more
+ COUNT_SYLLABLE_DIM = 100
+ UNKNOWN_RHYME_GROUP = 'UNKNOWN_RHYME_GROUP'
+ PHRASE_ENDS = '.?!'
+
+ POETRY_BANNED_TOKENS = [198, 50256, 628, 220] # newlines and eos and such
+
fudge/data.py ADDED
@@ -0,0 +1,415 @@
+ import random
+ import math
+ import os
+ import pickle
+ from collections import defaultdict, namedtuple
+ import string
+
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false' # turn off since we're using multiple threads for loading anyway
+
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
+ import numpy as np
+ from tqdm import tqdm
+ import torch
+
+ from fudge.util import suppress_stdout
+ from fudge.poetry_util import is_iambic, count_syllables, get_rhymes, get_rhyme_group
+ from fudge.constants import *
+
+ DatasetInfo = namedtuple('DatasetInfo',
+         ['index2word', 'word2index', 'total_words', 'vocab', 'glove_embeddings'])
+ RhymeInfo = namedtuple('RhymeInfo',
+         ['word2rhyme_group', 'rhyme_group_counts', 'rhyme_groups', 'index2rhyme_group', 'rhyme_group2index', 'total_rhyme_groups'])
+
+ def collate(batch):
+     pad_id = batch[0][4]
+     inputs = [b[0] for b in batch]
+     lengths = torch.LongTensor([b[1] for b in batch])
+     max_length = lengths.max()
+     for i in range(len(inputs)):
+         if len(inputs[i]) < max_length:
+             inputs[i] = torch.cat([inputs[i], torch.zeros(max_length - len(inputs[i])).long()], dim=0) # actually 0 is fine as pad since it's masked out
+     inputs = torch.stack(inputs, dim=0)
+     future_words = torch.LongTensor([b[2] for b in batch]).unsqueeze(0).expand(len(batch), -1).clone() # batch x N=batch
+     labels = torch.zeros_like(future_words).long()
+     labels = labels.scatter(1, torch.arange(len(batch)).unsqueeze(1), torch.ones(len(batch)).long().unsqueeze(1)).clone()
+     log_probs = torch.Tensor([b[3] for b in batch])
+     classification_labels = [b[5] for b in batch] # batch
+     if type(classification_labels[0]) == list:
+         for i in range(len(classification_labels)):
+             assert len(classification_labels[i]) == lengths[i]
+             if len(classification_labels[i]) < max_length:
+                 classification_labels[i] = torch.cat([torch.LongTensor(classification_labels[i]), -1 + torch.zeros(max_length - len(classification_labels[i])).long()], dim=0)
+             else:
+                 classification_labels[i] = torch.LongTensor(classification_labels[i])
+         classification_labels = torch.stack(classification_labels, dim=0) # batch x seq
+     else:
+         assert type(classification_labels[0]) == int
+         classification_labels = torch.LongTensor(classification_labels) # they're just int labels
+     syllables_to_go = torch.LongTensor([b[6] for b in batch])
+     future_word_num_syllables = torch.LongTensor([b[7] for b in batch])
+     rhyme_group_index = torch.LongTensor([b[8] for b in batch])
+     return (inputs, lengths, future_words, log_probs, labels, classification_labels, syllables_to_go, future_word_num_syllables, rhyme_group_index)
+
+
+ def load_rhyme_info(index2word, vocab):
+     word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP)
+     rhyme_group_counts = defaultdict(lambda: 0)
+     rhyme_groups = set()
+     for word in index2word:
+         try:
+             rhyme_group = get_rhyme_group(word)
+             word2rhyme_group[word] = rhyme_group
+             rhyme_group_counts[rhyme_group] += (vocab[word] if word in vocab else 1) # for rare words not in vocab, just use 1
+             rhyme_groups.add(rhyme_group)
+         except:
+             rhyme_group_counts[UNKNOWN_RHYME_GROUP] += (vocab[word] if word in vocab else 1)
+     index2rhyme_group = [UNKNOWN_RHYME_GROUP] + sorted(list(rhyme_groups))
+     rhyme_group2index = {s: i for i, s in enumerate(index2rhyme_group)}
+     total_rhyme_groups = sum(rhyme_group_counts.values())
+
+     return RhymeInfo(word2rhyme_group=dict(word2rhyme_group),
+             rhyme_group_counts=dict(rhyme_group_counts),
+             rhyme_groups=rhyme_groups,
+             index2rhyme_group=index2rhyme_group,
+             rhyme_group2index=rhyme_group2index,
+             total_rhyme_groups=total_rhyme_groups)
+
+
+ class Dataset:
+     def __init__(self, args):
+         print('loading data')
+         random.seed(args.seed)
+         self.batch_size = args.batch_size
+         self.data_dir = args.data_dir
+         self.topic = args.task == 'topic'
+         self.formality = args.task == 'formality'
+         self.iambic = args.task == 'iambic'
+         self.rhyme = args.task == 'rhyme'
+         self.newline = args.task == 'newline'
+
+         self.tokenizer = AutoTokenizer.from_pretrained(FORMALITY_MODEL_STRING if self.formality else TOPIC_MODEL_STRING)
+         self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
+         self.gpt_pad_id = self.tokenizer.encode(PAD_TOKEN)[0] # actually just the vocab size
+         sentences = []
+         self.vocab = defaultdict(lambda: 0)
+         if self.formality:
+             self.vocab['placeholder'] = 1 # anything so we don't crash
+             train, val, test = [], [], []
+             for category, label in [('formal', 1), ('informal', 0)]:
+                 with open(os.path.join(args.data_dir, 'train', category), 'r') as rf:
+                     for i, line in enumerate(rf):
+                         if len(line) > FORMALITY_MAX_LEN:
+                             line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len; chosen so only ~20 examples affected in dataset
+                         if i < FORMALITY_VAL_SIZE // 2:
+                             val.append((line.strip(), label))
+                         else:
+                             train.append((line.strip(), label))
+                 with open(os.path.join(args.data_dir, 'test', category), 'r') as rf:
+                     for line in rf:
+                         if len(line) > FORMALITY_MAX_LEN:
+                             line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len
+                         test.append((line.strip(), label))
+             self.splits = {}
+             self.splits['train'], self.splits['val'], self.splits['test'] = train, val, test
+         else: # topic / poetry
+             for root, _, filenames in os.walk(args.data_dir):
+                 for fname in filenames:
+                     with open(os.path.join(root, fname), 'r') as rf:
+                         for line in rf:
+                             sentences.append(line.strip())
+                             for word in line.strip().split(' '):
+                                 self.vocab[word] += 1
+             random.shuffle(sentences)
+             self.splits = {}
+             if args.debug:
+                 self.splits['val'] = sentences
+                 self.splits['test'] = sentences
+                 self.splits['train'] = sentences
+             else:
+                 self.splits['val'] = sentences[:TOPIC_VAL_SIZE]
+                 self.splits['test'] = sentences[TOPIC_VAL_SIZE:2*TOPIC_VAL_SIZE]
+                 self.splits['train'] = sentences[2*TOPIC_VAL_SIZE:]
+
+         if args.dataset_info is not None:
+             print('loading dataset info from file')
+             with open(args.dataset_info, 'rb') as rf:
+                 dataset_info = pickle.load(rf)
+             self.vocab, self.total_words, self.index2word, self.word2index, self.glove_embeddings = \
+                     dataset_info.vocab, dataset_info.total_words, dataset_info.index2word, dataset_info.word2index, dataset_info.glove_embeddings
+             self.dataset_info = dataset_info
+         else:
+             print('generating dataset info from scratch')
+             words_values = list(self.vocab.items())
+             words_values = sorted(words_values, key=lambda x: x[1], reverse=True)
+             if args.glove_file is None:
+                 print('no glove embeddings given')
+                 for word, _ in words_values[VOCAB_SIZE:]: # only use somewhat common tokens
+                     del self.vocab[word]
+                 glove_embeddings = None
+             else:
+                 print('loading glove embeddings')
+                 glove_embeddings = {}
+                 with open(args.glove_file, 'r') as rf:
+                     for i, line in enumerate(rf):
+                         if i % GLOVE_PRINT_PROGRESS_FREQ == 0:
+                             print(i)
+                         line = line.strip().split()
+                         if len(line) != GLOVE_DIM + 1:
+                             continue # skip multi-word embeddings which are rare anyway
+                         glove_embeddings[line[0]] = [float(x) for x in line[1:]]
+                 for word, _ in words_values:
+                     if word not in glove_embeddings:
+                         del self.vocab[word]
+             self.total_words = sum(self.vocab.values())
+             self.index2word = [PAD_TOKEN] + sorted(list(self.vocab.keys()))
+             self.word2index = {s: i for i, s in enumerate(self.index2word)}
+             self.vocab = dict(self.vocab) # so we can pickle later
+             if glove_embeddings is None:
+                 self.glove_embeddings = None
+             else:
+                 self.glove_embeddings = torch.stack([torch.zeros(GLOVE_DIM)] + [torch.Tensor(glove_embeddings[word]) for word in self.index2word[1:]], dim=0)
+
+             self.dataset_info = DatasetInfo(index2word=self.index2word,
+                     word2index=self.word2index,
+                     total_words=self.total_words,
+                     vocab=self.vocab,
+                     glove_embeddings=self.glove_embeddings)
+
+         if self.rhyme:
+             if args.rhyme_info is not None:
+                 print('loading rhyme info from file')
+                 with open(args.rhyme_info, 'rb') as rf:
+                     self.rhyme_info = pickle.load(rf)
+             else:
+                 self.rhyme_info = load_rhyme_info(self.index2word, self.vocab)
+             self.word2rhyme_group, self.rhyme_group_counts, self.rhyme_groups, self.index2rhyme_group, self.rhyme_group2index, self.total_rhyme_groups = \
+                     defaultdict(lambda: UNKNOWN_RHYME_GROUP, self.rhyme_info.word2rhyme_group), self.rhyme_info.rhyme_group_counts, self.rhyme_info.rhyme_groups, self.rhyme_info.index2rhyme_group, self.rhyme_info.rhyme_group2index, self.rhyme_info.total_rhyme_groups
+
+         print('done loading data')
+         print('split sizes:')
+         for key in ['train', 'val', 'test']:
+             print(key, len(self.splits[key]))
+         if not self.formality:
+             print('total words', self.total_words)
+             print('vocab size', len(self.index2word))
+
+     def shuffle(self, split, seed=None):
+         assert split in ['train', 'val', 'test']
+         if seed is not None:
+             random.seed(seed)
+         random.shuffle(self.splits[split])
+
+     def loader(self, split, num_workers=20, indices=None):
+         assert split in ['train', 'val', 'test']
+         data = self.splits[split] if indices is None else [self.splits[split][i] for i in indices]
+         return torch.utils.data.DataLoader(SplitLoader(data, self), batch_size=self.batch_size, pin_memory=True, collate_fn=collate, num_workers=num_workers)
+
+
+ class SplitLoader(torch.utils.data.IterableDataset):
+     def __init__(self, data, parent):
+         super(SplitLoader).__init__()
+         self.data = data
+         self.pos = 0
+         self.parent = parent
+
+     def __len__(self):
+         return len(self.data)
+