File size: 7,381 Bytes
ff78ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
## @Author: liuhan(liuhan@idea.edu.cn)
## @Created: 2022/12/28 11:24:43
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict
from logging import basicConfig
import json
import os
import numpy as np
from transformers import AutoTokenizer
import argparse
import copy
import streamlit as st
import time



from models import BagualuIEModel, BagualuIEExtractModel


class BagualuIEPipelines:
    """Inference pipeline bundling a BagualuIE model with its tokenizer."""

    def __init__(self, args: argparse.Namespace) -> None:
        """Load the pretrained model and build the matching tokenizer.

        Args:
            args (argparse.Namespace): configuration; must provide
                ``pretrained_model_root`` and ``batch_size``.
        """
        self.args = args

        # load pretrained model weights
        self.model = BagualuIEModel.from_pretrained(args.pretrained_model_root)

        # tokenizer with the reserved [unused1] .. [unused99] special tokens
        reserved_tokens = [f"[unused{idx}]" for idx in range(1, 100)]
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_root,
            additional_special_tokens=reserved_tokens)

    def predict(self, test_data: List[dict], cuda: bool = True) -> List[dict]:
        """ predict

        Args:
            test_data (List[dict]): test data
            cuda (bool, optional): cuda. Defaults to True.

        Returns:
            List[dict]: result
        """
        if cuda:
            self.model = self.model.cuda()
        self.model.eval()

        extractor = BagualuIEExtractModel(self.tokenizer, self.args)
        step = self.args.batch_size

        # run extraction batch by batch and collect the per-sample results
        predictions: List[dict] = []
        for start in range(0, len(test_data), step):
            chunk = test_data[start:start + step]
            predictions.extend(extractor.extract(chunk, self.model, cuda))
        return predictions


@st.experimental_memo()
def load_model(model_path):
    """Build (and memoize via streamlit) an IE pipeline for a checkpoint.

    Args:
        model_path (str): HuggingFace model id or local checkpoint directory.

    Returns:
        BagualuIEPipelines: ready-to-use inference pipeline.
    """
    parser = argparse.ArgumentParser()

    # pipeline arguments
    group_parser = parser.add_argument_group("piplines args")
    group_parser.add_argument("--pretrained_model_root", default="", type=str)
    group_parser.add_argument("--load_checkpoints_path", default="", type=str)

    group_parser.add_argument("--threshold_ent", default=0.3, type=float)
    group_parser.add_argument("--threshold_rel", default=0.3, type=float)
    group_parser.add_argument("--entity_multi_label", action="store_true", default=True)
    group_parser.add_argument("--relation_multi_label", action="store_true", default=True)

    # data model arguments
    group_parser = parser.add_argument_group("data_model")
    group_parser.add_argument("--batch_size", default=4, type=int)
    group_parser.add_argument("--max_length", default=512, type=int)

    # Parse an EMPTY argv so the app only uses the defaults above.  Under
    # `streamlit run` the process argv may carry streamlit/script flags that
    # this parser does not know, and bare parse_args() would abort with a
    # usage error instead of starting the demo.
    args = parser.parse_args([])
    args.pretrained_model_root = model_path

    return BagualuIEPipelines(args)

def main():
    """Streamlit page: zero-shot NER / relation-extraction demo for
    Erlangshen-BERT-120M-IE-Chinese."""
    model = load_model('IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese')

    st.subheader("Erlangshen-BERT-120M-IE-Chinese Zero-shot 体验")

    st.markdown("""
            Erlangshen-BERT-120M-IE-Chinese是以110M参数的base模型为底座，基于大规模信息抽取数据进行预训练后的模型，
            通过统一的抽取架构设计，可支持few-shot、zero-shot场景下的实体识别、关系三元组抽取任务。
            更多信息见https://github.com/IDEA-CCNL/GTS-Engine/tree/main
            模型效果见https://huggingface.co/IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese
            """)

    st.info("Please input the following information to experiencing Bagualu-IE「请输入以下信息开始体验 Bagualu-IE...」")
    model_type = st.selectbox('Select task type「选择任务类型」',['Named Entity Recognition「命名实体识别」','Relation Extraction「关系抽取」'])

    # Offer task-appropriate examples before building the input form.
    if '命名实体识别' in model_type:
        example = st.selectbox('Example', ['Example: 人物信息', 'Example: 财经新闻'])
    else:
        example = st.selectbox('Example', ['Example: 雇佣关系', 'Example: 影视关系'])

    form = st.form("参数设置")
    if '命名实体识别' in model_type:
        # NER: the choices are entity type names separated by ';'
        if '人物信息' in example:
            sentences = form.text_area(
                "Please input the context「请输入句子」", 
                "姚明，男，汉族，无党派人士，前中国职业篮球运动员。")
            choice = form.text_input("Please input the choice「请输入抽取实体名称，用中文;分割」", "姓名;性别;民族;运动项目;政治面貌")
        else:
            sentences = form.text_area(
                "Please input the context「请输入句子」", 
                "寒流吹响华尔街，摩根士丹利、高盛、瑞信三大银行裁员合计超过8千人")
            choice = form.text_input("Please input the choice「请输入抽取实体名称，用中文;分割」", "裁员单位;裁员人数")
    else:
        # RE: each choice is "head_type|relation|tail_type", several separated by ';'
        if '雇佣关系' in example:
            sentences = form.text_area(
                "Please input the context「请输入句子」", 
                "东阳市企业家协会六届一次会员大会上，横店集团董事长、总裁徐永安当选为东阳市企业家协会会长。")
            choice = form.text_input("Please input the choice「请输入抽取关系名称，用中文;分割（头实体类型|关系|尾实体类型）」", "企业|董事长|人物")
        else:
            sentences = form.text_area(
                "Please input the context「请输入句子」", 
                "《傲骨贤妻第六季》是一套美国法律剧情电视连续剧，2014年9月29日在CBS上首播。")
            choice = form.text_input("Please input the choice「请输入抽取关系名称，用中文;分割（头实体类型|关系|尾实体类型）」", "影视作品|上映时间|时间")

    form.form_submit_button("Submit「点击一下，开始预测！」")

    # Assemble the single-sample request in the format extract() expects.
    if '命名实体识别' in model_type:
        data = [{"task": '实体识别',
                "text": sentences,
                "entity_list": [], 
                "choice": choice.split(';'),
                }]
    else:
        choice = [one.split('|') for one in choice.split(';')]
        data = [{"task": '关系抽取',
                "text": sentences,
                "entity_list": [], 
                "choice": choice,
                }]

    start = time.time()
    # CPU inference: the public demo runs without a GPU.
    result = model.predict(data, cuda=False)
    st.success(f"Prediction is successful, consumes {str(time.time() - start)} seconds")
    st.json(result[0])


if __name__ == "__main__":
    main()