napsternxg
committed on
Commit
•
e70cca1
1
Parent(s):
8a3f0c5
End of training
Browse files- README.md +69 -0
- added_tokens.json +7 -0
- all_results.json +41 -0
- config.json +55 -0
- crf.py +243 -0
- crf_model.py +78 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +65 -0
- trainer_state.json +112 -0
- training_args.bin +3 -0
- validation_results.json +41 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
base_model: sentence-transformers/paraphrase-MiniLM-L3-v2
|
4 |
+
tags:
|
5 |
+
- generated_from_trainer
|
6 |
+
datasets:
|
7 |
+
- nyt_ingredients
|
8 |
+
model-index:
|
9 |
+
- name: nyt_ingredients-crf-tagger-paraphrase-MiniLM-L3-v2
|
10 |
+
results: []
|
11 |
+
---
|
12 |
+
|
13 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
14 |
+
should probably proofread and complete it, then remove this comment. -->
|
15 |
+
|
16 |
+
# nyt_ingredients-crf-tagger-paraphrase-MiniLM-L3-v2
|
17 |
+
|
18 |
+
This model is a fine-tuned version of [sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2) on the nyt_ingredients dataset.
|
19 |
+
It achieves the following results on the evaluation set:
|
20 |
+
- Loss: 10.2590
|
21 |
+
- Comment: {'precision': 0.03657262277951933, 'recall': 0.0264750378214826, 'f1': 0.030715225976305396, 'number': 1322}
|
22 |
+
- Name: {'precision': 0.5238095238095238, 'recall': 0.01245753114382786, 'f1': 0.024336283185840708, 'number': 1766}
|
23 |
+
- Qty: {'precision': 0.0234375, 'recall': 0.0020920502092050207, 'f1': 0.003841229193341869, 'number': 1434}
|
24 |
+
- Range End: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 17}
|
25 |
+
- Unit: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1166}
|
26 |
+
- Overall Precision: 0.0419
|
27 |
+
- Overall Recall: 0.0105
|
28 |
+
- Overall F1: 0.0168
|
29 |
+
- Overall Accuracy: 0.1284
|
30 |
+
|
31 |
+
## Model description
|
32 |
+
|
33 |
+
More information needed
|
34 |
+
|
35 |
+
## Intended uses & limitations
|
36 |
+
|
37 |
+
More information needed
|
38 |
+
|
39 |
+
## Training and evaluation data
|
40 |
+
|
41 |
+
More information needed
|
42 |
+
|
43 |
+
## Training procedure
|
44 |
+
|
45 |
+
### Training hyperparameters
|
46 |
+
|
47 |
+
The following hyperparameters were used during training:
|
48 |
+
- learning_rate: 5e-05
|
49 |
+
- train_batch_size: 32
|
50 |
+
- eval_batch_size: 32
|
51 |
+
- seed: 42
|
52 |
+
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
53 |
+
- lr_scheduler_type: linear
|
54 |
+
- num_epochs: 2
|
55 |
+
|
56 |
+
### Training results
|
57 |
+
|
58 |
+
| Training Loss | Epoch | Step | Validation Loss | Comment | Name | Qty | Range End | Unit | Overall Precision | Overall Recall | Overall F1 | Overall Accuracy |
|
59 |
+
|:-------------:|:-----:|:----:|:---------------:|:------------------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------:|:----------------------------------------------------------:|:------------------------------------------------------------:|:-----------------:|:--------------:|:----------:|:----------------:|
|
60 |
+
| No log | 1.0 | 54 | 11.5992 | {'precision': 0.03826530612244898, 'recall': 0.0340393343419062, 'f1': 0.036028823058446756, 'number': 1322} | {'precision': 0.9047619047619048, 'recall': 0.010758776896942242, 'f1': 0.021264689423614997, 'number': 1766} | {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1434} | {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 17} | {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1166} | 0.0526 | 0.0112 | 0.0185 | 0.1319 |
|
61 |
+
| No log | 2.0 | 108 | 10.2590 | {'precision': 0.03657262277951933, 'recall': 0.0264750378214826, 'f1': 0.030715225976305396, 'number': 1322} | {'precision': 0.5238095238095238, 'recall': 0.01245753114382786, 'f1': 0.024336283185840708, 'number': 1766} | {'precision': 0.0234375, 'recall': 0.0020920502092050207, 'f1': 0.003841229193341869, 'number': 1434} | {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 17} | {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1166} | 0.0419 | 0.0105 | 0.0168 | 0.1284 |
|
62 |
+
|
63 |
+
|
64 |
+
### Framework versions
|
65 |
+
|
66 |
+
- Transformers 4.34.0
|
67 |
+
- Pytorch 2.0.1+cu118
|
68 |
+
- Datasets 2.14.5
|
69 |
+
- Tokenizers 0.14.0
|
added_tokens.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"[CLS]": 101,
|
3 |
+
"[MASK]": 103,
|
4 |
+
"[PAD]": 0,
|
5 |
+
"[SEP]": 102,
|
6 |
+
"[UNK]": 100
|
7 |
+
}
|
all_results.json
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"epoch": 2.0,
|
3 |
+
"eval_COMMENT": {
|
4 |
+
"f1": 0.030715225976305396,
|
5 |
+
"number": 1322,
|
6 |
+
"precision": 0.03657262277951933,
|
7 |
+
"recall": 0.0264750378214826
|
8 |
+
},
|
9 |
+
"eval_NAME": {
|
10 |
+
"f1": 0.024336283185840708,
|
11 |
+
"number": 1766,
|
12 |
+
"precision": 0.5238095238095238,
|
13 |
+
"recall": 0.01245753114382786
|
14 |
+
},
|
15 |
+
"eval_QTY": {
|
16 |
+
"f1": 0.003841229193341869,
|
17 |
+
"number": 1434,
|
18 |
+
"precision": 0.0234375,
|
19 |
+
"recall": 0.0020920502092050207
|
20 |
+
},
|
21 |
+
"eval_RANGE_END": {
|
22 |
+
"f1": 0.0,
|
23 |
+
"number": 17,
|
24 |
+
"precision": 0.0,
|
25 |
+
"recall": 0.0
|
26 |
+
},
|
27 |
+
"eval_UNIT": {
|
28 |
+
"f1": 0.0,
|
29 |
+
"number": 1166,
|
30 |
+
"precision": 0.0,
|
31 |
+
"recall": 0.0
|
32 |
+
},
|
33 |
+
"eval_loss": 10.259025573730469,
|
34 |
+
"eval_overall_accuracy": 0.12838815472171314,
|
35 |
+
"eval_overall_f1": 0.016813787305590584,
|
36 |
+
"eval_overall_precision": 0.04189944134078212,
|
37 |
+
"eval_overall_recall": 0.010517090271691499,
|
38 |
+
"eval_runtime": 15.7061,
|
39 |
+
"eval_samples_per_second": 108.365,
|
40 |
+
"eval_steps_per_second": 3.438
|
41 |
+
}
|
config.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "sentence-transformers/paraphrase-MiniLM-L3-v2",
|
3 |
+
"architectures": [
|
4 |
+
"PretrainedCRFModel"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoModel": "crf_model.PretrainedCRFModel"
|
9 |
+
},
|
10 |
+
"classifier_dropout": null,
|
11 |
+
"gradient_checkpointing": false,
|
12 |
+
"hidden_act": "gelu",
|
13 |
+
"hidden_dropout_prob": 0.1,
|
14 |
+
"hidden_size": 384,
|
15 |
+
"id2label": {
|
16 |
+
"0": "O",
|
17 |
+
"1": "B-COMMENT",
|
18 |
+
"2": "I-COMMENT",
|
19 |
+
"3": "B-NAME",
|
20 |
+
"4": "I-NAME",
|
21 |
+
"5": "B-RANGE_END",
|
22 |
+
"6": "I-RANGE_END",
|
23 |
+
"7": "B-QTY",
|
24 |
+
"8": "I-QTY",
|
25 |
+
"9": "B-UNIT",
|
26 |
+
"10": "I-UNIT"
|
27 |
+
},
|
28 |
+
"initializer_range": 0.02,
|
29 |
+
"intermediate_size": 1536,
|
30 |
+
"label2id": {
|
31 |
+
"B-COMMENT": 1,
|
32 |
+
"B-NAME": 3,
|
33 |
+
"B-QTY": 7,
|
34 |
+
"B-RANGE_END": 5,
|
35 |
+
"B-UNIT": 9,
|
36 |
+
"I-COMMENT": 2,
|
37 |
+
"I-NAME": 4,
|
38 |
+
"I-QTY": 8,
|
39 |
+
"I-RANGE_END": 6,
|
40 |
+
"I-UNIT": 10,
|
41 |
+
"O": 0
|
42 |
+
},
|
43 |
+
"layer_norm_eps": 1e-12,
|
44 |
+
"max_position_embeddings": 512,
|
45 |
+
"model_type": "bert",
|
46 |
+
"num_attention_heads": 12,
|
47 |
+
"num_hidden_layers": 3,
|
48 |
+
"pad_token_id": 0,
|
49 |
+
"position_embedding_type": "absolute",
|
50 |
+
"torch_dtype": "float32",
|
51 |
+
"transformers_version": "4.34.0",
|
52 |
+
"type_vocab_size": 2,
|
53 |
+
"use_cache": true,
|
54 |
+
"vocab_size": 30522
|
55 |
+
}
|
crf.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.autograd import Variable
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
class CRFOutput:
    """Container for the values produced by ``MaskedCRFLoss.forward``.

    Field order is part of the contract: instances are built positionally.
    """

    # Mean of (total_score - real_path_score); None when no tags were given.
    loss: Optional[torch.Tensor]
    # Score of the supplied tag sequence per batch item; None without tags.
    real_path_score: Optional[torch.Tensor]
    # Log-partition value per batch item.
    total_score: torch.Tensor
    # Score of the highest-scoring tag sequence per batch item.
    best_path_score: torch.Tensor
    # Decoded tag ids; None when the best path was not requested.
    best_path: Optional[torch.Tensor]
|
17 |
+
|
18 |
+
|
19 |
+
class MaskedCRFLoss(nn.Module):
    """Linear-chain CRF layer that skips masked time steps.

    Tensors are sequence-first: ``emissions`` is (seq_length, batch_size,
    num_tags) and ``tags`` / ``mask`` are (seq_length, batch_size).

    Scoring conventions (shared by the partition function, the gold-path
    score and Viterbi decoding):

    * ``transitions[j, k]`` is the score of moving *to* tag ``j`` *from*
      tag ``k``.
    * The first time step is always scored (``start_transitions`` plus
      ``emissions[0]``), regardless of ``mask[0]``.
    * A time step ``i >= 1`` with ``mask[i] == 0`` contributes nothing; the
      chain state is carried over it unchanged, so the next transition is
      taken from the last unmasked tag.
    """

    __constants__ = ["num_tags", "mask_id"]

    num_tags: int
    mask_id: int

    def __init__(self, num_tags: int, mask_id: int = 0):
        """Create randomly initialised transition, start and stop scores.

        Args:
            num_tags: number of distinct tag ids.
            mask_id: stored on the module and shown in ``extra_repr``.
        """
        super().__init__()
        self.num_tags = num_tags
        self.mask_id = mask_id
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.stop_transitions = nn.Parameter(torch.randn(num_tags))

    def extra_repr(self) -> str:
        """Return the constructor arguments for ``repr(module)``."""
        return f"num_tags={self.num_tags}, mask_id={self.mask_id}"

    @staticmethod
    def _numeric_mask(mask, like):
        """Return ``mask`` in a dtype that supports ``1 - mask``.

        Floating-point masks are returned untouched; bool/integer masks are
        cast to the dtype of ``like`` (bool tensors do not support ``-``).
        """
        if mask.is_floating_point():
            return mask
        return mask.to(dtype=like.dtype)

    @staticmethod
    def _backtrace(backpointers, best_final_tags, mask):
        """Follow backpointers from the last step to the first.

        ``backpointers[i - 1]`` holds, for every tag at step ``i``, the best
        tag at the previous unmasked step. Masked steps keep the current tag.
        A fresh tensor is produced at every step; mutating one tensor in
        place would make every row of the stacked path identical.

        Returns:
            (seq_length, batch_size) tag ids with -100 at masked positions.
        """
        active = mask.to(dtype=torch.bool)
        current = best_final_tags
        reversed_path = [current]
        for step in range(len(backpointers), 0, -1):
            previous = torch.gather(
                backpointers[step - 1], 1, current.unsqueeze(1)
            ).squeeze(1)
            current = torch.where(active[step], previous, current)
            reversed_path.append(current)
        # Collected last-to-first, so flip before stacking.
        best_path = torch.stack(reversed_path[::-1])
        return best_path.masked_fill(~active, -100)

    def _score_real_path(self, emissions, tags, mask):
        """Score the supplied tag sequence under the module's conventions.

        Mirrors ``compute_log_partition_function`` term by term so that the
        result is one of the paths summed in the partition function.

        Returns:
            (batch_size,) tensor of path scores.
        """
        mask = self._numeric_mask(mask, emissions)
        seq_length, batch_size = tags.shape
        batch_index = torch.arange(batch_size, device=tags.device)

        # The first step is always scored, matching the partition function.
        prev_tag = tags[0]
        score = self.start_transitions[prev_tag] + emissions[0, batch_index, prev_tag]
        for i in range(1, seq_length):
            current_tag = tags[i]
            step_score = (
                self.transitions[current_tag, prev_tag]
                + emissions[i, batch_index, current_tag]
            )
            score = score + step_score * mask[i]
            # Only advance the chain state on unmasked steps.
            prev_tag = torch.where(mask[i].to(dtype=torch.bool), current_tag, prev_tag)
        # Transition to STOP_TAG from the last unmasked tag.
        return score + self.stop_transitions[prev_tag]

    def forward(self, emissions, tags, mask, return_best_path=False):
        """Compute the CRF negative log-likelihood and optional best path.

        Args:
            emissions: (seq_length, batch_size, num_tags) scores.
            tags: (seq_length, batch_size) gold tag ids, or None to skip the
                loss and only compute partition / best-path values.
            mask: (seq_length, batch_size); non-zero marks an active step.
            return_best_path: decode the best path. Forced to True outside
                training mode; during training it only slows things down.

        Returns:
            CRFOutput(loss, real_path_score, total_score, best_path_score,
            best_path).
        """
        mask = mask.float()
        if not self.training:
            return_best_path = True

        total_score, best_path_score, best_path = self.compute_log_partition_function(
            emissions, mask, return_best_path=return_best_path
        )

        if tags is None:
            return CRFOutput(None, None, total_score, best_path_score, best_path)

        real_path_score = self._score_real_path(emissions, tags, mask)

        # Negative log-likelihood averaged over the batch.
        loss = torch.mean(total_score - real_path_score)
        return CRFOutput(loss, real_path_score, total_score, best_path_score, best_path)

    def compute_log_partition_function(self, emissions, mask, return_best_path=False):
        """Run the forward algorithm and Viterbi recursion together.

        Args:
            emissions: (seq_length, batch_size, num_tags) scores.
            mask: (seq_length, batch_size); non-zero marks an active step.
            return_best_path: also backtrace the best tag sequence.

        Returns:
            Tuple ``(alpha, best_path_score, best_path)``: the log-partition
            value and best path score, each (batch_size,), and the decoded
            path (seq_length, batch_size) with -100 at masked positions, or
            None when ``return_best_path`` is False.
        """
        mask = self._numeric_mask(mask, emissions)
        forward_var = self.start_transitions + emissions[0]  # (batch_size, num_tags)
        forward_viterbi_var = forward_var
        backpointers = []
        transitions = self.transitions.unsqueeze(0)  # (1, num_tags, num_tags)

        for i in range(1, emissions.shape[0]):
            emission = emissions[i].unsqueeze(2)  # (batch_size, num_tags, 1)
            keep = mask[i].unsqueeze(-1)  # (batch_size, 1)

            # Axis 1 indexes the next tag, axis 2 the previous tag.
            next_tag_var = forward_var.unsqueeze(1) + emission + transitions
            next_tag_viterbi_var = (
                forward_viterbi_var.unsqueeze(1) + emission + transitions
            )

            next_forward_var = torch.logsumexp(next_tag_var, dim=2)
            viterbi_scores, best_prev_tags = torch.max(next_tag_viterbi_var, dim=2)

            # Active step: take the new value. Masked step: carry the old one.
            forward_var = keep * next_forward_var + (1 - keep) * forward_var
            forward_viterbi_var = (
                keep * viterbi_scores + (1 - keep) * forward_viterbi_var
            )
            backpointers.append(best_prev_tags)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.stop_transitions
        terminal_viterbi_var = forward_viterbi_var + self.stop_transitions

        alpha = torch.logsumexp(terminal_var, dim=1)
        best_path_score, best_final_tags = torch.max(terminal_viterbi_var, dim=1)

        best_path = None
        if return_best_path:
            best_path = self._backtrace(backpointers, best_final_tags, mask)

        return alpha, best_path_score, best_path

    def viterbi_decode(self, emissions, mask):
        """Decode the highest-scoring tag sequence.

        Args:
            emissions: (seq_length, batch_size, num_tags) scores.
            mask: (seq_length, batch_size); non-zero marks an active step.

        Returns:
            Tuple ``(best_path, best_path_score)``: the path is
            (seq_length, batch_size) with -100 at masked positions, the
            score is (batch_size,).
        """
        mask = self._numeric_mask(mask, emissions)
        backpointers = []
        transitions = self.transitions.unsqueeze(0)

        # Viterbi variables in log space.
        forward_var = self.start_transitions + emissions[0]  # (batch_size, num_tags)

        for i in range(1, emissions.shape[0]):
            emission = emissions[i].unsqueeze(2)
            keep = mask[i].unsqueeze(-1)
            next_tag_var = forward_var.unsqueeze(1) + emission + transitions

            viterbi_scores, best_prev_tags = torch.max(next_tag_var, 2)
            # Active step: take the new value. Masked step: carry the old one.
            forward_var = keep * viterbi_scores + (1 - keep) * forward_var
            backpointers.append(best_prev_tags)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.stop_transitions
        best_path_score, best_final_tags = torch.max(terminal_var, dim=1)

        best_path = self._backtrace(backpointers, best_final_tags, mask)
        return best_path, best_path_score
|
180 |
+
|
181 |
+
|
182 |
+
class MaskedCRFLossTest(unittest.TestCase):
    """Smoke and sanity tests for MaskedCRFLoss on random inputs."""

    def setUp(self):
        """Build a small CRF plus random emissions, tags and a 0/1 mask."""
        self.num_tags = 5
        self.mask_id = 0

        self.crf_model = MaskedCRFLoss(self.num_tags, self.mask_id)

        self.seq_length, self.batch_size = 11, 5
        self.emissions = torch.randn(self.seq_length, self.batch_size, self.num_tags)
        self.tags = torch.randint(self.num_tags, (self.seq_length, self.batch_size))
        self.mask = torch.randint(2, (self.seq_length, self.batch_size))

    def test_forward(self):
        """forward runs; any exception fails the test instead of being printed."""
        output = self.crf_model(self.emissions, self.tags, self.mask)
        self.assertEqual(output.total_score.shape, (self.batch_size,))
        self.assertEqual(output.real_path_score.shape, (self.batch_size,))
        self.assertTrue(torch.isfinite(output.loss).all())

    def test_viterbi_decode(self):
        """viterbi_decode runs; any exception fails the test."""
        path, best_path_score = self.crf_model.viterbi_decode(
            self.emissions, self.mask
        )
        self.assertEqual(best_path_score.shape, (self.batch_size,))
        self.assertTrue(torch.isfinite(best_path_score).all())

    def test_forward_output(self):
        """The negative log-likelihood must be positive."""
        output = self.crf_model(self.emissions, self.tags, self.mask)
        loss = output.loss
        self.assertTrue((loss > 0).all())

    def test_compute_log_partition_function_output(self):
        """The log-partition has one finite value per batch item.

        Its sign is not constrained, so only shape and finiteness are checked.
        """
        (
            partition,
            best_path_score,
            best_path,
        ) = self.crf_model.compute_log_partition_function(self.emissions, self.mask)
        self.assertEqual(partition.shape, (self.batch_size,))
        self.assertTrue(torch.isfinite(partition).all())
        # The best path is only decoded on request.
        self.assertIsNone(best_path)

    def test_viterbi_decode_output(self):
        """Decoded path has the right shape, valid tag ids, and -100 exactly
        where the mask is zero."""
        path, best_path_score = self.crf_model.viterbi_decode(self.emissions, self.mask)
        self.assertEqual(
            path.shape, (self.seq_length, self.batch_size)
        )  # checking dimensions
        self.assertTrue(
            ((0 <= path) | (path == -100)).all() and (path < self.num_tags).all()
        )  # checking tag validity
        self.assertTrue(((path == -100) == (self.mask == 0)).all())
|
crf_model.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import datasets
|
5 |
+
import torch
|
6 |
+
from datasets import load_dataset
|
7 |
+
from torch import nn
|
8 |
+
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
|
9 |
+
from transformers.modeling_outputs import TokenClassifierOutput
|
10 |
+
from transformers.modeling_utils import PreTrainedModel
|
11 |
+
|
12 |
+
from .crf import MaskedCRFLoss
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
class TokenClassifierCRFOutput(TokenClassifierOutput):
    """Token-classification output extended with CRF scores and decoded path.

    Every field defaults to None, so all of them are annotated Optional.
    """

    # CRF negative log-likelihood; None when no labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # Score of the labelled tag sequence per batch item; None without labels.
    real_path_score: Optional[torch.FloatTensor] = None
    # Log-partition value per batch item.
    total_score: Optional[torch.FloatTensor] = None
    # Score of the best tag sequence per batch item.
    best_path_score: Optional[torch.FloatTensor] = None
    # Decoded tag ids; None when the best path was not computed.
    best_path: Optional[torch.LongTensor] = None
    # Passed through from the encoder output.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
24 |
+
|
25 |
+
|
26 |
+
class PretrainedCRFModel(PreTrainedModel):
    """Token-classification encoder whose logits feed a MaskedCRFLoss layer."""

    config_class = AutoConfig

    def __init__(self, config):
        """Build the encoder named by ``config._name_or_path`` and a CRF layer
        with ``config.num_labels`` tags."""
        super().__init__(config)
        self.encoder = AutoModelForTokenClassification.from_pretrained(
            config._name_or_path, config=config
        )
        self.crf_model = MaskedCRFLoss(self.config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_best_path=False,
        **kwargs
    ):
        """Run the encoder, then score and optionally decode with the CRF.

        Args:
            input_ids, token_type_ids, attention_mask: forwarded to the encoder.
            labels: (batch, seq) tag ids with -100 at positions to ignore.
                May be None, in which case no loss is computed and the CRF
                mask comes from ``attention_mask`` (all ones if that is also
                None).
            return_best_path: request the decoded path; outside training mode
                it is always decoded.
            **kwargs: forwarded to the encoder.

        Returns:
            TokenClassifierCRFOutput with ``best_path`` in batch-first layout.
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        # The CRF layer is sequence-first.
        emissions = encoder_output.logits.transpose(1, 0)

        if labels is not None:
            tags = labels.transpose(1, 0)
            mask = tags != -100
            tags = tags.masked_fill(~mask, 0)  # CRF can't index with the -100 id

            crf_output = self.crf_model(
                emissions, tags, mask, return_best_path=return_best_path
            )
            loss = crf_output.loss
            real_path_score = crf_output.real_path_score
            total_score = crf_output.total_score
            best_path_score = crf_output.best_path_score
            best_path = crf_output.best_path
        else:
            # No labels: there is no gold path to score, so only the
            # partition value and best path are computed.
            if attention_mask is not None:
                mask = attention_mask.transpose(1, 0).to(dtype=emissions.dtype)
            else:
                mask = torch.ones(
                    emissions.shape[:2], dtype=emissions.dtype, device=emissions.device
                )
            loss = None
            real_path_score = None
            (
                total_score,
                best_path_score,
                best_path,
            ) = self.crf_model.compute_log_partition_function(
                emissions,
                mask,
                # Same rule as the CRF layer's forward: always decode in eval.
                return_best_path=return_best_path or not self.training,
            )

        # Convert best_path to batch first
        if best_path is not None:
            best_path = best_path.transpose(1, 0)

        return TokenClassifierCRFOutput(
            loss=loss,
            real_path_score=real_path_score,
            total_score=total_score,
            best_path_score=best_path_score,
            best_path=best_path,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:957a43de6de74f7ecd4ba71f7c05807847718572ae4a3af8cd70c7f799baf978
|
3 |
+
size 69004255
|
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "[PAD]",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"100": {
|
12 |
+
"content": "[UNK]",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"101": {
|
20 |
+
"content": "[CLS]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"102": {
|
28 |
+
"content": "[SEP]",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"103": {
|
36 |
+
"content": "[MASK]",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"additional_special_tokens": [],
|
45 |
+
"clean_up_tokenization_spaces": true,
|
46 |
+
"cls_token": "[CLS]",
|
47 |
+
"do_basic_tokenize": true,
|
48 |
+
"do_lower_case": true,
|
49 |
+
"mask_token": "[MASK]",
|
50 |
+
"max_length": 128,
|
51 |
+
"model_max_length": 512,
|
52 |
+
"never_split": null,
|
53 |
+
"pad_to_multiple_of": null,
|
54 |
+
"pad_token": "[PAD]",
|
55 |
+
"pad_token_type_id": 0,
|
56 |
+
"padding_side": "right",
|
57 |
+
"sep_token": "[SEP]",
|
58 |
+
"stride": 0,
|
59 |
+
"strip_accents": null,
|
60 |
+
"tokenize_chinese_chars": true,
|
61 |
+
"tokenizer_class": "BertTokenizer",
|
62 |
+
"truncation_side": "right",
|
63 |
+
"truncation_strategy": "longest_first",
|
64 |
+
"unk_token": "[UNK]"
|
65 |
+
}
|
trainer_state.json
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"best_metric": null,
|
3 |
+
"best_model_checkpoint": null,
|
4 |
+
"epoch": 2.0,
|
5 |
+
"eval_steps": 500,
|
6 |
+
"global_step": 108,
|
7 |
+
"is_hyper_param_search": false,
|
8 |
+
"is_local_process_zero": true,
|
9 |
+
"is_world_process_zero": true,
|
10 |
+
"log_history": [
|
11 |
+
{
|
12 |
+
"epoch": 1.0,
|
13 |
+
"eval_COMMENT": {
|
14 |
+
"f1": 0.036028823058446756,
|
15 |
+
"number": 1322,
|
16 |
+
"precision": 0.03826530612244898,
|
17 |
+
"recall": 0.0340393343419062
|
18 |
+
},
|
19 |
+
"eval_NAME": {
|
20 |
+
"f1": 0.021264689423614997,
|
21 |
+
"number": 1766,
|
22 |
+
"precision": 0.9047619047619048,
|
23 |
+
"recall": 0.010758776896942242
|
24 |
+
},
|
25 |
+
"eval_QTY": {
|
26 |
+
"f1": 0.0,
|
27 |
+
"number": 1434,
|
28 |
+
"precision": 0.0,
|
29 |
+
"recall": 0.0
|
30 |
+
},
|
31 |
+
"eval_RANGE_END": {
|
32 |
+
"f1": 0.0,
|
33 |
+
"number": 17,
|
34 |
+
"precision": 0.0,
|
35 |
+
"recall": 0.0
|
36 |
+
},
|
37 |
+
"eval_UNIT": {
|
38 |
+
"f1": 0.0,
|
39 |
+
"number": 1166,
|
40 |
+
"precision": 0.0,
|
41 |
+
"recall": 0.0
|
42 |
+
},
|
43 |
+
"eval_loss": 11.59915828704834,
|
44 |
+
"eval_overall_accuracy": 0.13186071187421627,
|
45 |
+
"eval_overall_f1": 0.018494437220054907,
|
46 |
+
"eval_overall_precision": 0.05263157894736842,
|
47 |
+
"eval_overall_recall": 0.011218229623137599,
|
48 |
+
"eval_runtime": 14.249,
|
49 |
+
"eval_samples_per_second": 119.447,
|
50 |
+
"eval_steps_per_second": 3.79,
|
51 |
+
"step": 54
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"epoch": 2.0,
|
55 |
+
"eval_COMMENT": {
|
56 |
+
"f1": 0.030715225976305396,
|
57 |
+
"number": 1322,
|
58 |
+
"precision": 0.03657262277951933,
|
59 |
+
"recall": 0.0264750378214826
|
60 |
+
},
|
61 |
+
"eval_NAME": {
|
62 |
+
"f1": 0.024336283185840708,
|
63 |
+
"number": 1766,
|
64 |
+
"precision": 0.5238095238095238,
|
65 |
+
"recall": 0.01245753114382786
|
66 |
+
},
|
67 |
+
"eval_QTY": {
|
68 |
+
"f1": 0.003841229193341869,
|
69 |
+
"number": 1434,
|
70 |
+
"precision": 0.0234375,
|
71 |
+
"recall": 0.0020920502092050207
|
72 |
+
},
|
73 |
+
"eval_RANGE_END": {
|
74 |
+
"f1": 0.0,
|
75 |
+
"number": 17,
|
76 |
+
"precision": 0.0,
|
77 |
+
"recall": 0.0
|
78 |
+
},
|
79 |
+
"eval_UNIT": {
|
80 |
+
"f1": 0.0,
|
81 |
+
"number": 1166,
|
82 |
+
"precision": 0.0,
|
83 |
+
"recall": 0.0
|
84 |
+
},
|
85 |
+
"eval_loss": 10.259025573730469,
|
86 |
+
"eval_overall_accuracy": 0.12838815472171314,
|
87 |
+
"eval_overall_f1": 0.016813787305590584,
|
88 |
+
"eval_overall_precision": 0.04189944134078212,
|
89 |
+
"eval_overall_recall": 0.010517090271691499,
|
90 |
+
"eval_runtime": 15.3099,
|
91 |
+
"eval_samples_per_second": 111.17,
|
92 |
+
"eval_steps_per_second": 3.527,
|
93 |
+
"step": 108
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"epoch": 2.0,
|
97 |
+
"step": 108,
|
98 |
+
"total_flos": 3334407253032.0,
|
99 |
+
"train_loss": 12.8824462890625,
|
100 |
+
"train_runtime": 151.3626,
|
101 |
+
"train_samples_per_second": 22.489,
|
102 |
+
"train_steps_per_second": 0.714
|
103 |
+
}
|
104 |
+
],
|
105 |
+
"logging_steps": 500,
|
106 |
+
"max_steps": 108,
|
107 |
+
"num_train_epochs": 2,
|
108 |
+
"save_steps": 500,
|
109 |
+
"total_flos": 3334407253032.0,
|
110 |
+
"trial_name": null,
|
111 |
+
"trial_params": null
|
112 |
+
}
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2c5254dbe381bcf40ffb21d73c66be21031100bd1243da923a4f4347812cc60c
|
3 |
+
size 4155
|
validation_results.json
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"epoch": 2.0,
|
3 |
+
"eval_COMMENT": {
|
4 |
+
"f1": 0.030715225976305396,
|
5 |
+
"number": 1322,
|
6 |
+
"precision": 0.03657262277951933,
|
7 |
+
"recall": 0.0264750378214826
|
8 |
+
},
|
9 |
+
"eval_NAME": {
|
10 |
+
"f1": 0.024336283185840708,
|
11 |
+
"number": 1766,
|
12 |
+
"precision": 0.5238095238095238,
|
13 |
+
"recall": 0.01245753114382786
|
14 |
+
},
|
15 |
+
"eval_QTY": {
|
16 |
+
"f1": 0.003841229193341869,
|
17 |
+
"number": 1434,
|
18 |
+
"precision": 0.0234375,
|
19 |
+
"recall": 0.0020920502092050207
|
20 |
+
},
|
21 |
+
"eval_RANGE_END": {
|
22 |
+
"f1": 0.0,
|
23 |
+
"number": 17,
|
24 |
+
"precision": 0.0,
|
25 |
+
"recall": 0.0
|
26 |
+
},
|
27 |
+
"eval_UNIT": {
|
28 |
+
"f1": 0.0,
|
29 |
+
"number": 1166,
|
30 |
+
"precision": 0.0,
|
31 |
+
"recall": 0.0
|
32 |
+
},
|
33 |
+
"eval_loss": 10.259025573730469,
|
34 |
+
"eval_overall_accuracy": 0.12838815472171314,
|
35 |
+
"eval_overall_f1": 0.016813787305590584,
|
36 |
+
"eval_overall_precision": 0.04189944134078212,
|
37 |
+
"eval_overall_recall": 0.010517090271691499,
|
38 |
+
"eval_runtime": 15.7061,
|
39 |
+
"eval_samples_per_second": 108.365,
|
40 |
+
"eval_steps_per_second": 3.438
|
41 |
+
}
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|