arubenruben commited on
Commit
bacde3a
1 Parent(s): f465827

Upload BERT_CRF

Browse files
Files changed (3) hide show
  1. config.json +41 -0
  2. model.py +85 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/notebooks/src/hugging_face_pipeline/BERT-CRF/out/model",
3
+ "architectures": [
4
+ "BERT_CRF"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "model.BERT_CRF_Config",
8
+ "AutoModelForTokenClassification": "model.BERT_CRF"
9
+ },
10
+ "bert_name": "neuralmind/bert-large-portuguese-cased",
11
+ "id2label": {
12
+ "0": "O",
13
+ "1": "B-PESSOA",
14
+ "2": "I-PESSOA",
15
+ "3": "B-ORGANIZACAO",
16
+ "4": "I-ORGANIZACAO",
17
+ "5": "B-LOCAL",
18
+ "6": "I-LOCAL",
19
+ "7": "B-TEMPO",
20
+ "8": "I-TEMPO",
21
+ "9": "B-VALOR",
22
+ "10": "I-VALOR"
23
+ },
24
+ "label2id": {
25
+ "B-LOCAL": 5,
26
+ "B-ORGANIZACAO": 3,
27
+ "B-PESSOA": 1,
28
+ "B-TEMPO": 7,
29
+ "B-VALOR": 9,
30
+ "I-LOCAL": 6,
31
+ "I-ORGANIZACAO": 4,
32
+ "I-PESSOA": 2,
33
+ "I-TEMPO": 8,
34
+ "I-VALOR": 10,
35
+ "O": 0
36
+ },
37
+ "model_name": "BERT_CRF",
38
+ "model_type": "BERT_CRF",
39
+ "torch_dtype": "float32",
40
+ "transformers_version": "4.29.1"
41
+ }
model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from transformers import PreTrainedModel, PretrainedConfig
3
+ from transformers import BertModel, BertConfig
4
+ from transformers import AutoModelForTokenClassification, AutoConfig
5
+ from torchcrf import CRF
6
+
7
+ class BERT_CRF_Config(PretrainedConfig):
8
+ model_type = "BERT_CRF"
9
+
10
+ def __init__(self, **kwarg):
11
+ super().__init__(**kwarg)
12
+ self.model_name = "BERT_CRF"
13
+
14
+
15
+ class BERT_CRF(PreTrainedModel):
16
+ config_class = BERT_CRF_Config
17
+
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+
21
+ bert_config = BertConfig.from_pretrained(config.bert_name)
22
+
23
+ bert_config.output_attentions = True
24
+ bert_config.output_hidden_states = True
25
+
26
+ self.bert = BertModel.from_pretrained(config.bert_name, config=bert_config)
27
+
28
+ self.dropout = nn.Dropout(p=0.5)
29
+
30
+ self.linear = nn.Linear(
31
+ self.bert.config.hidden_size, config.num_labels)
32
+
33
+ self.crf = CRF(config.num_labels, batch_first=True)
34
+
35
+ def forward(self, input_ids, token_type_ids, attention_mask, labels, labels_mask):
36
+
37
+ last_hidden_layer = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[
38
+ 'last_hidden_state']
39
+
40
+ last_hidden_layer = self.dropout(last_hidden_layer)
41
+
42
+ logits = self.linear(last_hidden_layer)
43
+
44
+ batch_size = logits.shape[0]
45
+
46
+ output_tags = []
47
+
48
+ if labels is not None:
49
+ loss = 0
50
+
51
+ for seq_logits, seq_labels, seq_mask in zip(logits, labels, labels_mask):
52
+ # Index logits and labels using prediction mask to pass only the
53
+ # first subtoken of each word to CRF.
54
+ seq_logits = seq_logits[seq_mask].unsqueeze(0)
55
+ seq_labels = seq_labels[seq_mask].unsqueeze(0)
56
+
57
+ if seq_logits.numel() != 0:
58
+ loss -= self.crf(seq_logits, seq_labels,
59
+ reduction='token_mean')
60
+
61
+ return loss / batch_size
62
+ else:
63
+ for seq_logits, seq_mask in zip(logits, labels_mask):
64
+ seq_logits = seq_logits[seq_mask].unsqueeze(0)
65
+
66
+ if seq_logits.numel() != 0:
67
+ tags = self.crf.decode(seq_logits)
68
+ else:
69
+ tags = [[]]
70
+
71
+ # Unpack "batch" results
72
+ output_tags.append(tags[0])
73
+
74
+ return output_tags
75
+
76
+
77
+ class ModelRegisterStep():
78
+ def __call__(self, args):
79
+
80
+ AutoConfig.register("BERT_CRF", BERT_CRF_Config)
81
+ AutoModelForTokenClassification.register(BERT_CRF_Config, BERT_CRF)
82
+
83
+ return {
84
+ **args,
85
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea365911ab10412519feafffc47fc9192b4ebf2e42f71081ee900c66f3284063
3
+ size 1337762471