| | |
| | |
| | import argparse |
| | import json |
| | import os |
| | from pathlib import Path |
| | import platform |
| | import tempfile |
| | import time |
| | from typing import List, Dict |
| | import zipfile |
| |
|
| | from cacheout import Cache |
| | import gradio as gr |
| | import huggingface_hub |
| | import numpy as np |
| | import torch |
| |
|
| | from project_settings import project_path, environment |
| | from toolbox.torch.utils.data.tokenizers.pretrained_bert_tokenizer import PretrainedBertTokenizer |
| | from toolbox.torch.utils.data.vocabulary import Vocabulary |
| |
|
| |
|
def get_args():
    """Parse command-line options for the demo (example/markdown file paths)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--waba_intent_examples_file",
        default=(project_path / "waba_intent_examples.json").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--waba_intent_md_file",
        default=(project_path / "waba_intent.md").as_posix(),
        type=str,
    )
    return parser.parse_args()
| |
|
| |
|
# In-memory cache of loaded model groups, keyed by repo_id.
# Holds at most 256 entries; each entry expires 60 seconds after insertion
# (wall-clock via time.time), so idle models are reloaded on next use.
model_cache = Cache(maxsize=256, ttl=1 * 60, timer=time.time)
| |
|
| |
|
def load_waba_intent_model(repo_id: str) -> Dict[str, object]:
    """Download (if needed) and load a WABA intent model from the HF Hub.

    :param repo_id: Hugging Face repository id, e.g. "nxcloud/waba_intent_en".
    :return: dict with keys "model" (TorchScript module), "vocabulary"
             (label/token index mapping) and "tokenizer".
    """
    model_local_dir = project_path / "trained_models/{}".format(repo_id)
    model_local_dir.mkdir(parents=True, exist_ok=True)

    # Only authenticate when a token is actually configured; the original code
    # called login() unconditionally, which fails when "hf_token" is unset.
    hf_token = environment.get("hf_token")
    if hf_token:
        huggingface_hub.login(token=hf_token)
    # snapshot_download is idempotent: already-present files are not re-fetched.
    huggingface_hub.snapshot_download(
        repo_id=repo_id,
        local_dir=model_local_dir
    )

    model = torch.jit.load((model_local_dir / "final.zip").as_posix())
    # Force inference mode (disables dropout etc.) — this demo never trains.
    model.eval()
    vocabulary = Vocabulary.from_files((model_local_dir / "vocabulary").as_posix())
    tokenizer = PretrainedBertTokenizer(model_local_dir.as_posix())

    result = {
        "model": model,
        "vocabulary": vocabulary,
        "tokenizer": tokenizer,
    }
    return result
| |
|
| |
|
def click_waba_intent_button(repo_id: str, text: str):
    """Classify `text` with the intent model from `repo_id`.

    Loads (and TTL-caches) the model group on first use.

    :param repo_id: Hugging Face repository id of the intent model.
    :param text: raw input text to classify.
    :return: tuple (label_str, prob) — predicted label and its probability
             rounded to 4 decimal places.
    """
    model_group = model_cache.get(repo_id)
    if model_group is None:
        model_group = load_waba_intent_model(repo_id)
        model_cache.set(key=repo_id, value=model_group)

    model = model_group["model"]
    vocabulary = model_group["vocabulary"]
    tokenizer = model_group["tokenizer"]

    tokens: List[str] = tokenizer.tokenize(text)
    token_ids: List[int] = [
        vocabulary.get_token_index(token, namespace="tokens") for token in tokens
    ]

    # The model needs at least 5 tokens; pad short inputs up to that length.
    if len(token_ids) < 5:
        token_ids = vocabulary.pad_or_truncate_ids_by_max_length(token_ids, max_length=5)

    # Pin int64 explicitly: np.array of Python ints defaults to int32 on
    # Windows, while torch embedding lookups require LongTensor indices.
    batch_tokens = torch.from_numpy(np.array([token_ids], dtype=np.int64))

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model.forward(batch_tokens)

    probs = outputs["probs"]
    argmax = torch.argmax(probs, dim=-1)
    probs = probs.tolist()[0]
    argmax = argmax.tolist()[0]

    label_str = vocabulary.get_token_from_index(argmax, namespace="labels")
    prob = round(probs[argmax], 4)

    return label_str, prob
| |
|
| |
|
def main():
    """Build and launch the Gradio demo UI for WABA intent classification."""
    args = get_args()

    brief_description = """
## Text Classification
"""

    # Load click-to-fill examples and the description markdown from disk.
    with open(args.waba_intent_examples_file, "r", encoding="utf-8") as f:
        waba_intent_examples = json.load(f)
    with open(args.waba_intent_md_file, "r", encoding="utf-8") as f:
        waba_intent_md = f.read()

    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("waba_intent"):
                        with gr.Row():
                            # Left column: inputs.
                            with gr.Column(scale=1):
                                waba_intent_repo_id = gr.Dropdown(
                                    choices=["nxcloud/waba_intent_en"],
                                    value="nxcloud/waba_intent_en",
                                    label="repo_id"
                                )
                                waba_intent_text = gr.Textbox(label="text", max_lines=5)
                                waba_intent_button = gr.Button("predict", variant="primary")

                            # Right column: prediction outputs.
                            with gr.Column(scale=1):
                                waba_intent_label = gr.Textbox(label="label")
                                waba_intent_prob = gr.Textbox(label="prob")

                        gr.Examples(
                            examples=waba_intent_examples,
                            inputs=[
                                waba_intent_repo_id,
                                waba_intent_text,
                            ],
                            outputs=[
                                waba_intent_label,
                                waba_intent_prob
                            ],
                            fn=click_waba_intent_button
                        )

                        gr.Markdown(value=waba_intent_md)

        waba_intent_button.click(
            fn=click_waba_intent_button,
            inputs=[
                waba_intent_repo_id,
                waba_intent_text,
            ],
            outputs=[
                waba_intent_label,
                waba_intent_prob
            ],
        )

    blocks.queue().launch(
        # Original read `False if platform.system() == "Windows" else False`
        # — a dead conditional that always yields False; simplified.
        share=False,
        # Bind to loopback on Windows dev boxes, all interfaces elsewhere.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )
    return
| |
|
| |
|
# Script entry point: run the Gradio demo when executed directly.
if __name__ == '__main__':
    main()
| |
|