ccmaymay committed · verified
Commit c034681 · 1 Parent(s): e0be3dc

Add files from rrivera1849/LUAR-MUD 858fcb1.

Files changed (9)
  1. README.md +71 -3
  2. config.json +16 -0
  3. config.py +18 -0
  4. merges.txt +0 -0
  5. model.py +219 -0
  6. special_tokens_map.json +51 -0
  7. tokenizer.json +0 -0
  8. tokenizer_config.json +58 -0
  9. vocab.json +0 -0
README.md CHANGED
@@ -1,3 +1,71 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language:
+ - en
+ license: apache-2.0
+ ---
+
+ # LUAR-MUD development model forked from rrivera1849/LUAR-MUD
+
+ Author Style Representations using [LUAR](https://aclanthology.org/2021.emnlp-main.70.pdf).
+
+ The LUAR training and evaluation repository can be found [here](https://github.com/llnl/luar).
+
+ This model was trained on the Reddit Million User Dataset (MUD) found [here](https://aclanthology.org/2021.naacl-main.415.pdf).
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")
+ model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
+
+ # we embed `episodes`, a collection of documents presumed to come from an author
+ # NOTE: make sure that `episode_length` is consistent across episodes
+ batch_size = 3
+ episode_length = 16
+ text = [
+     ["Foo"] * episode_length,
+     ["Bar"] * episode_length,
+     ["Zoo"] * episode_length,
+ ]
+ text = [j for i in text for j in i]
+ tokenized_text = tokenizer(
+     text,
+     max_length=32,
+     padding="max_length",
+     truncation=True,
+     return_tensors="pt"
+ )
+ # inputs size: (batch_size, episode_length, max_token_length)
+ tokenized_text["input_ids"] = tokenized_text["input_ids"].reshape(batch_size, episode_length, -1)
+ tokenized_text["attention_mask"] = tokenized_text["attention_mask"].reshape(batch_size, episode_length, -1)
+ print(tokenized_text["input_ids"].size())       # torch.Size([3, 16, 32])
+ print(tokenized_text["attention_mask"].size())  # torch.Size([3, 16, 32])
+
+ out = model(**tokenized_text)
+ print(out.size())  # torch.Size([3, 512])
+
+ # to get the Transformer attentions:
+ out, attentions = model(**tokenized_text, output_attentions=True)
+ print(attentions[0].size())  # torch.Size([48, 12, 32, 32])
+ ```
+
+ ## Citing & Authors
+
+ If you find this model helpful, feel free to cite our [publication](https://aclanthology.org/2021.emnlp-main.70.pdf).
+
+ ```
+ @inproceedings{uar-emnlp2021,
+   author = {Rafael A. Rivera Soto and Olivia Miano and Juanita Ordonez and Barry Chen and Aleem Khan and Marcus Bishop and Nicholas Andrews},
+   title = {Learning Universal Authorship Representations},
+   booktitle = {EMNLP},
+   year = {2021},
+ }
+ ```
+
+ ## License
+
+ LUAR is distributed under the terms of the Apache License (Version 2.0).
+
+ All new contributions must be made under the Apache-2.0 license.
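The 512-dimensional outputs above are author embeddings, so two authors can be compared directly with cosine similarity. Below is a minimal sketch building on the README example (the `tokenizer` and `model` objects and the reshape convention are assumed to be set up exactly as above; `embed_episode` is a hypothetical helper, not something shipped in this repo):

```python
import torch
import torch.nn.functional as F

def embed_episode(docs, episode_length=16, max_length=32):
    # docs: a list of `episode_length` texts written by one author
    tok = tokenizer(docs, max_length=max_length, padding="max_length",
                    truncation=True, return_tensors="pt")
    tok["input_ids"] = tok["input_ids"].reshape(1, episode_length, -1)
    tok["attention_mask"] = tok["attention_mask"].reshape(1, episode_length, -1)
    with torch.no_grad():
        return model(**tok)  # shape: (1, 512)

emb_a = embed_episode(["Foo"] * 16)
emb_b = embed_episode(["Bar"] * 16)
print(F.cosine_similarity(emb_a, emb_b).item())  # higher => more similar writing style
```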
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "architectures": [
+     "LUAR"
+   ],
+   "auto_map": {
+     "AutoConfig": "config.LUARConfig",
+     "AutoModel": "model.LUAR"
+   },
+   "embedding_size": 512,
+   "k_bucket_size": 1024,
+   "model_type": "LUAR",
+   "q_bucket_size": 512,
+   "torch_dtype": "float32",
+   "transformers_version": "4.33.2",
+   "use_memory_efficient_attention": false
+ }
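The `auto_map` block is what lets `transformers` resolve this repository's custom classes: `AutoConfig` dispatches to `config.LUARConfig` and `AutoModel` to `model.LUAR`, which is why `trust_remote_code=True` is required when loading. A small sketch of inspecting the configuration, assuming the upstream repo id from the README (substitute this fork's id as appropriate):

```python
from transformers import AutoConfig

# trust_remote_code is needed because auto_map points at custom code in the repo
config = AutoConfig.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
print(config.model_type)                           # "LUAR"
print(config.embedding_size)                       # 512, the author-embedding size
print(config.q_bucket_size, config.k_bucket_size)  # chunk sizes for memory-efficient attention
```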
config.py ADDED
@@ -0,0 +1,18 @@
+
+ from transformers import PretrainedConfig
+
+ class LUARConfig(PretrainedConfig):
+     model_type = "LUAR"
+
+     def __init__(self,
+         embedding_size: int = 512,
+         use_memory_efficient_attention=False,
+         q_bucket_size=512,
+         k_bucket_size=1024,
+         **kwargs,
+     ):
+         self.embedding_size = embedding_size
+         self.use_memory_efficient_attention = use_memory_efficient_attention
+         self.q_bucket_size = q_bucket_size
+         self.k_bucket_size = k_bucket_size
+         super().__init__(**kwargs)
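Because `LUARConfig` subclasses `PretrainedConfig`, it can also be constructed and round-tripped locally like any other Hugging Face config. A short sketch, assuming you are working from a local clone of this repository (the `./luar-config` path is just an example):

```python
from config import LUARConfig  # the module added in this commit

cfg = LUARConfig(embedding_size=512, use_memory_efficient_attention=False)
cfg.save_pretrained("./luar-config")             # writes config.json to that directory
reloaded = LUARConfig.from_pretrained("./luar-config")
assert reloaded.embedding_size == 512
assert reloaded.model_type == "LUAR"
```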
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,219 @@
+
+ import math
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange, reduce, repeat
+ from torch.utils.checkpoint import checkpoint
+ from transformers import AutoModel, PreTrainedModel
+
+ from .config import LUARConfig
+
+ # Adapted LucidRains impl. of Memory Efficient Attention
+ # https://github.com/lucidrains/memory-efficient-attention-pytorch
+
+ def exists(val):
+     return val is not None
+
+ def summarize_qkv_chunk(
+     q, k, v,
+     mask
+ ):
+     """Dot-Product Attention for a chunk of queries, keys, and values.
+     """
+     weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+
+     if exists(mask):
+         # HuggingFace masks have to be added:
+         weight += mask
+
+     weight_max = weight.amax(dim = -1, keepdim = True).detach()
+     weight = weight - weight_max
+
+     exp_weight = weight.exp()
+     weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)
+
+     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
+
+ checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
+
+ def memory_efficient_attention(
+     q, k, v,
+     mask = None,
+     q_bucket_size = 512,
+     k_bucket_size = 1024,
+     eps = 1e-8
+ ):
+     scale = q.shape[-1] ** -0.5
+     q = q * scale
+
+     # function
+     needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
+     summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk
+
+     # chunk all the inputs
+     q_chunks = q.split(q_bucket_size, dim = -2)
+     k_chunks = k.split(k_bucket_size, dim = -2)
+     v_chunks = v.split(k_bucket_size, dim = -2)
+     mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
+
+     # loop through all chunks and accumulate
+     out = []
+     for q_index, q_chunk in enumerate(q_chunks):
+         exp_weights = []
+         weighted_values = []
+         weight_maxes = []
+
+         for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
+
+             exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
+                 q_chunk,
+                 k_chunk,
+                 v_chunk,
+                 mask_chunk,
+             )
+
+             exp_weights.append(exp_weight_chunk)
+             weighted_values.append(weighted_value_chunk)
+             weight_maxes.append(weight_max_chunk)
+
+         exp_weights = torch.stack(exp_weights, dim = -1)
+         weighted_values = torch.stack(weighted_values, dim = -1)
+         weight_maxes = torch.stack(weight_maxes, dim = -1)
+
+         global_max = weight_maxes.amax(dim = -1, keepdim = True)
+         renorm_factor = (weight_maxes - global_max).exp().detach()
+
+         exp_weights = exp_weights * renorm_factor
+         weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')
+
+         all_values = weighted_values.sum(dim = -1)
+         all_weights = exp_weights.sum(dim = -1)
+
+         normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
+         out.append(normalized_values)
+
+     return torch.cat(out, dim=-2)
+
+ class SelfAttention(nn.Module):
+     """Implements Dot-Product Self-Attention as used in "Attention is all You Need".
+     """
+     def __init__(
+         self,
+         memory_efficient_attention=False,
+         q_bucket_size=512,
+         k_bucket_size=1024,
+     ):
+         super(SelfAttention, self).__init__()
+         self.use_memory_efficient_attention = memory_efficient_attention
+         self.q_bucket_size = q_bucket_size
+         self.k_bucket_size = k_bucket_size
+
+     def forward(self, k, q, v):
+
+         if self.use_memory_efficient_attention:
+             q, k, v = map(
+                 lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 12),
+                 (q, k, v)
+             )
+
+             out = memory_efficient_attention(
+                 q, k, v,
+                 q_bucket_size=self.q_bucket_size,
+                 k_bucket_size=self.k_bucket_size
+             )
+             out = rearrange(out, 'b h n d -> b n (h d)')
+             return out
+         else:
+             d_k = q.size(-1)
+             scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
+             p_attn = F.softmax(scores, dim=-1)
+             return torch.matmul(p_attn, v)
+
+ class LUAR(PreTrainedModel):
+     """Defines the LUAR model.
+     """
+     config_class = LUARConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.create_transformer()
+         self.attn_fn = SelfAttention(
+             config.use_memory_efficient_attention,
+             config.q_bucket_size,
+             config.k_bucket_size,
+         )
+         self.linear = nn.Linear(self.hidden_size, config.embedding_size)
+
+     def create_transformer(self):
+         """Creates the Transformer backbone.
+         """
+         self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
+         self.hidden_size = self.transformer.config.hidden_size
+         self.num_attention_heads = self.transformer.config.num_attention_heads
+         self.dim_head = self.hidden_size // self.num_attention_heads
+
+     def mean_pooling(self, token_embeddings, attention_mask):
+         """Mean Pooling as described in the SBERT paper.
+         """
+         input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
+         sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
+         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
+         return sum_embeddings / sum_mask
+
+     def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
+         """Computes the Author Embedding.
+         """
+         B, E, _ = attention_mask.shape
+
+         input_ids = rearrange(input_ids, 'b e l -> (b e) l')
+         attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')
+
+         if document_batch_size > 0:
+             outputs = {"last_hidden_state": [], "attentions": []}
+             for i in range(0, len(input_ids), document_batch_size):
+                 out = self.transformer(
+                     input_ids=input_ids[i:i+document_batch_size],
+                     attention_mask=attention_mask[i:i+document_batch_size],
+                     return_dict=True,
+                     output_hidden_states=False,
+                     output_attentions=output_attentions,
+                 )
+                 outputs["last_hidden_state"].append(out["last_hidden_state"])
+                 if output_attentions:
+                     outputs["attentions"].append(out["attentions"])
+             outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0)
+             if output_attentions:
+                 outputs["attentions"] = tuple([torch.cat([x[i] for x in outputs["attentions"]], dim=0) for i in range(len(outputs["attentions"][0]))])
+         else:
+             outputs = self.transformer(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 return_dict=True,
+                 output_hidden_states=False,
+                 output_attentions=output_attentions,
+             )
+
+         # at this point, we're embedding individual "comments"
+         comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
+         comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)
+
+         # aggregate individual comments embeddings into episode embeddings
+         episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
+         episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')
+
+         episode_embeddings = self.linear(episode_embeddings)
+
+         if output_attentions:
+             return episode_embeddings, outputs["attentions"]
+
+         return episode_embeddings
+
+     def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
+         """Calculates a fixed-length feature vector for a batch of episode samples.
+         """
+         output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size)
+
+         return output
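`forward` expects `input_ids` and `attention_mask` of shape `(batch, episode_length, tokens)` and returns one embedding per episode; `document_batch_size > 0` caps how many of the `batch * episode_length` documents pass through the `paraphrase-distilroberta-base-v1` backbone per pass, trading throughput for memory. A minimal shape check, assuming `model` is a loaded LUAR instance and using random token ids purely for illustration:

```python
import torch

B, E, L = 2, 16, 32                       # authors, docs per episode, tokens per doc
input_ids = torch.randint(0, 50000, (B, E, L))
attention_mask = torch.ones(B, E, L, dtype=torch.long)

with torch.no_grad():
    # run the backbone over at most 8 documents at a time
    embeddings = model(input_ids, attention_mask, document_batch_size=8)

print(embeddings.shape)  # torch.Size([2, 512])
```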
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "cls_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "50264": {
+       "content": "<mask>",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "errors": "replace",
+   "full_tokenizer_file": null,
+   "mask_token": "<mask>",
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "tokenizer_class": "RobertaTokenizer",
+   "trim_offsets": true,
+   "unk_token": "<unk>"
+ }
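Together with `vocab.json`, `merges.txt`, and `tokenizer.json`, this configures a standard RoBERTa byte-level BPE tokenizer with a 512-token model maximum. A quick sketch of confirming the special-token wiring above (loading the tokenizer does not require `trust_remote_code`, since `RobertaTokenizer` is a stock class):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")
print(tokenizer.model_max_length)                                     # 512
print(tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token)  # <s> </s> <pad>
print(tokenizer("Foo")["input_ids"])                                  # ids wrapped in <s> ... </s>
```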
vocab.json ADDED
The diff for this file is too large to render. See raw diff