qgyd2021 commited on
Commit
147e44c
·
1 Parent(s): 71c05fc

[update]add code

Browse files
Files changed (49) hide show
  1. .gitattributes +1 -0
  2. .gitignore +8 -0
  3. README.md +4 -5
  4. examples/text_classification/telemarketing_intent_classification/1.prepare_data.py +85 -0
  5. examples/text_classification/telemarketing_intent_classification/2.make_hierarchical_labels.py +74 -0
  6. examples/text_classification/telemarketing_intent_classification/3.make_vocabulary.py +78 -0
  7. examples/text_classification/telemarketing_intent_classification/4.train_model.py +172 -0
  8. examples/text_classification/telemarketing_intent_classification/5.predict_model.py +117 -0
  9. examples/text_classification/telemarketing_intent_classification/6.make_json_config.py +121 -0
  10. examples/text_classification/telemarketing_intent_classification/7.predict_by_archive.py +74 -0
  11. examples/text_classification/telemarketing_intent_classification/run.sh +254 -0
  12. main.py +141 -0
  13. predict.py +142 -0
  14. pretrained_models/bert-base-japanese/.gitattributes +9 -0
  15. pretrained_models/bert-base-japanese/README.md +43 -0
  16. pretrained_models/bert-base-japanese/config.json +20 -0
  17. pretrained_models/bert-base-japanese/tokenizer_config.json +5 -0
  18. pretrained_models/bert-base-japanese/vocab.txt +0 -0
  19. pretrained_models/bert-base-uncased/.gitattributes +11 -0
  20. pretrained_models/bert-base-uncased/LICENSE +201 -0
  21. pretrained_models/bert-base-uncased/README.md +251 -0
  22. pretrained_models/bert-base-uncased/config.json +23 -0
  23. pretrained_models/bert-base-uncased/tokenizer.json +0 -0
  24. pretrained_models/bert-base-uncased/tokenizer_config.json +3 -0
  25. pretrained_models/bert-base-uncased/vocab.txt +0 -0
  26. pretrained_models/bert-base-vietnamese-uncased/.gitattributes +17 -0
  27. pretrained_models/bert-base-vietnamese-uncased/README.md +22 -0
  28. pretrained_models/bert-base-vietnamese-uncased/config.json +27 -0
  29. pretrained_models/bert-base-vietnamese-uncased/special_tokens_map.json +1 -0
  30. pretrained_models/bert-base-vietnamese-uncased/tokenizer_config.json +1 -0
  31. pretrained_models/bert-base-vietnamese-uncased/vocab.txt +0 -0
  32. pretrained_models/chinese-bert-wwm-ext/.gitattributes +9 -0
  33. pretrained_models/chinese-bert-wwm-ext/README.md +52 -0
  34. pretrained_models/chinese-bert-wwm-ext/added_tokens.json +1 -0
  35. pretrained_models/chinese-bert-wwm-ext/config.json +26 -0
  36. pretrained_models/chinese-bert-wwm-ext/special_tokens_map.json +1 -0
  37. pretrained_models/chinese-bert-wwm-ext/tokenizer.json +0 -0
  38. pretrained_models/chinese-bert-wwm-ext/tokenizer_config.json +1 -0
  39. pretrained_models/chinese-bert-wwm-ext/vocab.txt +0 -0
  40. project_settings.py +12 -0
  41. requirements.txt +14 -0
  42. toolbox/__init__.py +6 -0
  43. toolbox/allennlp_models/text_classifier/dataset_readers/__init__.py +6 -0
  44. toolbox/allennlp_models/text_classifier/dataset_readers/hierarchical_classification_json.py +99 -0
  45. toolbox/allennlp_models/text_classifier/models/__init__.py +6 -0
  46. toolbox/allennlp_models/text_classifier/models/hierarchical_text_classifier.py +291 -0
  47. toolbox/torch/__init__.py +6 -0
  48. toolbox/torch/modules/__init__.py +6 -0
  49. toolbox/torch/modules/loss.py +738 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.th filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+ .git/
3
+ .idea/
4
+
5
+ **/flagged/
6
+ **/__pycache__/
7
+
8
+ trained_models/
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
  title: Telemarketing Intent Classification
3
- emoji: 🐢
4
  colorFrom: indigo
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.50.2
8
- app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Telemarketing Intent Classification
3
+ emoji: 😻
4
  colorFrom: indigo
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.20.1
8
+ app_file: main.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
examples/text_classification/telemarketing_intent_classification/1.prepare_data.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, '../../../'))
12
+
13
+ import pandas as pd
14
+ from tqdm import tqdm
15
+
16
+ from project_settings import project_path
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--without_irrelevant_domain', action='store_true')
22
+ parser.add_argument('--dataset_filename', default='dataset.xlsx', type=str)
23
+ parser.add_argument('--do_lowercase', action='store_true')
24
+
25
+ parser.add_argument('--train_subset', default='train.json', type=str)
26
+ parser.add_argument('--valid_subset', default='valid.json', type=str)
27
+
28
+ args = parser.parse_args()
29
+ return args
30
+
31
+
32
+ def main():
33
+ args = get_args()
34
+
35
+ n_hierarchical = 2
36
+
37
+ df = pd.read_excel(args.dataset_filename)
38
+ df = df[df['selected'] == 1]
39
+
40
+ dataset = list()
41
+ for i, row in tqdm(df.iterrows(), total=len(df)):
42
+ text = row['text']
43
+ label0 = row['label0']
44
+ if args.without_irrelevant_domain and label0 == '无关领域':
45
+ continue
46
+
47
+ text = str(text)
48
+ if args.do_lowercase:
49
+ text = text.lower()
50
+
51
+ labels = {'label{}'.format(idx): str(row['label{}'.format(idx)]) for idx in range(n_hierarchical)}
52
+
53
+ random1 = random.random()
54
+ random2 = random.random()
55
+
56
+ dataset.append({
57
+ 'text': text,
58
+ **labels,
59
+
60
+ 'random1': random1,
61
+ 'random2': random2,
62
+ 'flag': 'TRAIN' if random2 < 0.8 else 'TEST',
63
+ })
64
+
65
+ dataset = list(sorted(dataset, key=lambda x: x['random1'], reverse=True))
66
+
67
+ f_train = open(args.train_subset, 'w', encoding='utf-8')
68
+ f_test = open(args.valid_subset, 'w', encoding='utf-8')
69
+
70
+ for row in tqdm(dataset):
71
+
72
+ flag = row['flag']
73
+ row = json.dumps(row, ensure_ascii=False)
74
+ if flag == 'TRAIN':
75
+ f_train.write('{}\n'.format(row))
76
+ else:
77
+ f_test.write('{}\n'.format(row))
78
+
79
+ f_train.close()
80
+ f_test.close()
81
+ return
82
+
83
+
84
+ if __name__ == '__main__':
85
+ main()
examples/text_classification/telemarketing_intent_classification/2.make_hierarchical_labels.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import OrderedDict
5
+ import os
6
+ from pathlib import Path
7
+ import pickle
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, '../../'))
12
+
13
+ import pandas as pd
14
+
15
+
16
+ def get_args():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('--dataset_filename', default='dataset.xlsx', type=str)
19
+ parser.add_argument('--hierarchical_labels_pkl', default='hierarchical_labels.pkl', type=str)
20
+
21
+ args = parser.parse_args()
22
+ return args
23
+
24
+
25
+ def main():
26
+ args = get_args()
27
+
28
+ n_hierarchical = 2
29
+
30
+ df = pd.read_excel(args.dataset_filename)
31
+ df = df[df['selected'] == 1]
32
+
33
+ # 生成 hierarchical_labels
34
+ temp_hierarchical_labels = OrderedDict()
35
+
36
+ for i, row in df.iterrows():
37
+ text = row['text']
38
+ label0 = row['label0']
39
+ label1 = row['label1']
40
+
41
+ if temp_hierarchical_labels.get(label0) is None:
42
+ temp_hierarchical_labels[label0] = list()
43
+
44
+ if label1 not in temp_hierarchical_labels[label0]:
45
+ temp_hierarchical_labels[label0].append(label1)
46
+
47
+ if n_hierarchical > 2:
48
+ hierarchical_labels = OrderedDict()
49
+ for idx in range(n_hierarchical - 2):
50
+ for k, v in temp_hierarchical_labels.items():
51
+ parent, label = k.rsplit('_', maxsplit=1)
52
+
53
+ if hierarchical_labels.get(parent) is None:
54
+ hierarchical_labels[parent] = OrderedDict({
55
+ label: v
56
+ })
57
+ else:
58
+ if hierarchical_labels[parent].get(label) is None:
59
+ hierarchical_labels[parent][label] = v
60
+ else:
61
+ hierarchical_labels = temp_hierarchical_labels
62
+
63
+ with open(args.hierarchical_labels_pkl, 'wb') as f:
64
+ pickle.dump(hierarchical_labels, f)
65
+
66
+ with open(args.hierarchical_labels_pkl, 'rb') as f:
67
+ hierarchical_labels = pickle.load(f)
68
+
69
+ # print(hierarchical_labels)
70
+ return
71
+
72
+
73
+ if __name__ == '__main__':
74
+ main()
examples/text_classification/telemarketing_intent_classification/3.make_vocabulary.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import OrderedDict
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import pickle
9
+ import sys
10
+
11
+ pwd = os.path.abspath(os.path.dirname(__file__))
12
+ sys.path.append(os.path.join(pwd, '../../../'))
13
+
14
+ import pandas as pd
15
+
16
+ from allennlp.data.vocabulary import Vocabulary
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument(
24
+ "--pretrained_model_path",
25
+ default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
26
+ type=str
27
+ )
28
+ parser.add_argument('--hierarchical_labels_pkl', default='hierarchical_labels.pkl', type=str)
29
+ parser.add_argument('--vocabulary', default='vocabulary', type=str)
30
+
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ def main():
36
+ args = get_args()
37
+
38
+ with open(args.hierarchical_labels_pkl, 'rb') as f:
39
+ hierarchical_labels = pickle.load(f)
40
+ # print(hierarchical_labels)
41
+ # 深度遍历
42
+ token_to_index = OrderedDict()
43
+ tasks = [hierarchical_labels]
44
+ while len(tasks) != 0:
45
+ task = tasks.pop(0)
46
+ for parent, downstream in task.items():
47
+ if isinstance(downstream, list):
48
+ for label in downstream:
49
+ if pd.isna(label):
50
+ continue
51
+ label = '{}_{}'.format(parent, label)
52
+ token_to_index[label] = len(token_to_index)
53
+ elif isinstance(downstream, OrderedDict):
54
+ new_task = OrderedDict()
55
+ for k, v in downstream.items():
56
+ new_task['{}_{}'.format(parent, k)] = v
57
+ tasks.append(new_task)
58
+ else:
59
+ raise NotImplementedError
60
+
61
+ vocabulary = Vocabulary(non_padded_namespaces=['tokens', 'labels'])
62
+ for label, index in token_to_index.items():
63
+ vocabulary.add_token_to_namespace(label, namespace='labels')
64
+
65
+ vocabulary.set_from_file(
66
+ filename=os.path.join(args.pretrained_model_path, 'vocab.txt'),
67
+ is_padded=False,
68
+ oov_token='[UNK]',
69
+ namespace='tokens',
70
+ )
71
+ vocabulary.save_to_files(args.vocabulary)
72
+
73
+ print('注意检查 Vocabulary 中标签的顺序与 hierarchical_labels 是否一致. ')
74
+ return
75
+
76
+
77
+ if __name__ == '__main__':
78
+ main()
examples/text_classification/telemarketing_intent_classification/4.train_model.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 多层 softmax 实现多极文本分类
5
+
6
+ 由于初始化时, 各层 softmax 的概率趋于平衡.
7
+
8
+ 因此在第一层时 `领域无关` 就分到了 50% 的概率.
9
+
10
+ `领域相关` 中的各类别去分剩下的 50% 的概率.
11
+ 这会导致模型一开始时输出的类别全是 `领域无关`, 这导致模型无法优化.
12
+
13
+ 解决方案:
14
+ 1. 从数据集中去除 `领域无关` 数据. 并训练模型.
15
+ 2. 等模型收敛之后, 再使用包含 `领域无关` 的数据集, 让模型加载之前的权重, 并重新开始训练模型.
16
+
17
+ """
18
+ import argparse
19
+ import json
20
+ import os
21
+ import sys
22
+
23
+ pwd = os.path.abspath(os.path.dirname(__file__))
24
+ sys.path.append(os.path.join(pwd, '../../../'))
25
+
26
+ from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader
27
+ from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
28
+ from allennlp.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
29
+ from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
30
+ from allennlp.data.vocabulary import Vocabulary
31
+ from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
32
+ from allennlp.modules.token_embedders.embedding import Embedding
33
+ from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
34
+ from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
35
+ from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
36
+ from allennlp.training.checkpointer import Checkpointer
37
+ from pytorch_pretrained_bert.optimization import BertAdam
38
+ import torch
39
+
40
+ from project_settings import project_path
41
+ from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
42
+ from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
43
+
44
+
45
+ def get_args():
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument(
48
+ "--pretrained_model_path",
49
+ default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
50
+ type=str
51
+ )
52
+ parser.add_argument('--hierarchical_labels_pkl', default='data_dir/hierarchical_labels.pkl', type=str)
53
+ parser.add_argument('--vocabulary_dir', default='data_dir/vocabulary', type=str)
54
+ parser.add_argument('--train_subset', default='data_dir/train.json', type=str)
55
+ parser.add_argument('--valid_subset', default='data_dir/valid.json', type=str)
56
+ parser.add_argument("--serialization_dir", default="data_dir/serialization_dir", type=str)
57
+ # parser.add_argument('--checkpoint_path', default="data_dir/serialization_dir/best.th", type=str)
58
+ parser.add_argument('--checkpoint_path', default=None, type=str)
59
+
60
+ args = parser.parse_args()
61
+ return args
62
+
63
+
64
+ def main():
65
+ args = get_args()
66
+
67
+ dataset_reader = HierarchicalClassificationJsonReader(
68
+ token_indexers={
69
+ 'tokens': SingleIdTokenIndexer(
70
+ namespace='tokens',
71
+ lowercase_tokens=True,
72
+ token_min_padding_length=5,
73
+ )
74
+ },
75
+ tokenizer=PretrainedTransformerTokenizer(
76
+ model_name=os.path.join(project_path, args.pretrained_model_path),
77
+ ),
78
+ )
79
+
80
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
81
+
82
+ data_loader = MultiProcessDataLoader(
83
+ reader=dataset_reader,
84
+ data_path=args.train_subset,
85
+ batch_size=64,
86
+ shuffle=True,
87
+ )
88
+ data_loader.index_with(vocabulary)
89
+
90
+ validation_data_loader = MultiProcessDataLoader(
91
+ reader=dataset_reader,
92
+ data_path=args.valid_subset,
93
+ batch_size=64,
94
+ shuffle=True,
95
+ )
96
+ validation_data_loader.index_with(vocabulary)
97
+
98
+ model = HierarchicalClassifier(
99
+ vocab=vocabulary,
100
+ hierarchical_labels_pkl=args.hierarchical_labels_pkl,
101
+ text_field_embedder=BasicTextFieldEmbedder(
102
+ token_embedders={
103
+ 'tokens': Embedding(
104
+ num_embeddings=vocabulary.get_vocab_size('tokens'),
105
+ embedding_dim=128,
106
+ )
107
+ }
108
+ ),
109
+ seq2seq_encoder=StackedSelfAttentionEncoder(
110
+ input_dim=128,
111
+ hidden_dim=128,
112
+ projection_dim=128,
113
+ feedforward_hidden_dim=128,
114
+ num_layers=2,
115
+ num_attention_heads=4,
116
+ use_positional_encoding=False,
117
+ ),
118
+ seq2vec_encoder=CnnEncoder(
119
+ embedding_dim=128,
120
+ num_filters=32,
121
+ ngram_filter_sizes=(2, 3, 4, 5),
122
+ ),
123
+ )
124
+
125
+ if args.checkpoint_path is not None:
126
+ with open(args.checkpoint_path, "rb") as f:
127
+ state_dict = torch.load(f, map_location=torch.device("cpu"))
128
+ model.load_state_dict(state_dict)
129
+ model.train()
130
+
131
+ parameters = [v for n, v in model.named_parameters()]
132
+
133
+ optimizer = BertAdam(
134
+ params=parameters,
135
+ lr=5e-4,
136
+ warmup=0.1,
137
+ t_total=10000,
138
+ # t_total=200000,
139
+ schedule='warmup_linear'
140
+ )
141
+
142
+ if torch.cuda.is_available():
143
+ cuda_device = 0
144
+ model.cuda(device=0)
145
+ else:
146
+ cuda_device = -1
147
+
148
+ print(cuda_device)
149
+
150
+ trainer = GradientDescentTrainer(
151
+ cuda_device=cuda_device,
152
+
153
+ model=model,
154
+ optimizer=optimizer,
155
+ checkpointer=Checkpointer(
156
+ serialization_dir=args.serialization_dir,
157
+ keep_most_recent_by_count=10,
158
+ ),
159
+ data_loader=data_loader,
160
+ validation_data_loader=validation_data_loader,
161
+ patience=5,
162
+ validation_metric='+accuracy',
163
+ num_epochs=100,
164
+ serialization_dir=args.serialization_dir,
165
+ )
166
+ trainer.train()
167
+
168
+ return
169
+
170
+
171
+ if __name__ == '__main__':
172
+ main()
examples/text_classification/telemarketing_intent_classification/5.predict_model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
8
+ from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
9
+ from allennlp.data.vocabulary import Vocabulary
10
+ from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
11
+ from allennlp.modules.token_embedders.embedding import Embedding
12
+ from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
13
+ from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
14
+ from allennlp.predictors.text_classifier import TextClassifierPredictor
15
+ import torch
16
+
17
+ from project_settings import project_path
18
+ from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
19
+ from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
20
+
21
+
22
+ def get_args():
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "--pretrained_model_path",
26
+ default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
27
+ type=str
28
+ )
29
+ parser.add_argument('--hierarchical_labels_pkl', default='data_dir/hierarchical_labels.pkl', type=str)
30
+ parser.add_argument('--vocabulary_dir', default='data_dir/vocabulary', type=str)
31
+
32
+ parser.add_argument(
33
+ "--serialization_dir",
34
+ default="data_dir/serialization_dir2",
35
+ type=str
36
+ )
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ def main():
42
+ args = get_args()
43
+
44
+ dataset_reader = HierarchicalClassificationJsonReader(
45
+ token_indexers={
46
+ 'tokens': SingleIdTokenIndexer(
47
+ namespace='tokens',
48
+ lowercase_tokens=True,
49
+ token_min_padding_length=5,
50
+ )
51
+ },
52
+ tokenizer=PretrainedTransformerTokenizer(
53
+ model_name=os.path.join(project_path, args.pretrained_model_path),
54
+ ),
55
+ )
56
+
57
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
58
+
59
+ model = HierarchicalClassifier(
60
+ vocab=vocabulary,
61
+ hierarchical_labels_pkl=args.hierarchical_labels_pkl,
62
+ text_field_embedder=BasicTextFieldEmbedder(
63
+ token_embedders={
64
+ 'tokens': Embedding(
65
+ num_embeddings=vocabulary.get_vocab_size('tokens'),
66
+ embedding_dim=128,
67
+ )
68
+ }
69
+ ),
70
+ seq2seq_encoder=StackedSelfAttentionEncoder(
71
+ input_dim=128,
72
+ hidden_dim=128,
73
+ projection_dim=128,
74
+ feedforward_hidden_dim=128,
75
+ num_layers=2,
76
+ num_attention_heads=4,
77
+ use_positional_encoding=False,
78
+ ),
79
+ seq2vec_encoder=CnnEncoder(
80
+ embedding_dim=128,
81
+ num_filters=32,
82
+ ngram_filter_sizes=(2, 3, 4, 5),
83
+ ),
84
+ )
85
+
86
+ checkpoint_path = os.path.join(args.serialization_dir, "best.th")
87
+ with open(checkpoint_path, 'rb') as f:
88
+ state_dict = torch.load(f, map_location="cpu")
89
+ model.load_state_dict(state_dict, strict=True)
90
+ model.eval()
91
+
92
+ predictor = TextClassifierPredictor(
93
+ model=model,
94
+ dataset_reader=dataset_reader,
95
+ )
96
+
97
+ while True:
98
+ text = input("text: ")
99
+ if text == "Quit":
100
+ break
101
+
102
+ json_dict = {'sentence': text}
103
+
104
+ begin_time = time.time()
105
+ outputs = predictor.predict_json(
106
+ json_dict
107
+ )
108
+
109
+ outputs = predictor._model.decode(outputs)
110
+ label = outputs['label']
111
+ print(label)
112
+ print('time cost: {}'.format(time.time() - begin_time))
113
+ return
114
+
115
+
116
+ if __name__ == '__main__':
117
+ main()
examples/text_classification/telemarketing_intent_classification/6.make_json_config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+
7
+ from allennlp.data.vocabulary import Vocabulary
8
+ import torch
9
+
10
+ from project_settings import project_path
11
+
12
+
13
+ def get_args():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "--pretrained_model_path",
17
+ default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
18
+ type=str
19
+ )
20
+ parser.add_argument('--hierarchical_labels_pkl', default='data_dir/hierarchical_labels.pkl', type=str)
21
+ parser.add_argument('--vocabulary_dir', default='data_dir/vocabulary', type=str)
22
+ parser.add_argument('--train_subset', default='data_dir/train.json', type=str)
23
+ parser.add_argument('--valid_subset', default='data_dir/valid.json', type=str)
24
+ parser.add_argument("--serialization_dir", default="data_dir/serialization_dir", type=str)
25
+ parser.add_argument("--json_config_dir", default="data_dir", type=str)
26
+ args = parser.parse_args()
27
+ return args
28
+
29
+
30
+ def main():
31
+ args = get_args()
32
+
33
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
34
+
35
+ if torch.cuda.is_available():
36
+ cuda_device = 0
37
+ else:
38
+ cuda_device = -1
39
+
40
+ json_config = {
41
+ "dataset_reader": {
42
+ "type": "hierarchical_classification_json",
43
+ "token_indexers": {
44
+ "tokens": {
45
+ "type": "single_id",
46
+ "namespace": "tokens",
47
+ "lowercase_tokens": True,
48
+ "token_min_padding_length": 5
49
+ }
50
+ },
51
+ "tokenizer": {
52
+ "type": "pretrained_transformer",
53
+ "model_name": args.pretrained_model_path
54
+ }
55
+ },
56
+ "train_data_path": args.train_subset,
57
+ "validation_data_path": args.valid_subset,
58
+ "vocabulary": {
59
+ "directory_path": args.vocabulary_dir,
60
+ },
61
+ "model": {
62
+ "type": "hierarchical_classifier",
63
+ "hierarchical_labels_pkl": args.hierarchical_labels_pkl,
64
+ "text_field_embedder": {
65
+ "token_embedders": {
66
+ "tokens": {
67
+ "type": "embedding",
68
+ "num_embeddings": vocabulary.get_vocab_size(namespace="tokens"),
69
+ "embedding_dim": 128
70
+ }
71
+ }
72
+ },
73
+ "seq2seq_encoder": {
74
+ "type": "stacked_self_attention",
75
+ "input_dim": 128,
76
+ "hidden_dim": 128,
77
+ "projection_dim": 128,
78
+ "feedforward_hidden_dim": 128,
79
+ "num_layers": 2,
80
+ "num_attention_heads": 4,
81
+ "use_positional_encoding": False
82
+ },
83
+ "seq2vec_encoder": {
84
+ "type": "cnn",
85
+ "embedding_dim": 128,
86
+ "num_filters": 32,
87
+ "ngram_filter_sizes": (2, 3, 4, 5),
88
+ },
89
+ },
90
+ "data_loader": {
91
+ "type": "multiprocess",
92
+ "batch_size": 64,
93
+ "shuffle": True
94
+ },
95
+ "trainer": {
96
+ "type": "gradient_descent",
97
+ "cuda_device": cuda_device,
98
+ "optimizer": {
99
+ "type": "bert_adam",
100
+ "lr": 5e-5,
101
+ "warmup": 0.1,
102
+ "t_total": 50000,
103
+ "schedule": "warmup_linear"
104
+ },
105
+ "checkpointer": {
106
+ "serialization_dir": args.serialization_dir,
107
+ "keep_most_recent_by_count": 10
108
+ },
109
+ "patience": 5,
110
+ "validation_metric": "+accuracy",
111
+ "num_epochs": 200
112
+ }
113
+ }
114
+
115
+ with open(os.path.join(args.json_config_dir, "config.json"), "w", encoding="utf-8") as f:
116
+ json.dump(json_config, f, indent=4, ensure_ascii=False)
117
+ return
118
+
119
+
120
+ if __name__ == '__main__':
121
+ main()
examples/text_classification/telemarketing_intent_classification/7.predict_by_archive.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
8
+ from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
9
+ from allennlp.data.vocabulary import Vocabulary
10
+ from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
11
+ from allennlp.modules.token_embedders.embedding import Embedding
12
+ from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
13
+ from allennlp.models.archival import archive_model, load_archive
14
+ from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
15
+ from allennlp.predictors.predictor import Predictor
16
+ from allennlp.predictors.text_classifier import TextClassifierPredictor
17
+ import torch
18
+
19
+ from project_settings import project_path
20
+ from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
21
+ from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
22
+
23
+
24
+ def get_args():
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument(
27
+ "--text",
28
+ default="给我推荐一些篮球游戏?",
29
+ type=str
30
+ )
31
+ parser.add_argument(
32
+ "--archive_file",
33
+ default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
34
+ type=str
35
+ )
36
+ parser.add_argument(
37
+ "--pretrained_model_path",
38
+ default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
39
+ type=str
40
+ )
41
+ parser.add_argument(
42
+ "--predictor_name",
43
+ default="text_classifier",
44
+ type=str
45
+ )
46
+ args = parser.parse_args()
47
+ return args
48
+
49
+
50
+ def main():
51
+ args = get_args()
52
+
53
+ archive = load_archive(archive_file=args.archive_file)
54
+
55
+ predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)
56
+
57
+ json_dict = {
58
+ "sentence": args.text
59
+ }
60
+
61
+ begin_time = time.time()
62
+ outputs = predictor.predict_json(
63
+ json_dict
64
+ )
65
+ outputs = predictor._model.decode(outputs)
66
+ label = outputs['label']
67
+ print(label)
68
+ print('time cost: {}'.format(time.time() - begin_time))
69
+
70
+ return
71
+
72
+
73
+ if __name__ == '__main__':
74
+ main()
examples/text_classification/telemarketing_intent_classification/run.sh ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # nohup sh run.sh --system_version centos --stage 0 --stop_stage 5 &
4
+
5
+ # sh run.sh --system_version windows --stage -2 --stop_stage 8
6
+ # sh run.sh --system_version windows --stage 0 --stop_stage 0
7
+ # sh run.sh --system_version windows --stage 1 --stop_stage 1
8
+ # sh run.sh --system_version windows --stage 2 --stop_stage 2
9
+ # sh run.sh --system_version windows --stage 0 --stop_stage 5
10
+ # sh run.sh --system_version windows --stage 6 --stop_stage 6
11
+
12
+ # params
13
+ system_version="centos";
14
+ verbose=true;
15
+ stage=0 # start from 0 if you need to start from data preparation
16
+ stop_stage=5
17
+
18
+
19
+ #trained_model_name=telemarketing_intent_classification_cn
20
+ #pretrained_bert_model_name=chinese-bert-wwm-ext
21
+ #dataset_fn="telemarketing_intent_cn.xlsx"
22
+
23
+ #trained_model_name=telemarketing_intent_classification_en
24
+ #pretrained_bert_model_name=bert-base-uncased
25
+ #dataset_fn="telemarketing_intent_en.xlsx"
26
+
27
+ #trained_model_name=telemarketing_intent_classification_jp
28
+ #pretrained_bert_model_name=bert-base-japanese
29
+ #dataset_fn="telemarketing_intent_jp.xlsx"
30
+
31
+ trained_model_name=telemarketing_intent_classification_vi
32
+ pretrained_bert_model_name=bert-base-vietnamese-uncased
33
+ dataset_fn="telemarketing_intent_vi.xlsx"
34
+
35
+ # parse options
36
+ while true; do
37
+ [ -z "${1:-}" ] && break; # break if there are no arguments
38
+ case "$1" in
39
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
40
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
41
+ old_value="(eval echo \\$$name)";
42
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
43
+ was_bool=true;
44
+ else
45
+ was_bool=false;
46
+ fi
47
+
48
+ # Set the variable to the right value-- the escaped quotes make it work if
49
+ # the option had spaces, like --cmd "queue.pl -sync y"
50
+ eval "${name}=\"$2\"";
51
+
52
+ # Check that Boolean-valued arguments are really Boolean.
53
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
54
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
55
+ exit 1;
56
+ fi
57
+ shift 2;
58
+ ;;
59
+
60
+ *) break;
61
+ esac
62
+ done
63
+
64
+ $verbose && echo "system_version: ${system_version}"
65
+
66
+ work_dir="$(pwd)"
67
+ data_dir="$(pwd)/data_dir"
68
+ pretrained_models_dir="${work_dir}/../../../pretrained_models";
69
+ trained_models_dir="${work_dir}/../../../trained_models";
70
+
71
+ serialization_dir1="${data_dir}/serialization_dir1";
72
+ serialization_dir2="${data_dir}/serialization_dir2";
73
+
74
+ mkdir -p "${data_dir}"
75
+ mkdir -p "${pretrained_models_dir}"
76
+ mkdir -p "${trained_models_dir}"
77
+ mkdir -p "${serialization_dir1}"
78
+ mkdir -p "${serialization_dir2}"
79
+
80
+ vocabulary_dir="${data_dir}/vocabulary"
81
+ train_subset="${data_dir}/train.json"
82
+ valid_subset="${data_dir}/valid.json"
83
+ hierarchical_labels_pkl="${data_dir}/hierarchical_labels.pkl"
84
+ dataset_filename="${data_dir}/${dataset_fn}"
85
+
86
+ export PYTHONPATH="${work_dir}/../../.."
87
+
88
+ if [ $system_version == "windows" ]; then
89
+ alias python3='C:/Users/tianx/PycharmProjects/virtualenv/AllenNLP/Scripts/python.exe'
90
+ elif [ $system_version == "centos" ]; then
91
+ source /data/local/bin/AllenNLP/bin/activate
92
+ alias python3='/data/local/bin/AllenNLP/bin/python3'
93
+ elif [ $system_version == "ubuntu" ]; then
94
+ source /data/local/bin/AllenNLP/bin/activate
95
+ alias python3='/data/local/bin/AllenNLP/bin/python3'
96
+ fi
97
+
98
+
99
+ declare -A pretrained_bert_model_dict
100
+ pretrained_bert_model_dict=(
101
+ ["chinese-bert-wwm-ext"]="https://huggingface.co/hfl/chinese-bert-wwm-ext"
102
+ ["bert-base-uncased"]="https://huggingface.co/bert-base-uncased"
103
+ ["bert-base-japanese"]="https://huggingface.co/cl-tohoku/bert-base-japanese"
104
+ ["bert-base-vietnamese-uncased"]="https://huggingface.co/trituenhantaoio/bert-base-vietnamese-uncased"
105
+
106
+ )
107
+ pretrained_model_dir="${pretrained_models_dir}/${pretrained_bert_model_name}"
108
+
109
+
110
+ if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
111
+ $verbose && echo "stage -2: download pretrained models"
112
+ cd "${work_dir}" || exit 1;
113
+
114
+ if [ ! -d "${pretrained_model_dir}" ]; then
115
+ mkdir -p "${pretrained_models_dir}"
116
+ cd "${pretrained_models_dir}" || exit 1;
117
+
118
+ repository_url="${pretrained_bert_model_dict[${pretrained_bert_model_name}]}"
119
+ git clone "${repository_url}"
120
+
121
+ cd "${pretrained_model_dir}" || exit 1;
122
+ rm flax_model.msgpack && rm pytorch_model.bin && rm tf_model.h5
123
+ rm -rf .git/
124
+ wget "${repository_url}/resolve/main/pytorch_model.bin"
125
+ fi
126
+ fi
127
+
128
+
129
+ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
130
+ $verbose && echo "stage -1: download data"
131
+ cd "${data_dir}" || exit 1;
132
+
133
+ wget "https://huggingface.co/datasets/qgyd2021/telemarketing_intent/resolve/main/${dataset_fn}"
134
+
135
+ fi
136
+
137
+
138
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
139
+ $verbose && echo "stage 0: prepare data without irrelevant domain (make train subset, valid subset file)"
140
+ cd "${work_dir}" || exit 1;
141
+
142
+ python3 1.prepare_data.py \
143
+ --without_irrelevant_domain \
144
+ --dataset_filename "${dataset_filename}" \
145
+ --do_lowercase \
146
+ --train_subset "${train_subset}" \
147
+ --valid_subset "${valid_subset}" \
148
+
149
+ fi
150
+
151
+
152
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
153
+ $verbose && echo "stage 1: make hierarchical labels dictionary (make hierarchical_labels.pkl file)"
154
+ cd "${work_dir}" || exit 1
155
+ python3 2.make_hierarchical_labels.py \
156
+ --dataset_filename "${dataset_filename}" \
157
+ --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
158
+
159
+ fi
160
+
161
+
162
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
163
+ $verbose && echo "stage 2: make vocabulary (make vocabulary directory)"
164
+ cd "${work_dir}" || exit 1
165
+ python3 3.make_vocabulary.py \
166
+ --pretrained_model_path "${pretrained_model_dir}" \
167
+ --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
168
+ --vocabulary "${vocabulary_dir}" \
169
+
170
+ fi
171
+
172
+
173
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
174
+ $verbose && echo "stage 3: train model without irrelevant domain"
175
+ cd "${work_dir}" || exit 1
176
+
177
+ python3 4.train_model.py \
178
+ --pretrained_model_path "${pretrained_model_dir}" \
179
+ --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
180
+ --vocabulary_dir "${vocabulary_dir}" \
181
+ --train_subset "${train_subset}" \
182
+ --valid_subset "${valid_subset}" \
183
+ --serialization_dir "${serialization_dir1}" \
184
+
185
+ fi
186
+
187
+
188
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
189
+ $verbose && echo "stage 4: prepare data with irrelevant domain"
190
+ cd "${work_dir}" || exit 1
191
+
192
+ python3 1.prepare_data.py \
193
+ --dataset_filename "${dataset_filename}" \
194
+ --do_lowercase \
195
+ --train_subset "${train_subset}" \
196
+ --valid_subset "${valid_subset}" \
197
+
198
+ fi
199
+
200
+
201
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
202
+ $verbose && echo "stage 5: train model with irrelevant domain"
203
+ cd "${work_dir}" || exit 1
204
+
205
+ python3 4.train_model.py \
206
+ --pretrained_model_path "${pretrained_model_dir}" \
207
+ --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
208
+ --vocabulary_dir "${vocabulary_dir}" \
209
+ --train_subset "${train_subset}" \
210
+ --valid_subset "${valid_subset}" \
211
+ --serialization_dir "${serialization_dir2}" \
212
+ --checkpoint_path "${serialization_dir1}/best.th"
213
+
214
+ fi
215
+
216
+
217
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
218
+ $verbose && echo "stage 6: make json config"
219
+ cd "${work_dir}" || exit 1
220
+ python3 6.make_json_config.py \
221
+ --pretrained_model_path "${pretrained_model_dir}" \
222
+ --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
223
+ --vocabulary_dir "${vocabulary_dir}" \
224
+ --train_subset "${train_subset}" \
225
+ --valid_subset "${valid_subset}" \
226
+ --serialization_dir "${serialization_dir2}" \
227
+ --json_config_dir "${data_dir}" \
228
+
229
+ fi
230
+
231
+
232
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
233
+ $verbose && echo "stage 7: collect files"
234
+ cd "${work_dir}" || exit 1;
235
+
236
+ mkdir -p "${trained_models_dir}/${trained_model_name}"
237
+
238
+ cp -r "${vocabulary_dir}" "${trained_models_dir}/${trained_model_name}/vocabulary/"
239
+ cp "${serialization_dir2}/best.th" "${trained_models_dir}/${trained_model_name}/weights.th"
240
+ cp "${data_dir}/config.json" "${trained_models_dir}/${trained_model_name}/config.json"
241
+ cp "${hierarchical_labels_pkl}" "${trained_models_dir}/${trained_model_name}"
242
+
243
+ fi
244
+
245
+
246
+ if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
247
+ $verbose && echo "stage 8: predict by archive"
248
+ cd "${work_dir}" || exit 1;
249
+
250
+ python3 7.predict_by_archive.py \
251
+ --archive_file "${trained_models_dir}/${trained_model_name}" \
252
+ --pretrained_model_path "${pretrained_model_dir}" \
253
+
254
+ fi
main.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
8
+ from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
9
+ from allennlp.data.vocabulary import Vocabulary
10
+ from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
11
+ from allennlp.modules.token_embedders.embedding import Embedding
12
+ from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
13
+ from allennlp.models.archival import archive_model, load_archive
14
+ from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
15
+ from allennlp.predictors.predictor import Predictor
16
+ from allennlp.predictors.text_classifier import TextClassifierPredictor
17
+ import gradio as gr
18
+ import torch
19
+
20
+ from project_settings import project_path
21
+ from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
22
+ from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
23
+
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument(
28
+ "--cn_archive_file",
29
+ default=(project_path / "trained_models/telemarketing_intent_classification_cn").as_posix(),
30
+ type=str
31
+ )
32
+ parser.add_argument(
33
+ "--en_archive_file",
34
+ default=(project_path / "trained_models/telemarketing_intent_classification_en").as_posix(),
35
+ type=str
36
+ )
37
+ parser.add_argument(
38
+ "--jp_archive_file",
39
+ default=(project_path / "trained_models/telemarketing_intent_classification_jp").as_posix(),
40
+ type=str
41
+ )
42
+ parser.add_argument(
43
+ "--vi_archive_file",
44
+ default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
45
+ type=str
46
+ )
47
+ parser.add_argument(
48
+ "--predictor_name",
49
+ default="text_classifier",
50
+ type=str
51
+ )
52
+ args = parser.parse_args()
53
+ return args
54
+
55
+
56
+ def main():
57
+ args = get_args()
58
+
59
+ cn_archive = load_archive(archive_file=args.cn_archive_file)
60
+ cn_predictor = Predictor.from_archive(cn_archive, predictor_name=args.predictor_name)
61
+ en_archive = load_archive(archive_file=args.en_archive_file)
62
+ en_predictor = Predictor.from_archive(en_archive, predictor_name=args.predictor_name)
63
+ jp_archive = load_archive(archive_file=args.jp_archive_file)
64
+ jp_predictor = Predictor.from_archive(jp_archive, predictor_name=args.predictor_name)
65
+ vi_archive = load_archive(archive_file=args.vi_archive_file)
66
+ vi_predictor = Predictor.from_archive(vi_archive, predictor_name=args.predictor_name)
67
+
68
+ predictor_map = {
69
+ "chinese": cn_predictor,
70
+ "english": en_predictor,
71
+ "japanese": jp_predictor,
72
+ "vietnamese": vi_predictor,
73
+ }
74
+
75
+ def fn(text: str, language: str):
76
+ predictor = predictor_map.get(language, cn_predictor)
77
+
78
+ json_dict = {'sentence': text}
79
+ outputs = predictor.predict_json(
80
+ json_dict
81
+ )
82
+ outputs = predictor._model.decode(outputs)
83
+ label = outputs['label'][0]
84
+ prob = outputs['prob'][0]
85
+ prob = round(prob, 4)
86
+ return label, prob
87
+
88
+ description = """
89
+ 电销场景意图识别.
90
+ 语言: 汉语, 英语, 日语, 越南语.
91
+ 数据集是私有的.
92
+
93
+ model: selfattention-cnn
94
+ dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)
95
+
96
+ accuracy:
97
+ chinese: 0.8002
98
+ english: 0.7011
99
+ japanese: 0.8154
100
+ vietnamese: 0.8168
101
+
102
+ """
103
+ demo = gr.Interface(
104
+ fn=fn,
105
+ inputs=[
106
+ gr.Text(label="text"),
107
+ gr.Dropdown(
108
+ choices=list(sorted(predictor_map.keys())),
109
+ label="language"
110
+ )
111
+ ],
112
+ outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
113
+ examples=[
114
+ ["你找谁", "chinese"],
115
+ ["你是谁啊", "chinese"],
116
+ ["不好意思我现在很忙", "chinese"],
117
+ ["对不起, 不需要哈", "chinese"],
118
+ ["u have got the wrong number", "english"],
119
+ ["sure, thank a lot", "english"],
120
+ ["please leave your message for 95688496", "english"],
121
+ ["yes well", "english"],
122
+ ["失礼の", "japanese"],
123
+ ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
124
+ ["わかんない。", "japanese"],
125
+ ["に出ることができません", "japanese"],
126
+ ["À không phải em nha.", "vietnamese"],
127
+ ["Dạ nhầm số rồi ạ?", "vietnamese"],
128
+ ["Ừ, cảm ơn em nhá.", "vietnamese"],
129
+ ["Không, chị không có tiền.", "vietnamese"],
130
+ ],
131
+ examples_per_page=50,
132
+ title="Telemarketing Intent Classification",
133
+ description=description,
134
+ )
135
+ demo.launch()
136
+
137
+ return
138
+
139
+
140
+ if __name__ == '__main__':
141
+ main()
predict.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+
6
+ from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
7
+ from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
8
+ from allennlp.data.vocabulary import Vocabulary
9
+ from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
10
+ from allennlp.modules.token_embedders.embedding import Embedding
11
+ from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
12
+ from allennlp.models.archival import archive_model, load_archive
13
+ from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
14
+ from allennlp.predictors.predictor import Predictor
15
+ from allennlp.predictors.text_classifier import TextClassifierPredictor
16
+ import gradio as gr
17
+ import numpy as np
18
+ import pandas as pd
19
+ import torch
20
+ from tqdm import tqdm
21
+
22
+ from project_settings import project_path
23
+ from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
24
+ from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
25
+
26
+
27
+ def get_args():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument(
30
+ "--excel_file",
31
+ default=r"D:\Users\tianx\PycharmProjects\telemarketing_intent\data\excel\telemarketing_intent_vi.xlsx",
32
+ type=str,
33
+ )
34
+ parser.add_argument(
35
+ "--archive_file",
36
+ default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
37
+ type=str
38
+ )
39
+ parser.add_argument(
40
+ "--predictor_name",
41
+ default="text_classifier",
42
+ type=str
43
+ )
44
+ parser.add_argument(
45
+ "--top_k",
46
+ default=10,
47
+ type=int
48
+ )
49
+ parser.add_argument(
50
+ "--output_file",
51
+ default="intent_top_k.jsonl",
52
+ type=str
53
+ )
54
+ args = parser.parse_args()
55
+ return args
56
+
57
+
58
+ def main():
59
+ args = get_args()
60
+
61
+ archive = load_archive(archive_file=args.archive_file)
62
+ predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)
63
+
64
+ df = pd.read_excel(args.excel_file)
65
+
66
+ with open(args.output_file, "w", encoding="utf-8") as f:
67
+ for i, row in tqdm(df.iterrows(), total=len(df)):
68
+ if i < 26976:
69
+ continue
70
+
71
+ source = row["source"]
72
+ text = row["text"]
73
+ label0 = row["label0"]
74
+ label1 = row["label1"]
75
+ selected = row["selected"]
76
+ checked = row["checked"]
77
+
78
+ if pd.isna(source) or source is None:
79
+ source = None
80
+
81
+ if pd.isna(text) or text is None:
82
+ continue
83
+ text = str(text)
84
+
85
+ if pd.isna(label0) or label0 is None:
86
+ label0 = None
87
+
88
+ if pd.isna(label1) or label1 is None:
89
+ label1 = None
90
+
91
+ if pd.isna(selected) or selected is None:
92
+ selected = None
93
+ else:
94
+ try:
95
+ selected = int(selected)
96
+ except Exception:
97
+ print(type(selected))
98
+ selected = None
99
+
100
+ if pd.isna(checked) or checked is None:
101
+ checked = None
102
+ else:
103
+ try:
104
+ checked = int(checked)
105
+ except Exception:
106
+ print(type(checked))
107
+ checked = None
108
+
109
+ # print(text)
110
+ json_dict = {'sentence': text}
111
+ outputs = predictor.predict_json(
112
+ json_dict
113
+ )
114
+ probs = outputs["probs"]
115
+ arg_idx = np.argsort(probs)
116
+
117
+ arg_idx_top_k = arg_idx[-10:]
118
+ label_top_k = [
119
+ predictor._model.vocab.get_token_from_index(index=idx, namespace="labels").split("_")[-1] for idx in arg_idx_top_k
120
+ ]
121
+ prob_top_k = [
122
+ str(round(probs[idx], 5)) for idx in arg_idx_top_k
123
+ ]
124
+
125
+ row_ = {
126
+ "source": source,
127
+ "text": text,
128
+ "label0": label0,
129
+ "label1": label1,
130
+ "selected": selected,
131
+ "checked": checked,
132
+ "predict_label_top_k": ";".join(list(reversed(label_top_k))),
133
+ "predict_prob_top_k": ";".join(list(reversed(prob_top_k)))
134
+ }
135
+ row_ = json.dumps(row_, ensure_ascii=False)
136
+ f.write("{}\n".format(row_))
137
+
138
+ return
139
+
140
+
141
+ if __name__ == '__main__':
142
+ main()
pretrained_models/bert-base-japanese/.gitattributes ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
pretrained_models/bert-base-japanese/README.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: ja
3
+ license: cc-by-sa-4.0
4
+ datasets:
5
+ - wikipedia
6
+ widget:
7
+ - text: 東北大学で[MASK]の研究をしています。
8
+ ---
9
+
10
+ # BERT base Japanese (IPA dictionary)
11
+
12
+ This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language.
13
+
14
+ This version of the model processes input texts with word-level tokenization based on the IPA dictionary, followed by the WordPiece subword tokenization.
15
+
16
+ The codes for the pretraining are available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/tree/v1.0).
17
+
18
+ ## Model architecture
19
+
20
+ The model architecture is the same as the original BERT base model; 12 layers, 768 dimensions of hidden states, and 12 attention heads.
21
+
22
+ ## Training Data
23
+
24
+ The model is trained on Japanese Wikipedia as of September 1, 2019.
25
+ To generate the training corpus, [WikiExtractor](https://github.com/attardi/wikiextractor) is used to extract plain texts from a dump file of Wikipedia articles.
26
+ The text files used for the training are 2.6GB in size, consisting of approximately 17M sentences.
27
+
28
+ ## Tokenization
29
+
30
+ The texts are first tokenized by [MeCab](https://taku910.github.io/mecab/) morphological parser with the IPA dictionary and then split into subwords by the WordPiece algorithm.
31
+ The vocabulary size is 32000.
32
+
33
+ ## Training
34
+
35
+ The model is trained with the same configuration as the original BERT; 512 tokens per instance, 256 instances per batch, and 1M training steps.
36
+
37
+ ## Licenses
38
+
39
+ The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/).
40
+
41
+ ## Acknowledgments
42
+
43
+ For training models, we used Cloud TPUs provided by [TensorFlow Research Cloud](https://www.tensorflow.org/tfrc/) program.
pretrained_models/bert-base-japanese/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "tokenizer_class": "BertJapaneseTokenizer",
18
+ "type_vocab_size": 2,
19
+ "vocab_size": 32000
20
+ }
pretrained_models/bert-base-japanese/tokenizer_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "do_lower_case": false,
3
+ "subword_tokenizer_type": "wordpiece",
4
+ "word_tokenizer_type": "mecab"
5
+ }
pretrained_models/bert-base-japanese/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
pretrained_models/bert-base-uncased/.gitattributes ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
10
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
11
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
pretrained_models/bert-base-uncased/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
pretrained_models/bert-base-uncased/README.md ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags:
4
+ - exbert
5
+ license: apache-2.0
6
+ datasets:
7
+ - bookcorpus
8
+ - wikipedia
9
+ ---
10
+
11
+ # BERT base model (uncased)
12
+
13
+ Pretrained model on English language using a masked language modeling (MLM) objective. It was introduced in
14
+ [this paper](https://arxiv.org/abs/1810.04805) and first released in
15
+ [this repository](https://github.com/google-research/bert). This model is uncased: it does not make a difference
16
+ between english and English.
17
+
18
+ Disclaimer: The team releasing BERT did not write a model card for this model so this model card has been written by
19
+ the Hugging Face team.
20
+
21
+ ## Model description
22
+
23
+ BERT is a transformers model pretrained on a large corpus of English data in a self-supervised fashion. This means it
24
+ was pretrained on the raw texts only, with no humans labeling them in any way (which is why it can use lots of
25
+ publicly available data) with an automatic process to generate inputs and labels from those texts. More precisely, it
26
+ was pretrained with two objectives:
27
+
28
+ - Masked language modeling (MLM): taking a sentence, the model randomly masks 15% of the words in the input then run
29
+ the entire masked sentence through the model and has to predict the masked words. This is different from traditional
30
+ recurrent neural networks (RNNs) that usually see the words one after the other, or from autoregressive models like
31
+ GPT which internally masks the future tokens. It allows the model to learn a bidirectional representation of the
32
+ sentence.
33
+ - Next sentence prediction (NSP): the models concatenates two masked sentences as inputs during pretraining. Sometimes
34
+ they correspond to sentences that were next to each other in the original text, sometimes not. The model then has to
35
+ predict if the two sentences were following each other or not.
36
+
37
+ This way, the model learns an inner representation of the English language that can then be used to extract features
38
+ useful for downstream tasks: if you have a dataset of labeled sentences, for instance, you can train a standard
39
+ classifier using the features produced by the BERT model as inputs.
40
+
41
+ ## Model variations
42
+
43
+ BERT has originally been released in base and large variations, for cased and uncased input text. The uncased models also strips out an accent markers.
44
+ Chinese and multilingual uncased and cased versions followed shortly after.
45
+ Modified preprocessing with whole word masking has replaced subpiece masking in a following work, with the release of two models.
46
+ Other 24 smaller models are released afterward.
47
+
48
+ The detailed release history can be found on the [google-research/bert readme](https://github.com/google-research/bert/blob/master/README.md) on github.
49
+
50
+ | Model | #params | Language |
51
+ |------------------------|--------------------------------|-------|
52
+ | [`bert-base-uncased`](https://huggingface.co/bert-base-uncased) | 110M | English |
53
+ | [`bert-large-uncased`](https://huggingface.co/bert-large-uncased) | 340M | English | sub
54
+ | [`bert-base-cased`](https://huggingface.co/bert-base-cased) | 110M | English |
55
+ | [`bert-large-cased`](https://huggingface.co/bert-large-cased) | 340M | English |
56
+ | [`bert-base-chinese`](https://huggingface.co/bert-base-chinese) | 110M | Chinese |
57
+ | [`bert-base-multilingual-cased`](https://huggingface.co/bert-base-multilingual-cased) | 110M | Multiple |
58
+ | [`bert-large-uncased-whole-word-masking`](https://huggingface.co/bert-large-uncased-whole-word-masking) | 340M | English |
59
+ | [`bert-large-cased-whole-word-masking`](https://huggingface.co/bert-large-cased-whole-word-masking) | 340M | English |
60
+
61
+ ## Intended uses & limitations
62
+
63
+ You can use the raw model for either masked language modeling or next sentence prediction, but it's mostly intended to
64
+ be fine-tuned on a downstream task. See the [model hub](https://huggingface.co/models?filter=bert) to look for
65
+ fine-tuned versions of a task that interests you.
66
+
67
+ Note that this model is primarily aimed at being fine-tuned on tasks that use the whole sentence (potentially masked)
68
+ to make decisions, such as sequence classification, token classification or question answering. For tasks such as text
69
+ generation you should look at model like GPT2.
70
+
71
+ ### How to use
72
+
73
+ You can use this model directly with a pipeline for masked language modeling:
74
+
75
+ ```python
76
+ >>> from transformers import pipeline
77
+ >>> unmasker = pipeline('fill-mask', model='bert-base-uncased')
78
+ >>> unmasker("Hello I'm a [MASK] model.")
79
+
80
+ [{'sequence': "[CLS] hello i'm a fashion model. [SEP]",
81
+ 'score': 0.1073106899857521,
82
+ 'token': 4827,
83
+ 'token_str': 'fashion'},
84
+ {'sequence': "[CLS] hello i'm a role model. [SEP]",
85
+ 'score': 0.08774490654468536,
86
+ 'token': 2535,
87
+ 'token_str': 'role'},
88
+ {'sequence': "[CLS] hello i'm a new model. [SEP]",
89
+ 'score': 0.05338378623127937,
90
+ 'token': 2047,
91
+ 'token_str': 'new'},
92
+ {'sequence': "[CLS] hello i'm a super model. [SEP]",
93
+ 'score': 0.04667217284440994,
94
+ 'token': 3565,
95
+ 'token_str': 'super'},
96
+ {'sequence': "[CLS] hello i'm a fine model. [SEP]",
97
+ 'score': 0.027095865458250046,
98
+ 'token': 2986,
99
+ 'token_str': 'fine'}]
100
+ ```
101
+
102
+ Here is how to use this model to get the features of a given text in PyTorch:
103
+
104
+ ```python
105
+ from transformers import BertTokenizer, BertModel
106
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
107
+ model = BertModel.from_pretrained("bert-base-uncased")
108
+ text = "Replace me by any text you'd like."
109
+ encoded_input = tokenizer(text, return_tensors='pt')
110
+ output = model(**encoded_input)
111
+ ```
112
+
113
+ and in TensorFlow:
114
+
115
+ ```python
116
+ from transformers import BertTokenizer, TFBertModel
117
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
118
+ model = TFBertModel.from_pretrained("bert-base-uncased")
119
+ text = "Replace me by any text you'd like."
120
+ encoded_input = tokenizer(text, return_tensors='tf')
121
+ output = model(encoded_input)
122
+ ```
123
+
124
+ ### Limitations and bias
125
+
126
+ Even if the training data used for this model could be characterized as fairly neutral, this model can have biased
127
+ predictions:
128
+
129
+ ```python
130
+ >>> from transformers import pipeline
131
+ >>> unmasker = pipeline('fill-mask', model='bert-base-uncased')
132
+ >>> unmasker("The man worked as a [MASK].")
133
+
134
+ [{'sequence': '[CLS] the man worked as a carpenter. [SEP]',
135
+ 'score': 0.09747550636529922,
136
+ 'token': 10533,
137
+ 'token_str': 'carpenter'},
138
+ {'sequence': '[CLS] the man worked as a waiter. [SEP]',
139
+ 'score': 0.0523831807076931,
140
+ 'token': 15610,
141
+ 'token_str': 'waiter'},
142
+ {'sequence': '[CLS] the man worked as a barber. [SEP]',
143
+ 'score': 0.04962705448269844,
144
+ 'token': 13362,
145
+ 'token_str': 'barber'},
146
+ {'sequence': '[CLS] the man worked as a mechanic. [SEP]',
147
+ 'score': 0.03788609802722931,
148
+ 'token': 15893,
149
+ 'token_str': 'mechanic'},
150
+ {'sequence': '[CLS] the man worked as a salesman. [SEP]',
151
+ 'score': 0.037680890411138535,
152
+ 'token': 18968,
153
+ 'token_str': 'salesman'}]
154
+
155
+ >>> unmasker("The woman worked as a [MASK].")
156
+
157
+ [{'sequence': '[CLS] the woman worked as a nurse. [SEP]',
158
+ 'score': 0.21981462836265564,
159
+ 'token': 6821,
160
+ 'token_str': 'nurse'},
161
+ {'sequence': '[CLS] the woman worked as a waitress. [SEP]',
162
+ 'score': 0.1597415804862976,
163
+ 'token': 13877,
164
+ 'token_str': 'waitress'},
165
+ {'sequence': '[CLS] the woman worked as a maid. [SEP]',
166
+ 'score': 0.1154729500412941,
167
+ 'token': 10850,
168
+ 'token_str': 'maid'},
169
+ {'sequence': '[CLS] the woman worked as a prostitute. [SEP]',
170
+ 'score': 0.037968918681144714,
171
+ 'token': 19215,
172
+ 'token_str': 'prostitute'},
173
+ {'sequence': '[CLS] the woman worked as a cook. [SEP]',
174
+ 'score': 0.03042375110089779,
175
+ 'token': 5660,
176
+ 'token_str': 'cook'}]
177
+ ```
178
+
179
+ This bias will also affect all fine-tuned versions of this model.
180
+
181
+ ## Training data
182
+
183
+ The BERT model was pretrained on [BookCorpus](https://yknzhu.wixsite.com/mbweb), a dataset consisting of 11,038
184
+ unpublished books and [English Wikipedia](https://en.wikipedia.org/wiki/English_Wikipedia) (excluding lists, tables and
185
+ headers).
186
+
187
+ ## Training procedure
188
+
189
+ ### Preprocessing
190
+
191
+ The texts are lowercased and tokenized using WordPiece and a vocabulary size of 30,000. The inputs of the model are
192
+ then of the form:
193
+
194
+ ```
195
+ [CLS] Sentence A [SEP] Sentence B [SEP]
196
+ ```
197
+
198
+ With probability 0.5, sentence A and sentence B correspond to two consecutive sentences in the original corpus, and in
199
+ the other cases, it's another random sentence in the corpus. Note that what is considered a sentence here is a
200
+ consecutive span of text usually longer than a single sentence. The only constrain is that the result with the two
201
+ "sentences" has a combined length of less than 512 tokens.
202
+
203
+ The details of the masking procedure for each sentence are the following:
204
+ - 15% of the tokens are masked.
205
+ - In 80% of the cases, the masked tokens are replaced by `[MASK]`.
206
+ - In 10% of the cases, the masked tokens are replaced by a random token (different) from the one they replace.
207
+ - In the 10% remaining cases, the masked tokens are left as is.
208
+
209
+ ### Pretraining
210
+
211
+ The model was trained on 4 cloud TPUs in Pod configuration (16 TPU chips total) for one million steps with a batch size
212
+ of 256. The sequence length was limited to 128 tokens for 90% of the steps and 512 for the remaining 10%. The optimizer
213
+ used is Adam with a learning rate of 1e-4, \\(\beta_{1} = 0.9\\) and \\(\beta_{2} = 0.999\\), a weight decay of 0.01,
214
+ learning rate warmup for 10,000 steps and linear decay of the learning rate after.
215
+
216
+ ## Evaluation results
217
+
218
+ When fine-tuned on downstream tasks, this model achieves the following results:
219
+
220
+ Glue test results:
221
+
222
+ | Task | MNLI-(m/mm) | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE | Average |
223
+ |:----:|:-----------:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:-------:|
224
+ | | 84.6/83.4 | 71.2 | 90.5 | 93.5 | 52.1 | 85.8 | 88.9 | 66.4 | 79.6 |
225
+
226
+
227
+ ### BibTeX entry and citation info
228
+
229
+ ```bibtex
230
+ @article{DBLP:journals/corr/abs-1810-04805,
231
+ author = {Jacob Devlin and
232
+ Ming{-}Wei Chang and
233
+ Kenton Lee and
234
+ Kristina Toutanova},
235
+ title = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language
236
+ Understanding},
237
+ journal = {CoRR},
238
+ volume = {abs/1810.04805},
239
+ year = {2018},
240
+ url = {http://arxiv.org/abs/1810.04805},
241
+ archivePrefix = {arXiv},
242
+ eprint = {1810.04805},
243
+ timestamp = {Tue, 30 Oct 2018 20:39:56 +0100},
244
+ biburl = {https://dblp.org/rec/journals/corr/abs-1810-04805.bib},
245
+ bibsource = {dblp computer science bibliography, https://dblp.org}
246
+ }
247
+ ```
248
+
249
+ <a href="https://huggingface.co/exbert/?model=bert-base-uncased">
250
+ <img width="300px" src="https://cdn-media.huggingface.co/exbert/button.png">
251
+ </a>
pretrained_models/bert-base-uncased/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "gradient_checkpointing": false,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0.dev0",
20
+ "type_vocab_size": 2,
21
+ "use_cache": true,
22
+ "vocab_size": 30522
23
+ }
pretrained_models/bert-base-uncased/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pretrained_models/bert-base-uncased/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "do_lower_case": true
3
+ }
pretrained_models/bert-base-uncased/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
pretrained_models/bert-base-vietnamese-uncased/.gitattributes ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
17
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
pretrained_models/bert-base-vietnamese-uncased/README.md ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Usage
2
+ ```python
3
+ from transformers import BertForSequenceClassification
4
+ from transformers import BertTokenizer
5
+ model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
6
+ tokenizer = BertTokenizer.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
7
+ ```
8
+
9
+ ### References
10
+
11
+ ```
12
+ @article{ttnt2020bert,
13
+ title={Vietnamese BERT: Pretrained on News and Wiki},
14
+ author={trituenhantao.io},
15
+ year = {2020},
16
+ publisher = {GitHub},
17
+ journal = {GitHub repository},
18
+ howpublished = {\url{https://github.com/trituenhantaoio/vn-bert-base-uncased}},
19
+ }
20
+ ```
21
+
22
+ [trituenhantao.io](https://trituenhantao.io)
pretrained_models/bert-base-vietnamese-uncased/config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/content/drive/My Drive/Colab_data/2020/vietnamese-bert-19-11-2020/vietnamese-bert-10-2020/bert_model/pytorch_model",
3
+ "attention_probs_dropout_prob": 0.1,
4
+ "directionality": "bidi",
5
+ "gradient_checkpointing": false,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "pooler_fc_size": 768,
18
+ "pooler_num_attention_heads": 12,
19
+ "pooler_num_fc_layers": 3,
20
+ "pooler_size_per_head": 128,
21
+ "pooler_type": "first_token_transform",
22
+ "position_embedding_type": "absolute",
23
+ "transformers_version": "4.2.2",
24
+ "type_vocab_size": 2,
25
+ "use_cache": true,
26
+ "vocab_size": 32000
27
+ }
pretrained_models/bert-base-vietnamese-uncased/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
pretrained_models/bert-base-vietnamese-uncased/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null}
pretrained_models/bert-base-vietnamese-uncased/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
pretrained_models/chinese-bert-wwm-ext/.gitattributes ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
pretrained_models/chinese-bert-wwm-ext/README.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - zh
4
+ license: "apache-2.0"
5
+ ---
6
+ ## Chinese BERT with Whole Word Masking
7
+ For further accelerating Chinese natural language processing, we provide **Chinese pre-trained BERT with Whole Word Masking**.
8
+
9
+ **[Pre-Training with Whole Word Masking for Chinese BERT](https://arxiv.org/abs/1906.08101)**
10
+ Yiming Cui, Wanxiang Che, Ting Liu, Bing Qin, Ziqing Yang, Shijin Wang, Guoping Hu
11
+
12
+ This repository is developed based on:https://github.com/google-research/bert
13
+
14
+ You may also interested in,
15
+ - Chinese BERT series: https://github.com/ymcui/Chinese-BERT-wwm
16
+ - Chinese MacBERT: https://github.com/ymcui/MacBERT
17
+ - Chinese ELECTRA: https://github.com/ymcui/Chinese-ELECTRA
18
+ - Chinese XLNet: https://github.com/ymcui/Chinese-XLNet
19
+ - Knowledge Distillation Toolkit - TextBrewer: https://github.com/airaria/TextBrewer
20
+
21
+ More resources by HFL: https://github.com/ymcui/HFL-Anthology
22
+
23
+ ## Citation
24
+ If you find the technical report or resource is useful, please cite the following technical report in your paper.
25
+ - Primary: https://arxiv.org/abs/2004.13922
26
+ ```
27
+ @inproceedings{cui-etal-2020-revisiting,
28
+ title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
29
+ author = "Cui, Yiming and
30
+ Che, Wanxiang and
31
+ Liu, Ting and
32
+ Qin, Bing and
33
+ Wang, Shijin and
34
+ Hu, Guoping",
35
+ booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
36
+ month = nov,
37
+ year = "2020",
38
+ address = "Online",
39
+ publisher = "Association for Computational Linguistics",
40
+ url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
41
+ pages = "657--668",
42
+ }
43
+ ```
44
+ - Secondary: https://arxiv.org/abs/1906.08101
45
+ ```
46
+ @article{chinese-bert-wwm,
47
+ title={Pre-Training with Whole Word Masking for Chinese BERT},
48
+ author={Cui, Yiming and Che, Wanxiang and Liu, Ting and Qin, Bing and Yang, Ziqing and Wang, Shijin and Hu, Guoping},
49
+ journal={arXiv preprint arXiv:1906.08101},
50
+ year={2019}
51
+ }
52
+ ```
pretrained_models/chinese-bert-wwm-ext/added_tokens.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
pretrained_models/chinese-bert-wwm-ext/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "directionality": "bidi",
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "output_past": true,
18
+ "pad_token_id": 0,
19
+ "pooler_fc_size": 768,
20
+ "pooler_num_attention_heads": 12,
21
+ "pooler_num_fc_layers": 3,
22
+ "pooler_size_per_head": 128,
23
+ "pooler_type": "first_token_transform",
24
+ "type_vocab_size": 2,
25
+ "vocab_size": 21128
26
+ }
pretrained_models/chinese-bert-wwm-ext/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
pretrained_models/chinese-bert-wwm-ext/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pretrained_models/chinese-bert-wwm-ext/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"init_inputs": []}
pretrained_models/chinese-bert-wwm-ext/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
project_settings.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from pathlib import Path
5
+
6
+
7
+ project_path = os.path.abspath(os.path.dirname(__file__))
8
+ project_path = Path(project_path)
9
+
10
+
11
+ if __name__ == '__main__':
12
+ pass
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.20.1
2
+ allennlp==2.10.1
3
+ allennlp-models==2.10.1
4
+ torch==1.12.1
5
+ overrides==7.3.1
6
+ pytorch_pretrained_bert==0.6.2
7
+ pydantic==1.10.12
8
+ thinc==7.4.6
9
+ spacy==2.3.9
10
+ fugashi==1.1.2
11
+ ipadic==1.0.0
12
+ pandas==2.0.3
13
+ xlrd==1.2.0
14
+ openpyxl==3.0.9
toolbox/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/allennlp_models/text_classifier/dataset_readers/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/allennlp_models/text_classifier/dataset_readers/hierarchical_classification_json.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Dict, Iterable, List, Union
4
+ import logging
5
+ import json
6
+ from overrides import overrides
7
+ from allennlp.common.file_utils import cached_path
8
+ from allennlp.data.dataset_readers.dataset_reader import DatasetReader
9
+ from allennlp.data.fields import LabelField, TextField, Field, ListField
10
+ from allennlp.data.instance import Instance
11
+ from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
12
+ from allennlp.data.tokenizers import Tokenizer, SpacyTokenizer
13
+ from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @DatasetReader.register("hierarchical_classification_json")
19
+ class HierarchicalClassificationJsonReader(DatasetReader):
20
+ def __init__(self,
21
+ n_hierarchical: int = 2,
22
+ token_indexers: Dict[str, TokenIndexer] = None,
23
+ tokenizer: Tokenizer = None,
24
+ segment_sentences: bool = False,
25
+ max_sequence_length: int = None,
26
+ skip_label_indexing: bool = False,
27
+ **kwargs) -> None:
28
+ super().__init__(**kwargs)
29
+ self._n_hierarchical = n_hierarchical
30
+ self._tokenizer = tokenizer or SpacyTokenizer()
31
+ self._segment_sentences = segment_sentences
32
+ self._max_sequence_length = max_sequence_length
33
+ self._skip_label_indexing = skip_label_indexing
34
+ self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
35
+ if self._segment_sentences:
36
+ self._sentence_segmenter = SpacySentenceSplitter()
37
+
38
+ @overrides
39
+ def _read(self, file_path) -> Iterable[Instance]:
40
+ with open(cached_path(file_path), "r", encoding='utf-8') as data_file:
41
+ for line in data_file.readlines():
42
+ if not line:
43
+ continue
44
+ items = json.loads(line)
45
+ text = items["text"]
46
+
47
+ labels = [items.get("label{}".format(idx), None) for idx in range(self._n_hierarchical)]
48
+ if all(labels):
49
+ label = '_'.join(labels)
50
+ else:
51
+ label = None
52
+
53
+ if label is not None:
54
+ if self._skip_label_indexing:
55
+ try:
56
+ label = int(label)
57
+ except ValueError:
58
+ raise ValueError('Labels must be integers if skip_label_indexing is True.')
59
+ else:
60
+ label = str(label)
61
+ instance = self.text_to_instance(text=text, label=label)
62
+ if instance is not None:
63
+ yield instance
64
+
65
+ def _truncate(self, tokens):
66
+ if len(tokens) > self._max_sequence_length:
67
+ tokens = tokens[:self._max_sequence_length]
68
+ return tokens
69
+
70
+ @overrides
71
+ # def text_to_instance(self, text: str, label: Union[str, int] = None) -> Instance:
72
+ def text_to_instance(self, *inputs) -> Instance:
73
+ if len(inputs) == 1:
74
+ text = inputs[0]
75
+ label = None
76
+ elif len(inputs) == 2:
77
+ text, label = inputs
78
+ else:
79
+ raise AssertionError
80
+
81
+ fields: Dict[str, Field] = {}
82
+ if self._segment_sentences:
83
+ sentences: List[Field] = []
84
+ sentence_splits = self._sentence_segmenter.split_sentences(text)
85
+ for sentence in sentence_splits:
86
+ word_tokens = self._tokenizer.tokenize(sentence)
87
+ if self._max_sequence_length is not None:
88
+ word_tokens = self._truncate(word_tokens)
89
+ sentences.append(TextField(word_tokens, self._token_indexers))
90
+ fields['tokens'] = ListField(sentences)
91
+ else:
92
+ tokens = self._tokenizer.tokenize(text)
93
+ if self._max_sequence_length is not None:
94
+ tokens = self._truncate(tokens)
95
+ fields['tokens'] = TextField(tokens, self._token_indexers)
96
+ if label is not None:
97
+ fields['label'] = LabelField(label,
98
+ skip_indexing=self._skip_label_indexing)
99
+ return Instance(fields)
toolbox/allennlp_models/text_classifier/models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/allennlp_models/text_classifier/models/hierarchical_text_classifier.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from collections import OrderedDict
4
+ import pickle
5
+ from typing import Dict, Optional
6
+
7
+ from overrides import overrides
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from allennlp.data import Vocabulary
12
+ from allennlp.models.model import Model
13
+ from allennlp.modules import Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
14
+ from allennlp.nn import InitializerApplicator, RegularizerApplicator
15
+ from allennlp.nn.util import get_text_field_mask
16
+ from allennlp.training.metrics import CategoricalAccuracy
17
+
18
+ from toolbox.torch.modules.loss import FocalLoss, NegativeEntropy
19
+
20
+
21
+ @Model.register("hierarchical_classifier")
22
+ class HierarchicalClassifier(Model):
23
+ def __init__(self,
24
+ vocab: Vocabulary,
25
+ hierarchical_labels_pkl: str,
26
+ text_field_embedder: TextFieldEmbedder,
27
+ seq2vec_encoder: Seq2VecEncoder,
28
+ seq2seq_encoder: Seq2SeqEncoder = None,
29
+ dropout: float = None,
30
+ num_labels: int = None,
31
+ label_namespace: str = "labels",
32
+ balance_probs: bool = False,
33
+ initializer: InitializerApplicator = InitializerApplicator(),
34
+ regularizer: Optional[RegularizerApplicator] = None) -> None:
35
+
36
+ super().__init__(vocab, regularizer)
37
+ self._hierarchical_labels_pkl = hierarchical_labels_pkl
38
+ self._text_field_embedder = text_field_embedder
39
+
40
+ if seq2seq_encoder:
41
+ self._seq2seq_encoder = seq2seq_encoder
42
+ else:
43
+ self._seq2seq_encoder = None
44
+
45
+ self._seq2vec_encoder = seq2vec_encoder
46
+ self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()
47
+
48
+ if dropout:
49
+ self._dropout = torch.nn.Dropout(dropout)
50
+ else:
51
+ self._dropout = None
52
+
53
+ self._label_namespace = label_namespace
54
+
55
+ if num_labels:
56
+ self._num_labels = num_labels
57
+ else:
58
+ self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace)
59
+
60
+ with open(self._hierarchical_labels_pkl, 'rb') as f:
61
+ hierarchical_labels = pickle.load(f)
62
+ self._classification_layer = HierarchicalSoftMaxClassificationLayer(
63
+ classifier_input_dim=self._classifier_input_dim,
64
+ hierarchical_labels=hierarchical_labels,
65
+ activation='softmax',
66
+ )
67
+
68
+ self._accuracy = CategoricalAccuracy()
69
+
70
+ if balance_probs:
71
+ self._loss = NegativeEntropy(
72
+ inputs_logits=False,
73
+ )
74
+ else:
75
+ self._loss = FocalLoss(
76
+ num_classes=self._num_labels,
77
+ inputs_logits=False,
78
+ )
79
+ initializer(self)
80
+
81
+ def forward(self, # type: ignore
82
+ tokens: Dict[str, torch.LongTensor],
83
+ label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
84
+ embedded_text = self._text_field_embedder(tokens)
85
+ mask = get_text_field_mask(tokens)
86
+
87
+ if self._seq2seq_encoder:
88
+ embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)
89
+
90
+ embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)
91
+
92
+ if self._dropout:
93
+ embedded_text = self._dropout(embedded_text)
94
+
95
+ probs = self._classification_layer(embedded_text)
96
+
97
+ output_dict = {"probs": probs}
98
+
99
+ if label is not None:
100
+ loss = self._loss(probs, label.long().view(-1))
101
+ output_dict["loss"] = loss
102
+ self._accuracy(probs, label)
103
+
104
+ return output_dict
105
+
106
+ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
107
+ predictions = output_dict["probs"]
108
+ predictions = torch.tensor(predictions)
109
+ if predictions.dim() == 2:
110
+ predictions_list = [predictions[i] for i in range(predictions.shape[0])]
111
+ else:
112
+ predictions_list = [predictions]
113
+ classes = list()
114
+ prob = list()
115
+ for prediction in predictions_list:
116
+ label_idx = prediction.argmax(dim=-1).item()
117
+ label_str = self.vocab.get_index_to_token_vocabulary(self._label_namespace).get(label_idx, str(label_idx))
118
+ classes.append(label_str)
119
+ prob.append(prediction[label_idx].item())
120
+ output_dict["label"] = classes
121
+ output_dict["prob"] = prob
122
+ return output_dict
123
+
124
+ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
125
+ metrics = {'accuracy': self._accuracy.get_metric(reset)}
126
+ return metrics
127
+
128
+
129
+ class HierarchicalSoftMaxClassificationLayer(nn.Module):
130
+ """多层 softmax 实现多极文本分类
131
+
132
+ 由于初始化时, 各层 softmax 的概率趋于平衡.
133
+
134
+ 因此在第一层时 `领域无关` 就分到了 50% 的概率.
135
+
136
+ `领域相关` 中的各类别去分剩下的 50% 的概率.
137
+ 这会导致模型一开始时输出的类别全是 `领域无关`, 这导致模型无法优化.
138
+
139
+ 解决方��:
140
+ 1. 从数据集中去除 `领域无关` 数据. 并训练模型.
141
+ 2. 等模型收敛之后, 再使用包含 `领域无关` 的数据集, 让模型加载之前的权重, 并重新开始训练模型.
142
+
143
+ """
144
+
145
+ @staticmethod
146
+ def demo1():
147
+ # hierarchical_labels = OrderedDict({
148
+ # '领域相关': OrderedDict({
149
+ # '肯定答复': [
150
+ # '肯定(好的)', '肯定(可以)', '肯定(正确)'
151
+ # ],
152
+ # '否定答复': [
153
+ # '否定(不可以)', '否定(不知道)', '否定(错误)'
154
+ # ],
155
+ # '用户正忙': [
156
+ # '用户正忙'
157
+ # ]
158
+ # }),
159
+ # '领域无关': OrderedDict({
160
+ # '领域无关': [
161
+ # '领域无关'
162
+ # ]
163
+ # })
164
+ # })
165
+
166
+ hierarchical_labels = OrderedDict({
167
+ '领域相关': ['肯定答复', '否定答复', '用户正忙', '查联系方式'],
168
+ '领域无关': ['领域无关'],
169
+ })
170
+
171
+ softmax_layer = HierarchicalSoftMaxClassificationLayer(
172
+ classifier_input_dim=3,
173
+ hierarchical_labels=hierarchical_labels,
174
+ activation='softmax',
175
+ # activation='sigmoid',
176
+
177
+ )
178
+
179
+ for k, v in softmax_layer.__dict__['_modules'].items():
180
+ print(k)
181
+ print(v)
182
+
183
+ inputs = torch.ones(size=(2, 3), dtype=torch.float32)
184
+
185
+ probs = softmax_layer.forward(inputs)
186
+ print(probs)
187
+ print(torch.sum(probs, dim=-1))
188
+ return
189
+
190
+ def __init__(self, classifier_input_dim: int, hierarchical_labels: OrderedDict, activation: str = 'softmax'):
191
+ super(HierarchicalSoftMaxClassificationLayer, self).__init__()
192
+ self.classifier_input_dim = classifier_input_dim
193
+ self.hierarchical_labels = hierarchical_labels
194
+ self.activation: str = activation
195
+
196
+ self._init_hierarchical_classification_layer(hierarchical_labels)
197
+
198
+ def _init_hierarchical_classification_layer(self,
199
+ hierarchical_labels: OrderedDict,
200
+ key: str = 'classification_layer',
201
+ child_class: str = None):
202
+ num_labels = len(hierarchical_labels)
203
+
204
+ classification_layer = torch.nn.Linear(self.classifier_input_dim, num_labels)
205
+ if child_class is not None:
206
+ key = '{header}_{child_class}'.format(header=key, child_class=child_class)
207
+ setattr(
208
+ self,
209
+ key,
210
+ classification_layer
211
+ )
212
+
213
+ branch = 0
214
+ for k, v in hierarchical_labels.items():
215
+ if isinstance(v, OrderedDict):
216
+ self._init_hierarchical_classification_layer(
217
+ v,
218
+ key=key,
219
+ child_class=branch,
220
+ )
221
+ elif isinstance(v, list):
222
+ num_labels = len(v)
223
+ classification_layer = torch.nn.Linear(self.classifier_input_dim, num_labels)
224
+ setattr(
225
+ self,
226
+ '{key}_{child_class}'.format(key=key, child_class=branch),
227
+ classification_layer,
228
+ )
229
+ else:
230
+ raise NotImplementedError
231
+ branch += 1
232
+ return
233
+
234
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
235
+ key = 'classification_layer'
236
+ classification_layer = getattr(self, key)
237
+ logits = classification_layer.forward(inputs)
238
+ probs = torch.softmax(logits, dim=-1)
239
+
240
+ probs = self._layer_probs(
241
+ inputs=inputs,
242
+ probs=probs,
243
+ key=key,
244
+ )
245
+
246
+ return probs
247
+
248
+ def _layer_probs(self,
249
+ inputs: torch.Tensor,
250
+ probs: torch.Tensor,
251
+ key: str,
252
+ ):
253
+
254
+ result = list()
255
+ for child_class in range(probs.shape[1]):
256
+ parent_probs = torch.unsqueeze(probs[:, child_class], dim=-1)
257
+
258
+ child_key = '{key}_{child_class}'.format(key=key, child_class=child_class)
259
+ classification_layer = getattr(self, child_key)
260
+ logits = classification_layer.forward(inputs)
261
+
262
+ child_child_key = '{key}_{child_class}'.format(key=child_key, child_class=0)
263
+ if hasattr(self, child_child_key):
264
+ child_probs = torch.softmax(logits, dim=-1)
265
+ child_probs = child_probs * parent_probs
266
+
267
+ child_probs = self._layer_probs(
268
+ inputs=inputs,
269
+ probs=child_probs,
270
+ key=child_key,
271
+ )
272
+ else:
273
+ if self.activation == 'softmax':
274
+ child_probs = torch.softmax(logits, dim=-1)
275
+ else:
276
+ child_probs = torch.sigmoid(logits)
277
+ child_probs = child_probs * parent_probs
278
+
279
+ result.append(child_probs)
280
+
281
+ result = torch.concat(result, dim=-1)
282
+ return result
283
+
284
+
285
+ def demo1():
286
+ HierarchicalSoftMaxClassificationLayer.demo1()
287
+ return
288
+
289
+
290
+ if __name__ == '__main__':
291
+ demo1()
toolbox/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/modules/loss.py ADDED
@@ -0,0 +1,738 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ from typing import List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.modules.loss import _Loss
11
+ from torch.autograd import Variable
12
+
13
+
14
+ class ClassBalancedLoss(_Loss):
15
+ """
16
+ https://arxiv.org/abs/1901.05555
17
+ """
18
+ @staticmethod
19
+ def demo1():
20
+ batch_loss: torch.FloatTensor = torch.randn(size=(2, 1), dtype=torch.float32)
21
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
22
+
23
+ class_balanced_loss = ClassBalancedLoss(
24
+ num_classes=3,
25
+ num_samples_each_class=[300, 433, 50],
26
+ reduction='mean',
27
+ )
28
+ loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
29
+ print(loss)
30
+ return
31
+
32
+ @staticmethod
33
+ def demo2():
34
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
35
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
36
+
37
+ focal_loss = FocalLoss(
38
+ num_classes=3,
39
+ # reduction='mean',
40
+ # reduction='sum',
41
+ reduction='none',
42
+ )
43
+ batch_loss = focal_loss.forward(inputs, targets)
44
+ print(batch_loss)
45
+
46
+ class_balanced_loss = ClassBalancedLoss(
47
+ num_classes=3,
48
+ num_samples_each_class=[300, 433, 50],
49
+ reduction='mean',
50
+ )
51
+ loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
52
+ print(loss)
53
+
54
+ return
55
+
56
+ def __init__(self,
57
+ num_classes: int,
58
+ num_samples_each_class: List[int],
59
+ beta: float = 0.999,
60
+ reduction: str = 'mean') -> None:
61
+ super(ClassBalancedLoss, self).__init__(None, None, reduction)
62
+
63
+ effective_num = 1.0 - np.power(beta, num_samples_each_class)
64
+ weights = (1.0 - beta) / np.array(effective_num)
65
+ self.weights = weights / np.sum(weights) * num_classes
66
+
67
+ def forward(self, batch_loss: torch.FloatTensor, targets: torch.LongTensor):
68
+ """
69
+ :param batch_loss: shape=[batch_size, 1]
70
+ :param targets: shape=[batch_size,]
71
+ :return:
72
+ """
73
+ weights = list()
74
+ targets = targets.numpy()
75
+ for target in targets:
76
+ weights.append([self.weights[target]])
77
+
78
+ weights = torch.tensor(weights, dtype=torch.float32)
79
+ batch_loss = weights * batch_loss
80
+
81
+ if self.reduction == 'mean':
82
+ loss = batch_loss.mean()
83
+ elif self.reduction == 'sum':
84
+ loss = batch_loss.sum()
85
+ else:
86
+ loss = batch_loss
87
+ return loss
88
+
89
+
90
+ class EqualizationLoss(_Loss):
91
+ """
92
+ 在图像识别中的, sigmoid 的多标签分类, 且 num_classes 类别数之外有一个 background 背景类别.
93
+ Equalization Loss
94
+ https://arxiv.org/abs/2003.05176
95
+ Equalization Loss v2
96
+ https://arxiv.org/abs/2012.08548
97
+ """
98
+
99
+ @staticmethod
100
+ def demo1():
101
+ logits: torch.FloatTensor = torch.randn(size=(3, 3), dtype=torch.float32)
102
+ targets: torch.LongTensor = torch.tensor([1, 2, 3], dtype=torch.long)
103
+
104
+ equalization_loss = EqualizationLoss(
105
+ num_samples_each_class=[300, 433, 50],
106
+ threshold=100,
107
+ reduction='mean',
108
+ )
109
+ loss = equalization_loss.forward(logits=logits, targets=targets)
110
+ print(loss)
111
+ return
112
+
113
+ def __init__(self,
114
+ num_samples_each_class: List[int],
115
+ threshold: int = 100,
116
+ reduction: str = 'mean') -> None:
117
+ super(EqualizationLoss, self).__init__(None, None, reduction)
118
+ self.num_samples_each_class = np.array(num_samples_each_class, dtype=np.int32)
119
+ self.threshold = threshold
120
+
121
+ def forward(self,
122
+ logits: torch.FloatTensor,
123
+ targets: torch.LongTensor
124
+ ):
125
+ """
126
+ num_classes + 1 对应于背景类别 background.
127
+ :param logits: shape=[batch_size, num_classes]
128
+ :param targets: shape=[batch_size]
129
+ :return:
130
+ """
131
+ batch_size, num_classes = logits.size()
132
+
133
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes + 1)
134
+ one_hot_targets = one_hot_targets[:, :-1]
135
+
136
+ exclude = self.exclude_func(
137
+ num_classes=num_classes,
138
+ targets=targets
139
+ )
140
+ is_tail = self.threshold_func(
141
+ num_classes=num_classes,
142
+ num_samples_each_class=self.num_samples_each_class,
143
+ threshold=self.threshold,
144
+ )
145
+
146
+ weights = 1 - exclude * is_tail * (1 - one_hot_targets)
147
+
148
+ batch_loss = F.binary_cross_entropy_with_logits(
149
+ logits,
150
+ one_hot_targets.float(),
151
+ reduction='none'
152
+ )
153
+
154
+ batch_loss = weights * batch_loss
155
+
156
+ if self.reduction == 'mean':
157
+ loss = batch_loss.mean()
158
+ elif self.reduction == 'sum':
159
+ loss = batch_loss.sum()
160
+ else:
161
+ loss = batch_loss
162
+
163
+ loss = loss / num_classes
164
+ return loss
165
+
166
+ @staticmethod
167
+ def exclude_func(num_classes: int, targets: torch.LongTensor):
168
+ """
169
+ 最后一个类别是背景 background.
170
+ :param num_classes: int,
171
+ :param targets: shape=[batch_size,]
172
+ :return: weight, shape=[batch_size, num_classes]
173
+ """
174
+ batch_size = targets.shape[0]
175
+ weight = (targets != num_classes).float()
176
+ weight = weight.view(batch_size, 1).expand(batch_size, num_classes)
177
+ return weight
178
+
179
+ @staticmethod
180
+ def threshold_func(num_classes: int, num_samples_each_class: np.ndarray, threshold: int):
181
+ """
182
+ :param num_classes: int,
183
+ :param num_samples_each_class: shape=[num_classes]
184
+ :param threshold: int,
185
+ :return: weight, shape=[1, num_classes]
186
+ """
187
+ weight = torch.zeros(size=(num_classes,))
188
+ weight[num_samples_each_class < threshold] = 1
189
+ weight = torch.unsqueeze(weight, dim=0)
190
+ return weight
191
+
192
+
193
+ class FocalLoss(_Loss):
194
+ """
195
+ https://arxiv.org/abs/1708.02002
196
+ """
197
+ @staticmethod
198
+ def demo1(self):
199
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
200
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
201
+
202
+ focal_loss = FocalLoss(
203
+ num_classes=3,
204
+ reduction='mean',
205
+ # reduction='sum',
206
+ # reduction='none',
207
+ )
208
+ loss = focal_loss.forward(inputs, targets)
209
+ print(loss)
210
+ return
211
+
212
+ def __init__(self,
213
+ num_classes: int,
214
+ alpha: List[float] = None,
215
+ gamma: int = 2,
216
+ reduction: str = 'mean',
217
+ inputs_logits: bool = True) -> None:
218
+ """
219
+ :param num_classes:
220
+ :param alpha:
221
+ :param gamma:
222
+ :param reduction: (`none`, `mean`, `sum`) available.
223
+ :param inputs_logits: if False, the inputs should be probs.
224
+ """
225
+ super(FocalLoss, self).__init__(None, None, reduction)
226
+ if alpha is None:
227
+ self.alpha = torch.ones(num_classes, 1)
228
+ else:
229
+ self.alpha = torch.tensor(alpha, dtype=torch.float32)
230
+ self.gamma = gamma
231
+ self.num_classes = num_classes
232
+ self.inputs_logits = inputs_logits
233
+
234
+ def forward(self,
235
+ inputs: torch.FloatTensor,
236
+ targets: torch.LongTensor):
237
+ """
238
+ :param inputs: logits, shape=[batch_size, num_classes]
239
+ :param targets: shape=[batch_size,]
240
+ :return:
241
+ """
242
+ batch_size, num_classes = inputs.shape
243
+
244
+ if self.inputs_logits:
245
+ probs = F.softmax(inputs, dim=-1)
246
+ else:
247
+ probs = inputs
248
+
249
+ # class_mask = inputs.data.new(batch_size, num_classes).fill_(0)
250
+ class_mask = torch.zeros(size=(batch_size, num_classes), dtype=inputs.dtype, device=inputs.device)
251
+ # class_mask = Variable(class_mask)
252
+ ids = targets.view(-1, 1)
253
+ class_mask.scatter_(1, ids.data, 1.)
254
+
255
+ if inputs.is_cuda and not self.alpha.is_cuda:
256
+ self.alpha = self.alpha.cuda()
257
+ alpha = self.alpha[ids.data.view(-1)]
258
+
259
+ probs = (probs * class_mask).sum(1).view(-1, 1)
260
+
261
+ log_p = probs.log()
262
+
263
+ batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
264
+
265
+ if self.reduction == 'mean':
266
+ loss = batch_loss.mean()
267
+ elif self.reduction == 'sum':
268
+ loss = batch_loss.sum()
269
+ else:
270
+ loss = batch_loss
271
+ return loss
272
+
273
+
274
+ class HingeLoss(_Loss):
275
+ @staticmethod
276
+ def demo1():
277
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
278
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
279
+
280
+ hinge_loss = HingeLoss(
281
+ margin_list=[300, 433, 50],
282
+ reduction='mean',
283
+ )
284
+ loss = hinge_loss.forward(inputs=inputs, targets=targets)
285
+ print(loss)
286
+ return
287
+
288
+ def __init__(self,
289
+ margin_list: List[float],
290
+ max_margin: float = 0.5,
291
+ scale: float = 1.0,
292
+ weight: Optional[torch.Tensor] = None,
293
+ reduction: str = 'mean') -> None:
294
+ super(HingeLoss, self).__init__(None, None, reduction)
295
+
296
+ self.max_margin = max_margin
297
+ self.scale = scale
298
+ self.weight = weight
299
+
300
+ margin_list = np.array(margin_list)
301
+ margin_list = margin_list * (max_margin / np.max(margin_list))
302
+ self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
303
+
304
+ def forward(self,
305
+ inputs: torch.FloatTensor,
306
+ targets: torch.LongTensor
307
+ ):
308
+ """
309
+ :param inputs: logits, shape=[batch_size, num_classes]
310
+ :param targets: shape=[batch_size,]
311
+ :return:
312
+ """
313
+ batch_size, num_classes = inputs.shape
314
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes)
315
+ margin_list = torch.unsqueeze(self.margin_list, dim=0)
316
+
317
+ batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
318
+ batch_margin = torch.unsqueeze(batch_margin, dim=-1)
319
+ inputs_margin = inputs - batch_margin
320
+
321
+ # 将类别对应的 logits 值减小一点, 以形成 margin 边界.
322
+ logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
323
+
324
+ loss = F.cross_entropy(
325
+ input=self.scale * logits,
326
+ target=targets,
327
+ weight=self.weight,
328
+ reduction=self.reduction,
329
+ )
330
+ return loss
331
+
332
+
333
+ class HingeLinear(nn.Module):
334
+ """
335
+ use this instead of `HingeLoss`, then you can combine it with `FocalLoss` or others.
336
+ """
337
+ def __init__(self,
338
+ margin_list: List[float],
339
+ max_margin: float = 0.5,
340
+ scale: float = 1.0,
341
+ weight: Optional[torch.Tensor] = None
342
+ ) -> None:
343
+ super(HingeLinear, self).__init__()
344
+
345
+ self.max_margin = max_margin
346
+ self.scale = scale
347
+ self.weight = weight
348
+
349
+ margin_list = np.array(margin_list)
350
+ margin_list = margin_list * (max_margin / np.max(margin_list))
351
+ self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
352
+
353
+ def forward(self,
354
+ inputs: torch.FloatTensor,
355
+ targets: torch.LongTensor
356
+ ):
357
+ """
358
+ :param inputs: logits, shape=[batch_size, num_classes]
359
+ :param targets: shape=[batch_size,]
360
+ :return:
361
+ """
362
+ if self.training and targets is not None:
363
+ batch_size, num_classes = inputs.shape
364
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes)
365
+ margin_list = torch.unsqueeze(self.margin_list, dim=0)
366
+
367
+ batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
368
+ batch_margin = torch.unsqueeze(batch_margin, dim=-1)
369
+ inputs_margin = inputs - batch_margin
370
+
371
+ # 将类别对应的 logits 值减小一点, 以形成 margin 边界.
372
+ logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
373
+ logits = logits * self.scale
374
+ else:
375
+ logits = inputs
376
+ return logits
377
+
378
+
379
+ class LDAMLoss(_Loss):
380
+ """
381
+ https://arxiv.org/abs/1906.07413
382
+ """
383
+ @staticmethod
384
+ def demo1():
385
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
386
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
387
+
388
+ ldam_loss = LDAMLoss(
389
+ num_samples_each_class=[300, 433, 50],
390
+ reduction='mean',
391
+ )
392
+ loss = ldam_loss.forward(inputs=inputs, targets=targets)
393
+ print(loss)
394
+ return
395
+
396
+ def __init__(self,
397
+ num_samples_each_class: List[int],
398
+ max_margin: float = 0.5,
399
+ scale: float = 30.0,
400
+ weight: Optional[torch.Tensor] = None,
401
+ reduction: str = 'mean') -> None:
402
+ super(LDAMLoss, self).__init__(None, None, reduction)
403
+
404
+ margin_list = np.power(num_samples_each_class, -0.25)
405
+ margin_list = margin_list * (max_margin / np.max(margin_list))
406
+
407
+ self.num_samples_each_class = num_samples_each_class
408
+ self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
409
+ self.scale = scale
410
+ self.weight = weight
411
+
412
+ def forward(self,
413
+ inputs: torch.FloatTensor,
414
+ targets: torch.LongTensor
415
+ ):
416
+ """
417
+ :param inputs: logits, shape=[batch_size, num_classes]
418
+ :param targets: shape=[batch_size,]
419
+ :return:
420
+ """
421
+ batch_size, num_classes = inputs.shape
422
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes)
423
+ margin_list = torch.unsqueeze(self.margin_list, dim=0)
424
+
425
+ batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
426
+ batch_margin = torch.unsqueeze(batch_margin, dim=-1)
427
+ inputs_margin = inputs - batch_margin
428
+
429
+ # 将类别对应的 logits 值减小一点, 以形成 margin 边界.
430
+ logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
431
+
432
+ loss = F.cross_entropy(
433
+ input=self.scale * logits,
434
+ target=targets,
435
+ weight=self.weight,
436
+ reduction=self.reduction,
437
+ )
438
+ return loss
439
+
440
+
441
+ class NegativeEntropy(_Loss):
442
+ def __init__(self,
443
+ reduction: str = 'mean',
444
+ inputs_logits: bool = True) -> None:
445
+ super(NegativeEntropy, self).__init__(None, None, reduction)
446
+ self.inputs_logits = inputs_logits
447
+
448
+ def forward(self,
449
+ inputs: torch.FloatTensor,
450
+ targets: torch.LongTensor):
451
+ if self.inputs_logits:
452
+ probs = F.softmax(inputs, dim=-1)
453
+ log_probs = torch.nn.functional.log_softmax(probs, dim=-1)
454
+ else:
455
+ probs = inputs
456
+ log_probs = torch.log(probs)
457
+
458
+ weighted_negative_likelihood = - log_probs * probs
459
+
460
+ loss = - weighted_negative_likelihood.sum()
461
+ return loss
462
+
463
+
464
+ class LargeMarginSoftMaxLoss(_Loss):
465
+ """
466
+ Alias: L-Softmax
467
+
468
+ https://arxiv.org/abs/1612.02295
469
+ https://github.com/wy1iu/LargeMargin_Softmax_Loss
470
+ https://github.com/amirhfarzaneh/lsoftmax-pytorch/blob/master/lsoftmax.py
471
+
472
+ 参考链接:
473
+ https://www.jianshu.com/p/06cc3f84aa85
474
+
475
+ 论文认为, softmax 和 cross entropy 的组合, 没有明确鼓励对特征进行判别学习.
476
+
477
+ """
478
+ def __init__(self,
479
+ reduction: str = 'mean') -> None:
480
+ super(LargeMarginSoftMaxLoss, self).__init__(None, None, reduction)
481
+
482
+
483
+ class AngularSoftMaxLoss(_Loss):
484
+ """
485
+ Alias: A-Softmax
486
+
487
+ https://arxiv.org/abs/1704.08063
488
+
489
+ https://github.com/woshildh/a-softmax_pytorch/blob/master/a_softmax.py
490
+
491
+ 参考链接:
492
+ https://www.jianshu.com/p/06cc3f84aa85
493
+
494
+ 好像作者认为人脸是一个球面, 所以将向量转换到一个球面上是有帮助的.
495
+ """
496
+ def __init__(self,
497
+ reduction: str = 'mean') -> None:
498
+ super(AngularSoftMaxLoss, self).__init__(None, None, reduction)
499
+
500
+
501
+ class AdditiveMarginSoftMax(_Loss):
502
+ """
503
+ Alias: AM-Softmax
504
+
505
+ https://arxiv.org/abs/1801.05599
506
+
507
+ Large Margin Cosine Loss
508
+ https://arxiv.org/abs/1801.09414
509
+
510
+ 参考链接:
511
+ https://www.jianshu.com/p/06cc3f84aa85
512
+
513
+ 说明:
514
+ 相对于普通的 对 logits 做 softmax,
515
+ 它将真实标签对应的 logit 值减去 m, 来让模型它该值调整得更大一些.
516
+ 另外, 它还将每个 logits 乘以 s, 这可以控制各 logits 之间的相对大小.
517
+ 根 HingeLoss 有点像.
518
+ """
519
+ def __init__(self,
520
+ reduction: str = 'mean') -> None:
521
+ super(AdditiveMarginSoftMax, self).__init__(None, None, reduction)
522
+
523
+
524
+ class AdditiveAngularMarginSoftMax(_Loss):
525
+ """
526
+ Alias: ArcFace, AAM-Softmax
527
+
528
+ ArcFace: Additive Angular Margin Loss for Deep Face Recognition
529
+ https://arxiv.org/abs/1801.07698
530
+
531
+ 参考代码:
532
+ https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
533
+
534
+ """
535
+ @staticmethod
536
+ def demo1():
537
+ """
538
+ 角度与数值转换
539
+ pi / 180 代表 1 度,
540
+ pi / 180 = 0.01745
541
+ """
542
+
543
+ # 度数转数值
544
+ degree = 10
545
+ result = degree * math.pi / 180
546
+ print(result)
547
+
548
+ # 数值转数度
549
+ radian = 0.2
550
+ result = radian / (math.pi / 180)
551
+ print(result)
552
+
553
+ return
554
+
555
+ def __init__(self,
556
+ hidden_size: int,
557
+ num_labels: int,
558
+ margin: float = 0.2,
559
+ scale: float = 10.0,
560
+ ):
561
+ """
562
+ :param hidden_size:
563
+ :param num_labels:
564
+ :param margin: 建议取值角度为 [10, 30], 对应的数值为 [0.1745, 0.5236]
565
+ :param scale:
566
+ """
567
+ super(AdditiveAngularMarginSoftMax, self).__init__()
568
+ self.margin = margin
569
+ self.scale = scale
570
+ self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
571
+ nn.init.xavier_uniform_(self.weight)
572
+
573
+ self.cos_margin = math.cos(self.margin)
574
+ self.sin_margin = math.sin(self.margin)
575
+
576
+ # sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
577
+ # sin(pi - a) = sin(a)
578
+
579
+ self.loss = nn.CrossEntropyLoss()
580
+
581
+ def forward(self,
582
+ inputs: torch.Tensor,
583
+ label: torch.LongTensor = None
584
+ ):
585
+ """
586
+ :param inputs: shape=[batch_size, ..., hidden_size]
587
+ :param label:
588
+ :return: logits
589
+ """
590
+ x = F.normalize(inputs)
591
+ weight = F.normalize(self.weight)
592
+ cosine = F.linear(x, weight)
593
+
594
+ if self.training:
595
+
596
+ # sin^2 + cos^2 = 1
597
+ sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
598
+
599
+ # cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
600
+ cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
601
+
602
+ # when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
603
+ cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
604
+
605
+ one_hot = torch.zeros_like(cosine)
606
+ one_hot.scatter_(1, label.view(-1, 1), 1)
607
+
608
+ #
609
+ logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
610
+ logits = logits * self.scale
611
+ else:
612
+ logits = cosine
613
+
614
+ loss = self.loss(logits, label)
615
+ # prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
616
+ return loss
617
+
618
+
619
+ class AdditiveAngularMarginLinear(nn.Module):
620
+ """
621
+ Alias: ArcFace, AAM-Softmax
622
+
623
+ ArcFace: Additive Angular Margin Loss for Deep Face Recognition
624
+ https://arxiv.org/abs/1801.07698
625
+
626
+ 参考代码:
627
+ https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
628
+
629
+ """
630
+ @staticmethod
631
+ def demo1():
632
+ """
633
+ 角度与数值转换
634
+ pi / 180 代表 1 度,
635
+ pi / 180 = 0.01745
636
+ """
637
+
638
+ # 度数转数值
639
+ degree = 10
640
+ result = degree * math.pi / 180
641
+ print(result)
642
+
643
+ # 数值转数度
644
+ radian = 0.2
645
+ result = radian / (math.pi / 180)
646
+ print(result)
647
+
648
+ return
649
+
650
+ @staticmethod
651
+ def demo2():
652
+
653
+ return
654
+
655
+ def __init__(self,
656
+ hidden_size: int,
657
+ num_labels: int,
658
+ margin: float = 0.2,
659
+ scale: float = 10.0,
660
+ ):
661
+ """
662
+ :param hidden_size:
663
+ :param num_labels:
664
+ :param margin: 建议取值角度为 [10, 30], 对应的数值为 [0.1745, 0.5236]
665
+ :param scale:
666
+ """
667
+ super(AdditiveAngularMarginLinear, self).__init__()
668
+ self.margin = margin
669
+ self.scale = scale
670
+ self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
671
+ nn.init.xavier_uniform_(self.weight)
672
+
673
+ self.cos_margin = math.cos(self.margin)
674
+ self.sin_margin = math.sin(self.margin)
675
+
676
+ # sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
677
+ # sin(pi - a) = sin(a)
678
+
679
+ def forward(self,
680
+ inputs: torch.Tensor,
681
+ targets: torch.LongTensor = None
682
+ ):
683
+ """
684
+ :param inputs: shape=[batch_size, ..., hidden_size]
685
+ :param targets:
686
+ :return: logits
687
+ """
688
+ x = F.normalize(inputs)
689
+ weight = F.normalize(self.weight)
690
+ cosine = F.linear(x, weight)
691
+
692
+ if self.training and targets is not None:
693
+ # sin^2 + cos^2 = 1
694
+ sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
695
+
696
+ # cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
697
+ cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
698
+
699
+ # when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
700
+ cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
701
+
702
+ one_hot = torch.zeros_like(cosine)
703
+ one_hot.scatter_(1, targets.view(-1, 1), 1)
704
+
705
+ logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
706
+ logits = logits * self.scale
707
+ else:
708
+ logits = cosine
709
+ return logits
710
+
711
+
712
+ def demo1():
713
+ HingeLoss.demo1()
714
+ return
715
+
716
+
717
+ def demo2():
718
+ AdditiveAngularMarginSoftMax.demo1()
719
+
720
+ inputs = torch.ones(size=(2, 5), dtype=torch.float32)
721
+ label: torch.LongTensor = torch.tensor(data=[0, 1], dtype=torch.long)
722
+
723
+ aam_softmax = AdditiveAngularMarginSoftMax(
724
+ hidden_size=5,
725
+ num_labels=2,
726
+ margin=1,
727
+ scale=1
728
+ )
729
+
730
+ outputs = aam_softmax.forward(inputs, label)
731
+ print(outputs)
732
+
733
+ return
734
+
735
+
736
+ if __name__ == '__main__':
737
+ # demo1()
738
+ demo2()