|
|
|
|
|
import argparse |
|
import os |
|
import platform |
|
import time |
|
|
|
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer |
|
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer |
|
from allennlp.data.vocabulary import Vocabulary |
|
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder |
|
from allennlp.modules.token_embedders.embedding import Embedding |
|
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder |
|
from allennlp.models.archival import archive_model, load_archive |
|
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder |
|
from allennlp.predictors.predictor import Predictor |
|
from allennlp.predictors.text_classifier import TextClassifierPredictor |
|
import gradio as gr |
|
import torch |
|
|
|
from project_settings import project_path |
|
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier |
|
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader |
|
|
|
|
|
def get_args():
    """Parse the command-line options of the demo.

    Options:
        --cn_archive_file, --en_archive_file, --jp_archive_file,
        --vi_archive_file: string paths of the per-language model archives.
            Each defaults to the matching entry under
            ``project_path / "trained_models"``.
        --predictor_name: string, defaults to ``"text_classifier"``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    arg_parser = argparse.ArgumentParser()

    # The archive options differ only in their flag and default location,
    # so they are registered from a table instead of four copied calls.
    archive_options = (
        ("--cn_archive_file", "trained_models/telemarketing_intent_classification_cn"),
        ("--en_archive_file", "trained_models/telemarketing_intent_classification_en"),
        ("--jp_archive_file", "trained_models/telemarketing_intent_classification_jp"),
        ("--vi_archive_file", "trained_models/telemarketing_intent_classification_vi"),
    )
    for flag, relative_path in archive_options:
        arg_parser.add_argument(
            flag,
            default=(project_path / relative_path).as_posix(),
            type=str,
        )

    arg_parser.add_argument("--predictor_name", default="text_classifier", type=str)

    return arg_parser.parse_args()
|
|
|
|
|
def main():
    """Load one predictor per language and serve them through a Gradio UI.

    The archive paths and the predictor name come from ``get_args()``.
    The interface takes a text and a language and shows the predicted
    intent label together with its probability.
    """
    args = get_args()

    # Dict order is the loading order: chinese, english, japanese, vietnamese.
    archive_files = {
        "chinese": args.cn_archive_file,
        "english": args.en_archive_file,
        "japanese": args.jp_archive_file,
        "vietnamese": args.vi_archive_file,
    }
    predictor_map = {}
    for language_name, archive_file in archive_files.items():
        archive = load_archive(archive_file=archive_file)
        predictor_map[language_name] = Predictor.from_archive(
            archive, predictor_name=args.predictor_name
        )

    # Used whenever the requested language has no predictor of its own.
    fallback_predictor = predictor_map["chinese"]

    def fn(text: str, language: str):
        """Classify ``text`` with the predictor registered for ``language``.

        A language that is not a key of ``predictor_map`` is handled by the
        Chinese predictor. Returns ``(label, prob)`` where ``prob`` is
        rounded to four decimal places.
        """
        chosen = predictor_map.get(language, fallback_predictor)
        raw_outputs = chosen.predict_json({'sentence': text})
        # The model's own decode step supplies the 'label' and 'prob' entries
        # read below; only the first element of each is reported.
        decoded = chosen._model.decode(raw_outputs)
        return decoded['label'][0], round(decoded['prob'][0], 4)

    description = """
电销场景意图识别.
语言: 汉语, 英语, 日语, 越南语.
数据集是私有的.

model: selfattention-cnn
dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)

accuracy:
chinese: 0.8002
english: 0.7011
japanese: 0.8154
vietnamese: 0.8168

"""

    # Each example row is [text, language], matching the two input widgets.
    example_rows = [
        ["你找谁", "chinese"],
        ["你是谁啊", "chinese"],
        ["不好意思我现在很忙", "chinese"],
        ["对不起, 不需要哈", "chinese"],
        ["u have got the wrong number", "english"],
        ["sure, thank a lot", "english"],
        ["please leave your message for 95688496", "english"],
        ["yes well", "english"],
        ["失礼の", "japanese"],
        ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
        ["わかんない。", "japanese"],
        ["に出ることができません", "japanese"],
        ["À không phải em nha.", "vietnamese"],
        ["Dạ nhầm số rồi ạ?", "vietnamese"],
        ["Ừ, cảm ơn em nhá.", "vietnamese"],
        ["Không, chị không có tiền.", "vietnamese"],
    ]

    text_input = gr.Text(label="text")
    language_input = gr.Dropdown(
        choices=sorted(predictor_map),
        label="language"
    )

    demo = gr.Interface(
        fn=fn,
        inputs=[text_input, language_input],
        outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
        examples=example_rows,
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=description,
    )

    # A public share link is requested only when running on Windows.
    share_publicly = platform.system() == "Windows"
    demo.launch(share=share_publicly)
|
|
|
|
|
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|