luciusssss committed
Commit • a48216a
Parent(s): 36a788d
Upload 22 files
Browse files

- .gitattributes +1 -0
- client.py +25 -0
- data/labels2id.pkl +3 -0
- models.py +36 -0
- preprocessors.py +48 -0
- pretrained_models/ELECT +3 -0
- pretrained_models/chinese-roberta-wwm-ext/added_tokens.json +1 -0
- pretrained_models/chinese-roberta-wwm-ext/config.json +28 -0
- pretrained_models/chinese-roberta-wwm-ext/pytorch_model.bin +3 -0
- pretrained_models/chinese-roberta-wwm-ext/special_tokens_map.json +1 -0
- pretrained_models/chinese-roberta-wwm-ext/tokenizer.json +0 -0
- pretrained_models/chinese-roberta-wwm-ext/tokenizer_config.json +1 -0
- pretrained_models/chinese-roberta-wwm-ext/vocab.txt +0 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/README.md +55 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/config.json +43 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/pytorch_model.bin +3 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/special_tokens_map.json +7 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer.json +0 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer_config.json +13 -0
- pretrained_models/roberta_wwm_ext_hunyin_2epoch/vocab.txt +0 -0
- server.py +156 -0
- utils/__init__.py +2 -0
- utils/arg_parser.py +24 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+pretrained_models/ELECT filter=lfs diff=lfs merge=lfs -text
client.py
ADDED
@@ -0,0 +1,25 @@

import json
import requests
import time


def json_send(data, url):
    """POST `data` as JSON to `url` and return the decoded JSON response."""
    headers = {"Content-type": "application/json",
               "Accept": "text/plain", "charset": "UTF-8"}
    response = requests.post(url=url, headers=headers, data=json.dumps(data))
    return json.loads(response.text)


if __name__ == "__main__":
    # NOTE: server.py in this commit binds port 9098; keep the port here in sync
    # with your deployment.
    url = 'http://127.0.0.1:9099/check_hunyin'

    print("Start inference")

    while True:
        input_text = input("Enter text:").strip()
        if len(input_text) == 0:
            continue
        data = {"input": input_text}
        result = json_send(data, url)
        print(result['output'])
data/labels2id.pkl
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:179f76b8b014524ca915315f6eab916a20b582d89016e15b36bbdc055f1790cd
size 54968
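server.py below loads this pickle as a dict mapping each law-article string to an integer id. A quick inspection sketch, assuming the LFS object has been pulled locally:

import pickle as pkl

with open("data/labels2id.pkl", "rb") as f:
    laws2id = pkl.load(f)
print(len(laws2id))                  # number of law articles
print(next(iter(laws2id.items())))   # (article text, integer id)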
models.py
ADDED
@@ -0,0 +1,36 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class Elect(nn.Module):
    """Label-attention classifier over a pretrained LM encoder."""

    def __init__(self, args, device):
        super(Elect, self).__init__()
        self.device = device
        self.plm = AutoModel.from_pretrained(args.ckpt_dir)
        self.hidden_size = self.plm.config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
        self.clf = nn.Linear(self.hidden_size, len(args.labels))
        self.dropout = nn.Dropout(0.3)

        # Defined but not used in forward().
        self.p2l = nn.Linear(self.hidden_size, 256)
        self.proj = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.l2a = nn.Linear(11, 256)

        # One embedding per label (law article); filled in externally before use.
        self.la = nn.Parameter(torch.zeros(len(args.labels), self.hidden_size))

    def forward(self, batch):
        ids = batch['ids'].to(self.device, dtype=torch.long)
        mask = batch['mask'].to(self.device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(self.device, dtype=torch.long)
        hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]
        pooler = hidden_state[:, 0]    # [batch_size, hidden_size], the [CLS] embedding
        pooler = self.dropout(pooler)  # [batch_size, hidden_size]

        # Attend over the label embeddings with the [CLS] representation.
        attn = torch.softmax(pooler @ self.la.transpose(0, 1), dim=-1)  # [batch_size, num_labels]
        art = attn @ self.la                                            # [batch_size, hidden_size]
        oa = F.relu(self.proj(torch.cat([art, pooler], dim=-1)))        # [batch_size, hidden_size]

        output = self.clf(oa)  # [batch_size, len(labels)]

        return output
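For reference, a minimal sketch of driving Elect directly, mirroring how server.py (below) builds its batch; the checkpoint path and the two-entry label list here are placeholder assumptions, not values fixed by this commit:

from argparse import Namespace

import torch
from transformers import AutoTokenizer

from models import Elect

# Placeholder args: point ckpt_dir at a real encoder and labels at the real label set.
args = Namespace(ckpt_dir="./pretrained_models/chinese-roberta-wwm-ext",
                 labels=["label_a", "label_b"])
model = Elect(args, device="cpu").eval()

tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
inputs = tokenizer("some query text", padding='max_length', truncation=True,
                   max_length=256, return_tensors="pt")
batch = {'ids': inputs['input_ids'],
         'mask': inputs['attention_mask'],
         'token_type_ids': inputs['token_type_ids']}
with torch.no_grad():
    scores = torch.sigmoid(model(batch))  # [1, len(args.labels)], one score per label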
preprocessors.py
ADDED
@@ -0,0 +1,48 @@

import os
import json
import pickle as pkl
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer


class BasicPreprocessor(object):
    def __init__(self, data_generator, tokenizer, args):
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args
        file_path = os.path.join(args.data_dir, args.data_file)
        if file_path.endswith("pkl"):
            with open(file_path, "rb") as f:
                self.raw_data = pkl.load(f)
            # NOTE: debug leftover, prints the first record and aborts.
            print(self.raw_data[0])
            exit()
        elif file_path.endswith("json"):
            self.raw_data = json.load(open(file_path, "r", encoding="utf-8"))
        self.shuffle()

        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([args.labels])

    def shuffle(self):
        idx = np.arange(len(self.raw_data))
        np.random.shuffle(idx)
        self.raw_data = np.array(self.raw_data)[idx]

    def process(self):
        args = self.args
        data_generator = self.data_generator
        raw_data = self.raw_data
        tokenizer = self.tokenizer
        mlb = self.mlb

        if args.test_only:
            train_data = data_generator(raw_data[:1], tokenizer, mlb, 'test', args)
            test_data = data_generator(raw_data, tokenizer, mlb, 'test', args)
            return train_data, test_data
        # Use 90% of the data for training and 10% for testing; no validation set.
        train_data = data_generator(raw_data[:int(len(raw_data) * 0.9)], tokenizer, mlb, 'train', args)
        test_data = data_generator(raw_data[int(len(raw_data) * 0.9):], tokenizer, mlb, 'test', args)

        return train_data, test_data
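BasicPreprocessor leaves tokenization to the injected data_generator callable, invoked as data_generator(records, tokenizer, mlb, split, args). The generator itself is not part of this commit; a hypothetical stand-in, shown only to illustrate the expected interface:

def toy_data_generator(records, tokenizer, mlb, split, args):
    # Hypothetical generator: encodes each record and binarizes its labels.
    examples = []
    for rec in records:
        enc = tokenizer(rec["text"], padding="max_length", truncation=True,
                        max_length=256, return_tensors="pt")
        examples.append({
            "ids": enc["input_ids"].squeeze(0),
            "mask": enc["attention_mask"].squeeze(0),
            "token_type_ids": enc["token_type_ids"].squeeze(0),
            "labels": mlb.transform([rec.get("labels", [])])[0],
        })
    return examples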
pretrained_models/ELECT
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:acc44b4361b2a738336dce66dab399e54338f6100b900ddf1c654fd2d444b0ee
size 415790649
pretrained_models/chinese-roberta-wwm-ext/added_tokens.json
ADDED
@@ -0,0 +1 @@

{}
pretrained_models/chinese-roberta-wwm-ext/config.json
ADDED
@@ -0,0 +1,28 @@

{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "directionality": "bidi",
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
pretrained_models/chinese-roberta-wwm-ext/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:1ded5a5a1c7841dee6e47942f7b5bf2bcf6f73ff19197580f852f7f638f86b35
size 411578458
pretrained_models/chinese-roberta-wwm-ext/special_tokens_map.json
ADDED
@@ -0,0 +1 @@

{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
pretrained_models/chinese-roberta-wwm-ext/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
pretrained_models/chinese-roberta-wwm-ext/tokenizer_config.json
ADDED
@@ -0,0 +1 @@

{"init_inputs": []}
pretrained_models/chinese-roberta-wwm-ext/vocab.txt
ADDED
The diff for this file is too large to render. See raw diff.
pretrained_models/roberta_wwm_ext_hunyin_2epoch/README.md
ADDED
@@ -0,0 +1,55 @@

---
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: roberta_wwm_ext_hunyin_2epoch
  results: []
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# roberta_wwm_ext_hunyin_2epoch

This model is a fine-tuned version of [/home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext](https://huggingface.co//home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext) on an unknown dataset.
It achieves the following results on the evaluation set:
- Loss: 0.0510
- Accuracy: 0.9881

## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 32
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 2.0

### Training results

### Framework versions

- Transformers 4.28.0.dev0
- Pytorch 1.13.1+cu117
- Datasets 2.10.1
- Tokenizers 0.13.2
pretrained_models/roberta_wwm_ext_hunyin_2epoch/config.json
ADDED
@@ -0,0 +1,43 @@

{
  "_name_or_path": "/home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "directionality": "bidi",
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": false,
    "1": true
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "false": 0,
    "true": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.28.0.dev0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
pretrained_models/roberta_wwm_ext_hunyin_2epoch/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:cd02e6af0b827ddf0cf89fe32850c1da32c1ce8f83e0157e2f2fb11a93b1a4f9
size 409149557
pretrained_models/roberta_wwm_ext_hunyin_2epoch/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@

{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer_config.json
ADDED
@@ -0,0 +1,13 @@

{
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "special_tokens_map_file": "/home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext/special_tokens_map.json",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
pretrained_models/roberta_wwm_ext_hunyin_2epoch/vocab.txt
ADDED
The diff for this file is too large to render. See raw diff.
server.py
ADDED
@@ -0,0 +1,156 @@

import os
import logging
import pickle as pkl
from argparse import Namespace

import torch
from flask import Flask, request, jsonify
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, pipeline

from models import Elect

logger = logging.getLogger(__name__)

app = Flask(__name__)

hunyin_classifier = None

fatiao_args = Namespace()
fatiao_tokenizer = None
fatiao_model = None


@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    input_text = request.json['input'].strip()
    force_return = request.json['force_return'] if 'force_return' in request.json else False

    print("input_text:", input_text)

    if len(input_text) == 0:
        json_result = {
            "output": []
        }
        return jsonify(json_result)

    if not force_return:
        classifier_result = hunyin_classifier(input_text[:500])
        print(classifier_result)
        # The fine-tuned config maps id2label to booleans, so the label is True/False.
        classifier_result = classifier_result[0]['label']

        # Extra rule: if the input contains the character "婚" (marriage),
        # treat it as marriage-related regardless of the classifier.
        if '婚' in input_text:
            classifier_result = True

        # If the query is not marriage-related, return an empty result.
        if classifier_result == False:
            json_result = {
                "output": []
            }
            return jsonify(json_result)

    inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
    batch = {
        'ids': inputs['input_ids'],
        'mask': inputs['attention_mask'],
        'token_type_ids': inputs["token_type_ids"]
    }
    model_output = fatiao_model(batch)
    pred = torch.sigmoid(model_output).cpu().detach().numpy()[0]
    pred_laws = []
    for law_id, score in sorted(enumerate(pred), key=lambda x: x[1], reverse=True):
        pred_laws.append({
            'id': law_id,
            'score': float(score),
            'text': fatiao_args.mlb.classes_[law_id]
        })

    # Return the top-3 statutes by score.
    json_result = {
        "output": pred_laws[:3]
    }

    print("json_result:", json_result)
    return jsonify(json_result)


if __name__ == '__main__':
    # Load the consultation classifier that decides whether a query is marriage-related.
    hunyin_classifier_path = "./pretrained_models/roberta_wwm_ext_hunyin_2epoch"
    hunyin_config = AutoConfig.from_pretrained(
        hunyin_classifier_path,
        num_labels=2,
    )
    hunyin_tokenizer = AutoTokenizer.from_pretrained(
        hunyin_classifier_path
    )
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)

    # Load the statute (fatiao) retrieval model.
    fatiao_args.ckpt_dir = "./pretrained_models/chinese-roberta-wwm-ext"
    fatiao_args.device = "cuda:0"

    with open(os.path.join("data/labels2id.pkl"), "rb") as f:
        laws2id = pkl.load(f)
    fatiao_args.labels = list(laws2id.keys())
    # Build the inverse mapping id2laws.
    id2laws = {}
    for k, v in laws2id.items():
        id2laws[v] = k
    # fatiao_args.id2laws = id2laws
    print("Number of law articles:", len(id2laws))

    fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)

    fatiao_args.tokenizer = fatiao_tokenizer
    fatiao_model = Elect(fatiao_args, "cuda:0").to("cuda:0")
    fatiao_model.eval()

    mlb = MultiLabelBinarizer()  # mlb.classes_: idx to law article
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    # Initialize each label embedding with the encoder's [CLS] vector for the
    # article text (the leading "《民法典》第xxxx条:" prefix is stripped first).
    with torch.no_grad():
        for idx, l in enumerate(fatiao_args.labels):
            text = ':'.join(l.split(':')[1:]).lower()
            la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256,
                                     return_tensors="pt")
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    fatiao_model.load_state_dict(torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device)))
    fatiao_model.to(fatiao_args.device)

    logger.info("model loaded")
    app.run(host="0.0.0.0", port=9098, debug=False)
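For reference, a minimal sketch of calling this endpoint once the server is up; the host and port match app.run above, and force_return=True skips the marriage classifier and goes straight to statute retrieval:

import requests

# Assumes server.py is running locally on port 9098.
resp = requests.post(
    "http://127.0.0.1:9098/check_hunyin",
    json={"input": "离婚后财产如何分割?", "force_return": True},
)
# Response shape: {"output": [{"id": ..., "score": ..., "text": ...}, ...]} (top 3)
print(resp.json()["output"])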
utils/__init__.py
ADDED
@@ -0,0 +1,2 @@

from .arg_parser import get_parser
# from .eval_metric import EvalMetric
utils/arg_parser.py
ADDED
@@ -0,0 +1,24 @@

import argparse


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="telecom_data/", type=str,
                        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", )
    parser.add_argument("--data_file", default="data_filter.pkl", type=str)
    parser.add_argument("--ckpt_dir", default="./PLMs/chinese-roberta-wwm-ext", type=str,
                        help="The checkpoints dir. Should contain the pretrained model.", )
    parser.add_argument("--preprocessor", default="BasePreprocessor", type=str,
                        help="Name of preprocessor.", )
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--max_epoch", default=100, type=int)
    parser.add_argument("--top_k", default=5, type=int)
    parser.add_argument("--output_name", default='ELECT_test_output.json', type=str)
    return parser

'''
python main_elect_inference.py \
    --data_file jicheng_questions.json \
    --output_name jicheng_questions_output.json
'''
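A minimal sketch of consuming this parser; the driver script named in the docstring above (main_elect_inference.py) is not part of this commit, so the snippet below only shows the intended pattern:

from utils import get_parser

# Parse CLI flags with the defaults defined above, then hand them to downstream code.
args = get_parser().parse_args()
print(args.ckpt_dir, args.batch_size, args.top_k)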