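"""Flask service for marriage-law consultation.

Exposes POST /check_hunyin: a RoBERTa classifier first decides whether the
query is marriage-related; if so, the ELECT retriever returns the top-3
most relevant law articles from 《民法典》 (the PRC Civil Code).
"""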
import logging
import pickle as pkl
from argparse import Namespace

import torch
from flask import Flask, request, jsonify
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

from models import Elect

logger = logging.getLogger(__name__)


app = Flask(__name__)

hunyin_classifier = None

fatiao_args = Namespace()
fatiao_tokenizer = None
fatiao_model = None
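# NOTE: these module-level globals are populated in the __main__ block below
# before the server starts.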


@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
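    """Return the top-3 law articles for a marriage-related query.

    Request JSON: {"input": <text>, "force_return": <bool, optional>}.
    Response JSON: {"output": [{"id": int, "score": float, "text": str}, ...]},
    empty when the query is judged not marriage-related.
    """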
    input_text = request.json['input'].strip()
    # Optional flag: skip the marriage-relevance gate and always run retrieval.
    force_return = request.json.get('force_return', False)

    print("input_text:", input_text)

    if len(input_text) == 0:
        json_result = {
            "output": []
        }
        return jsonify(json_result)
    
    if not force_return:
        classifier_result = hunyin_classifier(input_text[:500])
        print(classifier_result)
        classifier_result = classifier_result[0]['label']

        # Extra rule: if the input contains the character "婚" (marriage),
        # treat it as marriage-related regardless of the classifier output.
        if '婚' in input_text:
            classifier_result = True

        # If the query is not marriage-related, return an empty result.
        # NOTE: this assumes the pipeline's negative label is the literal
        # boolean False; if it returns a string label (e.g. "False"),
        # compare against that string instead.
        if classifier_result == False:
            json_result = {
                "output": []
            }
            return jsonify(json_result)

    inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
    # Rename the tokenizer outputs to the keys the Elect model expects.
    batch = {
        'ids': inputs['input_ids'],
        'mask': inputs['attention_mask'],
        'token_type_ids': inputs['token_type_ids']
    }
    model_output = fatiao_model(batch)
    # Multi-label prediction: one sigmoid score per law article, ranked below.
    pred = torch.sigmoid(model_output).cpu().detach().numpy()[0]
    pred_laws = []
    for law_id, score in sorted(enumerate(pred), key=lambda x: x[1], reverse=True):
        pred_laws.append({
            'id': law_id,
            'score': float(score),
            'text': fatiao_args.mlb.classes_[law_id]
        })

    json_result = {
        "output": pred_laws[:3]
    }

    print("json_result:", json_result)
    return jsonify(json_result)

if __name__ == '__main__':

    logging.basicConfig(level=logging.INFO)

    # Load the consultation classifier, used to decide whether the query is marriage-related.
    hunyin_classifier_path = "./pretrained_models/roberta_wwm_ext_hunyin_2epoch"
    hunyin_config = AutoConfig.from_pretrained(
        hunyin_classifier_path,
        num_labels=2,
    )
    hunyin_tokenizer = AutoTokenizer.from_pretrained(
        hunyin_classifier_path
    )
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)

    # Load the ELECT law-article retrieval model.
    fatiao_args.ckpt_dir = "./pretrained_models/chinese-roberta-wwm-ext"
    fatiao_args.device = "cuda:0"

    with open("data/labels2id.pkl", "rb") as f:
        laws2id = pkl.load(f)
        fatiao_args.labels = list(laws2id.keys())
    # Invert the mapping: law-article id -> article text.
    id2laws = {v: k for k, v in laws2id.items()}
    # fatiao_args.id2laws = id2laws
    print("number of law articles:", len(id2laws))

    fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)

    fatiao_args.tokenizer = fatiao_tokenizer
    fatiao_model = Elect(fatiao_args, fatiao_args.device).to(fatiao_args.device)
    fatiao_model.eval()

    mlb = MultiLabelBinarizer()  # mlb.classes_ maps prediction index -> law-article text
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    # Precompute an embedding for every law article with the PLM encoder
    # ([CLS] token of the last hidden state) and store it in the model.
    with torch.no_grad():
        for idx, l in enumerate(fatiao_args.labels):
            # Strip the "《民法典》第xxxx条:" (Civil Code article number) prefix.
            text = ':'.join(l.split(':')[1:]).lower()
            la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256,
                                     return_tensors="pt")
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    # Load the fine-tuned ELECT checkpoint.
    fatiao_model.load_state_dict(torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device)))
    fatiao_model.to(fatiao_args.device)
    logger.info("model loaded")
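    # Example request (a sketch; host/port follow app.run below, and the
    # query text is only an illustration):
    #   curl -X POST http://localhost:9098/check_hunyin \
    #        -H 'Content-Type: application/json' \
    #        -d '{"input": "我想离婚,财产怎么分?", "force_return": false}'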
    app.run(host="0.0.0.0", port=9098, debug=False)