File size: 5,101 Bytes
147e44c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import time
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import torch
from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
def get_args():
    """Parse the command-line options of the demo.

    Every option is a string with a default value:

    * ``--cn_archive_file`` / ``--en_archive_file`` / ``--jp_archive_file`` /
      ``--vi_archive_file``: model archive locations, defaulting to
      directories below ``project_path / "trained_models"``.
    * ``--predictor_name``: predictor name, defaulting to ``"text_classifier"``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    # (flag, default) pairs, listed in the order they are registered so the
    # generated ``--help`` output keeps the same ordering.
    string_options = (
        (
            "--cn_archive_file",
            (project_path / "trained_models/telemarketing_intent_classification_cn").as_posix(),
        ),
        (
            "--en_archive_file",
            (project_path / "trained_models/telemarketing_intent_classification_en").as_posix(),
        ),
        (
            "--jp_archive_file",
            (project_path / "trained_models/telemarketing_intent_classification_jp").as_posix(),
        ),
        (
            "--vi_archive_file",
            (project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        ),
        (
            "--predictor_name",
            "text_classifier",
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value in string_options:
        parser.add_argument(flag, default=default_value, type=str)

    return parser.parse_args()
def main():
    """Load one predictor per language and serve them in a Gradio demo.

    The archive paths and the predictor name come from ``get_args()``. Each
    archive is loaded with ``load_archive`` and wrapped with
    ``Predictor.from_archive``; the resulting predictors are exposed through
    a ``gr.Interface`` with a text box and a language drop-down, and the
    interface is launched with ``demo.launch()``.
    """
    args = get_args()

    # Language name shown in the UI -> archive path from the command line.
    # Insertion order is the order in which the archives are loaded.
    archive_files = {
        "chinese": args.cn_archive_file,
        "english": args.en_archive_file,
        "japanese": args.jp_archive_file,
        "vietnamese": args.vi_archive_file,
    }
    predictor_map = {
        language_name: Predictor.from_archive(
            load_archive(archive_file=archive_file),
            predictor_name=args.predictor_name,
        )
        for language_name, archive_file in archive_files.items()
    }
    cn_predictor = predictor_map["chinese"]

    def fn(text: str, language: str):
        """Classify ``text`` with the predictor selected by ``language``.

        Returns a ``(label, prob)`` tuple where ``prob`` is rounded to four
        decimal places.
        """
        # Any language key that is not in the map (including no selection)
        # falls back to the Chinese predictor.
        predictor = predictor_map.get(language, cn_predictor)

        raw_outputs = predictor.predict_json({'sentence': text})
        # NOTE(review): relies on the predictor's private ``_model`` attribute
        # and on ``decode`` returning 'label' / 'prob' sequences -- confirm
        # against the model implementation.
        decoded = predictor._model.decode(raw_outputs)

        return decoded['label'][0], round(decoded['prob'][0], 4)

    # User-facing text rendered by Gradio; kept verbatim.
    description = """
电销场景意图识别.
语言: 汉语, 英语, 日语, 越南语.
数据集是私有的.
model: selfattention-cnn
dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)
accuracy:
chinese: 0.8002
english: 0.7011
japanese: 0.8154
vietnamese: 0.8168
"""

    # Components are created in the same order as before: inputs, then outputs.
    text_input = gr.Text(label="text")
    language_input = gr.Dropdown(
        choices=sorted(predictor_map),
        label="language",
    )
    intent_output = gr.Text(label="intent")
    prob_output = gr.Number(label="prob")

    # Sample (text, language) rows offered below the form.
    example_rows = [
        ["你找谁", "chinese"],
        ["你是谁啊", "chinese"],
        ["不好意思我现在很忙", "chinese"],
        ["对不起, 不需要哈", "chinese"],
        ["u have got the wrong number", "english"],
        ["sure, thank a lot", "english"],
        ["please leave your message for 95688496", "english"],
        ["yes well", "english"],
        ["失礼の", "japanese"],
        ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
        ["わかんない。", "japanese"],
        ["に出ることができません", "japanese"],
        ["À không phải em nha.", "vietnamese"],
        ["Dạ nhầm số rồi ạ?", "vietnamese"],
        ["Ừ, cảm ơn em nhá.", "vietnamese"],
        ["Không, chị không có tiền.", "vietnamese"],
    ]

    demo = gr.Interface(
        fn=fn,
        inputs=[text_input, language_input],
        outputs=[intent_output, prob_output],
        examples=example_rows,
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=description,
    )
    demo.launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|