yehzw committed
Commit b1c4dc3 · verified · 1 parent: 1af9f6d

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "add_cross_attention": false,
+   "architectures": [
+     "ParserkerModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoConfig": "configuration_parserker.ParserkerConfig",
+     "AutoModel": "modeling_parserker.ParserkerModel"
+   },
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "dtype": "float32",
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "is_decoder": false,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "parserker",
+   "num_attention_heads": 12,
+   "num_bits": 16,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "tie_word_embeddings": true,
+   "transformers_version": "5.7.0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 50265
+ }
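Since config.json registers the custom classes under auto_map, the model loads through the Auto* API with trust_remote_code. A minimal sketch, assuming torch, torchrua, and nltk are installed; the repo id "yehzw/parserker" below is a hypothetical placeholder, since this commit does not show the repo name:

from transformers import AutoConfig, AutoModel

# "yehzw/parserker" is a hypothetical repo id used only for illustration.
config = AutoConfig.from_pretrained("yehzw/parserker", trust_remote_code=True)
assert config.model_type == "parserker" and config.num_bits == 16

# auto_map routes AutoModel to modeling_parserker.ParserkerModel.
model = AutoModel.from_pretrained("yehzw/parserker", trust_remote_code=True)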
configuration_parserker.py ADDED
@@ -0,0 +1,9 @@
+ from transformers.models.roberta.configuration_roberta import RobertaConfig
+
+
+ class ParserkerConfig(RobertaConfig):
+     model_type = "parserker"
+
+     def __init__(self, num_bits=16, **kwargs):
+         super(ParserkerConfig, self).__init__(**kwargs)
+         self.num_bits = num_bits
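ParserkerConfig only adds the num_bits field on top of RobertaConfig, so the main thing worth checking is that the field survives serialization. A minimal sketch, assuming configuration_parserker.py is on the import path:

from configuration_parserker import ParserkerConfig

config = ParserkerConfig(num_bits=16)
assert config.model_type == "parserker"
# PretrainedConfig serialization should carry the custom field through.
assert ParserkerConfig.from_dict(config.to_dict()).num_bits == 16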
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5075a6ca3c482f96f2d076096557713801cad7dcb91d644089ca9d0a802363c1
+ size 500969248
modeling_parserker.py ADDED
@@ -0,0 +1,306 @@
+ from typing import Callable, List, NamedTuple, Tuple, Type, Union
+
+ import torch
+ from nltk import Tree
+ from torch import Tensor
+ from torch import nn
+ from torch.distributions.utils import lazy_property
+ from torchrua import C, segment_mean, L, Z
+ from transformers import PreTrainedModel
+ from transformers.models.roberta.modeling_roberta import RobertaModel
+
+ from .configuration_parserker import ParserkerConfig
+
+ Frames = Union[List[Tensor], Tuple[Tensor, ...]]
+
+
+ def diag(tensor: Tensor, offset: int) -> Tensor:
+     # View of the width-`offset` diagonal of a (batch, t, t, ...) chart.
+     return tensor.diagonal(offset=offset, dim1=1, dim2=2)
+
+
+ def diag_scatter(chart: Tensor, score: Tensor, offset: int) -> None:
+     # In-place write of `score` onto the width-`offset` diagonal.
+     chart.diagonal(offset=offset, dim1=1, dim2=2)[::] = score
+
+
+ def left(chart: Tensor, offset: int) -> Tensor:
+     # Strided view over the left children [i, i + k] of each span [i, i + offset].
+     b, t, _, *size = chart.size()
+     c, n, m, *stride = chart.stride()
+     return chart.as_strided(
+         size=(b, t - offset, offset, *size),
+         stride=(c, n + m, m, *stride),
+     )
+
+
+ def right(chart: Tensor, offset: int) -> Tensor:
+     # Strided view over the right children [i + k + 1, i + offset] of each span [i, i + offset].
+     b, t, _, *size = chart.size()
+     c, n, m, *stride = chart.stride()
+     return chart[:, 1:, offset:].as_strided(
+         size=(b, t - offset, offset, *size),
+         stride=(c, n + m, n, *stride),
+     )
+
+
+ def to_hex(x: int, num_bits: int) -> str:
+     return f'{x:0{(num_bits + 3) // 4}X}'
+
+
+ def bits_to_long(tensor: Tensor) -> Tensor:
+     # Pack a trailing dimension of bits (little-endian) into one integer.
+     *_, num_bits = tensor.size()
+     index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
+     return (tensor << index).sum(dim=-1)
+
+
+ def long_to_bits(tensor: Tensor, num_bits: int) -> Tensor:
+     index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
+     return (tensor[..., None] >> index) & 1
+
+
+ def max(tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+     # torch.max that keeps only the values; used as the Max semiring's sum.
+     return torch.max(tensor, dim=dim, keepdim=keepdim).values
+
+
+ class Semiring(NamedTuple):
+     zero: float
+     one: float
+     add: Callable
+     mul: Callable
+     sum: Callable
+     prod: Callable
+
+
+ Log = Semiring(
+     zero=-float('inf'),
+     one=0.,
+     add=torch.logaddexp,
+     mul=torch.add,
+     sum=torch.logsumexp,
+     prod=torch.sum,
+ )
+
+ Max = Semiring(
+     zero=-float('inf'),
+     one=0.,
+     add=torch.maximum,
+     mul=torch.add,
+     sum=max,
+     prod=torch.sum,
+ )
+
+
+ def cumsum(tensor: Tensor) -> Tensor:
+     # 2-D prefix sum over the upper triangle: output[b, i, j, k] is the sum of
+     # tensor[b, i', j', k] over all spans with i <= i' <= j' <= j.
+     b, t1, t2, k = tensor.size()
+     assert t1 == t2, f'{t1} != {t2}'
+
+     p1 = tensor.permute(0, 3, 1, 2).triu()
+     c1 = p1.cumsum(dim=-1)
+     c2 = c1.flip(dims=[-2]).cumsum(dim=-2).flip(dims=[-2])
+     p2 = c2.permute(0, 2, 3, 1)
+     return p2
+
+
+ def cky_partitions(logits: Tensor, token_sizes: Tensor, semiring: Type[Semiring]):
+     # CKY inside pass over the given semiring (Log -> log-partition, Max -> Viterbi).
+     # The per-width `frames` are kept so autograd through them yields marginals.
+     logits = cumsum(logits)
+     logits = torch.stack([torch.zeros_like(logits), logits], dim=-1)
+     b, t, _, k, _ = logits.size()
+
+     chart = torch.full_like(logits[..., 0, 0], fill_value=semiring.zero, requires_grad=False)
+
+     z = diag(logits, offset=0)[..., None].permute([0, 3, 4, 1, 2])
+
+     frames = [z]
+     z = semiring.sum(z, dim=-1)
+     z = semiring.prod(z, dim=-1)
+
+     diag_scatter(chart, z[..., 0], offset=0)
+     index = torch.arange(t, dtype=chart.dtype, device=chart.device)
+
+     for w in range(1, t):
+         z = diag(logits, offset=w)[..., None].permute([0, 3, 4, 1, 2])
+         z = z - left(logits, offset=w) - right(logits, offset=w)
+         z = z / ((1 + index[:w]) * (w - index[:w]))[:, None, None]
+
+         frames.append(z)
+         z = semiring.sum(z, dim=-1)
+         z = semiring.prod(z, dim=-1)
+
+         xyz = semiring.mul(z, semiring.mul(left(chart, offset=w), right(chart, offset=w)))
+         score = semiring.sum(xyz, dim=-1)
+
+         diag_scatter(chart, score, offset=w)
+
+     index = torch.arange(b, dtype=torch.long, device=chart.device)
+     return chart[index, 0, token_sizes - 1], frames
+
+
+ class Distribution(object):
+     def __init__(self, logits: Tensor, token_sizes: Tensor) -> None:
+         super(Distribution, self).__init__()
+         self.logits = logits
+         self.token_sizes = token_sizes
+
+     @lazy_property
+     def log_partitions(self):
+         partitions, frames = cky_partitions(
+             logits=self.logits,
+             token_sizes=self.token_sizes,
+             semiring=Log,
+         )
+
+         return partitions, frames
+
+     @lazy_property
+     def max(self):
+         partitions, frames = cky_partitions(
+             logits=self.logits,
+             token_sizes=self.token_sizes,
+             semiring=Max,
+         )
+
+         return partitions, frames
+
+     @lazy_property
+     def marginals(self) -> Frames:
+         # Span marginals as gradients of the log-partition w.r.t. the frames.
+         partitions, frames = self.log_partitions
+         return torch.autograd.grad(
+             partitions, frames, torch.ones_like(partitions),
+             create_graph=True, retain_graph=True,
+             only_inputs=True, allow_unused=True,
+         )
+
+     @lazy_property
+     def grads(self) -> Frames:
+         # Gradients of the Viterbi score: indicators of the best tree's spans.
+         partitions, frames = self.max
+         return torch.autograd.grad(
+             partitions, frames, torch.ones_like(partitions),
+             create_graph=False, retain_graph=False,
+             only_inputs=True, allow_unused=True,
+         )
+
+     @staticmethod
+     def gather(marginals: Frames, grads: Frames, spans: Tensor):
+         b, _, _, k, _ = marginals[0].size()
+
+         xs, ys, zs = [], [], []
+         for w, (x, grad) in enumerate(zip(marginals, grads)):
+             mask, y = grad.max(dim=-1, keepdim=True)
+             mask = mask.sum(dim=-2, keepdim=True) > 0
+
+             z = diag(spans, offset=w)[..., None, None, None]
+
+             xs.append(torch.masked_select(x, mask))
+             ys.append(torch.masked_select(y, mask))
+             zs.append(torch.masked_select(z, mask))
+
+         xs = torch.cat(xs, dim=0).view((-1, k, 2))
+         ys = torch.cat(ys, dim=0).view((-1, k))
+         zs = torch.cat(zs, dim=0)
+         return xs, ys, zs
+
+     @lazy_property
+     def argmax(self) -> C:
+         b, t, _, _, _ = self.grads[0].size()
+
+         b = torch.arange(b, dtype=torch.long, device=self.grads[0].device)
+         x = torch.arange(t, dtype=torch.long, device=self.grads[0].device)
+         y = torch.arange(t, dtype=torch.long, device=self.grads[0].device)
+         b, x, y = torch.broadcast_tensors(b[:, None, None], x[None, :, None], y[None, None, :])
+
+         data = []
+         for w, grad in enumerate(self.grads):
+             mask, z = grad.max(dim=-1, keepdim=False)
+             mask = mask.sum(dim=-1, keepdim=False) > 0
+
+             # (batch, start, end, packed bit label) for each selected span.
+             data.append(torch.stack([
+                 torch.masked_select(diag(b, offset=w)[..., None], mask),
+                 torch.masked_select(diag(x, offset=w)[..., None], mask),
+                 torch.masked_select(diag(y, offset=w)[..., None], mask),
+                 torch.masked_select(bits_to_long(z), mask),
+             ], dim=-1))
+
+         data = torch.cat(data, dim=0)
+         b = torch.argsort(data[..., 0], dim=0, descending=False)
+         return C(data=data[b, 1:], token_sizes=self.token_sizes * 2 - 1)
+
+
+ class HashLayer(nn.Module):
+     # Scores every token pair with `num_bits` independent low-rank bilinear
+     # heads, producing one logit per bit per span.
+     def __init__(self, config: ParserkerConfig) -> None:
+         super(HashLayer, self).__init__()
+
+         self.num_bits = config.num_bits
+         self.bit_size = (config.hidden_size + config.num_bits - 1) // config.num_bits
+         self.scale = self.bit_size ** -0.5
+
+         self.q_proj = nn.Linear(config.hidden_size, self.num_bits * self.bit_size, bias=True)
+         self.k_proj = nn.Linear(config.hidden_size, self.num_bits * self.bit_size, bias=True)
+
+     def forward(self, q: Tensor, k: Tensor):
+         q = self.q_proj(q).unflatten(dim=-1, sizes=(self.num_bits, 1, self.bit_size))
+         k = self.k_proj(k).unflatten(dim=-1, sizes=(self.num_bits, self.bit_size, 1))
+
+         return (q[:, :, None] @ k[:, None, :]).flatten(start_dim=-3).transpose(1, 2) * self.scale
+
+
+ class ParserkerModel(PreTrainedModel):
+     config_class = ParserkerConfig
+     base_model_prefix = "backbone"
+     _tied_weights_keys = {}
+
+     def __init__(self, config: ParserkerConfig, **kwargs):
+         super(ParserkerModel, self).__init__(config=config, **kwargs)
+
+         self.pad_token_id = config.pad_token_id
+         self.num_bits = config.num_bits
+
+         self.backbone = RobertaModel(config, add_pooling_layer=False)
+         self.hash_layer = HashLayer(config)
+
+     @property
+     def all_tied_weights_keys(self):
+         return getattr(self, "_tied_weights_keys", [])
+
+     def forward(self, input_ids: Z, duration: Z) -> L:
+         out = self.backbone.forward(
+             input_ids=input_ids.left(self.pad_token_id).data,
+             attention_mask=input_ids.bmask(),
+             return_dict=True,
+         )
+
+         # Average subword embeddings over each word (via `duration`), then drop
+         # the BOS/EOS positions before scoring spans.
+         tensor = L(data=out.last_hidden_state, token_sizes=input_ids.cat().token_sizes)
+         tensor, token_sizes = tensor.seg(duration, segment_mean).trunc((1, 1))
+
+         logits = self.hash_layer(tensor, tensor)
+
+         return L(data=logits, token_sizes=token_sizes)
+
+     def parse(self, input_ids: Z, duration: C):
+         logits, token_sizes = self(input_ids, duration)
+         logits = logits.clone().requires_grad_(True)
+
+         dist = Distribution(logits=logits, token_sizes=token_sizes)
+         return dist.argmax
+
+     def to_tree(self, words, spans) -> Tree:
+         stack = []
+
+         for x, y, z in sorted(spans, key=lambda item: (item[0], -item[1]), reverse=True):
+             children = []
+             while len(stack) > 0:
+                 xx, yy, zz = stack.pop()
+                 if x <= xx and yy <= y:
+                     children.append(zz)
+                 else:
+                     stack.append((xx, yy, zz))
+                     break
+
+             if len(children) == 0:
+                 children = ['__tok']
+
+             stack.append((x, y, Tree(to_hex(z, self.num_bits), children)))
+
+         [(_, _, tree)] = stack
+
+         for index in range(len(tree.leaves())):
+             position = tree.leaf_treeposition(index)
+             tree[position] = words[index]
+
+         return tree
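The span labels that argmax emits are num_bits-wide bit vectors packed into integers by bits_to_long, and to_tree renders them as hex node labels via to_hex. A standalone sanity check of that packing, with the three helpers copied verbatim from the file above so nothing beyond torch is needed:

import torch
from torch import Tensor


def bits_to_long(tensor: Tensor) -> Tensor:
    # Little-endian packing: bit k contributes 2 ** k.
    *_, num_bits = tensor.size()
    index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
    return (tensor << index).sum(dim=-1)


def long_to_bits(tensor: Tensor, num_bits: int) -> Tensor:
    index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
    return (tensor[..., None] >> index) & 1


def to_hex(x: int, num_bits: int) -> str:
    # ceil(num_bits / 4) hex digits, zero-padded.
    return f'{x:0{(num_bits + 3) // 4}X}'


num_bits = 16
x = torch.randint(0, 2 ** num_bits, (8,), dtype=torch.long)
assert torch.equal(bits_to_long(long_to_bits(x, num_bits)), x)
print(to_hex(int(x[0]), num_bits))  # four hex digits, e.g. '3FA1'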
tokenization_parserker.py ADDED
@@ -0,0 +1,74 @@
+ from typing import List, Union
+
+ import torch
+ from nltk.tokenize import TreebankWordTokenizer
+ from torchrua import C
+ from transformers.models.roberta import RobertaTokenizer
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, EncodedInput
+
+ nltk_tokenizer = TreebankWordTokenizer()
+
+ PTB_UNESCAPE_MAPPING = {
+     "«": '"',
+     "»": '"',
+     "‘": "'",
+     "’": "'",
+     "“": '"',
+     "”": '"',
+     "„": '"',
+     "‹": "'",
+     "›": "'",
+     "\u2013": "--",  # en dash
+     "\u2014": "--",  # em dash
+ }
+
+
+ def ptb_unescape(words: List[str]) -> List[str]:
+     cleaned_words = []
+
+     for word in words:
+         word = PTB_UNESCAPE_MAPPING.get(word, word)
+         # This un-escaping for / and * was not yet added for the
+         # parser version in https://arxiv.org/abs/1812.11760v1
+         # and related model releases (e.g. benepar_en2)
+         word = word.replace("\\/", "/").replace("\\*", "*")
+         # Mid-token punctuation occurs in biomedical text
+         word = word.replace("-LSB-", "[").replace("-RSB-", "]")
+         word = word.replace("-LRB-", "(").replace("-RRB-", ")")
+         word = word.replace("-LCB-", "{").replace("-RCB-", "}")
+         word = word.replace("``", '"').replace("`", "'").replace("''", '"')
+         cleaned_words.append(word)
+
+     return cleaned_words
+
+
+ class ParserkerTokenizer(RobertaTokenizer):
+     def __call__(self, text: Union[TextInput, PreTokenizedInput, EncodedInput], **kwargs):
+         input_ids_list = []
+         duration_list = []
+
+         if isinstance(text, str):
+             tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(text))]
+         else:
+             tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(t)) for t in text]
+
+         for tokens in tokens_list:
+             out = super().__call__(
+                 tokens,
+                 return_attention_mask=False,
+                 add_special_tokens=False,
+                 is_split_into_words=False,
+                 return_tensors=None,
+             )
+
+             # Each word maps to one or more subword ids; `duration` records how
+             # many subwords each word spans, so embeddings can be segment-averaged.
+             input_ids = [t for ts in out['input_ids'] for t in ts]
+             duration = [len(ts) for ts in out['input_ids']]
+
+             input_ids_list.append([self.bos_token_id, *input_ids, self.eos_token_id])
+             duration_list.append([1, *duration, 1])
+
+         input_ids = C.new([torch.tensor(t, dtype=torch.long) for t in input_ids_list])
+         duration = C.new([torch.tensor(t, dtype=torch.long) for t in duration_list])
+
+         return tokens_list, input_ids, duration
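A quick check of the PTB un-escaping above, assuming the file is importable and its dependencies (torch, nltk, torchrua, transformers) are installed:

from tokenization_parserker import ptb_unescape

words = ["-LRB-", "``", "\\/", "''", "-RRB-", "\u2014"]
print(ptb_unescape(words))
# expected: ['(', '"', '/', '"', ')', '--']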
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "add_prefix_space": true,
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_parserker.ParserkerTokenizer",
+       null
+     ]
+   },
+   "backend": "tokenizers",
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "errors": "replace",
+   "is_local": true,
+   "local_files_only": false,
+   "mask_token": "<mask>",
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "tokenizer_class": "ParserkerTokenizer",
+   "trim_offsets": true,
+   "unk_token": "<unk>",
+   "use_fast": true
+ }
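tokenizer_config.json likewise registers the custom tokenizer through auto_map, with the fast-tokenizer slot left null. A hedged loading sketch ("yehzw/parserker" is again a hypothetical placeholder repo id); note that ParserkerTokenizer.__call__ returns a (tokens_list, input_ids, duration) triple rather than a BatchEncoding:

from transformers import AutoTokenizer

# use_fast=False may be needed, since auto_map only provides the slow
# ParserkerTokenizer (its fast slot is null).
tokenizer = AutoTokenizer.from_pretrained(
    "yehzw/parserker", trust_remote_code=True, use_fast=False,
)
tokens_list, input_ids, duration = tokenizer("The market -LRB- briefly -RRB- rallied.")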