"""Flask service exposing the ``/check_hunyin`` endpoint.

A text-classification pipeline first decides whether the incoming text is
marriage-related; if it is (or the caller forces it), the ``Elect`` model
scores every law article and the three highest-scoring ones are returned.
"""
import codecs
import json
import logging
import math
import os
import pickle as pkl
import random
import subprocess
from argparse import Namespace

import torch
from flask import Flask, jsonify, request
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

from models import Elect

logger = logging.getLogger(__name__)

app = Flask(__name__)

# Populated by main() before the server starts handling requests.
hunyin_classifier = None
fatiao_args = Namespace()
fatiao_tokenizer = None
fatiao_model = None

_DEVICE = "cuda:0"
_MAX_CLASSIFIER_CHARS = 500  # the classifier only sees this many leading characters
_MAX_SEQ_LEN = 256  # tokenizer padding/truncation length for the law-article model
_TOP_K = 3  # number of law articles returned per request


@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    """Return the top-scoring law articles for the text in the JSON body.

    Request JSON:
        input: the text to analyse (required; surrounding whitespace is
            stripped).
        force_return: optional; when truthy the marriage-relatedness check is
            skipped and law articles are always scored.

    Response JSON:
        ``{"output": [...]}`` where each item has ``id``, ``score`` and
        ``text``. The list is empty when the input is empty or the classifier
        rejects it, and holds at most three items otherwise.
    """
    payload = request.json
    input_text = payload['input'].strip()
    force_return = payload.get('force_return', False)
    print("input_text:", input_text)

    if not input_text:
        return jsonify({"output": []})

    if not force_return:
        classifier_result = hunyin_classifier(input_text[:_MAX_CLASSIFIER_CHARS])
        print(classifier_result)
        classifier_result = classifier_result[0]['label']
        # Extra rule: any text containing the character "婚" is treated as
        # marriage-related regardless of the classifier output.
        if '婚' in input_text:
            classifier_result = True
        # Not marriage-related: return an empty result.
        # NOTE(review): this only triggers if the pipeline's label compares
        # equal to False (e.g. the checkpoint's id2label maps to booleans).
        # With string labels such as "LABEL_0" this branch is never taken --
        # confirm against the checkpoint config.
        if classifier_result == False:  # noqa: E712 - keep original equality semantics
            return jsonify({"output": []})

    inputs = fatiao_tokenizer(
        input_text,
        padding='max_length',
        truncation=True,
        max_length=_MAX_SEQ_LEN,
        return_tensors="pt",
    )
    batch = {
        'ids': inputs['input_ids'],
        'mask': inputs['attention_mask'],
        'token_type_ids': inputs["token_type_ids"],
    }
    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        model_output = fatiao_model(batch)
        pred = torch.sigmoid(model_output).cpu().detach().numpy()[0]

    # Rank all law ids by score (descending) and materialise only the top few.
    ranked = sorted(enumerate(pred), key=lambda pair: pair[1], reverse=True)
    pred_laws = [
        {
            'id': law_id,
            'score': float(score),
            'text': fatiao_args.mlb.classes_[law_id],
        }
        for law_id, score in ranked[:_TOP_K]
    ]

    json_result = {"output": pred_laws}
    print("json_result:", json_result)
    return jsonify(json_result)


def _load_hunyin_classifier():
    """Build the text-classification pipeline used to detect marriage-related input."""
    hunyin_classifier_path = "./pretrained_models/roberta_wwm_ext_hunyin_2epoch"
    hunyin_config = AutoConfig.from_pretrained(
        hunyin_classifier_path,
        num_labels=2,
    )
    hunyin_tokenizer = AutoTokenizer.from_pretrained(hunyin_classifier_path)
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    return pipeline(
        model=hunyin_model,
        tokenizer=hunyin_tokenizer,
        task="text-classification",
        device=0,
    )


def _load_fatiao_model():
    """Load the law-article retrieval model and fill in ``fatiao_args``.

    Sets ``ckpt_dir``, ``device``, ``labels``, ``tokenizer`` and ``mlb`` on the
    module-level ``fatiao_args`` and returns ``(tokenizer, model)``.
    """
    fatiao_args.ckpt_dir = "./pretrained_models/chinese-roberta-wwm-ext"
    fatiao_args.device = _DEVICE

    # pickle can execute arbitrary code on load; this file must come from a
    # trusted source.
    with open(os.path.join("data", "labels2id.pkl"), "rb") as f:
        laws2id = pkl.load(f)
    fatiao_args.labels = list(laws2id.keys())

    # Inverse mapping; currently only used to report the number of articles.
    id2laws = {law_id: law for law, law_id in laws2id.items()}
    print("法条个数:", len(id2laws))

    tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)
    fatiao_args.tokenizer = tokenizer

    model = Elect(fatiao_args, _DEVICE).to(_DEVICE)
    model.eval()

    # mlb.classes_ maps a predicted index back to its law article text.
    mlb = MultiLabelBinarizer()
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    with torch.no_grad():
        for idx, label in enumerate(fatiao_args.labels):
            # Drop the leading "《民法典》第xxxx条:" prefix, keep the article body.
            text = ':'.join(label.split(':')[1:]).lower()
            la_in = tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=_MAX_SEQ_LEN,
                return_tensors="pt",
            )
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            # Add the encoder's first-position output for this article to its
            # row in model.la.
            model.la[idx] += (model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    # NOTE(review): if `la` is part of the checkpoint's state dict, the values
    # accumulated above are overwritten here -- confirm the intended order.
    model.load_state_dict(
        torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device))
    )
    model.to(fatiao_args.device)
    return tokenizer, model


def main():
    """Load both models into the module globals and start the HTTP server."""
    global hunyin_classifier, fatiao_tokenizer, fatiao_model

    # Consultation classifier: decides whether the text is marriage-related.
    hunyin_classifier = _load_hunyin_classifier()

    # Law-article retrieval model.
    fatiao_tokenizer, fatiao_model = _load_fatiao_model()

    logger.info("model loaded")
    app.run(host="0.0.0.0", port=9098, debug=False)


if __name__ == '__main__':
    main()