# VEGA_AE/Scripts/Exp/Acc/gen_accuracy.py (revision 33264ad, "Initial", 8.6 kB)
import os, sys,string,re,glob
import json
import csv
import copy
import pathlib
import time
# Absolute directory that contains this script; every input and output path
# used below is built relative to it.
folder = str(pathlib.Path(__file__).parent.resolve())

# Per-target totals, keyed by lower-cased target name. They serve as the
# denominator of the "ALL" accuracy rows written by calculate_accuracy().
func_num_dic = dict(riscv=568, pulp=698, xcore=188)

# Module-level accumulators filled by get_wrong_list() and
# calculate_accuracy(), then consumed by the script entry point.
wrong_lis_all = list()
wrong_stmt = list()
err_def_dic = dict()
def get_wrong_list():
    """Load ``wrong_func_list_def.csv`` into the module-level accumulators.

    The CSV is read from the directory of this script and interpreted
    positionally:

    * columns 0-2, stripped and lower-cased, are joined with single spaces and
      appended to ``wrong_stmt`` (calculate_accuracy() compares this key with
      its own "file func target" key);
    * the whole row, joined with single spaces, is appended to
      ``wrong_lis_all``;
    * columns 2-3, joined with a space, form the key whose occurrences are
      counted in ``err_def_dic``.

    A row whose first cell is ``"idx"`` is treated as the header and skipped.
    Blank lines are skipped as well; a non-blank row with fewer than four
    cells still raises ``IndexError`` so malformed data is not hidden.

    The three accumulators are mutated in place; nothing is returned.
    """
    # newline="" lets the csv module do its own newline handling, as its
    # documentation recommends for file objects.
    with open(folder + "/wrong_func_list_def.csv", 'r', encoding='utf-8', newline="") as fcsv:
        for row in csv.reader(fcsv):
            # csv.reader yields [] for an empty line; indexing it would crash.
            if not row or row[0] == "idx":
                continue
            wrong_stmt.append(" ".join(cell.strip().lower() for cell in row[:3]))
            wrong_lis_all.append(" ".join(row))
            def_key = " ".join([row[2], row[3]])
            err_def_dic[def_key] = err_def_dic.get(def_key, 0) + 1
def calculate_accuracy():
    """Score every record of ``result.jsonl`` and write the accuracy outputs.

    Each non-blank line of ``result.jsonl`` (next to this script) is parsed as
    a JSON object with the keys ``File``, ``Func``, ``Target``, ``Module``,
    ``Stable``, ``vega_pre``, ``ans_pre``, ``vega_code`` and ``ans_code``.

    Side effects:

    * appends one row per "<Target> <Module>" group, followed by an "AVG" and
      an "ALL" row per target, to ``Fig8_Acc.csv`` (opened in append mode, so
      the caller is expected to have written the header already);
    * appends "<File> <Func> <Target> <Module> Err_V|Err_CS" entries to the
      module-level ``wrong_lis_all``;
    * for every function that has records with ``vega_pre == 1``, overwrites
      ``../ForkFlow/VEGA_Code/<Target>/<Module>/<File>/<Func>.cpp`` with those
      records' ``vega_code`` values, one per line, but only if that file
      already exists; otherwise the path is printed.

    Returns:
        dict mapping "<Target> <Module>" to the list of
        "<File> <Func> <Target>" ids seen for it (one entry per record, so
        duplicates are retained).

    Raises:
        KeyError: if a target is missing from ``func_num_dic``.
    """
    func_res = {}
    stable_stmt_dic = {}
    total_dic = {}
    wrong_dic = {}
    # Build the lookup set once; testing against the list is O(n) per record.
    known_wrong = set(wrong_stmt)

    # The context manager closes the handle; the bare open() in a for-loop
    # header would leave that to the garbage collector.
    with open(folder + "/result.jsonl", 'r', encoding="utf-8") as results:
        for line in results:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) in the JSONL.
                continue
            dic = json.loads(line)
            module_key = dic["Target"] + " " + dic["Module"]
            stmt_id = dic["File"].strip() + " " + dic["Func"].strip() + " " + dic["Target"].strip()

            if int(dic["vega_pre"]) == 1:
                func_key = " ".join([dic["Target"], dic["Module"], dic["File"], dic["Func"]])
                # "zmtarzm" is a placeholder that is replaced by the target name.
                func_res.setdefault(func_key, []).append(
                    dic["vega_code"].replace("zmtarzm", dic["Target"])
                )

            # counts == [records with Stable == "true", all records]
            counts = stable_stmt_dic.setdefault(module_key, [0, 0])
            if dic["Stable"].lower() == "true":
                counts[0] += 1
            counts[1] += 1

            total_dic.setdefault(module_key, []).append(stmt_id)
            module_wrong = wrong_dic.setdefault(module_key, [])

            if stmt_id.lower() in known_wrong:
                module_wrong.append(stmt_id)

            code_differs = dic["ans_code"].replace(" ", "") != dic["vega_code"].replace(" ", "")
            pre_differs = dic["ans_pre"] != dic["vega_pre"]
            if code_differs or pre_differs:
                module_wrong.append(stmt_id)
                # NOTE(review): the Err_V / Err_CS classification is taken to be
                # nested under the mismatch check above - confirm. Err_V compares
                # the raw strings, so once a record is already wrong a
                # whitespace-only code difference is also reported as Err_V.
                if dic["ans_code"] != dic["vega_code"]:
                    wrong_lis_all.append(
                        " ".join([dic["File"], dic["Func"], dic["Target"], dic["Module"], "Err_V"])
                    )
                if pre_differs:
                    wrong_lis_all.append(
                        " ".join([dic["File"], dic["Func"], dic["Target"], dic["Module"], "Err_CS"])
                    )

    with open(folder + "/Fig8_Acc.csv", 'a', encoding='utf-8', newline="") as f:
        f_csv = csv.writer(f)
        avg_dic = {}
        all_dic = {}
        for k in total_dic:
            # Functions are counted once each, however many records they have.
            Total_Func_Num = len(set(total_dic[k]))
            wrong_func_num = len(set(wrong_dic[k]))
            Correct_Func_Num = Total_Func_Num - wrong_func_num
            Wrong_Func_Percentage = round(wrong_func_num * 1.0 / Total_Func_Num, 3)
            Accuracy_Func = 1 - Wrong_Func_Percentage
            Pre_Equal_1_Stmt_Percentage = round(stable_stmt_dic[k][0] / stable_stmt_dic[k][1], 3)
            Pre_Less_1_Stmt_Percentage = 1 - Pre_Equal_1_Stmt_Percentage

            target = k.split(" ")[0]
            avg_dic[target] = avg_dic.get(target, 0) + Accuracy_Func
            all_dic[target] = all_dic.get(target, 0) + Correct_Func_Num

            # The "PULP" target is reported under the name "RI5CY".
            tem_k = k.replace("PULP", "RI5CY")
            f_csv.writerow(tem_k.split(" ") + [
                Correct_Func_Num,
                Total_Func_Num,
                Accuracy_Func,
                Wrong_Func_Percentage,
                Pre_Equal_1_Stmt_Percentage,
                Pre_Less_1_Stmt_Percentage,
            ])

        for k in avg_dic:
            lowered = k.lower()
            label = "RI5CY" if lowered == "pulp" else k
            # Fixed divisors: 7 for riscv and pulp, 6 for every other target
            # (presumably the number of modules per target - TODO confirm).
            divisor = 7.0 if lowered in ("riscv", "pulp") else 6.0
            f_csv.writerow([label, "AVG", round(avg_dic[k] / divisor, 3)])
            f_csv.writerow([label, "ALL", round(all_dic[k] / func_num_dic[lowered], 3)])

    for k, code_lines in func_res.items():
        Tar_Path = folder + "/../ForkFlow/VEGA_Code/" + "/".join(k.split(" ")) + ".cpp"
        # Function names that contain a space were split into two path
        # components above; restore them, and apply one fixed path remap.
        Tar_Path = Tar_Path.replace("enum/NodeType", "enum NodeType")
        Tar_Path = Tar_Path.replace("enum/CondCode", "enum CondCode")
        Tar_Path = Tar_Path.replace("ExpandSSRInsts/ExpandPseudo", "ExpandSSRInsts/ExpandSSRInsts")
        if os.path.exists(Tar_Path):
            # Written as UTF-8 to match the encoding the results were read in.
            with open(Tar_Path, 'w', encoding='utf-8') as file:
                # One statement per line, no trailing newline.
                file.write("\n".join(code_lines))
        else:
            # Only pre-existing files are overwritten; report the others.
            print(Tar_Path)
    return total_dic
if __name__ == '__main__':
    # Seed the shared accumulators from wrong_func_list_def.csv.
    get_wrong_list()

    # Start Fig8_Acc.csv afresh with its header; calculate_accuracy() appends
    # the data rows to the same file.
    with open(folder + "/Fig8_Acc.csv", 'w', encoding='utf-8', newline="") as f:
        csv.writer(f).writerow([
            "Target", "Module", "Correct", "Total", "Accurate", "Inaccurate",
            "Confidence Score≈1.00", "Confidence Score in [0.50, 1.00)",
        ])
    total_dic = calculate_accuracy()

    # De-duplicate while keeping first-seen order. Going through set() would
    # make the row order of the output files differ from run to run.
    wrong_lis_all = list(dict.fromkeys(wrong_lis_all))
    wrong_rows = [err.split(" ") for err in wrong_lis_all]

    # The same error list is written next to this script and into ../ForkFlow.
    for out_path in (folder + "/wrong_list_all.csv", folder + "/../ForkFlow/wrong_list_all.csv"):
        with open(out_path, 'w', encoding='utf-8', newline="") as f:
            csv.writer(f).writerows(wrong_rows)

    # Count errors per "<target> <error type>". Rows can have different
    # lengths (a field containing a space is split into several cells), so
    # the target and error type are addressed from the end of the row.
    wrong_dic = {}
    for row in wrong_rows:
        err_key = " ".join([row[-3].lower(), row[-1].lower()])
        wrong_dic[err_key] = wrong_dic.get(err_key, 0) + 1

    # Number of distinct functions per lower-cased target, summed over modules.
    target_func_num_dic = {}
    for k in total_dic:
        target = k.split(" ")[0].lower()
        target_func_num_dic[target] = target_func_num_dic.get(target, 0) + len(set(total_dic[k]))

    # Table2.csv: share of each error type per target; 0 when a type is absent.
    with open(folder + "/Table2.csv", 'w', encoding='utf-8', newline="") as f:
        f_csv = csv.writer(f)
        for k in target_func_num_dic:
            for err_type in ["err_v", "err_cs", "err_def"]:
                err_key = k + " " + err_type
                if err_key in wrong_dic:
                    ratio = round(float(wrong_dic[err_key]) / float(target_func_num_dic[k]), 3)
                else:
                    ratio = 0
                f_csv.writerow([k.replace("pulp", "ri5cy"), err_type, ratio])