|
''' |
|
Author: Qiguang Chen |
|
LastEditors: Qiguang Chen |
|
Date: 2023-01-23 17:26:47 |
|
LastEditTime: 2023-02-14 20:07:02 |
|
Description: |
|
|
|
''' |
|
import argparse |
|
import os |
|
import signal |
|
import sys |
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
import time |
|
from gradio import networking |
|
from common.utils import load_yaml, str2bool |
|
import json |
|
import threading |
|
|
|
from flask import Flask, request, render_template, render_template_string |
|
|
|
|
|
def get_example(start, end, predict_data_file_path):
    """Load the prediction records whose line index lies in ``[start, end]``.

    The prediction file is read line by line; every selected line is parsed
    with ``json.loads`` and must provide the keys ``text``, ``golden_intent``,
    ``pred_intent``, ``golden_slot`` and ``pred_slot``.

    Args:
        start: 0-based index of the first line to return (inclusive).
        end: 0-based index of the last line to return (inclusive).
        predict_data_file_path: Path of the UTF-8 encoded prediction file.

    Returns:
        A list with one dict per selected line, holding
        ``text`` (copied from the record),
        ``intent`` (a one-element list with the gold ``intent`` and the
        ``pred_intent``) and
        ``slot`` (one ``{"text", "pred_slot", "slot"}`` dict per zipped
        text element / gold slot / predicted slot triple).
    """
    data_list = []
    with open(predict_data_file_path, "r", encoding="utf8") as f1:
        for index, line in enumerate(f1):
            if index < start:
                continue
            if index > end:
                # Lines are visited in order, so nothing further can match.
                break
            record = json.loads(line.strip())
            data_list.append({
                "text": record["text"],
                "intent": [{"intent": record["golden_intent"],
                            "pred_intent": record["pred_intent"]}],
                # "slot" carries the gold label and "pred_slot" the model
                # prediction, mirroring the "intent"/"pred_intent" pair above.
                "slot": [
                    {"text": token, "pred_slot": pred, "slot": gold}
                    for token, gold, pred in zip(
                        record["text"],
                        record["golden_slot"],
                        record["pred_slot"])
                ],
            })
    return data_list
|
|
|
|
|
def analysis(predict_data_file_path):
    """Aggregate intent and slot error counts over a prediction file.

    Every line of the file is parsed as JSON and must provide the keys
    ``golden_slot``, ``pred_slot``, ``golden_intent`` and ``pred_intent``.

    Args:
        predict_data_file_path: Path of the UTF-8 encoded prediction file.

    Returns:
        A 5-tuple ``(intent_dict_list, slot_dict_list, intent_dict,
        slot_dict, sample_num)``:

        * ``intent_dict_list`` / ``slot_dict_list``: one
          ``{"value": <error count>, "name": <gold label>}`` dict per gold
          label, in first-seen order.
        * ``intent_dict`` / ``slot_dict``: map each gold label to a list of
          ``[key, count]`` pairs sorted by count, largest first. The keys are
          ``"_error_"``, ``"_total_"`` and every wrongly predicted label.
        * ``sample_num``: number of lines read from the file.
    """

    def tally(table, gold, pred):
        # Count one (gold, predicted) observation in ``table``.
        counts = table.setdefault(gold, {"_error_": 0, "_total_": 0})
        if gold != pred:
            counts["_error_"] += 1
            counts[pred] = counts.get(pred, 0) + 1
        counts["_total_"] += 1

    def finalise(table):
        # Build the per-label error overview, then replace each counter dict
        # in ``table`` by its entries ranked by count (descending).
        overview = [{"value": table[label]["_error_"], "name": label}
                    for label in table]
        for entry in overview:
            label = entry["name"]
            ranked = sorted(table[label].items(),
                            key=lambda pair: pair[1], reverse=True)
            table[label] = [list(pair) for pair in ranked]
        return overview

    intent_dict = {}
    slot_dict = {}
    sample_num = 0
    with open(predict_data_file_path, "r", encoding="utf8") as f1:
        for raw_line in f1:
            sample_num += 1
            record = json.loads(raw_line.strip())
            for gold_slot, pred_slot in zip(record["golden_slot"],
                                            record["pred_slot"]):
                tally(slot_dict, gold_slot, pred_slot)
            tally(intent_dict, record["golden_intent"], record["pred_intent"])

    intent_dict_list = finalise(intent_dict)
    slot_dict_list = finalise(slot_dict)
    return intent_dict_list, slot_dict_list, intent_dict, slot_dict, sample_num
|
|
|
# Command-line options. Values given here override the YAML configuration.
parser = argparse.ArgumentParser()
parser.add_argument('--config_path', '-cp', type=str, default="config/visual.yaml")
parser.add_argument('--output_path', '-op', type=str, default=None)
parser.add_argument('--push_to_public', '-p', type=str2bool, nargs='?',
                    const=True, default=None,
                    help="Push to public network.(Higher priority than config file)")
args = parser.parse_args()
# NOTE(review): not referenced anywhere in the visible part of this file.
button_html = ""
config = load_yaml(args.config_path)
if args.output_path is not None:
    config["output_path"] = args.output_path
if args.push_to_public is not None:
    config["is_push_to_public"] = args.push_to_public

# The statistics are computed once at start-up and reused by every request.
intent_dict_list, slot_dict_list, intent_dict, slot_dict, sample_num = analysis(config["output_path"])
PAGE_SIZE = config["page-size"]
# Ceiling division: the previous ``int(sample_num / PAGE_SIZE) + 1`` counted
# one page too many whenever sample_num was an exact multiple of PAGE_SIZE.
# An empty prediction file still yields a single (empty) page.
PAGE_NUM = max(1, (sample_num + PAGE_SIZE - 1) // PAGE_SIZE)

app = Flask(__name__, template_folder="static//template")
|
|
|
|
|
@app.route("/")
def hello():
    """Render one page of prediction examples plus the error statistics.

    Query parameters:
        page: 0-based page index. A missing, non-integer or negative value
            falls back to page 0.

    Returns:
        The rendered ``visualization.html`` template.
    """
    # ``type=int`` makes the lookup return the default instead of raising
    # ValueError (an HTTP 500) when the client sends e.g. ``?page=abc``.
    page = max(request.args.get('page', default=0, type=int), 0)
    init_index = page * PAGE_SIZE
    # get_example treats both bounds as inclusive, hence the ``- 1``.
    examples = get_example(init_index,
                           init_index + PAGE_SIZE - 1,
                           config["output_path"])
    return render_template('visualization.html',
                           examples=examples,
                           intent_dict_list=intent_dict_list,
                           slot_dict_list=slot_dict_list,
                           intent_dict=intent_dict,
                           slot_dict=slot_dict,
                           page=page)
|
|
|
# NOTE(review): not referenced anywhere in the visible part of this file;
# kept in case other code reads it.
thread_lock_1 = False


class PushToPublicThread():
    """Daemon thread wrapper that opens a Gradio tunnel to the local server.

    ``exit`` has the ``(signum, frame)`` signature expected of a handler
    registered with ``signal.signal``.
    """

    def __init__(self, config) -> None:
        """Prepare (but do not start) the worker thread.

        Args:
            config: Mapping providing ``"host"`` and ``"port"``; it is passed
                to ``push_to_public`` when the thread runs.
        """
        self.thread = threading.Thread(target=self.push_to_public, args=(config,))
        # Set to True by ``exit`` to let ``push_to_public`` return.
        self.thread_lock_2 = False
        self.thread.daemon = True

    def start(self):
        """Start the worker thread."""
        self.thread.start()

    def push_to_public(self, config):
        """Open the tunnel, print its URL, then idle until ``exit`` is called.

        Args:
            config: Mapping providing ``"host"`` and ``"port"`` of the local
                server to expose.
        """
        print("Push visualization results to public by Gradio....")
        print("Push to URL: ", networking.setup_tunnel(config["host"], str(config["port"])))
        print("This share link expires in 72 hours. And do not close this process for public sharing.")
        # Presumably the thread is kept alive for the tunnel's sake (TODO
        # confirm). Sleep between polls instead of spinning on the flag,
        # which would keep one CPU core fully busy.
        while not self.thread_lock_2:
            time.sleep(0.1)

    def exit(self, signum, frame):
        """Signal handler: release the worker loop and terminate the process.

        Args:
            signum: Signal number (unused).
            frame: Current stack frame (unused).
        """
        self.thread_lock_2 = True
        print("Exit..")
        # os._exit ends the process immediately, without waiting for other
        # threads to finish.
        os._exit(0)
|
|
|
if __name__ == '__main__':
    if not config["is_push_to_public"]:
        # Local-only mode: serve in the foreground until interrupted.
        app.run(config["host"], config["port"])
    else:
        def _serve():
            # Run the Flask server; host and port are read when the thread runs.
            app.run(config["host"], config["port"])

        thread_1 = threading.Thread(target=_serve)
        thread_1.start()

        thread_2 = PushToPublicThread(config)
        # Ctrl-C and termination requests both go through the tunnel
        # worker's exit handler.
        for signal_id in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signal_id, thread_2.exit)
        thread_2.start()

        # Keep the main thread idle so the registered handlers can run.
        while True:
            time.sleep(1)
|
|