qgyd2021's picture
[update]add model
aab0209
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import platform
import time
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import torch
from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
def get_args():
    """Parse the demo's command-line options.

    Returns:
        An ``argparse.Namespace`` with:
          * ``cn_archive_file``, ``en_archive_file``, ``jp_archive_file``,
            ``vi_archive_file`` -- model archive locations (strings), each
            defaulting to ``trained_models/telemarketing_intent_classification_<code>``
            under ``project_path``;
          * ``predictor_name`` -- predictor type name, default ``"text_classifier"``.
    """
    parser = argparse.ArgumentParser()

    # One archive option per supported language, registered in a fixed order
    # (cn, en, jp, vi) so ``--help`` output lists them the same way.
    for lang_code in ("cn", "en", "jp", "vi"):
        default_location = project_path / f"trained_models/telemarketing_intent_classification_{lang_code}"
        parser.add_argument(
            f"--{lang_code}_archive_file",
            default=default_location.as_posix(),
            type=str,
        )

    parser.add_argument("--predictor_name", default="text_classifier", type=str)
    return parser.parse_args()
def main():
    """Load the per-language intent classifiers and serve them in a Gradio demo.

    Four archives (Chinese, English, Japanese, Vietnamese) are loaded from the
    paths given on the command line. Each is wrapped in a predictor of type
    ``--predictor_name``. The demo is launched with ``share=True`` only when
    running on Windows.
    """
    args = get_args()

    # (dropdown label, CLI option prefix) pairs. They are loaded in this order,
    # and the dict below keeps the same insertion order.
    languages = (
        ("chinese", "cn"),
        ("english", "en"),
        ("japanese", "jp"),
        ("vietnamese", "vi"),
    )
    predictor_map = {}
    for language_name, code in languages:
        archive = load_archive(archive_file=getattr(args, f"{code}_archive_file"))
        predictor_map[language_name] = Predictor.from_archive(
            archive, predictor_name=args.predictor_name
        )
    fallback_predictor = predictor_map["chinese"]

    def fn(text: str, language: str):
        """Classify *text* with the model for *language*; return ``(label, prob)``.

        A language missing from ``predictor_map`` (e.g. no dropdown selection)
        falls back to the Chinese model. ``prob`` is rounded to 4 decimals.
        """
        predictor = predictor_map.get(language, fallback_predictor)
        raw_outputs = predictor.predict_json({'sentence': text})
        # The model's own decode() post-processes the prediction; presumably
        # it yields per-example 'label' and 'prob' lists -- verify in the model.
        decoded = predictor._model.decode(raw_outputs)
        return decoded['label'][0], round(decoded['prob'][0], 4)

    # Shown on the demo page. It says, in Chinese: "Telemarketing-scenario
    # intent recognition. Languages: Chinese, English, Japanese, Vietnamese.
    # The dataset is private."
    description = """
电销场景意图识别.
语言: 汉语, 英语, 日语, 越南语.
数据集是私有的.
model: selfattention-cnn
dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)
accuracy:
chinese: 0.8002
english: 0.7011
japanese: 0.8154
vietnamese: 0.8168
"""

    # Sample inputs: four per language, as [text, language] pairs.
    examples = [
        ["你找谁", "chinese"],
        ["你是谁啊", "chinese"],
        ["不好意思我现在很忙", "chinese"],
        ["对不起, 不需要哈", "chinese"],
        ["u have got the wrong number", "english"],
        ["sure, thank a lot", "english"],
        ["please leave your message for 95688496", "english"],
        ["yes well", "english"],
        ["失礼の", "japanese"],
        ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
        ["わかんない。", "japanese"],
        ["に出ることができません", "japanese"],
        ["À không phải em nha.", "vietnamese"],
        ["Dạ nhầm số rồi ạ?", "vietnamese"],
        ["Ừ, cảm ơn em nhá.", "vietnamese"],
        ["Không, chị không có tiền.", "vietnamese"],
    ]

    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Dropdown(choices=sorted(predictor_map), label="language"),
        ],
        outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
        examples=examples,
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=description,
    )
    # A public share link is requested only on Windows hosts.
    demo.launch(share=platform.system() == "Windows")
# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()