#!/usr/bin/python3
# -*- coding: utf-8 -*-
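"""
Gradio demo for telemarketing intent classification.

Loads one archived AllenNLP text classifier per supported language
(Chinese, English, Japanese, Vietnamese) and serves it behind a simple
text-plus-language-dropdown interface.
"""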
import argparse
import os
import time

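# Several of the imports below are not referenced directly in this file. They are
# kept for their side effects: importing them ensures the corresponding AllenNLP
# components are registered before the model archives are loaded.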
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import torch

from project_settings import project_path
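# Importing the custom model and dataset reader registers them with AllenNLP,
# which load_archive() needs in order to reconstruct them from the archive config.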
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cn_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_cn").as_posix(),
        type=str
    )
    parser.add_argument(
        "--en_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_en").as_posix(),
        type=str
    )
    parser.add_argument(
        "--jp_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_jp").as_posix(),
        type=str
    )
    parser.add_argument(
        "--vi_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

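    # Load one archived classifier per supported language and wrap it in a Predictor.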
    cn_archive = load_archive(archive_file=args.cn_archive_file)
    cn_predictor = Predictor.from_archive(cn_archive, predictor_name=args.predictor_name)
    en_archive = load_archive(archive_file=args.en_archive_file)
    en_predictor = Predictor.from_archive(en_archive, predictor_name=args.predictor_name)
    jp_archive = load_archive(archive_file=args.jp_archive_file)
    jp_predictor = Predictor.from_archive(jp_archive, predictor_name=args.predictor_name)
    vi_archive = load_archive(archive_file=args.vi_archive_file)
    vi_predictor = Predictor.from_archive(vi_archive, predictor_name=args.predictor_name)

    predictor_map = {
        "chinese": cn_predictor,
        "english": en_predictor,
        "japanese": jp_predictor,
        "vietnamese": vi_predictor,
    }

    def fn(text: str, language: str):
        # Fall back to the Chinese predictor if the selected language has no entry.
        predictor = predictor_map.get(language, cn_predictor)

        json_dict = {'sentence': text}
        outputs = predictor.predict_json(json_dict)
        # decode() turns the raw model outputs into a readable label and its probability.
        outputs = predictor._model.decode(outputs)
        label = outputs['label'][0]
        prob = round(outputs['prob'][0], 4)
        return label, prob

    description = """
    Telemarketing intent classification.
    Languages: Chinese, English, Japanese, Vietnamese.
    The dataset is private.

    model: selfattention-cnn
    dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)

    accuracy:
    chinese: 0.8002
    english: 0.7011
    japanese: 0.8154
    vietnamese: 0.8168

    """
    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Dropdown(
                choices=list(sorted(predictor_map.keys())),
                label="language"
            )
        ],
        outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
        examples=[
            ["你找谁", "chinese"],
            ["你是谁啊", "chinese"],
            ["不好意思我现在很忙", "chinese"],
            ["对不起, 不需要哈", "chinese"],
            ["u have got the wrong number", "english"],
            ["sure, thank a lot", "english"],
            ["please leave your message for 95688496", "english"],
            ["yes well", "english"],
            ["失礼の", "japanese"],
            ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
            ["わかんない。", "japanese"],
            ["に出ることができません", "japanese"],
            ["À không phải em nha.", "vietnamese"],
            ["Dạ nhầm số rồi ạ?", "vietnamese"],
            ["Ừ, cảm ơn em nhá.", "vietnamese"],
            ["Không, chị không có tiền.", "vietnamese"],
        ],
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=description,
    )
    demo.launch()

    return


if __name__ == '__main__':
    main()