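"""
Run an archived AllenNLP text classifier over an Excel sheet of telemarketing
utterances and write the top-k predicted intent labels, with probabilities,
to a JSONL file.

Example invocation (the script name is a placeholder; the paths are just the
defaults below and will differ per machine):

    python predict_intents.py \
        --excel_file data/excel/telemarketing_intent_vi.xlsx \
        --archive_file trained_models/telemarketing_intent_classification_vi \
        --top_k 10 \
        --output_file intent_top_k.jsonl

Each output line is a JSON object carrying the original row fields (source,
text, label0, label1, selected, checked) plus "predict_label_top_k" and
"predict_prob_top_k", both ";"-separated and ordered from most to least
probable.
"""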
import argparse
import json

import numpy as np
import pandas as pd
from tqdm import tqdm

from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor
# Registered under the name "text_classifier"; imported explicitly so the
# default --predictor_name resolves regardless of allennlp's lazy imports.
from allennlp.predictors.text_classifier import TextClassifierPredictor  # noqa: F401
# Imported for its registration side effect, in case the archived config
# references this encoder type.
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder  # noqa: F401

from project_settings import project_path
# Importing these modules registers the custom model and dataset reader with
# AllenNLP, so that load_archive() can rebuild them from the archived config.
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier  # noqa: F401
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader  # noqa: F401


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--excel_file",
        default=r"D:\Users\tianx\PycharmProjects\telemarketing_intent\data\excel\telemarketing_intent_vi.xlsx",
        type=str,
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str,
    )
    parser.add_argument(
        "--top_k",
        default=10,
        type=int,
    )
    parser.add_argument(
        "--output_file",
        default="intent_top_k.jsonl",
        type=str,
    )
    args = parser.parse_args()
    return args
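

# A minimal, self-contained sketch (not called by main) of how a single
# prediction flows through the predictor. The "sentence" input key and the
# "probs" output key follow AllenNLP's text_classifier conventions; the
# "labels" vocabulary namespace is assumed to match the archived model.
def predict_top_k_sketch(predictor: Predictor, text: str, top_k: int = 5):
    outputs = predictor.predict_json({"sentence": text})
    probs = outputs["probs"]
    # np.argsort is ascending: take the last top_k indices, then reverse so
    # the most probable label comes first.
    top_idx = np.argsort(probs)[-top_k:][::-1]
    vocab = predictor._model.vocab
    return [
        (vocab.get_token_from_index(index=int(idx), namespace="labels"), float(probs[idx]))
        for idx in top_idx
    ]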


def main():
    args = get_args()

    # Load the trained model and wrap it in a predictor.
    archive = load_archive(archive_file=args.archive_file)
    predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)

    # Expected columns: source, text, label0, label1, selected, checked.
    df = pd.read_excel(args.excel_file)

    with open(args.output_file, "w", encoding="utf-8") as f:
        for i, row in tqdm(df.iterrows(), total=len(df)):
            # Resume point: rows before this index were handled by a previous
            # run. Note that mode "w" above truncates the output file, so
            # switch it to "a" when actually resuming.
            if i < 26976:
                continue

            source = row["source"]
            text = row["text"]
            label0 = row["label0"]
            label1 = row["label1"]
            selected = row["selected"]
            checked = row["checked"]

            # Normalize missing cells to None; rows without text are skipped.
            if pd.isna(source):
                source = None

            if pd.isna(text):
                continue
            text = str(text)

            if pd.isna(label0):
                label0 = None

            if pd.isna(label1):
                label1 = None

            if pd.isna(selected):
                selected = None
            else:
                try:
                    selected = int(selected)
                except (TypeError, ValueError):
                    selected = None

            if pd.isna(checked):
                checked = None
            else:
                try:
                    checked = int(checked)
                except (TypeError, ValueError):
                    checked = None

            outputs = predictor.predict_json({"sentence": text})
            probs = outputs["probs"]

            # np.argsort is ascending, so the last top_k indices are the most
            # probable labels (this was hard-coded to 10, ignoring --top_k).
            arg_idx = np.argsort(probs)
            arg_idx_top_k = arg_idx[-args.top_k:]
            # Label tokens are "_"-joined; keep only the final component.
            label_top_k = [
                predictor._model.vocab.get_token_from_index(index=int(idx), namespace="labels").split("_")[-1]
                for idx in arg_idx_top_k
            ]
            prob_top_k = [
                str(round(probs[idx], 5)) for idx in arg_idx_top_k
            ]

            # Reverse so both lists run from most to least probable.
            row_ = {
                "source": source,
                "text": text,
                "label0": label0,
                "label1": label1,
                "selected": selected,
                "checked": checked,
                "predict_label_top_k": ";".join(reversed(label_top_k)),
                "predict_prob_top_k": ";".join(reversed(prob_top_k)),
            }
            f.write("{}\n".format(json.dumps(row_, ensure_ascii=False)))

    return


if __name__ == '__main__':
    main()