[update]add code
- .gitattributes +1 -0
- .gitignore +8 -0
- README.md +4 -5
- examples/text_classification/telemarketing_intent_classification/1.prepare_data.py +85 -0
- examples/text_classification/telemarketing_intent_classification/2.make_hierarchical_labels.py +74 -0
- examples/text_classification/telemarketing_intent_classification/3.make_vocabulary.py +78 -0
- examples/text_classification/telemarketing_intent_classification/4.train_model.py +172 -0
- examples/text_classification/telemarketing_intent_classification/5.predict_model.py +117 -0
- examples/text_classification/telemarketing_intent_classification/6.make_json_config.py +121 -0
- examples/text_classification/telemarketing_intent_classification/7.predict_by_archive.py +74 -0
- examples/text_classification/telemarketing_intent_classification/run.sh +254 -0
- main.py +141 -0
- predict.py +142 -0
- pretrained_models/bert-base-japanese/.gitattributes +9 -0
- pretrained_models/bert-base-japanese/README.md +43 -0
- pretrained_models/bert-base-japanese/config.json +20 -0
- pretrained_models/bert-base-japanese/tokenizer_config.json +5 -0
- pretrained_models/bert-base-japanese/vocab.txt +0 -0
- pretrained_models/bert-base-uncased/.gitattributes +11 -0
- pretrained_models/bert-base-uncased/LICENSE +201 -0
- pretrained_models/bert-base-uncased/README.md +251 -0
- pretrained_models/bert-base-uncased/config.json +23 -0
- pretrained_models/bert-base-uncased/tokenizer.json +0 -0
- pretrained_models/bert-base-uncased/tokenizer_config.json +3 -0
- pretrained_models/bert-base-uncased/vocab.txt +0 -0
- pretrained_models/bert-base-vietnamese-uncased/.gitattributes +17 -0
- pretrained_models/bert-base-vietnamese-uncased/README.md +22 -0
- pretrained_models/bert-base-vietnamese-uncased/config.json +27 -0
- pretrained_models/bert-base-vietnamese-uncased/special_tokens_map.json +1 -0
- pretrained_models/bert-base-vietnamese-uncased/tokenizer_config.json +1 -0
- pretrained_models/bert-base-vietnamese-uncased/vocab.txt +0 -0
- pretrained_models/chinese-bert-wwm-ext/.gitattributes +9 -0
- pretrained_models/chinese-bert-wwm-ext/README.md +52 -0
- pretrained_models/chinese-bert-wwm-ext/added_tokens.json +1 -0
- pretrained_models/chinese-bert-wwm-ext/config.json +26 -0
- pretrained_models/chinese-bert-wwm-ext/special_tokens_map.json +1 -0
- pretrained_models/chinese-bert-wwm-ext/tokenizer.json +0 -0
- pretrained_models/chinese-bert-wwm-ext/tokenizer_config.json +1 -0
- pretrained_models/chinese-bert-wwm-ext/vocab.txt +0 -0
- project_settings.py +12 -0
- requirements.txt +14 -0
- toolbox/__init__.py +6 -0
- toolbox/allennlp_models/text_classifier/dataset_readers/__init__.py +6 -0
- toolbox/allennlp_models/text_classifier/dataset_readers/hierarchical_classification_json.py +99 -0
- toolbox/allennlp_models/text_classifier/models/__init__.py +6 -0
- toolbox/allennlp_models/text_classifier/models/hierarchical_text_classifier.py +291 -0
- toolbox/torch/__init__.py +6 -0
- toolbox/torch/modules/__init__.py +6 -0
- toolbox/torch/modules/loss.py +738 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.th filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,8 @@
.git/
.idea/

**/flagged/
**/__pycache__/

trained_models/

README.md
CHANGED
@@ -1,13 +1,12 @@
 ---
 title: Telemarketing Intent Classification
-emoji:
+emoji: 😻
 colorFrom: indigo
-colorTo:
+colorTo: yellow
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.20.1
-app_file:
+app_file: main.py
 pinned: false
-license: apache-2.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

examples/text_classification/telemarketing_intent_classification/1.prepare_data.py
ADDED
@@ -0,0 +1,85 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import random
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, '../../../'))

import pandas as pd
from tqdm import tqdm

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--without_irrelevant_domain', action='store_true')
    parser.add_argument('--dataset_filename', default='dataset.xlsx', type=str)
    parser.add_argument('--do_lowercase', action='store_true')

    parser.add_argument('--train_subset', default='train.json', type=str)
    parser.add_argument('--valid_subset', default='valid.json', type=str)

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    n_hierarchical = 2

    df = pd.read_excel(args.dataset_filename)
    df = df[df['selected'] == 1]

    dataset = list()
    for i, row in tqdm(df.iterrows(), total=len(df)):
        text = row['text']
        label0 = row['label0']
        if args.without_irrelevant_domain and label0 == '无关领域':
            continue

        text = str(text)
        if args.do_lowercase:
            text = text.lower()

        labels = {'label{}'.format(idx): str(row['label{}'.format(idx)]) for idx in range(n_hierarchical)}

        random1 = random.random()
        random2 = random.random()

        dataset.append({
            'text': text,
            **labels,

            'random1': random1,
            'random2': random2,
            'flag': 'TRAIN' if random2 < 0.8 else 'TEST',
        })

    dataset = list(sorted(dataset, key=lambda x: x['random1'], reverse=True))

    f_train = open(args.train_subset, 'w', encoding='utf-8')
    f_test = open(args.valid_subset, 'w', encoding='utf-8')

    for row in tqdm(dataset):
        flag = row['flag']
        row = json.dumps(row, ensure_ascii=False)
        if flag == 'TRAIN':
            f_train.write('{}\n'.format(row))
        else:
            f_test.write('{}\n'.format(row))

    f_train.close()
    f_test.close()
    return


if __name__ == '__main__':
    main()

examples/text_classification/telemarketing_intent_classification/2.make_hierarchical_labels.py
ADDED
@@ -0,0 +1,74 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import OrderedDict
import os
from pathlib import Path
import pickle
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, '../../'))

import pandas as pd


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_filename', default='dataset.xlsx', type=str)
    parser.add_argument('--hierarchical_labels_pkl', default='hierarchical_labels.pkl', type=str)

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    n_hierarchical = 2

    df = pd.read_excel(args.dataset_filename)
    df = df[df['selected'] == 1]

    # build the hierarchical_labels mapping
    temp_hierarchical_labels = OrderedDict()

    for i, row in df.iterrows():
        text = row['text']
        label0 = row['label0']
        label1 = row['label1']

        if temp_hierarchical_labels.get(label0) is None:
            temp_hierarchical_labels[label0] = list()

        if label1 not in temp_hierarchical_labels[label0]:
            temp_hierarchical_labels[label0].append(label1)

    if n_hierarchical > 2:
        hierarchical_labels = OrderedDict()
        for idx in range(n_hierarchical - 2):
            for k, v in temp_hierarchical_labels.items():
                parent, label = k.rsplit('_', maxsplit=1)

                if hierarchical_labels.get(parent) is None:
                    hierarchical_labels[parent] = OrderedDict({
                        label: v
                    })
                else:
                    if hierarchical_labels[parent].get(label) is None:
                        hierarchical_labels[parent][label] = v
    else:
        hierarchical_labels = temp_hierarchical_labels

    with open(args.hierarchical_labels_pkl, 'wb') as f:
        pickle.dump(hierarchical_labels, f)

    with open(args.hierarchical_labels_pkl, 'rb') as f:
        hierarchical_labels = pickle.load(f)

    # print(hierarchical_labels)
    return


if __name__ == '__main__':
    main()

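For orientation, here is a minimal sketch (not part of the commit) of the structure this script pickles for the two-level case; the label names below are hypothetical placeholders, since the real dataset is private.

# Hypothetical sketch of the object written to hierarchical_labels.pkl when n_hierarchical == 2:
# an OrderedDict mapping each level-0 label to the level-1 labels observed under it.
# "domain_a", "intent_1", ... are placeholders, not real labels from the dataset.
from collections import OrderedDict

hierarchical_labels = OrderedDict([
    ("domain_a", ["intent_1", "intent_2"]),
    ("domain_b", ["intent_3"]),
])
# 3.make_vocabulary.py later flattens these into "domain_a_intent_1"-style entries
# in the "labels" vocabulary namespace.
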
examples/text_classification/telemarketing_intent_classification/3.make_vocabulary.py
ADDED
@@ -0,0 +1,78 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import OrderedDict
import json
import os
from pathlib import Path
import pickle
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, '../../../'))

import pandas as pd

from allennlp.data.vocabulary import Vocabulary

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_path",
        default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
        type=str
    )
    parser.add_argument('--hierarchical_labels_pkl', default='hierarchical_labels.pkl', type=str)
    parser.add_argument('--vocabulary', default='vocabulary', type=str)

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    with open(args.hierarchical_labels_pkl, 'rb') as f:
        hierarchical_labels = pickle.load(f)
    # print(hierarchical_labels)

    # traverse the label hierarchy, level by level
    token_to_index = OrderedDict()
    tasks = [hierarchical_labels]
    while len(tasks) != 0:
        task = tasks.pop(0)
        for parent, downstream in task.items():
            if isinstance(downstream, list):
                for label in downstream:
                    if pd.isna(label):
                        continue
                    label = '{}_{}'.format(parent, label)
                    token_to_index[label] = len(token_to_index)
            elif isinstance(downstream, OrderedDict):
                new_task = OrderedDict()
                for k, v in downstream.items():
                    new_task['{}_{}'.format(parent, k)] = v
                tasks.append(new_task)
            else:
                raise NotImplementedError

    vocabulary = Vocabulary(non_padded_namespaces=['tokens', 'labels'])
    for label, index in token_to_index.items():
        vocabulary.add_token_to_namespace(label, namespace='labels')

    vocabulary.set_from_file(
        filename=os.path.join(args.pretrained_model_path, 'vocab.txt'),
        is_padded=False,
        oov_token='[UNK]',
        namespace='tokens',
    )
    vocabulary.save_to_files(args.vocabulary)

    # Reminder: check that the label order in the Vocabulary matches hierarchical_labels.
    print('注意检查 Vocabulary 中标签的顺序与 hierarchical_labels 是否一致. ')
    return


if __name__ == '__main__':
    main()

examples/text_classification/telemarketing_intent_classification/4.train_model.py
ADDED
@@ -0,0 +1,172 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Hierarchical (multi-level) text classification with stacked softmax layers.

At initialization, the softmax probabilities on each level are roughly uniform,
so the first level already assigns about 50% probability to `领域无关` (irrelevant domain),
while the classes under `领域相关` (relevant domain) share the remaining 50%.
As a result, the model predicts `领域无关` for every input at the start of training,
and it cannot optimize its way out of that state.

Workaround:
1. Remove the `领域无关` samples from the dataset and train the model.
2. After the model converges, train again on the dataset that includes `领域无关`,
   loading the previously trained weights and restarting training.

"""
import argparse
import json
import os
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, '../../../'))

from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
from allennlp.training.checkpointer import Checkpointer
from pytorch_pretrained_bert.optimization import BertAdam
import torch

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_path",
        default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
        type=str
    )
    parser.add_argument('--hierarchical_labels_pkl', default='data_dir/hierarchical_labels.pkl', type=str)
    parser.add_argument('--vocabulary_dir', default='data_dir/vocabulary', type=str)
    parser.add_argument('--train_subset', default='data_dir/train.json', type=str)
    parser.add_argument('--valid_subset', default='data_dir/valid.json', type=str)
    parser.add_argument("--serialization_dir", default="data_dir/serialization_dir", type=str)
    # parser.add_argument('--checkpoint_path', default="data_dir/serialization_dir/best.th", type=str)
    parser.add_argument('--checkpoint_path', default=None, type=str)

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    dataset_reader = HierarchicalClassificationJsonReader(
        token_indexers={
            'tokens': SingleIdTokenIndexer(
                namespace='tokens',
                lowercase_tokens=True,
                token_min_padding_length=5,
            )
        },
        tokenizer=PretrainedTransformerTokenizer(
            model_name=os.path.join(project_path, args.pretrained_model_path),
        ),
    )

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    data_loader = MultiProcessDataLoader(
        reader=dataset_reader,
        data_path=args.train_subset,
        batch_size=64,
        shuffle=True,
    )
    data_loader.index_with(vocabulary)

    validation_data_loader = MultiProcessDataLoader(
        reader=dataset_reader,
        data_path=args.valid_subset,
        batch_size=64,
        shuffle=True,
    )
    validation_data_loader.index_with(vocabulary)

    model = HierarchicalClassifier(
        vocab=vocabulary,
        hierarchical_labels_pkl=args.hierarchical_labels_pkl,
        text_field_embedder=BasicTextFieldEmbedder(
            token_embedders={
                'tokens': Embedding(
                    num_embeddings=vocabulary.get_vocab_size('tokens'),
                    embedding_dim=128,
                )
            }
        ),
        seq2seq_encoder=StackedSelfAttentionEncoder(
            input_dim=128,
            hidden_dim=128,
            projection_dim=128,
            feedforward_hidden_dim=128,
            num_layers=2,
            num_attention_heads=4,
            use_positional_encoding=False,
        ),
        seq2vec_encoder=CnnEncoder(
            embedding_dim=128,
            num_filters=32,
            ngram_filter_sizes=(2, 3, 4, 5),
        ),
    )

    if args.checkpoint_path is not None:
        with open(args.checkpoint_path, "rb") as f:
            state_dict = torch.load(f, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
    model.train()

    parameters = [v for n, v in model.named_parameters()]

    optimizer = BertAdam(
        params=parameters,
        lr=5e-4,
        warmup=0.1,
        t_total=10000,
        # t_total=200000,
        schedule='warmup_linear'
    )

    if torch.cuda.is_available():
        cuda_device = 0
        model.cuda(device=0)
    else:
        cuda_device = -1

    print(cuda_device)

    trainer = GradientDescentTrainer(
        cuda_device=cuda_device,

        model=model,
        optimizer=optimizer,
        checkpointer=Checkpointer(
            serialization_dir=args.serialization_dir,
            keep_most_recent_by_count=10,
        ),
        data_loader=data_loader,
        validation_data_loader=validation_data_loader,
        patience=5,
        validation_metric='+accuracy',
        num_epochs=100,
        serialization_dir=args.serialization_dir,
    )
    trainer.train()

    return


if __name__ == '__main__':
    main()

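A quick back-of-the-envelope sketch (hypothetical numbers, not part of the commit) of the initialization imbalance described in the module docstring above:

# With a roughly uniform first-level softmax, `领域无关` (irrelevant domain) gets ~0.5,
# while each of the N second-level classes under `领域相关` gets ~0.5 / N.
n_second_level = 20                      # hypothetical number of fine-grained intents
p_irrelevant = 0.5                       # first-level softmax, ~uniform at initialization
p_each_relevant = 0.5 / n_second_level   # 0.025, far below p_irrelevant

print(p_irrelevant, p_each_relevant)     # 0.5 0.025
# The argmax over leaf probabilities is therefore always the irrelevant-domain class at
# the start, which is why run.sh first trains without that class and then resumes with it.
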
examples/text_classification/telemarketing_intent_classification/5.predict_model.py
ADDED
@@ -0,0 +1,117 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import time

from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.text_classifier import TextClassifierPredictor
import torch

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_path",
        default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
        type=str
    )
    parser.add_argument('--hierarchical_labels_pkl', default='data_dir/hierarchical_labels.pkl', type=str)
    parser.add_argument('--vocabulary_dir', default='data_dir/vocabulary', type=str)

    parser.add_argument(
        "--serialization_dir",
        default="data_dir/serialization_dir2",
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    dataset_reader = HierarchicalClassificationJsonReader(
        token_indexers={
            'tokens': SingleIdTokenIndexer(
                namespace='tokens',
                lowercase_tokens=True,
                token_min_padding_length=5,
            )
        },
        tokenizer=PretrainedTransformerTokenizer(
            model_name=os.path.join(project_path, args.pretrained_model_path),
        ),
    )

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    model = HierarchicalClassifier(
        vocab=vocabulary,
        hierarchical_labels_pkl=args.hierarchical_labels_pkl,
        text_field_embedder=BasicTextFieldEmbedder(
            token_embedders={
                'tokens': Embedding(
                    num_embeddings=vocabulary.get_vocab_size('tokens'),
                    embedding_dim=128,
                )
            }
        ),
        seq2seq_encoder=StackedSelfAttentionEncoder(
            input_dim=128,
            hidden_dim=128,
            projection_dim=128,
            feedforward_hidden_dim=128,
            num_layers=2,
            num_attention_heads=4,
            use_positional_encoding=False,
        ),
        seq2vec_encoder=CnnEncoder(
            embedding_dim=128,
            num_filters=32,
            ngram_filter_sizes=(2, 3, 4, 5),
        ),
    )

    checkpoint_path = os.path.join(args.serialization_dir, "best.th")
    with open(checkpoint_path, 'rb') as f:
        state_dict = torch.load(f, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    predictor = TextClassifierPredictor(
        model=model,
        dataset_reader=dataset_reader,
    )

    while True:
        text = input("text: ")
        if text == "Quit":
            break

        json_dict = {'sentence': text}

        begin_time = time.time()
        outputs = predictor.predict_json(
            json_dict
        )

        outputs = predictor._model.decode(outputs)
        label = outputs['label']
        print(label)
        print('time cost: {}'.format(time.time() - begin_time))
    return


if __name__ == '__main__':
    main()

examples/text_classification/telemarketing_intent_classification/6.make_json_config.py
ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os

from allennlp.data.vocabulary import Vocabulary
import torch

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_path",
        default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
        type=str
    )
    parser.add_argument('--hierarchical_labels_pkl', default='data_dir/hierarchical_labels.pkl', type=str)
    parser.add_argument('--vocabulary_dir', default='data_dir/vocabulary', type=str)
    parser.add_argument('--train_subset', default='data_dir/train.json', type=str)
    parser.add_argument('--valid_subset', default='data_dir/valid.json', type=str)
    parser.add_argument("--serialization_dir", default="data_dir/serialization_dir", type=str)
    parser.add_argument("--json_config_dir", default="data_dir", type=str)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    if torch.cuda.is_available():
        cuda_device = 0
    else:
        cuda_device = -1

    json_config = {
        "dataset_reader": {
            "type": "hierarchical_classification_json",
            "token_indexers": {
                "tokens": {
                    "type": "single_id",
                    "namespace": "tokens",
                    "lowercase_tokens": True,
                    "token_min_padding_length": 5
                }
            },
            "tokenizer": {
                "type": "pretrained_transformer",
                "model_name": args.pretrained_model_path
            }
        },
        "train_data_path": args.train_subset,
        "validation_data_path": args.valid_subset,
        "vocabulary": {
            "directory_path": args.vocabulary_dir,
        },
        "model": {
            "type": "hierarchical_classifier",
            "hierarchical_labels_pkl": args.hierarchical_labels_pkl,
            "text_field_embedder": {
                "token_embedders": {
                    "tokens": {
                        "type": "embedding",
                        "num_embeddings": vocabulary.get_vocab_size(namespace="tokens"),
                        "embedding_dim": 128
                    }
                }
            },
            "seq2seq_encoder": {
                "type": "stacked_self_attention",
                "input_dim": 128,
                "hidden_dim": 128,
                "projection_dim": 128,
                "feedforward_hidden_dim": 128,
                "num_layers": 2,
                "num_attention_heads": 4,
                "use_positional_encoding": False
            },
            "seq2vec_encoder": {
                "type": "cnn",
                "embedding_dim": 128,
                "num_filters": 32,
                "ngram_filter_sizes": (2, 3, 4, 5),
            },
        },
        "data_loader": {
            "type": "multiprocess",
            "batch_size": 64,
            "shuffle": True
        },
        "trainer": {
            "type": "gradient_descent",
            "cuda_device": cuda_device,
            "optimizer": {
                "type": "bert_adam",
                "lr": 5e-5,
                "warmup": 0.1,
                "t_total": 50000,
                "schedule": "warmup_linear"
            },
            "checkpointer": {
                "serialization_dir": args.serialization_dir,
                "keep_most_recent_by_count": 10
            },
            "patience": 5,
            "validation_metric": "+accuracy",
            "num_epochs": 200
        }
    }

    with open(os.path.join(args.json_config_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(json_config, f, indent=4, ensure_ascii=False)
    return


if __name__ == '__main__':
    main()

examples/text_classification/telemarketing_intent_classification/7.predict_by_archive.py
ADDED
@@ -0,0 +1,74 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import time

from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import torch

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        default="给我推荐一些篮球游戏?",
        type=str
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str
    )
    parser.add_argument(
        "--pretrained_model_path",
        default=(project_path / "pretrained_models/chinese-bert-wwm-ext").as_posix(),
        type=str
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    archive = load_archive(archive_file=args.archive_file)

    predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)

    json_dict = {
        "sentence": args.text
    }

    begin_time = time.time()
    outputs = predictor.predict_json(
        json_dict
    )
    outputs = predictor._model.decode(outputs)
    label = outputs['label']
    print(label)
    print('time cost: {}'.format(time.time() - begin_time))

    return


if __name__ == '__main__':
    main()

examples/text_classification/telemarketing_intent_classification/run.sh
ADDED
@@ -0,0 +1,254 @@
#!/usr/bin/env bash

# nohup sh run.sh --system_version centos --stage 0 --stop_stage 5 &

# sh run.sh --system_version windows --stage -2 --stop_stage 8
# sh run.sh --system_version windows --stage 0 --stop_stage 0
# sh run.sh --system_version windows --stage 1 --stop_stage 1
# sh run.sh --system_version windows --stage 2 --stop_stage 2
# sh run.sh --system_version windows --stage 0 --stop_stage 5
# sh run.sh --system_version windows --stage 6 --stop_stage 6

# params
system_version="centos";
verbose=true;
stage=0  # start from 0 if you need to start from data preparation
stop_stage=5


#trained_model_name=telemarketing_intent_classification_cn
#pretrained_bert_model_name=chinese-bert-wwm-ext
#dataset_fn="telemarketing_intent_cn.xlsx"

#trained_model_name=telemarketing_intent_classification_en
#pretrained_bert_model_name=bert-base-uncased
#dataset_fn="telemarketing_intent_en.xlsx"

#trained_model_name=telemarketing_intent_classification_jp
#pretrained_bert_model_name=bert-base-japanese
#dataset_fn="telemarketing_intent_jp.xlsx"

trained_model_name=telemarketing_intent_classification_vi
pretrained_bert_model_name=bert-base-vietnamese-uncased
dataset_fn="telemarketing_intent_vi.xlsx"

# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="(eval echo \\$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

$verbose && echo "system_version: ${system_version}"

work_dir="$(pwd)"
data_dir="$(pwd)/data_dir"
pretrained_models_dir="${work_dir}/../../../pretrained_models";
trained_models_dir="${work_dir}/../../../trained_models";

serialization_dir1="${data_dir}/serialization_dir1";
serialization_dir2="${data_dir}/serialization_dir2";

mkdir -p "${data_dir}"
mkdir -p "${pretrained_models_dir}"
mkdir -p "${trained_models_dir}"
mkdir -p "${serialization_dir1}"
mkdir -p "${serialization_dir2}"

vocabulary_dir="${data_dir}/vocabulary"
train_subset="${data_dir}/train.json"
valid_subset="${data_dir}/valid.json"
hierarchical_labels_pkl="${data_dir}/hierarchical_labels.pkl"
dataset_filename="${data_dir}/${dataset_fn}"

export PYTHONPATH="${work_dir}/../../.."

if [ $system_version == "windows" ]; then
  alias python3='C:/Users/tianx/PycharmProjects/virtualenv/AllenNLP/Scripts/python.exe'
elif [ $system_version == "centos" ]; then
  source /data/local/bin/AllenNLP/bin/activate
  alias python3='/data/local/bin/AllenNLP/bin/python3'
elif [ $system_version == "ubuntu" ]; then
  source /data/local/bin/AllenNLP/bin/activate
  alias python3='/data/local/bin/AllenNLP/bin/python3'
fi


declare -A pretrained_bert_model_dict
pretrained_bert_model_dict=(
  ["chinese-bert-wwm-ext"]="https://huggingface.co/hfl/chinese-bert-wwm-ext"
  ["bert-base-uncased"]="https://huggingface.co/bert-base-uncased"
  ["bert-base-japanese"]="https://huggingface.co/cl-tohoku/bert-base-japanese"
  ["bert-base-vietnamese-uncased"]="https://huggingface.co/trituenhantaoio/bert-base-vietnamese-uncased"
)
pretrained_model_dir="${pretrained_models_dir}/${pretrained_bert_model_name}"


if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
  $verbose && echo "stage -2: download pretrained models"
  cd "${work_dir}" || exit 1;

  if [ ! -d "${pretrained_model_dir}" ]; then
    mkdir -p "${pretrained_models_dir}"
    cd "${pretrained_models_dir}" || exit 1;

    repository_url="${pretrained_bert_model_dict[${pretrained_bert_model_name}]}"
    git clone "${repository_url}"

    cd "${pretrained_model_dir}" || exit 1;
    rm flax_model.msgpack && rm pytorch_model.bin && rm tf_model.h5
    rm -rf .git/
    wget "${repository_url}/resolve/main/pytorch_model.bin"
  fi
fi


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  $verbose && echo "stage -1: download data"
  cd "${data_dir}" || exit 1;

  wget "https://huggingface.co/datasets/qgyd2021/telemarketing_intent/resolve/main/${dataset_fn}"

fi


if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  $verbose && echo "stage 0: prepare data without irrelevant domain (make train subset, valid subset file)"
  cd "${work_dir}" || exit 1;

  python3 1.prepare_data.py \
  --without_irrelevant_domain \
  --dataset_filename "${dataset_filename}" \
  --do_lowercase \
  --train_subset "${train_subset}" \
  --valid_subset "${valid_subset}" \

fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: make hierarchical labels dictionary (make hierarchical_labels.pkl file)"
  cd "${work_dir}" || exit 1
  python3 2.make_hierarchical_labels.py \
  --dataset_filename "${dataset_filename}" \
  --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: make vocabulary (make vocabulary directory)"
  cd "${work_dir}" || exit 1
  python3 3.make_vocabulary.py \
  --pretrained_model_path "${pretrained_model_dir}" \
  --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
  --vocabulary "${vocabulary_dir}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: train model without irrelevant domain"
  cd "${work_dir}" || exit 1

  python3 4.train_model.py \
  --pretrained_model_path "${pretrained_model_dir}" \
  --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
  --vocabulary_dir "${vocabulary_dir}" \
  --train_subset "${train_subset}" \
  --valid_subset "${valid_subset}" \
  --serialization_dir "${serialization_dir1}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: prepare data with irrelevant domain"
  cd "${work_dir}" || exit 1

  python3 1.prepare_data.py \
  --dataset_filename "${dataset_filename}" \
  --do_lowercase \
  --train_subset "${train_subset}" \
  --valid_subset "${valid_subset}" \

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: train model with irrelevant domain"
  cd "${work_dir}" || exit 1

  python3 4.train_model.py \
  --pretrained_model_path "${pretrained_model_dir}" \
  --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
  --vocabulary_dir "${vocabulary_dir}" \
  --train_subset "${train_subset}" \
  --valid_subset "${valid_subset}" \
  --serialization_dir "${serialization_dir2}" \
  --checkpoint_path "${serialization_dir1}/best.th"

fi


if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  $verbose && echo "stage 6: make json config"
  cd "${work_dir}" || exit 1
  python3 6.make_json_config.py \
  --pretrained_model_path "${pretrained_model_dir}" \
  --hierarchical_labels_pkl "${hierarchical_labels_pkl}" \
  --vocabulary_dir "${vocabulary_dir}" \
  --train_subset "${train_subset}" \
  --valid_subset "${valid_subset}" \
  --serialization_dir "${serialization_dir2}" \
  --json_config_dir "${data_dir}" \

fi


if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  $verbose && echo "stage 7: collect files"
  cd "${work_dir}" || exit 1;

  mkdir -p "${trained_models_dir}/${trained_model_name}"

  cp -r "${vocabulary_dir}" "${trained_models_dir}/${trained_model_name}/vocabulary/"
  cp "${serialization_dir2}/best.th" "${trained_models_dir}/${trained_model_name}/weights.th"
  cp "${data_dir}/config.json" "${trained_models_dir}/${trained_model_name}/config.json"
  cp "${hierarchical_labels_pkl}" "${trained_models_dir}/${trained_model_name}"

fi


if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
  $verbose && echo "stage 8: predict by archive"
  cd "${work_dir}" || exit 1;

  python3 7.predict_by_archive.py \
  --archive_file "${trained_models_dir}/${trained_model_name}" \
  --pretrained_model_path "${pretrained_model_dir}" \

fi

main.py
ADDED
@@ -0,0 +1,141 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import time

from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import torch

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cn_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_cn").as_posix(),
        type=str
    )
    parser.add_argument(
        "--en_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_en").as_posix(),
        type=str
    )
    parser.add_argument(
        "--jp_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_jp").as_posix(),
        type=str
    )
    parser.add_argument(
        "--vi_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    cn_archive = load_archive(archive_file=args.cn_archive_file)
    cn_predictor = Predictor.from_archive(cn_archive, predictor_name=args.predictor_name)
    en_archive = load_archive(archive_file=args.en_archive_file)
    en_predictor = Predictor.from_archive(en_archive, predictor_name=args.predictor_name)
    jp_archive = load_archive(archive_file=args.jp_archive_file)
    jp_predictor = Predictor.from_archive(jp_archive, predictor_name=args.predictor_name)
    vi_archive = load_archive(archive_file=args.vi_archive_file)
    vi_predictor = Predictor.from_archive(vi_archive, predictor_name=args.predictor_name)

    predictor_map = {
        "chinese": cn_predictor,
        "english": en_predictor,
        "japanese": jp_predictor,
        "vietnamese": vi_predictor,
    }

    def fn(text: str, language: str):
        predictor = predictor_map.get(language, cn_predictor)

        json_dict = {'sentence': text}
        outputs = predictor.predict_json(
            json_dict
        )
        outputs = predictor._model.decode(outputs)
        label = outputs['label'][0]
        prob = outputs['prob'][0]
        prob = round(prob, 4)
        return label, prob

    description = """
电销场景意图识别.
语言: 汉语, 英语, 日语, 越南语.
数据集是私有的.

model: selfattention-cnn
dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)

accuracy:
chinese: 0.8002
english: 0.7011
japanese: 0.8154
vietnamese: 0.8168

"""
    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Dropdown(
                choices=list(sorted(predictor_map.keys())),
                label="language"
            )
        ],
        outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
        examples=[
            ["你找谁", "chinese"],
            ["你是谁啊", "chinese"],
            ["不好意思我现在很忙", "chinese"],
            ["对不起, 不需要哈", "chinese"],
            ["u have got the wrong number", "english"],
            ["sure, thank a lot", "english"],
            ["please leave your message for 95688496", "english"],
            ["yes well", "english"],
            ["失礼の", "japanese"],
            ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
            ["わかんない。", "japanese"],
            ["に出ることができません", "japanese"],
            ["À không phải em nha.", "vietnamese"],
            ["Dạ nhầm số rồi ạ?", "vietnamese"],
            ["Ừ, cảm ơn em nhá.", "vietnamese"],
            ["Không, chị không có tiền.", "vietnamese"],
        ],
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=description,
    )
    demo.launch()

    return


if __name__ == '__main__':
    main()

predict.py
ADDED
@@ -0,0 +1,142 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--excel_file",
        default=r"D:\Users\tianx\PycharmProjects\telemarketing_intent\data\excel\telemarketing_intent_vi.xlsx",
        type=str,
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str
    )
    parser.add_argument(
        "--top_k",
        default=10,
        type=int
    )
    parser.add_argument(
        "--output_file",
        default="intent_top_k.jsonl",
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    archive = load_archive(archive_file=args.archive_file)
    predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)

    df = pd.read_excel(args.excel_file)

    with open(args.output_file, "w", encoding="utf-8") as f:
        for i, row in tqdm(df.iterrows(), total=len(df)):
            if i < 26976:
                continue

            source = row["source"]
            text = row["text"]
            label0 = row["label0"]
            label1 = row["label1"]
            selected = row["selected"]
            checked = row["checked"]

            if pd.isna(source) or source is None:
                source = None

            if pd.isna(text) or text is None:
                continue
            text = str(text)

            if pd.isna(label0) or label0 is None:
                label0 = None

            if pd.isna(label1) or label1 is None:
                label1 = None

            if pd.isna(selected) or selected is None:
                selected = None
            else:
                try:
                    selected = int(selected)
                except Exception:
                    print(type(selected))
                    selected = None

            if pd.isna(checked) or checked is None:
                checked = None
            else:
                try:
                    checked = int(checked)
                except Exception:
                    print(type(checked))
                    checked = None

            # print(text)
            json_dict = {'sentence': text}
            outputs = predictor.predict_json(
                json_dict
            )
            probs = outputs["probs"]
            arg_idx = np.argsort(probs)

            arg_idx_top_k = arg_idx[-10:]
            label_top_k = [
                predictor._model.vocab.get_token_from_index(index=idx, namespace="labels").split("_")[-1] for idx in arg_idx_top_k
            ]
            prob_top_k = [
                str(round(probs[idx], 5)) for idx in arg_idx_top_k
            ]

            row_ = {
                "source": source,
                "text": text,
                "label0": label0,
                "label1": label1,
                "selected": selected,
                "checked": checked,
                "predict_label_top_k": ";".join(list(reversed(label_top_k))),
                "predict_prob_top_k": ";".join(list(reversed(prob_top_k)))
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            f.write("{}\n".format(row_))

    return


if __name__ == '__main__':
    main()

pretrained_models/bert-base-japanese/.gitattributes
ADDED
@@ -0,0 +1,9 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text

pretrained_models/bert-base-japanese/README.md
ADDED
@@ -0,0 +1,43 @@
+---
+language: ja
+license: cc-by-sa-4.0
+datasets:
+- wikipedia
+widget:
+- text: 東北大学で[MASK]の研究をしています。
+---
+
+# BERT base Japanese (IPA dictionary)
+
+This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language.
+
+This version of the model processes input texts with word-level tokenization based on the IPA dictionary, followed by the WordPiece subword tokenization.
+
+The codes for the pretraining are available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/tree/v1.0).
+
+## Model architecture
+
+The model architecture is the same as the original BERT base model; 12 layers, 768 dimensions of hidden states, and 12 attention heads.
+
+## Training Data
+
+The model is trained on Japanese Wikipedia as of September 1, 2019.
+To generate the training corpus, [WikiExtractor](https://github.com/attardi/wikiextractor) is used to extract plain texts from a dump file of Wikipedia articles.
+The text files used for the training are 2.6GB in size, consisting of approximately 17M sentences.
+
+## Tokenization
+
+The texts are first tokenized by the [MeCab](https://taku910.github.io/mecab/) morphological parser with the IPA dictionary and then split into subwords by the WordPiece algorithm.
+The vocabulary size is 32000.
+
+## Training
+
+The model is trained with the same configuration as the original BERT; 512 tokens per instance, 256 instances per batch, and 1M training steps.
+
+## Licenses
+
+The pretrained models are distributed under the terms of the [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/).
+
+## Acknowledgments
+
+For training models, we used Cloud TPUs provided by the [TensorFlow Research Cloud](https://www.tensorflow.org/tfrc/) program.
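As a quick illustration of the MeCab + WordPiece pipeline described above, here is a minimal sketch (not part of this commit) that loads the vendored copy from this repository; it assumes `transformers` plus the `fugashi`/`ipadic` pins from requirements.txt are installed.

```python
from transformers import AutoTokenizer

# config.json sets tokenizer_class to BertJapaneseTokenizer, so AutoTokenizer
# resolves MeCab word tokenization followed by WordPiece from the local files.
tokenizer = AutoTokenizer.from_pretrained("pretrained_models/bert-base-japanese")
print(tokenizer.tokenize("東北大学で自然言語処理の研究をしています。"))
```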
pretrained_models/bert-base-japanese/config.json
ADDED
@@ -0,0 +1,20 @@
+{
+  "architectures": [
+    "BertForMaskedLM"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "tokenizer_class": "BertJapaneseTokenizer",
+  "type_vocab_size": 2,
+  "vocab_size": 32000
+}
pretrained_models/bert-base-japanese/tokenizer_config.json
ADDED
@@ -0,0 +1,5 @@
+{
+  "do_lower_case": false,
+  "subword_tokenizer_type": "wordpiece",
+  "word_tokenizer_type": "mecab"
+}
pretrained_models/bert-base-japanese/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
pretrained_models/bert-base-uncased/.gitattributes
ADDED
@@ -0,0 +1,11 @@
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+model.safetensors filter=lfs diff=lfs merge=lfs -text
pretrained_models/bert-base-uncased/LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
pretrained_models/bert-base-uncased/README.md
ADDED
@@ -0,0 +1,251 @@
+---
+language: en
+tags:
+- exbert
+license: apache-2.0
+datasets:
+- bookcorpus
+- wikipedia
+---
+
+# BERT base model (uncased)
+
+Pretrained model on English language using a masked language modeling (MLM) objective. It was introduced in
+[this paper](https://arxiv.org/abs/1810.04805) and first released in
+[this repository](https://github.com/google-research/bert). This model is uncased: it does not make a difference
+between english and English.
+
+Disclaimer: The team releasing BERT did not write a model card for this model so this model card has been written by
+the Hugging Face team.
+
+## Model description
+
+BERT is a transformers model pretrained on a large corpus of English data in a self-supervised fashion. This means it
+was pretrained on the raw texts only, with no humans labeling them in any way (which is why it can use lots of
+publicly available data) with an automatic process to generate inputs and labels from those texts. More precisely, it
+was pretrained with two objectives:
+
+- Masked language modeling (MLM): taking a sentence, the model randomly masks 15% of the words in the input, then runs
+  the entire masked sentence through the model and has to predict the masked words. This is different from traditional
+  recurrent neural networks (RNNs) that usually see the words one after the other, or from autoregressive models like
+  GPT which internally mask the future tokens. It allows the model to learn a bidirectional representation of the
+  sentence.
+- Next sentence prediction (NSP): the model concatenates two masked sentences as inputs during pretraining. Sometimes
+  they correspond to sentences that were next to each other in the original text, sometimes not. The model then has to
+  predict if the two sentences were following each other or not.
+
+This way, the model learns an inner representation of the English language that can then be used to extract features
+useful for downstream tasks: if you have a dataset of labeled sentences, for instance, you can train a standard
+classifier using the features produced by the BERT model as inputs.
+
+## Model variations
+
+BERT has originally been released in base and large variations, for cased and uncased input text. The uncased models also strip out accent markers.
+Chinese and multilingual uncased and cased versions followed shortly after.
+Modified preprocessing with whole word masking replaced subpiece masking in a following work, with the release of two models.
+Another 24 smaller models were released afterward.
+
+The detailed release history can be found on the [google-research/bert readme](https://github.com/google-research/bert/blob/master/README.md) on github.
+
+| Model | #params | Language |
+|------------------------|--------------------------------|-------|
+| [`bert-base-uncased`](https://huggingface.co/bert-base-uncased) | 110M | English |
+| [`bert-large-uncased`](https://huggingface.co/bert-large-uncased) | 340M | English |
+| [`bert-base-cased`](https://huggingface.co/bert-base-cased) | 110M | English |
+| [`bert-large-cased`](https://huggingface.co/bert-large-cased) | 340M | English |
+| [`bert-base-chinese`](https://huggingface.co/bert-base-chinese) | 110M | Chinese |
+| [`bert-base-multilingual-cased`](https://huggingface.co/bert-base-multilingual-cased) | 110M | Multiple |
+| [`bert-large-uncased-whole-word-masking`](https://huggingface.co/bert-large-uncased-whole-word-masking) | 340M | English |
+| [`bert-large-cased-whole-word-masking`](https://huggingface.co/bert-large-cased-whole-word-masking) | 340M | English |
+
+## Intended uses & limitations
+
+You can use the raw model for either masked language modeling or next sentence prediction, but it's mostly intended to
+be fine-tuned on a downstream task. See the [model hub](https://huggingface.co/models?filter=bert) to look for
+fine-tuned versions of a task that interests you.
+
+Note that this model is primarily aimed at being fine-tuned on tasks that use the whole sentence (potentially masked)
+to make decisions, such as sequence classification, token classification or question answering. For tasks such as text
+generation you should look at models like GPT2.
+
+### How to use
+
+You can use this model directly with a pipeline for masked language modeling:
+
+```python
+>>> from transformers import pipeline
+>>> unmasker = pipeline('fill-mask', model='bert-base-uncased')
+>>> unmasker("Hello I'm a [MASK] model.")
+
+[{'sequence': "[CLS] hello i'm a fashion model. [SEP]",
+  'score': 0.1073106899857521,
+  'token': 4827,
+  'token_str': 'fashion'},
+ {'sequence': "[CLS] hello i'm a role model. [SEP]",
+  'score': 0.08774490654468536,
+  'token': 2535,
+  'token_str': 'role'},
+ {'sequence': "[CLS] hello i'm a new model. [SEP]",
+  'score': 0.05338378623127937,
+  'token': 2047,
+  'token_str': 'new'},
+ {'sequence': "[CLS] hello i'm a super model. [SEP]",
+  'score': 0.04667217284440994,
+  'token': 3565,
+  'token_str': 'super'},
+ {'sequence': "[CLS] hello i'm a fine model. [SEP]",
+  'score': 0.027095865458250046,
+  'token': 2986,
+  'token_str': 'fine'}]
+```
+
+Here is how to use this model to get the features of a given text in PyTorch:
+
+```python
+from transformers import BertTokenizer, BertModel
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+model = BertModel.from_pretrained("bert-base-uncased")
+text = "Replace me by any text you'd like."
+encoded_input = tokenizer(text, return_tensors='pt')
+output = model(**encoded_input)
+```
+
+and in TensorFlow:
+
+```python
+from transformers import BertTokenizer, TFBertModel
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+model = TFBertModel.from_pretrained("bert-base-uncased")
+text = "Replace me by any text you'd like."
+encoded_input = tokenizer(text, return_tensors='tf')
+output = model(encoded_input)
+```
+
+### Limitations and bias
+
+Even if the training data used for this model could be characterized as fairly neutral, this model can have biased
+predictions:
+
+```python
+>>> from transformers import pipeline
+>>> unmasker = pipeline('fill-mask', model='bert-base-uncased')
+>>> unmasker("The man worked as a [MASK].")
+
+[{'sequence': '[CLS] the man worked as a carpenter. [SEP]',
+  'score': 0.09747550636529922,
+  'token': 10533,
+  'token_str': 'carpenter'},
+ {'sequence': '[CLS] the man worked as a waiter. [SEP]',
+  'score': 0.0523831807076931,
+  'token': 15610,
+  'token_str': 'waiter'},
+ {'sequence': '[CLS] the man worked as a barber. [SEP]',
+  'score': 0.04962705448269844,
+  'token': 13362,
+  'token_str': 'barber'},
+ {'sequence': '[CLS] the man worked as a mechanic. [SEP]',
+  'score': 0.03788609802722931,
+  'token': 15893,
+  'token_str': 'mechanic'},
+ {'sequence': '[CLS] the man worked as a salesman. [SEP]',
+  'score': 0.037680890411138535,
+  'token': 18968,
+  'token_str': 'salesman'}]
+
+>>> unmasker("The woman worked as a [MASK].")
+
+[{'sequence': '[CLS] the woman worked as a nurse. [SEP]',
+  'score': 0.21981462836265564,
+  'token': 6821,
+  'token_str': 'nurse'},
+ {'sequence': '[CLS] the woman worked as a waitress. [SEP]',
+  'score': 0.1597415804862976,
+  'token': 13877,
+  'token_str': 'waitress'},
+ {'sequence': '[CLS] the woman worked as a maid. [SEP]',
+  'score': 0.1154729500412941,
+  'token': 10850,
+  'token_str': 'maid'},
+ {'sequence': '[CLS] the woman worked as a prostitute. [SEP]',
+  'score': 0.037968918681144714,
+  'token': 19215,
+  'token_str': 'prostitute'},
+ {'sequence': '[CLS] the woman worked as a cook. [SEP]',
+  'score': 0.03042375110089779,
+  'token': 5660,
+  'token_str': 'cook'}]
+```
+
+This bias will also affect all fine-tuned versions of this model.
+
+## Training data
+
+The BERT model was pretrained on [BookCorpus](https://yknzhu.wixsite.com/mbweb), a dataset consisting of 11,038
+unpublished books and [English Wikipedia](https://en.wikipedia.org/wiki/English_Wikipedia) (excluding lists, tables and
+headers).
+
+## Training procedure
+
+### Preprocessing
+
+The texts are lowercased and tokenized using WordPiece and a vocabulary size of 30,000. The inputs of the model are
+then of the form:
+
+```
+[CLS] Sentence A [SEP] Sentence B [SEP]
+```
+
+With probability 0.5, sentence A and sentence B correspond to two consecutive sentences in the original corpus, and in
+the other cases, it's another random sentence in the corpus. Note that what is considered a sentence here is a
+consecutive span of text usually longer than a single sentence. The only constraint is that the result with the two
+"sentences" has a combined length of less than 512 tokens.
+
+The details of the masking procedure for each sentence are the following:
+- 15% of the tokens are masked.
+- In 80% of the cases, the masked tokens are replaced by `[MASK]`.
+- In 10% of the cases, the masked tokens are replaced by a random token (different) from the one they replace.
+- In the 10% remaining cases, the masked tokens are left as is.
+
+### Pretraining
+
+The model was trained on 4 cloud TPUs in Pod configuration (16 TPU chips total) for one million steps with a batch size
+of 256. The sequence length was limited to 128 tokens for 90% of the steps and 512 for the remaining 10%. The optimizer
+used is Adam with a learning rate of 1e-4, \\(\beta_{1} = 0.9\\) and \\(\beta_{2} = 0.999\\), a weight decay of 0.01,
+learning rate warmup for 10,000 steps and linear decay of the learning rate after.
+
+## Evaluation results
+
+When fine-tuned on downstream tasks, this model achieves the following results:
+
+Glue test results:
+
+| Task | MNLI-(m/mm) | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE | Average |
+|:----:|:-----------:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:-------:|
+|      | 84.6/83.4   | 71.2 | 90.5 | 93.5  | 52.1 | 85.8  | 88.9 | 66.4 | 79.6    |
+
+
+### BibTeX entry and citation info
+
+```bibtex
+@article{DBLP:journals/corr/abs-1810-04805,
+  author    = {Jacob Devlin and
+               Ming{-}Wei Chang and
+               Kenton Lee and
+               Kristina Toutanova},
+  title     = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language
+               Understanding},
+  journal   = {CoRR},
+  volume    = {abs/1810.04805},
+  year      = {2018},
+  url       = {http://arxiv.org/abs/1810.04805},
+  archivePrefix = {arXiv},
+  eprint    = {1810.04805},
+  timestamp = {Tue, 30 Oct 2018 20:39:56 +0100},
+  biburl    = {https://dblp.org/rec/journals/corr/abs-1810-04805.bib},
+  bibsource = {dblp computer science bibliography, https://dblp.org}
+}
+```
+
+<a href="https://huggingface.co/exbert/?model=bert-base-uncased">
+	<img width="300px" src="https://cdn-media.huggingface.co/exbert/button.png">
+</a>
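The masking procedure in the model card above (15% of tokens selected; of those, 80% become `[MASK]`, 10% a random token, 10% unchanged) is straightforward to reproduce. A minimal sketch under those stated ratios, not the original TensorFlow preprocessing code:

```python
import random

def mask_tokens(token_ids, mask_id, vocab_size, mlm_prob=0.15):
    """BERT-style masking over a list of token ids (illustrative sketch)."""
    inputs, labels = list(token_ids), [-100] * len(token_ids)  # -100: ignored by the loss
    for i in range(len(inputs)):
        if random.random() >= mlm_prob:
            continue
        labels[i] = inputs[i]                         # the model must predict the original token
        r = random.random()
        if r < 0.8:
            inputs[i] = mask_id                       # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = random.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels
```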
pretrained_models/bert-base-uncased/config.json
ADDED
@@ -0,0 +1,23 @@
+{
+  "architectures": [
+    "BertForMaskedLM"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "transformers_version": "4.6.0.dev0",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 30522
+}
pretrained_models/bert-base-uncased/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
pretrained_models/bert-base-uncased/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
+{
+  "do_lower_case": true
+}
pretrained_models/bert-base-uncased/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
pretrained_models/bert-base-vietnamese-uncased/.gitattributes
ADDED
@@ -0,0 +1,17 @@
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
pretrained_models/bert-base-vietnamese-uncased/README.md
ADDED
@@ -0,0 +1,22 @@
+## Usage
+```python
+from transformers import BertForSequenceClassification
+from transformers import BertTokenizer
+model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
+tokenizer = BertTokenizer.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
+```
+
+### References
+
+```
+@article{ttnt2020bert,
+  title={Vietnamese BERT: Pretrained on News and Wiki},
+  author={trituenhantao.io},
+  year = {2020},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/trituenhantaoio/vn-bert-base-uncased}},
+}
+```
+
+[trituenhantao.io](https://trituenhantao.io)
pretrained_models/bert-base-vietnamese-uncased/config.json
ADDED
@@ -0,0 +1,27 @@
+{
+  "_name_or_path": "/content/drive/My Drive/Colab_data/2020/vietnamese-bert-19-11-2020/vietnamese-bert-10-2020/bert_model/pytorch_model",
+  "attention_probs_dropout_prob": 0.1,
+  "directionality": "bidi",
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "position_embedding_type": "absolute",
+  "transformers_version": "4.2.2",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 32000
+}
pretrained_models/bert-base-vietnamese-uncased/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
pretrained_models/bert-base-vietnamese-uncased/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null}
pretrained_models/bert-base-vietnamese-uncased/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
pretrained_models/chinese-bert-wwm-ext/.gitattributes
ADDED
@@ -0,0 +1,9 @@
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
pretrained_models/chinese-bert-wwm-ext/README.md
ADDED
@@ -0,0 +1,52 @@
+---
+language:
+- zh
+license: "apache-2.0"
+---
+## Chinese BERT with Whole Word Masking
+For further accelerating Chinese natural language processing, we provide **Chinese pre-trained BERT with Whole Word Masking**.
+
+**[Pre-Training with Whole Word Masking for Chinese BERT](https://arxiv.org/abs/1906.08101)**
+Yiming Cui, Wanxiang Che, Ting Liu, Bing Qin, Ziqing Yang, Shijin Wang, Guoping Hu
+
+This repository is developed based on: https://github.com/google-research/bert
+
+You may also be interested in,
+- Chinese BERT series: https://github.com/ymcui/Chinese-BERT-wwm
+- Chinese MacBERT: https://github.com/ymcui/MacBERT
+- Chinese ELECTRA: https://github.com/ymcui/Chinese-ELECTRA
+- Chinese XLNet: https://github.com/ymcui/Chinese-XLNet
+- Knowledge Distillation Toolkit - TextBrewer: https://github.com/airaria/TextBrewer
+
+More resources by HFL: https://github.com/ymcui/HFL-Anthology
+
+## Citation
+If you find the technical report or resource useful, please cite the following technical report in your paper.
+- Primary: https://arxiv.org/abs/2004.13922
+```
+@inproceedings{cui-etal-2020-revisiting,
+  title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
+  author = "Cui, Yiming and
+    Che, Wanxiang and
+    Liu, Ting and
+    Qin, Bing and
+    Wang, Shijin and
+    Hu, Guoping",
+  booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
+  month = nov,
+  year = "2020",
+  address = "Online",
+  publisher = "Association for Computational Linguistics",
+  url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
+  pages = "657--668",
+}
+```
+- Secondary: https://arxiv.org/abs/1906.08101
+```
+@article{chinese-bert-wwm,
+  title={Pre-Training with Whole Word Masking for Chinese BERT},
+  author={Cui, Yiming and Che, Wanxiang and Liu, Ting and Qin, Bing and Yang, Ziqing and Wang, Shijin and Hu, Guoping},
+  journal={arXiv preprint arXiv:1906.08101},
+  year={2019}
+}
+```
pretrained_models/chinese-bert-wwm-ext/added_tokens.json
ADDED
@@ -0,0 +1 @@
+{}
pretrained_models/chinese-bert-wwm-ext/config.json
ADDED
@@ -0,0 +1,26 @@
+{
+  "architectures": [
+    "BertForMaskedLM"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "directionality": "bidi",
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "output_past": true,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "type_vocab_size": 2,
+  "vocab_size": 21128
+}
pretrained_models/chinese-bert-wwm-ext/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
pretrained_models/chinese-bert-wwm-ext/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
pretrained_models/chinese-bert-wwm-ext/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"init_inputs": []}
pretrained_models/chinese-bert-wwm-ext/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
project_settings.py
ADDED
@@ -0,0 +1,12 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import os
+from pathlib import Path
+
+
+project_path = os.path.abspath(os.path.dirname(__file__))
+project_path = Path(project_path)
+
+
+if __name__ == '__main__':
+    pass
requirements.txt
ADDED
@@ -0,0 +1,14 @@
+gradio==3.20.1
+allennlp==2.10.1
+allennlp-models==2.10.1
+torch==1.12.1
+overrides==7.3.1
+pytorch_pretrained_bert==0.6.2
+pydantic==1.10.12
+thinc==7.4.6
+spacy==2.3.9
+fugashi==1.1.2
+ipadic==1.0.0
+pandas==2.0.3
+xlrd==1.2.0
+openpyxl==3.0.9
toolbox/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp_models/text_classifier/dataset_readers/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp_models/text_classifier/dataset_readers/hierarchical_classification_json.py
ADDED
@@ -0,0 +1,99 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Dict, Iterable, List, Union
+import logging
+import json
+from overrides import overrides
+from allennlp.common.file_utils import cached_path
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.fields import LabelField, TextField, Field, ListField
+from allennlp.data.instance import Instance
+from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
+from allennlp.data.tokenizers import Tokenizer, SpacyTokenizer
+from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter
+
+logger = logging.getLogger(__name__)
+
+
+@DatasetReader.register("hierarchical_classification_json")
+class HierarchicalClassificationJsonReader(DatasetReader):
+    def __init__(self,
+                 n_hierarchical: int = 2,
+                 token_indexers: Dict[str, TokenIndexer] = None,
+                 tokenizer: Tokenizer = None,
+                 segment_sentences: bool = False,
+                 max_sequence_length: int = None,
+                 skip_label_indexing: bool = False,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._n_hierarchical = n_hierarchical
+        self._tokenizer = tokenizer or SpacyTokenizer()
+        self._segment_sentences = segment_sentences
+        self._max_sequence_length = max_sequence_length
+        self._skip_label_indexing = skip_label_indexing
+        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
+        if self._segment_sentences:
+            self._sentence_segmenter = SpacySentenceSplitter()
+
+    @overrides
+    def _read(self, file_path) -> Iterable[Instance]:
+        with open(cached_path(file_path), "r", encoding='utf-8') as data_file:
+            for line in data_file.readlines():
+                if not line:
+                    continue
+                items = json.loads(line)
+                text = items["text"]
+
+                labels = [items.get("label{}".format(idx), None) for idx in range(self._n_hierarchical)]
+                if all(labels):
+                    label = '_'.join(labels)
+                else:
+                    label = None
+
+                if label is not None:
+                    if self._skip_label_indexing:
+                        try:
+                            label = int(label)
+                        except ValueError:
+                            raise ValueError('Labels must be integers if skip_label_indexing is True.')
+                    else:
+                        label = str(label)
+                instance = self.text_to_instance(text, label)
+                if instance is not None:
+                    yield instance
+
+    def _truncate(self, tokens):
+        if len(tokens) > self._max_sequence_length:
+            tokens = tokens[:self._max_sequence_length]
+        return tokens
+
+    @overrides
+    # def text_to_instance(self, text: str, label: Union[str, int] = None) -> Instance:
+    def text_to_instance(self, *inputs) -> Instance:
+        if len(inputs) == 1:
+            text = inputs[0]
+            label = None
+        elif len(inputs) == 2:
+            text, label = inputs
+        else:
+            raise AssertionError
+
+        fields: Dict[str, Field] = {}
+        if self._segment_sentences:
+            sentences: List[Field] = []
+            sentence_splits = self._sentence_segmenter.split_sentences(text)
+            for sentence in sentence_splits:
+                word_tokens = self._tokenizer.tokenize(sentence)
+                if self._max_sequence_length is not None:
+                    word_tokens = self._truncate(word_tokens)
+                sentences.append(TextField(word_tokens, self._token_indexers))
+            fields['tokens'] = ListField(sentences)
+        else:
+            tokens = self._tokenizer.tokenize(text)
+            if self._max_sequence_length is not None:
+                tokens = self._truncate(tokens)
+            fields['tokens'] = TextField(tokens, self._token_indexers)
+        if label is not None:
+            fields['label'] = LabelField(label,
+                                         skip_indexing=self._skip_label_indexing)
+        return Instance(fields)
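For clarity, a sketch of the JSON-lines layout this reader consumes with `n_hierarchical=2`. The texts, label values, and the `train.jsonl` file name below are illustrative examples, not taken from the project's dataset:

```python
import json

samples = [
    {"text": "好的，没问题", "label0": "领域相关", "label1": "肯定答复"},
    {"text": "现在不方便，晚点再说", "label0": "领域相关", "label1": "用户正忙"},
]
with open("train.jsonl", "w", encoding="utf-8") as f:   # hypothetical file name
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

# The reader joins label0..label{n-1} with "_" into the single label it indexes,
# e.g. "领域相关_肯定答复"; lines missing any level are treated as unlabeled.
```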
toolbox/allennlp_models/text_classifier/models/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp_models/text_classifier/models/hierarchical_text_classifier.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from collections import OrderedDict
|
4 |
+
import pickle
|
5 |
+
from typing import Dict, Optional
|
6 |
+
|
7 |
+
from overrides import overrides
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from allennlp.data import Vocabulary
|
12 |
+
from allennlp.models.model import Model
|
13 |
+
from allennlp.modules import Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
|
14 |
+
from allennlp.nn import InitializerApplicator, RegularizerApplicator
|
15 |
+
from allennlp.nn.util import get_text_field_mask
|
16 |
+
from allennlp.training.metrics import CategoricalAccuracy
|
17 |
+
|
18 |
+
from toolbox.torch.modules.loss import FocalLoss, NegativeEntropy
|
19 |
+
|
20 |
+
|
21 |
+
@Model.register("hierarchical_classifier")
|
22 |
+
class HierarchicalClassifier(Model):
|
23 |
+
def __init__(self,
|
24 |
+
vocab: Vocabulary,
|
25 |
+
hierarchical_labels_pkl: str,
|
26 |
+
text_field_embedder: TextFieldEmbedder,
|
27 |
+
seq2vec_encoder: Seq2VecEncoder,
|
28 |
+
seq2seq_encoder: Seq2SeqEncoder = None,
|
29 |
+
dropout: float = None,
|
30 |
+
num_labels: int = None,
|
31 |
+
label_namespace: str = "labels",
|
32 |
+
balance_probs: bool = False,
|
33 |
+
initializer: InitializerApplicator = InitializerApplicator(),
|
34 |
+
regularizer: Optional[RegularizerApplicator] = None) -> None:
|
35 |
+
|
36 |
+
super().__init__(vocab, regularizer)
|
37 |
+
self._hierarchical_labels_pkl = hierarchical_labels_pkl
|
38 |
+
self._text_field_embedder = text_field_embedder
|
39 |
+
|
40 |
+
if seq2seq_encoder:
|
41 |
+
self._seq2seq_encoder = seq2seq_encoder
|
42 |
+
else:
|
43 |
+
self._seq2seq_encoder = None
|
44 |
+
|
45 |
+
self._seq2vec_encoder = seq2vec_encoder
|
46 |
+
self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()
|
47 |
+
|
48 |
+
if dropout:
|
49 |
+
self._dropout = torch.nn.Dropout(dropout)
|
50 |
+
else:
|
51 |
+
self._dropout = None
|
52 |
+
|
53 |
+
self._label_namespace = label_namespace
|
54 |
+
|
55 |
+
if num_labels:
|
56 |
+
self._num_labels = num_labels
|
57 |
+
else:
|
58 |
+
self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace)
|
59 |
+
|
60 |
+
with open(self._hierarchical_labels_pkl, 'rb') as f:
|
61 |
+
hierarchical_labels = pickle.load(f)
|
62 |
+
self._classification_layer = HierarchicalSoftMaxClassificationLayer(
|
63 |
+
classifier_input_dim=self._classifier_input_dim,
|
64 |
+
hierarchical_labels=hierarchical_labels,
|
65 |
+
activation='softmax',
|
66 |
+
)
|
67 |
+
|
68 |
+
self._accuracy = CategoricalAccuracy()
|
69 |
+
|
70 |
+
if balance_probs:
|
71 |
+
self._loss = NegativeEntropy(
|
72 |
+
inputs_logits=False,
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
self._loss = FocalLoss(
|
76 |
+
num_classes=self._num_labels,
|
77 |
+
inputs_logits=False,
|
78 |
+
)
|
79 |
+
initializer(self)
|
80 |
+
|
81 |
+
def forward(self, # type: ignore
|
82 |
+
tokens: Dict[str, torch.LongTensor],
|
83 |
+
label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
|
84 |
+
embedded_text = self._text_field_embedder(tokens)
|
85 |
+
mask = get_text_field_mask(tokens)
|
86 |
+
|
87 |
+
if self._seq2seq_encoder:
|
88 |
+
embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)
|
89 |
+
|
90 |
+
embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)
|
91 |
+
|
92 |
+
if self._dropout:
|
93 |
+
embedded_text = self._dropout(embedded_text)
|
94 |
+
|
95 |
+
probs = self._classification_layer(embedded_text)
|
96 |
+
|
97 |
+
output_dict = {"probs": probs}
|
98 |
+
|
99 |
+
if label is not None:
|
100 |
+
loss = self._loss(probs, label.long().view(-1))
|
101 |
+
output_dict["loss"] = loss
|
102 |
+
self._accuracy(probs, label)
|
103 |
+
|
104 |
+
return output_dict
|
105 |
+
|
106 |
+
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
107 |
+
predictions = output_dict["probs"]
|
108 |
+
predictions = torch.tensor(predictions)
|
109 |
+
if predictions.dim() == 2:
|
110 |
+
predictions_list = [predictions[i] for i in range(predictions.shape[0])]
|
111 |
+
else:
|
112 |
+
predictions_list = [predictions]
|
113 |
+
classes = list()
|
114 |
+
prob = list()
|
115 |
+
for prediction in predictions_list:
|
116 |
+
label_idx = prediction.argmax(dim=-1).item()
|
117 |
+
label_str = self.vocab.get_index_to_token_vocabulary(self._label_namespace).get(label_idx, str(label_idx))
|
118 |
+
classes.append(label_str)
|
119 |
+
prob.append(prediction[label_idx].item())
|
120 |
+
output_dict["label"] = classes
|
121 |
+
output_dict["prob"] = prob
|
122 |
+
return output_dict
|
123 |
+
|
124 |
+
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
|
125 |
+
metrics = {'accuracy': self._accuracy.get_metric(reset)}
|
126 |
+
return metrics
|
127 |
+
|
128 |
+
|
129 |
+
class HierarchicalSoftMaxClassificationLayer(nn.Module):
|
130 |
+
"""多层 softmax 实现多极文本分类
|
131 |
+
|
132 |
+
由于初始化时, 各层 softmax 的概率趋于平衡.
|
133 |
+
|
134 |
+
因此在第一层时 `领域无关` 就分到了 50% 的概率.
|
135 |
+
|
136 |
+
`领域相关` 中的各类别去分剩下的 50% 的概率.
|
137 |
+
这会导致模型一开始时输出的类别全是 `领域无关`, 这导致模型无法优化.
|
138 |
+
|
139 |
+
解决方��:
|
140 |
+
1. 从数据集中去除 `领域无关` 数据. 并训练模型.
|
141 |
+
2. 等模型收敛之后, 再使用包含 `领域无关` 的数据集, 让模型加载之前的权重, 并重新开始训练模型.
|
142 |
+
|
143 |
+
"""
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def demo1():
|
147 |
+
# hierarchical_labels = OrderedDict({
|
148 |
+
# '领域相关': OrderedDict({
|
149 |
+
# '肯定答复': [
|
150 |
+
# '肯定(好的)', '肯定(可以)', '肯定(正确)'
|
151 |
+
# ],
|
152 |
+
# '否定答复': [
|
153 |
+
# '否定(不可以)', '否定(不知道)', '否定(错误)'
|
154 |
+
# ],
|
155 |
+
# '用户正忙': [
|
156 |
+
# '用户正忙'
|
157 |
+
# ]
|
158 |
+
# }),
|
159 |
+
# '领域无关': OrderedDict({
|
160 |
+
# '领域无关': [
|
161 |
+
# '领域无关'
|
162 |
+
# ]
|
163 |
+
# })
|
164 |
+
# })
|
165 |
+
|
166 |
+
hierarchical_labels = OrderedDict({
|
167 |
+
'领域相关': ['肯定答复', '否定答复', '用户正忙', '查联系方式'],
|
168 |
+
'领域无关': ['领域无关'],
|
169 |
+
})
|
170 |
+
|
171 |
+
softmax_layer = HierarchicalSoftMaxClassificationLayer(
|
172 |
+
classifier_input_dim=3,
|
173 |
+
hierarchical_labels=hierarchical_labels,
|
174 |
+
activation='softmax',
|
175 |
+
# activation='sigmoid',
|
176 |
+
|
177 |
+
)
|
178 |
+
|
179 |
+
for k, v in softmax_layer.__dict__['_modules'].items():
|
180 |
+
print(k)
|
181 |
+
print(v)
|
182 |
+
|
183 |
+
inputs = torch.ones(size=(2, 3), dtype=torch.float32)
|
184 |
+
|
185 |
+
probs = softmax_layer.forward(inputs)
|
186 |
+
print(probs)
|
187 |
+
print(torch.sum(probs, dim=-1))
|
188 |
+
return
|
189 |
+
|
190 |
+
def __init__(self, classifier_input_dim: int, hierarchical_labels: OrderedDict, activation: str = 'softmax'):
|
191 |
+
super(HierarchicalSoftMaxClassificationLayer, self).__init__()
|
192 |
+
self.classifier_input_dim = classifier_input_dim
|
193 |
+
self.hierarchical_labels = hierarchical_labels
|
194 |
+
self.activation: str = activation
|
195 |
+
|
196 |
+
self._init_hierarchical_classification_layer(hierarchical_labels)
|
197 |
+
|
198 |
+
def _init_hierarchical_classification_layer(self,
|
199 |
+
hierarchical_labels: OrderedDict,
|
200 |
+
key: str = 'classification_layer',
|
201 |
+
child_class: str = None):
|
202 |
+
num_labels = len(hierarchical_labels)
|
203 |
+
|
204 |
+
classification_layer = torch.nn.Linear(self.classifier_input_dim, num_labels)
|
205 |
+
if child_class is not None:
|
206 |
+
key = '{header}_{child_class}'.format(header=key, child_class=child_class)
|
207 |
+
setattr(
|
208 |
+
self,
|
209 |
+
key,
|
210 |
+
classification_layer
|
211 |
+
)
|
212 |
+
|
213 |
+
branch = 0
|
214 |
+
for k, v in hierarchical_labels.items():
|
215 |
+
if isinstance(v, OrderedDict):
|
216 |
+
self._init_hierarchical_classification_layer(
|
217 |
+
v,
|
218 |
+
key=key,
|
219 |
+
child_class=branch,
|
220 |
+
)
|
221 |
+
elif isinstance(v, list):
|
222 |
+
num_labels = len(v)
|
223 |
+
classification_layer = torch.nn.Linear(self.classifier_input_dim, num_labels)
|
224 |
+
setattr(
|
225 |
+
self,
|
226 |
+
'{key}_{child_class}'.format(key=key, child_class=branch),
|
227 |
+
classification_layer,
|
228 |
+
)
|
229 |
+
else:
|
230 |
+
raise NotImplementedError
|
231 |
+
branch += 1
|
232 |
+
return
|
233 |
+
|
234 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
235 |
+
key = 'classification_layer'
|
236 |
+
classification_layer = getattr(self, key)
|
237 |
+
logits = classification_layer.forward(inputs)
|
238 |
+
probs = torch.softmax(logits, dim=-1)
|
239 |
+
|
240 |
+
probs = self._layer_probs(
|
241 |
+
inputs=inputs,
|
242 |
+
probs=probs,
|
243 |
+
key=key,
|
244 |
+
)
|
245 |
+
|
246 |
+
return probs
|
247 |
+
|
248 |
+
def _layer_probs(self,
|
249 |
+
inputs: torch.Tensor,
|
250 |
+
probs: torch.Tensor,
|
251 |
+
key: str,
|
252 |
+
):
|
253 |
+
|
254 |
+
result = list()
|
255 |
+
for child_class in range(probs.shape[1]):
|
256 |
+
parent_probs = torch.unsqueeze(probs[:, child_class], dim=-1)
|
257 |
+
|
258 |
+
child_key = '{key}_{child_class}'.format(key=key, child_class=child_class)
|
259 |
+
classification_layer = getattr(self, child_key)
|
260 |
+
logits = classification_layer.forward(inputs)
|
261 |
+
|
262 |
+
child_child_key = '{key}_{child_class}'.format(key=child_key, child_class=0)
|
263 |
+
if hasattr(self, child_child_key):
|
264 |
+
child_probs = torch.softmax(logits, dim=-1)
|
265 |
+
child_probs = child_probs * parent_probs
|
266 |
+
|
267 |
+
child_probs = self._layer_probs(
|
268 |
+
inputs=inputs,
|
269 |
+
probs=child_probs,
|
270 |
+
key=child_key,
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
if self.activation == 'softmax':
|
274 |
+
child_probs = torch.softmax(logits, dim=-1)
|
275 |
+
else:
|
276 |
+
child_probs = torch.sigmoid(logits)
|
277 |
+
child_probs = child_probs * parent_probs
|
278 |
+
|
279 |
+
result.append(child_probs)
|
280 |
+
|
281 |
+
result = torch.concat(result, dim=-1)
|
282 |
+
return result
|
283 |
+
|
284 |
+
|
285 |
+
def demo1():
|
286 |
+
HierarchicalSoftMaxClassificationLayer.demo1()
|
287 |
+
return
|
288 |
+
|
289 |
+
|
290 |
+
if __name__ == '__main__':
|
291 |
+
demo1()
|
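A minimal usage sketch (not part of the diff) of the HierarchicalSoftMaxClassificationLayer defined above, assuming it is importable from the hierarchical_text_classifier module added in this commit. With activation='softmax' every child distribution is multiplied by its parent probability, so the concatenated leaf probabilities still sum to 1 per row; with activation='sigmoid' the leaves no longer sum to 1.

```python
import torch
from collections import OrderedDict

# Assumed import path, taken from the file list of this commit.
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import \
    HierarchicalSoftMaxClassificationLayer

hierarchical_labels = OrderedDict({
    '领域相关': ['肯定答复', '否定答复', '用户正忙', '查联系方式'],
    '领域无关': ['领域无关'],
})

layer = HierarchicalSoftMaxClassificationLayer(
    classifier_input_dim=3,
    hierarchical_labels=hierarchical_labels,
    activation='softmax',
)

inputs = torch.randn(2, 3)
probs = layer.forward(inputs)     # shape=[2, 5]: 4 leaves in the first group + 1 in the second
print(probs.shape)
print(torch.sum(probs, dim=-1))   # ~1.0 per row: child probs are scaled by their parent prob
```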
toolbox/torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/modules/loss.py
ADDED
@@ -0,0 +1,738 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import math
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.nn.modules.loss import _Loss
|
11 |
+
from torch.autograd import Variable
|
12 |
+
|
13 |
+
|
14 |
+
class ClassBalancedLoss(_Loss):
|
15 |
+
"""
|
16 |
+
https://arxiv.org/abs/1901.05555
|
17 |
+
"""
|
18 |
+
@staticmethod
|
19 |
+
def demo1():
|
20 |
+
batch_loss: torch.FloatTensor = torch.randn(size=(2, 1), dtype=torch.float32)
|
21 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
22 |
+
|
23 |
+
class_balanced_loss = ClassBalancedLoss(
|
24 |
+
num_classes=3,
|
25 |
+
num_samples_each_class=[300, 433, 50],
|
26 |
+
reduction='mean',
|
27 |
+
)
|
28 |
+
loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
|
29 |
+
print(loss)
|
30 |
+
return
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def demo2():
|
34 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
35 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
36 |
+
|
37 |
+
focal_loss = FocalLoss(
|
38 |
+
num_classes=3,
|
39 |
+
# reduction='mean',
|
40 |
+
# reduction='sum',
|
41 |
+
reduction='none',
|
42 |
+
)
|
43 |
+
batch_loss = focal_loss.forward(inputs, targets)
|
44 |
+
print(batch_loss)
|
45 |
+
|
46 |
+
class_balanced_loss = ClassBalancedLoss(
|
47 |
+
num_classes=3,
|
48 |
+
num_samples_each_class=[300, 433, 50],
|
49 |
+
reduction='mean',
|
50 |
+
)
|
51 |
+
loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
|
52 |
+
print(loss)
|
53 |
+
|
54 |
+
return
|
55 |
+
|
56 |
+
def __init__(self,
|
57 |
+
num_classes: int,
|
58 |
+
num_samples_each_class: List[int],
|
59 |
+
beta: float = 0.999,
|
60 |
+
reduction: str = 'mean') -> None:
|
61 |
+
super(ClassBalancedLoss, self).__init__(None, None, reduction)
|
62 |
+
|
63 |
+
effective_num = 1.0 - np.power(beta, num_samples_each_class)
|
64 |
+
weights = (1.0 - beta) / np.array(effective_num)
|
65 |
+
self.weights = weights / np.sum(weights) * num_classes
|
66 |
+
|
67 |
+
def forward(self, batch_loss: torch.FloatTensor, targets: torch.LongTensor):
|
68 |
+
"""
|
69 |
+
:param batch_loss: shape=[batch_size, 1]
|
70 |
+
:param targets: shape=[batch_size,]
|
71 |
+
:return:
|
72 |
+
"""
|
73 |
+
weights = list()
|
74 |
+
targets = targets.numpy()
|
75 |
+
for target in targets:
|
76 |
+
weights.append([self.weights[target]])
|
77 |
+
|
78 |
+
weights = torch.tensor(weights, dtype=torch.float32)
|
79 |
+
batch_loss = weights * batch_loss
|
80 |
+
|
81 |
+
if self.reduction == 'mean':
|
82 |
+
loss = batch_loss.mean()
|
83 |
+
elif self.reduction == 'sum':
|
84 |
+
loss = batch_loss.sum()
|
85 |
+
else:
|
86 |
+
loss = batch_loss
|
87 |
+
return loss
|
88 |
+
|
89 |
+
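As a quick sanity check of the effective-number weighting used by ClassBalancedLoss above (w_c = (1 - beta) / (1 - beta^{n_c}), renormalized to sum to num_classes), here is a small numpy sketch with the class counts from demo1; the rare class receives by far the largest weight.

```python
import numpy as np

beta = 0.999
num_samples_each_class = np.array([300, 433, 50])

# effective number of samples per class, then inverse weighting
effective_num = 1.0 - np.power(beta, num_samples_each_class)
weights = (1.0 - beta) / effective_num
weights = weights / np.sum(weights) * len(num_samples_each_class)
print(weights)  # approximately [0.43, 0.31, 2.26]: the 50-sample class is upweighted most
```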
|
90 |
+
class EqualizationLoss(_Loss):
|
91 |
+
"""
|
92 |
+
For sigmoid multi-label classification in image recognition, where there is one extra background class in addition to the num_classes foreground classes.
|
93 |
+
Equalization Loss
|
94 |
+
https://arxiv.org/abs/2003.05176
|
95 |
+
Equalization Loss v2
|
96 |
+
https://arxiv.org/abs/2012.08548
|
97 |
+
"""
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def demo1():
|
101 |
+
logits: torch.FloatTensor = torch.randn(size=(3, 3), dtype=torch.float32)
|
102 |
+
targets: torch.LongTensor = torch.tensor([1, 2, 3], dtype=torch.long)
|
103 |
+
|
104 |
+
equalization_loss = EqualizationLoss(
|
105 |
+
num_samples_each_class=[300, 433, 50],
|
106 |
+
threshold=100,
|
107 |
+
reduction='mean',
|
108 |
+
)
|
109 |
+
loss = equalization_loss.forward(logits=logits, targets=targets)
|
110 |
+
print(loss)
|
111 |
+
return
|
112 |
+
|
113 |
+
def __init__(self,
|
114 |
+
num_samples_each_class: List[int],
|
115 |
+
threshold: int = 100,
|
116 |
+
reduction: str = 'mean') -> None:
|
117 |
+
super(EqualizationLoss, self).__init__(None, None, reduction)
|
118 |
+
self.num_samples_each_class = np.array(num_samples_each_class, dtype=np.int32)
|
119 |
+
self.threshold = threshold
|
120 |
+
|
121 |
+
def forward(self,
|
122 |
+
logits: torch.FloatTensor,
|
123 |
+
targets: torch.LongTensor
|
124 |
+
):
|
125 |
+
"""
|
126 |
+
num_classes + 1 corresponds to the extra background class.
|
127 |
+
:param logits: shape=[batch_size, num_classes]
|
128 |
+
:param targets: shape=[batch_size]
|
129 |
+
:return:
|
130 |
+
"""
|
131 |
+
batch_size, num_classes = logits.size()
|
132 |
+
|
133 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes + 1)
|
134 |
+
one_hot_targets = one_hot_targets[:, :-1]
|
135 |
+
|
136 |
+
exclude = self.exclude_func(
|
137 |
+
num_classes=num_classes,
|
138 |
+
targets=targets
|
139 |
+
)
|
140 |
+
is_tail = self.threshold_func(
|
141 |
+
num_classes=num_classes,
|
142 |
+
num_samples_each_class=self.num_samples_each_class,
|
143 |
+
threshold=self.threshold,
|
144 |
+
)
|
145 |
+
|
146 |
+
weights = 1 - exclude * is_tail * (1 - one_hot_targets)
|
147 |
+
|
148 |
+
batch_loss = F.binary_cross_entropy_with_logits(
|
149 |
+
logits,
|
150 |
+
one_hot_targets.float(),
|
151 |
+
reduction='none'
|
152 |
+
)
|
153 |
+
|
154 |
+
batch_loss = weights * batch_loss
|
155 |
+
|
156 |
+
if self.reduction == 'mean':
|
157 |
+
loss = batch_loss.mean()
|
158 |
+
elif self.reduction == 'sum':
|
159 |
+
loss = batch_loss.sum()
|
160 |
+
else:
|
161 |
+
loss = batch_loss
|
162 |
+
|
163 |
+
loss = loss / num_classes
|
164 |
+
return loss
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def exclude_func(num_classes: int, targets: torch.LongTensor):
|
168 |
+
"""
|
169 |
+
The last class is the background class.
|
170 |
+
:param num_classes: int,
|
171 |
+
:param targets: shape=[batch_size,]
|
172 |
+
:return: weight, shape=[batch_size, num_classes]
|
173 |
+
"""
|
174 |
+
batch_size = targets.shape[0]
|
175 |
+
weight = (targets != num_classes).float()
|
176 |
+
weight = weight.view(batch_size, 1).expand(batch_size, num_classes)
|
177 |
+
return weight
|
178 |
+
|
179 |
+
@staticmethod
|
180 |
+
def threshold_func(num_classes: int, num_samples_each_class: np.ndarray, threshold: int):
|
181 |
+
"""
|
182 |
+
:param num_classes: int,
|
183 |
+
:param num_samples_each_class: shape=[num_classes]
|
184 |
+
:param threshold: int,
|
185 |
+
:return: weight, shape=[1, num_classes]
|
186 |
+
"""
|
187 |
+
weight = torch.zeros(size=(num_classes,))
|
188 |
+
weight[num_samples_each_class < threshold] = 1
|
189 |
+
weight = torch.unsqueeze(weight, dim=0)
|
190 |
+
return weight
|
191 |
+
|
192 |
+
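A small illustrative check (not part of the commit) of the weight matrix built in EqualizationLoss.forward above: for foreground samples, the negative terms of tail classes (count < threshold) are masked out, so frequent foreground examples do not suppress the rare classes; background samples keep all weights at 1.

```python
import torch
import torch.nn.functional as F

num_classes = 3
targets = torch.tensor([1, 2, 3])                    # 3 == background class
one_hot = F.one_hot(targets, num_classes + 1)[:, :-1]

# exclude: 1 for foreground samples, 0 for background samples
exclude = (targets != num_classes).float().view(-1, 1).expand(-1, num_classes)
# is_tail: only the 50-sample class is below threshold=100 in the demo above
is_tail = torch.tensor([[0., 0., 1.]])

weights = 1 - exclude * is_tail * (1 - one_hot)
print(weights)
# row 0 (target=1): the tail-class column is 0, so its negative BCE term is ignored
# row 2 (background): all weights stay 1
```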
|
193 |
+
class FocalLoss(_Loss):
|
194 |
+
"""
|
195 |
+
https://arxiv.org/abs/1708.02002
|
196 |
+
"""
|
197 |
+
@staticmethod
|
198 |
+
def demo1():
|
199 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
200 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
201 |
+
|
202 |
+
focal_loss = FocalLoss(
|
203 |
+
num_classes=3,
|
204 |
+
reduction='mean',
|
205 |
+
# reduction='sum',
|
206 |
+
# reduction='none',
|
207 |
+
)
|
208 |
+
loss = focal_loss.forward(inputs, targets)
|
209 |
+
print(loss)
|
210 |
+
return
|
211 |
+
|
212 |
+
def __init__(self,
|
213 |
+
num_classes: int,
|
214 |
+
alpha: List[float] = None,
|
215 |
+
gamma: int = 2,
|
216 |
+
reduction: str = 'mean',
|
217 |
+
inputs_logits: bool = True) -> None:
|
218 |
+
"""
|
219 |
+
:param num_classes:
|
220 |
+
:param alpha:
|
221 |
+
:param gamma:
|
222 |
+
:param reduction: (`none`, `mean`, `sum`) available.
|
223 |
+
:param inputs_logits: if False, the inputs should be probs.
|
224 |
+
"""
|
225 |
+
super(FocalLoss, self).__init__(None, None, reduction)
|
226 |
+
if alpha is None:
|
227 |
+
self.alpha = torch.ones(num_classes, 1)
|
228 |
+
else:
|
229 |
+
self.alpha = torch.tensor(alpha, dtype=torch.float32)
|
230 |
+
self.gamma = gamma
|
231 |
+
self.num_classes = num_classes
|
232 |
+
self.inputs_logits = inputs_logits
|
233 |
+
|
234 |
+
def forward(self,
|
235 |
+
inputs: torch.FloatTensor,
|
236 |
+
targets: torch.LongTensor):
|
237 |
+
"""
|
238 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
239 |
+
:param targets: shape=[batch_size,]
|
240 |
+
:return:
|
241 |
+
"""
|
242 |
+
batch_size, num_classes = inputs.shape
|
243 |
+
|
244 |
+
if self.inputs_logits:
|
245 |
+
probs = F.softmax(inputs, dim=-1)
|
246 |
+
else:
|
247 |
+
probs = inputs
|
248 |
+
|
249 |
+
# class_mask = inputs.data.new(batch_size, num_classes).fill_(0)
|
250 |
+
class_mask = torch.zeros(size=(batch_size, num_classes), dtype=inputs.dtype, device=inputs.device)
|
251 |
+
# class_mask = Variable(class_mask)
|
252 |
+
ids = targets.view(-1, 1)
|
253 |
+
class_mask.scatter_(1, ids.data, 1.)
|
254 |
+
|
255 |
+
if inputs.is_cuda and not self.alpha.is_cuda:
|
256 |
+
self.alpha = self.alpha.cuda()
|
257 |
+
alpha = self.alpha[ids.data.view(-1)]
|
258 |
+
|
259 |
+
probs = (probs * class_mask).sum(1).view(-1, 1)
|
260 |
+
|
261 |
+
log_p = probs.log()
|
262 |
+
|
263 |
+
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
|
264 |
+
|
265 |
+
if self.reduction == 'mean':
|
266 |
+
loss = batch_loss.mean()
|
267 |
+
elif self.reduction == 'sum':
|
268 |
+
loss = batch_loss.sum()
|
269 |
+
else:
|
270 |
+
loss = batch_loss
|
271 |
+
return loss
|
272 |
+
|
273 |
+
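The focal loss above implements FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t); with gamma=0 and the default uniform alpha it reduces to ordinary cross entropy. A quick check, assuming the FocalLoss class defined above is in scope:

```python
import torch
import torch.nn.functional as F

inputs = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])

focal_loss = FocalLoss(num_classes=3, gamma=0, reduction='mean')
print(focal_loss.forward(inputs, targets))
print(F.cross_entropy(inputs, targets))  # should match up to floating point error
```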
|
274 |
+
class HingeLoss(_Loss):
|
275 |
+
@staticmethod
|
276 |
+
def demo1():
|
277 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
278 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
279 |
+
|
280 |
+
hinge_loss = HingeLoss(
|
281 |
+
margin_list=[300, 433, 50],
|
282 |
+
reduction='mean',
|
283 |
+
)
|
284 |
+
loss = hinge_loss.forward(inputs=inputs, targets=targets)
|
285 |
+
print(loss)
|
286 |
+
return
|
287 |
+
|
288 |
+
def __init__(self,
|
289 |
+
margin_list: List[float],
|
290 |
+
max_margin: float = 0.5,
|
291 |
+
scale: float = 1.0,
|
292 |
+
weight: Optional[torch.Tensor] = None,
|
293 |
+
reduction: str = 'mean') -> None:
|
294 |
+
super(HingeLoss, self).__init__(None, None, reduction)
|
295 |
+
|
296 |
+
self.max_margin = max_margin
|
297 |
+
self.scale = scale
|
298 |
+
self.weight = weight
|
299 |
+
|
300 |
+
margin_list = np.array(margin_list)
|
301 |
+
margin_list = margin_list * (max_margin / np.max(margin_list))
|
302 |
+
self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
|
303 |
+
|
304 |
+
def forward(self,
|
305 |
+
inputs: torch.FloatTensor,
|
306 |
+
targets: torch.LongTensor
|
307 |
+
):
|
308 |
+
"""
|
309 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
310 |
+
:param targets: shape=[batch_size,]
|
311 |
+
:return:
|
312 |
+
"""
|
313 |
+
batch_size, num_classes = inputs.shape
|
314 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes)
|
315 |
+
margin_list = torch.unsqueeze(self.margin_list, dim=0)
|
316 |
+
|
317 |
+
batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
|
318 |
+
batch_margin = torch.unsqueeze(batch_margin, dim=-1)
|
319 |
+
inputs_margin = inputs - batch_margin
|
320 |
+
|
321 |
+
# Lower the logit of the target class slightly to create a margin boundary.
|
322 |
+
logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
|
323 |
+
|
324 |
+
loss = F.cross_entropy(
|
325 |
+
input=self.scale * logits,
|
326 |
+
target=targets,
|
327 |
+
weight=self.weight,
|
328 |
+
reduction=self.reduction,
|
329 |
+
)
|
330 |
+
return loss
|
331 |
+
|
332 |
+
|
333 |
+
class HingeLinear(nn.Module):
|
334 |
+
"""
|
335 |
+
use this instead of `HingeLoss`, then you can combine it with `FocalLoss` or others.
|
336 |
+
"""
|
337 |
+
def __init__(self,
|
338 |
+
margin_list: List[float],
|
339 |
+
max_margin: float = 0.5,
|
340 |
+
scale: float = 1.0,
|
341 |
+
weight: Optional[torch.Tensor] = None
|
342 |
+
) -> None:
|
343 |
+
super(HingeLinear, self).__init__()
|
344 |
+
|
345 |
+
self.max_margin = max_margin
|
346 |
+
self.scale = scale
|
347 |
+
self.weight = weight
|
348 |
+
|
349 |
+
margin_list = np.array(margin_list)
|
350 |
+
margin_list = margin_list * (max_margin / np.max(margin_list))
|
351 |
+
self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
|
352 |
+
|
353 |
+
def forward(self,
|
354 |
+
inputs: torch.FloatTensor,
|
355 |
+
targets: torch.LongTensor
|
356 |
+
):
|
357 |
+
"""
|
358 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
359 |
+
:param targets: shape=[batch_size,]
|
360 |
+
:return:
|
361 |
+
"""
|
362 |
+
if self.training and targets is not None:
|
363 |
+
batch_size, num_classes = inputs.shape
|
364 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes)
|
365 |
+
margin_list = torch.unsqueeze(self.margin_list, dim=0)
|
366 |
+
|
367 |
+
batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
|
368 |
+
batch_margin = torch.unsqueeze(batch_margin, dim=-1)
|
369 |
+
inputs_margin = inputs - batch_margin
|
370 |
+
|
371 |
+
# Lower the logit of the target class slightly to create a margin boundary.
|
372 |
+
logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
|
373 |
+
logits = logits * self.scale
|
374 |
+
else:
|
375 |
+
logits = inputs
|
376 |
+
return logits
|
377 |
+
|
378 |
+
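As the HingeLinear docstring suggests, the margin adjustment can be decoupled from the loss: apply HingeLinear to the logits during training and feed the result to another criterion such as FocalLoss. A minimal sketch, assuming both classes defined above are in scope:

```python
import torch

hinge_linear = HingeLinear(margin_list=[300, 433, 50], max_margin=0.5, scale=30.0)
focal_loss = FocalLoss(num_classes=3, reduction='mean')

inputs = torch.randn(2, 3)
targets = torch.tensor([1, 2])

hinge_linear.train()                             # margins are only applied in training mode
logits = hinge_linear.forward(inputs, targets)   # margin-adjusted, scaled logits
loss = focal_loss.forward(logits, targets)
print(loss)
```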
|
379 |
+
class LDAMLoss(_Loss):
|
380 |
+
"""
|
381 |
+
https://arxiv.org/abs/1906.07413
|
382 |
+
"""
|
383 |
+
@staticmethod
|
384 |
+
def demo1():
|
385 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
386 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
387 |
+
|
388 |
+
ldam_loss = LDAMLoss(
|
389 |
+
num_samples_each_class=[300, 433, 50],
|
390 |
+
reduction='mean',
|
391 |
+
)
|
392 |
+
loss = ldam_loss.forward(inputs=inputs, targets=targets)
|
393 |
+
print(loss)
|
394 |
+
return
|
395 |
+
|
396 |
+
def __init__(self,
|
397 |
+
num_samples_each_class: List[int],
|
398 |
+
max_margin: float = 0.5,
|
399 |
+
scale: float = 30.0,
|
400 |
+
weight: Optional[torch.Tensor] = None,
|
401 |
+
reduction: str = 'mean') -> None:
|
402 |
+
super(LDAMLoss, self).__init__(None, None, reduction)
|
403 |
+
|
404 |
+
margin_list = np.power(num_samples_each_class, -0.25)
|
405 |
+
margin_list = margin_list * (max_margin / np.max(margin_list))
|
406 |
+
|
407 |
+
self.num_samples_each_class = num_samples_each_class
|
408 |
+
self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
|
409 |
+
self.scale = scale
|
410 |
+
self.weight = weight
|
411 |
+
|
412 |
+
def forward(self,
|
413 |
+
inputs: torch.FloatTensor,
|
414 |
+
targets: torch.LongTensor
|
415 |
+
):
|
416 |
+
"""
|
417 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
418 |
+
:param targets: shape=[batch_size,]
|
419 |
+
:return:
|
420 |
+
"""
|
421 |
+
batch_size, num_classes = inputs.shape
|
422 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes)
|
423 |
+
margin_list = torch.unsqueeze(self.margin_list, dim=0)
|
424 |
+
|
425 |
+
batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
|
426 |
+
batch_margin = torch.unsqueeze(batch_margin, dim=-1)
|
427 |
+
inputs_margin = inputs - batch_margin
|
428 |
+
|
429 |
+
# Lower the logit of the target class slightly to create a margin boundary.
|
430 |
+
logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
|
431 |
+
|
432 |
+
loss = F.cross_entropy(
|
433 |
+
input=self.scale * logits,
|
434 |
+
target=targets,
|
435 |
+
weight=self.weight,
|
436 |
+
reduction=self.reduction,
|
437 |
+
)
|
438 |
+
return loss
|
439 |
+
|
440 |
+
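The LDAM margins above follow margin_j ∝ n_j^(-1/4), rescaled so that the rarest class receives max_margin; tail classes therefore get the widest margins. A quick numeric check with the counts used in demo1:

```python
import numpy as np

num_samples_each_class = [300, 433, 50]
max_margin = 0.5

margins = np.power(num_samples_each_class, -0.25)
margins = margins * (max_margin / np.max(margins))
print(margins)  # approximately [0.32, 0.29, 0.50]: the 50-sample class gets the largest margin
```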
|
441 |
+
class NegativeEntropy(_Loss):
|
442 |
+
def __init__(self,
|
443 |
+
reduction: str = 'mean',
|
444 |
+
inputs_logits: bool = True) -> None:
|
445 |
+
super(NegativeEntropy, self).__init__(None, None, reduction)
|
446 |
+
self.inputs_logits = inputs_logits
|
447 |
+
|
448 |
+
def forward(self,
|
449 |
+
inputs: torch.FloatTensor,
|
450 |
+
targets: torch.LongTensor):
|
451 |
+
if self.inputs_logits:
|
452 |
+
probs = F.softmax(inputs, dim=-1)
|
453 |
+
log_probs = torch.log(probs)
|
454 |
+
else:
|
455 |
+
probs = inputs
|
456 |
+
log_probs = torch.log(probs)
|
457 |
+
|
458 |
+
weighted_negative_likelihood = - log_probs * probs
|
459 |
+
|
460 |
+
loss = - weighted_negative_likelihood.sum()
|
461 |
+
return loss
|
462 |
+
|
463 |
+
|
464 |
+
class LargeMarginSoftMaxLoss(_Loss):
|
465 |
+
"""
|
466 |
+
Alias: L-Softmax
|
467 |
+
|
468 |
+
https://arxiv.org/abs/1612.02295
|
469 |
+
https://github.com/wy1iu/LargeMargin_Softmax_Loss
|
470 |
+
https://github.com/amirhfarzaneh/lsoftmax-pytorch/blob/master/lsoftmax.py
|
471 |
+
|
472 |
+
Reference:
|
473 |
+
https://www.jianshu.com/p/06cc3f84aa85
|
474 |
+
|
475 |
+
The paper argues that the combination of softmax and cross entropy does not explicitly encourage discriminative feature learning.
|
476 |
+
|
477 |
+
"""
|
478 |
+
def __init__(self,
|
479 |
+
reduction: str = 'mean') -> None:
|
480 |
+
super(LargeMarginSoftMaxLoss, self).__init__(None, None, reduction)
|
481 |
+
|
482 |
+
|
483 |
+
class AngularSoftMaxLoss(_Loss):
|
484 |
+
"""
|
485 |
+
Alias: A-Softmax
|
486 |
+
|
487 |
+
https://arxiv.org/abs/1704.08063
|
488 |
+
|
489 |
+
https://github.com/woshildh/a-softmax_pytorch/blob/master/a_softmax.py
|
490 |
+
|
491 |
+
Reference:
|
492 |
+
https://www.jianshu.com/p/06cc3f84aa85
|
493 |
+
|
494 |
+
The authors seem to assume that face embeddings lie on a sphere, so mapping the vectors onto a sphere is helpful.
|
495 |
+
"""
|
496 |
+
def __init__(self,
|
497 |
+
reduction: str = 'mean') -> None:
|
498 |
+
super(AngularSoftMaxLoss, self).__init__(None, None, reduction)
|
499 |
+
|
500 |
+
|
501 |
+
class AdditiveMarginSoftMax(_Loss):
|
502 |
+
"""
|
503 |
+
Alias: AM-Softmax
|
504 |
+
|
505 |
+
https://arxiv.org/abs/1801.05599
|
506 |
+
|
507 |
+
Large Margin Cosine Loss
|
508 |
+
https://arxiv.org/abs/1801.09414
|
509 |
+
|
510 |
+
Reference:
|
511 |
+
https://www.jianshu.com/p/06cc3f84aa85
|
512 |
+
|
513 |
+
Notes:
|
514 |
+
Compared with a plain softmax over the logits,
|
515 |
+
it subtracts m from the logit of the true label, pushing the model to make that value larger.
|
516 |
+
It also multiplies every logit by s, which controls the relative scale of the logits.
|
517 |
+
It is somewhat similar to HingeLoss.
|
518 |
+
"""
|
519 |
+
def __init__(self,
|
520 |
+
reduction: str = 'mean') -> None:
|
521 |
+
super(AdditiveMarginSoftMax, self).__init__(None, None, reduction)
|
522 |
+
|
523 |
+
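The AdditiveMarginSoftMax class above is only a stub in this commit, so the formula described in its docstring is sketched functionally below. The function name am_softmax_loss and the defaults m=0.35, s=30 are illustrative assumptions, not part of the repository.

```python
import torch
import torch.nn.functional as F

def am_softmax_loss(cos_logits: torch.Tensor, targets: torch.LongTensor,
                    m: float = 0.35, s: float = 30.0) -> torch.Tensor:
    # subtract m from the true-class (cosine) logit, scale everything by s,
    # then apply ordinary cross entropy, as described in the docstring above
    one_hot = F.one_hot(targets, num_classes=cos_logits.shape[-1]).float()
    adjusted = cos_logits - m * one_hot
    return F.cross_entropy(s * adjusted, targets)

cos_logits = torch.randn(2, 3).clamp(-1, 1)   # stand-in for cosine similarities
targets = torch.tensor([1, 2])
print(am_softmax_loss(cos_logits, targets))
```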
|
524 |
+
class AdditiveAngularMarginSoftMax(_Loss):
|
525 |
+
"""
|
526 |
+
Alias: ArcFace, AAM-Softmax
|
527 |
+
|
528 |
+
ArcFace: Additive Angular Margin Loss for Deep Face Recognition
|
529 |
+
https://arxiv.org/abs/1801.07698
|
530 |
+
|
531 |
+
Reference code:
|
532 |
+
https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
|
533 |
+
|
534 |
+
"""
|
535 |
+
@staticmethod
|
536 |
+
def demo1():
|
537 |
+
"""
|
538 |
+
Conversion between degrees and radians:
|
539 |
+
pi / 180 corresponds to 1 degree,
|
540 |
+
pi / 180 = 0.01745
|
541 |
+
"""
|
542 |
+
|
543 |
+
# degrees to radians
|
544 |
+
degree = 10
|
545 |
+
result = degree * math.pi / 180
|
546 |
+
print(result)
|
547 |
+
|
548 |
+
# radians to degrees
|
549 |
+
radian = 0.2
|
550 |
+
result = radian / (math.pi / 180)
|
551 |
+
print(result)
|
552 |
+
|
553 |
+
return
|
554 |
+
|
555 |
+
def __init__(self,
|
556 |
+
hidden_size: int,
|
557 |
+
num_labels: int,
|
558 |
+
margin: float = 0.2,
|
559 |
+
scale: float = 10.0,
|
560 |
+
):
|
561 |
+
"""
|
562 |
+
:param hidden_size:
|
563 |
+
:param num_labels:
|
564 |
+
:param margin: recommended angle range is [10, 30] degrees, i.e. [0.1745, 0.5236] in radians
|
565 |
+
:param scale:
|
566 |
+
"""
|
567 |
+
super(AdditiveAngularMarginSoftMax, self).__init__()
|
568 |
+
self.margin = margin
|
569 |
+
self.scale = scale
|
570 |
+
self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
|
571 |
+
nn.init.xavier_uniform_(self.weight)
|
572 |
+
|
573 |
+
self.cos_margin = math.cos(self.margin)
|
574 |
+
self.sin_margin = math.sin(self.margin)
|
575 |
+
|
576 |
+
# sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
|
577 |
+
# sin(pi - a) = sin(a)
|
578 |
+
|
579 |
+
self.loss = nn.CrossEntropyLoss()
|
580 |
+
|
581 |
+
def forward(self,
|
582 |
+
inputs: torch.Tensor,
|
583 |
+
label: torch.LongTensor = None
|
584 |
+
):
|
585 |
+
"""
|
586 |
+
:param inputs: shape=[batch_size, ..., hidden_size]
|
587 |
+
:param label:
|
588 |
+
:return: logits
|
589 |
+
"""
|
590 |
+
x = F.normalize(inputs)
|
591 |
+
weight = F.normalize(self.weight)
|
592 |
+
cosine = F.linear(x, weight)
|
593 |
+
|
594 |
+
if self.training:
|
595 |
+
|
596 |
+
# sin^2 + cos^2 = 1
|
597 |
+
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
|
598 |
+
|
599 |
+
# cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
|
600 |
+
cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
|
601 |
+
|
602 |
+
# when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
|
603 |
+
cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
|
604 |
+
|
605 |
+
one_hot = torch.zeros_like(cosine)
|
606 |
+
one_hot.scatter_(1, label.view(-1, 1), 1)
|
607 |
+
|
608 |
+
#
|
609 |
+
logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
|
610 |
+
logits = logits * self.scale
|
611 |
+
else:
|
612 |
+
logits = cosine
|
613 |
+
|
614 |
+
loss = self.loss(logits, label)
|
615 |
+
# prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
|
616 |
+
return loss
|
617 |
+
|
618 |
+
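The training branch of AdditiveAngularMarginSoftMax above computes cos(theta + m) from cosine = cos(theta) using the identity cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m). A quick numeric check of that step:

```python
import math
import torch

margin = 0.2
theta = torch.tensor([0.3, 1.0, 2.0])         # angles in [0, pi], where sine is non-negative
cosine = torch.cos(theta)

sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
cosine_theta_margin = cosine * math.cos(margin) - sine * math.sin(margin)
print(cosine_theta_margin)
print(torch.cos(theta + margin))              # identical up to floating point error
```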
|
619 |
+
class AdditiveAngularMarginLinear(nn.Module):
|
620 |
+
"""
|
621 |
+
Alias: ArcFace, AAM-Softmax
|
622 |
+
|
623 |
+
ArcFace: Additive Angular Margin Loss for Deep Face Recognition
|
624 |
+
https://arxiv.org/abs/1801.07698
|
625 |
+
|
626 |
+
Reference code:
|
627 |
+
https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
|
628 |
+
|
629 |
+
"""
|
630 |
+
@staticmethod
|
631 |
+
def demo1():
|
632 |
+
"""
|
633 |
+
Conversion between degrees and radians:
|
634 |
+
pi / 180 corresponds to 1 degree,
|
635 |
+
pi / 180 = 0.01745
|
636 |
+
"""
|
637 |
+
|
638 |
+
# degrees to radians
|
639 |
+
degree = 10
|
640 |
+
result = degree * math.pi / 180
|
641 |
+
print(result)
|
642 |
+
|
643 |
+
# radians to degrees
|
644 |
+
radian = 0.2
|
645 |
+
result = radian / (math.pi / 180)
|
646 |
+
print(result)
|
647 |
+
|
648 |
+
return
|
649 |
+
|
650 |
+
@staticmethod
|
651 |
+
def demo2():
|
652 |
+
|
653 |
+
return
|
654 |
+
|
655 |
+
def __init__(self,
|
656 |
+
hidden_size: int,
|
657 |
+
num_labels: int,
|
658 |
+
margin: float = 0.2,
|
659 |
+
scale: float = 10.0,
|
660 |
+
):
|
661 |
+
"""
|
662 |
+
:param hidden_size:
|
663 |
+
:param num_labels:
|
664 |
+
:param margin: recommended angle range is [10, 30] degrees, i.e. [0.1745, 0.5236] in radians
|
665 |
+
:param scale:
|
666 |
+
"""
|
667 |
+
super(AdditiveAngularMarginLinear, self).__init__()
|
668 |
+
self.margin = margin
|
669 |
+
self.scale = scale
|
670 |
+
self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
|
671 |
+
nn.init.xavier_uniform_(self.weight)
|
672 |
+
|
673 |
+
self.cos_margin = math.cos(self.margin)
|
674 |
+
self.sin_margin = math.sin(self.margin)
|
675 |
+
|
676 |
+
# sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
|
677 |
+
# sin(pi - a) = sin(a)
|
678 |
+
|
679 |
+
def forward(self,
|
680 |
+
inputs: torch.Tensor,
|
681 |
+
targets: torch.LongTensor = None
|
682 |
+
):
|
683 |
+
"""
|
684 |
+
:param inputs: shape=[batch_size, ..., hidden_size]
|
685 |
+
:param targets:
|
686 |
+
:return: logits
|
687 |
+
"""
|
688 |
+
x = F.normalize(inputs)
|
689 |
+
weight = F.normalize(self.weight)
|
690 |
+
cosine = F.linear(x, weight)
|
691 |
+
|
692 |
+
if self.training and targets is not None:
|
693 |
+
# sin^2 + cos^2 = 1
|
694 |
+
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
|
695 |
+
|
696 |
+
# cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
|
697 |
+
cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
|
698 |
+
|
699 |
+
# when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
|
700 |
+
cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
|
701 |
+
|
702 |
+
one_hot = torch.zeros_like(cosine)
|
703 |
+
one_hot.scatter_(1, targets.view(-1, 1), 1)
|
704 |
+
|
705 |
+
logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
|
706 |
+
logits = logits * self.scale
|
707 |
+
else:
|
708 |
+
logits = cosine
|
709 |
+
return logits
|
710 |
+
|
711 |
+
|
712 |
+
def demo1():
|
713 |
+
HingeLoss.demo1()
|
714 |
+
return
|
715 |
+
|
716 |
+
|
717 |
+
def demo2():
|
718 |
+
AdditiveAngularMarginSoftMax.demo1()
|
719 |
+
|
720 |
+
inputs = torch.ones(size=(2, 5), dtype=torch.float32)
|
721 |
+
label: torch.LongTensor = torch.tensor(data=[0, 1], dtype=torch.long)
|
722 |
+
|
723 |
+
aam_softmax = AdditiveAngularMarginSoftMax(
|
724 |
+
hidden_size=5,
|
725 |
+
num_labels=2,
|
726 |
+
margin=1,
|
727 |
+
scale=1
|
728 |
+
)
|
729 |
+
|
730 |
+
outputs = aam_softmax.forward(inputs, label)
|
731 |
+
print(outputs)
|
732 |
+
|
733 |
+
return
|
734 |
+
|
735 |
+
|
736 |
+
if __name__ == '__main__':
|
737 |
+
# demo1()
|
738 |
+
demo2()
|
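Unlike AdditiveAngularMarginSoftMax, which returns a loss, AdditiveAngularMarginLinear only returns (margin-adjusted) logits, so an external criterion is still required. A minimal usage sketch, assuming the class defined above is in scope:

```python
import torch
import torch.nn as nn

head = AdditiveAngularMarginLinear(hidden_size=5, num_labels=2, margin=0.2, scale=10.0)
criterion = nn.CrossEntropyLoss()

features = torch.randn(4, 5)
targets = torch.tensor([0, 1, 1, 0])

head.train()
logits = head.forward(features, targets)   # cos(theta + m) on the target class, scaled by s
loss = criterion(logits, targets)
print(loss)

head.eval()
print(head.forward(features))              # inference: plain cosine similarities, no margin
```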