#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Run a trained telemarketing-intent classifier over every row of an excel
file and write the top-k predicted labels/probabilities to a jsonl file.

NOTE: several imports below look unused but must stay — importing the
allennlp / toolbox modules registers the model, dataset reader, and
predictor classes that `load_archive` resolves by name.
"""
import argparse
import json

from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


def get_args():
    """Parse command-line arguments for the prediction run."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--excel_file",
        default=r"D:\Users\tianx\PycharmProjects\telemarketing_intent\data\excel\telemarketing_intent_vi.xlsx",
        type=str,
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str
    )
    parser.add_argument(
        "--top_k",
        default=10,
        type=int
    )
    # Rows with index below this are skipped; lets an interrupted run resume.
    # Default preserves the previously hard-coded offset.
    parser.add_argument(
        "--resume_index",
        default=26976,
        type=int
    )
    parser.add_argument(
        "--output_file",
        default="intent_top_k.jsonl",
        type=str
    )
    args = parser.parse_args()
    return args


def _opt_value(value):
    """Return *value* unchanged, or None when it is missing (None/NaN)."""
    if value is None or pd.isna(value):
        return None
    return value


def _opt_int(value):
    """Coerce *value* to int; return None when missing or not convertible.

    On a failed conversion the value's type is printed (best-effort
    diagnostic, matching the original behavior) and None is returned.
    """
    if value is None or pd.isna(value):
        return None
    try:
        return int(value)
    except Exception:
        print(type(value))
        return None


def main():
    args = get_args()

    archive = load_archive(archive_file=args.archive_file)
    predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)

    df = pd.read_excel(args.excel_file)

    with open(args.output_file, "w", encoding="utf-8") as f:
        for i, row in tqdm(df.iterrows(), total=len(df)):
            if i < args.resume_index:
                continue

            # Rows without text cannot be classified; skip them entirely.
            text = row["text"]
            if text is None or pd.isna(text):
                continue
            text = str(text)

            source = _opt_value(row["source"])
            label0 = _opt_value(row["label0"])
            label1 = _opt_value(row["label1"])
            selected = _opt_int(row["selected"])
            checked = _opt_int(row["checked"])

            outputs = predictor.predict_json({"sentence": text})
            probs = outputs["probs"]

            # np.argsort is ascending, so the last top_k indices are the
            # highest-probability labels (was hard-coded to 10, ignoring
            # the --top_k flag).
            arg_idx_top_k = np.argsort(probs)[-args.top_k:]
            label_top_k = [
                predictor._model.vocab.get_token_from_index(index=idx, namespace="labels").split("_")[-1]
                for idx in arg_idx_top_k
            ]
            prob_top_k = [
                str(round(probs[idx], 5))
                for idx in arg_idx_top_k
            ]

            row_ = {
                "source": source,
                "text": text,
                "label0": label0,
                "label1": label1,
                "selected": selected,
                "checked": checked,
                # reversed(): emit best-first, joined with ';'
                "predict_label_top_k": ";".join(reversed(label_top_k)),
                "predict_prob_top_k": ";".join(reversed(prob_top_k)),
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            f.write("{}\n".format(row_))
    return


if __name__ == '__main__':
    main()