# bagualu-ie / app.py
# Author: han liu · commit ff78ef7 ("init") · 7.38 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
## @Author: liuhan(liuhan@idea.edu.cn)
## @Created: 2022/12/28 11:24:43
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict
from logging import basicConfig
import json
import os
import numpy as np
from transformers import AutoTokenizer
import argparse
import copy
import streamlit as st
import time
from models import BagualuIEModel, BagualuIEExtractModel
class BagualuIEPipelines:
    """Inference wrapper that pairs a BagualuIE model with its tokenizer.

    Args:
        args: parsed options. This class reads ``pretrained_model_root`` and
            ``batch_size`` directly and forwards the whole namespace to
            ``BagualuIEExtractModel``.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        self.args = args
        model_root = args.pretrained_model_root
        # Model weights come from the same path / hub id as the tokenizer.
        self.model = BagualuIEModel.from_pretrained(model_root)
        # Register the placeholder vocabulary entries [unused1] .. [unused99]
        # as additional special tokens on the tokenizer.
        extra_tokens = [f"[unused{idx}]" for idx in range(1, 100)]
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_root,
            additional_special_tokens=extra_tokens,
        )

    def predict(self, test_data: List[dict], cuda: bool = True) -> List[dict]:
        """Run extraction over ``test_data`` in chunks of ``args.batch_size``.

        Args:
            test_data (List[dict]): input samples, forwarded unchanged to
                ``BagualuIEExtractModel.extract``.
            cuda (bool, optional): when True, move the model to the GPU
                before inference and tell the extractor to use CUDA.
                Defaults to True.

        Returns:
            List[dict]: the per-chunk extractor outputs concatenated in
            input order.
        """
        if cuda:
            self.model = self.model.cuda()
        self.model.eval()
        step = self.args.batch_size
        extractor = BagualuIEExtractModel(self.tokenizer, self.args)
        outputs = []
        for offset in range(0, len(test_data), step):
            chunk = test_data[offset:offset + step]
            outputs.extend(extractor.extract(chunk, self.model, cuda))
        return outputs
@st.experimental_singleton()
def load_model(model_path):
    """Build the IE pipeline for ``model_path`` once and reuse it across reruns.

    ``experimental_singleton`` keeps a single shared instance. The previous
    ``experimental_memo`` cache pickled the return value and handed each call
    a deserialized copy, which is costly (or impossible) for a model object.

    Args:
        model_path: local directory or hub id passed to ``from_pretrained``.
            It overrides any ``--pretrained_model_root`` given on the command line.

    Returns:
        BagualuIEPipelines: the ready-to-use pipeline.
    """
    parser = argparse.ArgumentParser()
    # pipeline arguments
    group_parser = parser.add_argument_group("piplines args")
    group_parser.add_argument("--pretrained_model_root", default="", type=str)
    group_parser.add_argument("--load_checkpoints_path", default="", type=str)
    group_parser.add_argument("--threshold_ent", default=0.3, type=float)
    group_parser.add_argument("--threshold_rel", default=0.3, type=float)
    # NOTE: `store_true` combined with `default=True` means these two options
    # are always True; the flags exist only for CLI compatibility.
    group_parser.add_argument("--entity_multi_label", action="store_true", default=True)
    group_parser.add_argument("--relation_multi_label", action="store_true", default=True)
    # data model arguments
    group_parser = parser.add_argument_group("data_model")
    group_parser.add_argument("--batch_size", default=4, type=int)
    group_parser.add_argument("--max_length", default=512, type=int)
    # sys.argv may contain arguments this parser does not define (e.g. ones
    # forwarded by `streamlit run app.py -- ...`). Ignore them instead of
    # letting argparse exit the app.
    args, _unknown = parser.parse_known_args()
    args.pretrained_model_root = model_path
    return BagualuIEPipelines(args)
def _split_choices(raw: str) -> List[str]:
    """Split a user-entered label list into stripped, non-empty items.

    Both ASCII ``;`` and full-width ``；`` are accepted as separators, because
    the prompt asks for Chinese labels and IME input often yields the
    full-width form. Empty items, e.g. from a trailing separator, are dropped.
    """
    normalized = raw.replace("；", ";")
    return [item.strip() for item in normalized.split(";") if item.strip()]


def _split_relation(raw: str) -> List[str]:
    """Split one ``head type|relation|tail type`` spec, accepting ``|`` or ``｜``."""
    return [part.strip() for part in raw.replace("｜", "|").split("|")]


def main():
    """Render the zero-shot IE Streamlit demo and show one prediction.

    The user picks a task (NER or relation extraction) and an example, then
    edits the sentence and label list in a form. The page shows the model
    output for the current form values as JSON.
    """
    model = load_model('IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese')

    st.subheader("Erlangshen-BERT-120M-IE-Chinese Zero-shot 体验")
    st.markdown("""
Erlangshen-BERT-120M-IE-Chinese是以110M参数的base模型为底座,基于大规模信息抽取数据进行预训练后的模型,
通过统一的抽取架构设计,可支持few-shot、zero-shot场景下的实体识别、关系三元组抽取任务。
更多信息见https://github.com/IDEA-CCNL/GTS-Engine/tree/main
模型效果见https://huggingface.co/IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese
""")
    st.info("Please input the following information to experiencing Bagualu-IE「请输入以下信息开始体验 Bagualu-IE...」")
    model_type = st.selectbox('Select task type「选择任务类型」', ['Named Entity Recognition「命名实体识别」', 'Relation Extraction「关系抽取」'])
    is_ner = '命名实体识别' in model_type
    if is_ner:
        example = st.selectbox('Example', ['Example: 人物信息', 'Example: 财经新闻'])
    else:
        example = st.selectbox('Example', ['Example: 雇佣关系', 'Example: 影视关系'])

    form = st.form("参数设置")
    if is_ner:
        if '人物信息' in example:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "姚明,男,汉族,无党派人士,前中国职业篮球运动员。")
            choice = form.text_input("Please input the choice「请输入抽取实体名称,用中文;分割」", "姓名;性别;民族;运动项目;政治面貌")
        else:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "寒流吹响华尔街,摩根士丹利、高盛、瑞信三大银行裁员合计超过8千人")
            choice = form.text_input("Please input the choice「请输入抽取实体名称,用中文;分割」", "裁员单位;裁员人数")
    else:
        if '雇佣关系' in example:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "东阳市企业家协会六届一次会员大会上,横店集团董事长、总裁徐永安当选为东阳市企业家协会会长。")
            choice = form.text_input("Please input the choice「请输入抽取关系名称,用中文;分割(头实体类型|关系|尾实体类型)」", "企业|董事长|人物")
        else:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "《傲骨贤妻第六季》是一套美国法律剧情电视连续剧,2014年9月29日在CBS上首播。")
            choice = form.text_input("Please input the choice「请输入抽取关系名称,用中文;分割(头实体类型|关系|尾实体类型)」", "影视作品|上映时间|时间")
    form.form_submit_button("Submit「点击一下,开始预测!」")

    # Validate the form values before spending time on inference.
    if not sentences.strip():
        st.warning("Please input a non-empty context「请输入非空句子」")
        return

    if is_ner:
        task = '实体识别'
        labels = _split_choices(choice)
    else:
        task = '关系抽取'
        labels = [_split_relation(spec) for spec in _split_choices(choice)]
        # Each relation must be exactly head type | relation | tail type.
        malformed = ['|'.join(spec) for spec in labels if len(spec) != 3 or not all(spec)]
        if malformed:
            st.error("Invalid relation choice「关系格式错误」(expected 头实体类型|关系|尾实体类型): "
                     + "; ".join(malformed))
            return
    if not labels:
        st.warning("Please input at least one choice「请至少输入一个抽取名称」")
        return

    data = [{"task": task,
             "text": sentences,
             "entity_list": [],
             "choice": labels,
             }]

    # Inference runs on CPU (cuda=False).
    start = time.perf_counter()
    rs = model.predict(data, False)
    st.success(f"Prediction is successful, consumes {str(time.perf_counter() - start)} seconds")
    st.json(rs[0])


if __name__ == "__main__":
    main()