# Source: test1/job (2)/code/detection3d.py
# Uploaded by ehovel2023 ("Upload 11 files", commit e287bc1); size 6.32 kB.
import json
import math
import os
import shutil
from code.base import BaseEvaluation, DetectionEvalConfig, EvalResult
from typing import List
import cv2
import numpy as np
from hat.evaluation.detection3d import evaluate
from hat.visualize.detection3d.draw_samples import (
draw_sample,
list_failure_samples,
)
from aidisdk.experiment import Image, Line, Table
def generate_plot(file_name: str) -> List[dict]:
    """Build recall/precision and fppi/recall line plots from a result file.

    Args:
        file_name: Path to a JSON file holding the lists ``"recall"``,
            ``"precision"`` and ``"fppi"``.

    Returns:
        Two plot dicts, each with a ``"Table"`` of curve points and a
        ``"Line"`` naming the x/y columns to draw: recall-vs-precision first,
        then fppi-vs-recall.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(file_name, "rb") as fp:
        results = json.load(fp)
    recall = results["recall"]
    precision = results["precision"]
    fppi = results["fppi"]

    # Table-name suffix: base file name up to its first dot,
    # e.g. "outputs/result_auto.json" -> "result_auto".
    stem = file_name.split("/")[-1].split(".")[0]

    # Points are paired element-wise. If the lists differ in length, the
    # surplus tail is dropped instead of raising IndexError part-way through.
    recall_precision = [
        {"recall": r, "precision": p} for r, p in zip(recall, precision)
    ]
    fppi_recall = [{"fppi": f, "recall": r} for f, r in zip(fppi, recall)]

    table1 = Table(
        name="recall_vs_precision-{}".format(stem),
        columns=["recall", "precision"],
        data=recall_precision,
    )
    table2 = Table(
        name="fppi_vs_recall-{}".format(stem),
        columns=["fppi", "recall"],
        data=fppi_recall,
    )
    return [
        {
            "Table": table1,
            "Line": Line(x="recall", y="precision", stroke="recall-precision"),
        },
        {
            "Table": table2,
            "Line": Line(x="fppi", y="recall", stroke="fppi-recall"),
        },
    ]
class Detection3dEval(BaseEvaluation):
    """3D-detection evaluation built on the ``hat`` detection3d evaluator.

    ``evaluate`` runs the evaluator and packages what it produces into an
    ``EvalResult``: summary metrics, tables, curve plots and rendered
    failure-sample images.
    """

    # Working directory (relative to the cwd) for evaluator artifacts.
    # It is deleted and recreated on every evaluate() call, so it must not
    # point at anything worth keeping.
    output_dir = "outputs"

    def __init__(self, run_config):
        super().__init__(run_config)

    def preprocess(self) -> DetectionEvalConfig:
        """Return the detection eval config prepared by the base class."""
        return super().detection_preprocess()

    def evaluate(self, eval_config: DetectionEvalConfig) -> EvalResult:
        """Run the 3D detection evaluator and collect its outputs.

        Side effects: wipes and recreates ``self.output_dir``. Reads
        ``tables.json`` and ``all.json`` from it, plus ``result.json`` and
        ``result_auto.json`` when they exist. Writes annotated failure images
        under ``<output_dir>/samples``.

        Args:
            eval_config: Supplies ``gt``, ``prediction``, ``setting`` and
                ``images_dir``.

        Returns:
            EvalResult with summary, tables, plots and images.
        """
        out_dir = self.output_dir
        self._reset_dir(out_dir)
        # NOTE: this is the module-level ``evaluate`` imported from
        # hat.evaluation.detection3d, not this method.
        metrics = evaluate(
            eval_config.gt,
            eval_config.prediction,
            eval_config.setting,
            out_dir,
        )
        return EvalResult(
            summary=self._summarize(metrics),
            tables=self._load_tables(os.path.join(out_dir, "tables.json")),
            plots=self._load_plots(out_dir),
            images=self._render_failure_samples(eval_config, out_dir),
        )

    @staticmethod
    def _reset_dir(path: str) -> None:
        """Delete ``path`` if it exists, then recreate it empty."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def _summarize(metrics) -> dict:
        """Merge the evaluator's metric dicts into one; later keys win."""
        summary = {}
        for metric in metrics:
            summary.update(metric)
        return summary

    @staticmethod
    def _table_cell(value):
        """Normalize one table cell: NaN -> "nan", list/tuple/set -> str."""
        # NOTE(review): presumably the Table backend cannot serialize NaN or
        # nested values, hence the stringification; confirm.
        if isinstance(value, float) and math.isnan(value):
            return "nan"
        if isinstance(value, (list, tuple, set)):
            return str(value)
        return value

    def _load_tables(self, tables_path: str) -> list:
        """Load evaluator tables (name/header/data records) as ``Table`` objects."""
        with open(tables_path, "rb") as fp:
            raw_tables = json.load(fp)
        return [
            Table(
                name=raw["name"],
                columns=raw["header"],
                data=[
                    {key: self._table_cell(val) for key, val in row.items()}
                    for row in raw["data"]
                ],
            )
            for raw in raw_tables
        ]

    @staticmethod
    def _load_plots(out_dir: str) -> list:
        """Build curve plots from whichever result files the evaluator wrote."""
        plots = []
        for result_name in ("result.json", "result_auto.json"):
            result_path = os.path.join(out_dir, result_name)
            if os.path.exists(result_path):
                plots.extend(generate_plot(result_path))
        return plots

    @staticmethod
    def _sample_attrs(sample) -> dict:
        """Compute per-image attributes from a sample's detections.

        Returns the max FP/ignore score, the max TP score, and the max TP
        ``drot``/``dxy`` metrics. -1 is the floor, and is the value when no
        detection matches.
        """
        dets = sample["det_bboxes"]
        fp_dets = [d for d in dets if d["eval_type"] in ("FP", "ignore")]
        tp_dets = [d for d in dets if d["eval_type"] == "TP"]
        return {
            "fp_score": max([d["score"] for d in fp_dets] + [-1]),
            "tp_score": max([d["score"] for d in tp_dets] + [-1]),
            "tp_drot": max([d["metrics"]["drot"] for d in tp_dets] + [-1]),
            "tp_dxy": max([d["metrics"]["dxy"] for d in tp_dets] + [-1]),
        }

    def _render_failure_samples(self, eval_config, out_dir: str) -> list:
        """Draw each failure sample onto its source image and wrap it as ``Image``.

        Samples whose source image is missing, cannot be decoded, or cannot
        be written out are skipped.
        """
        with open(os.path.join(out_dir, "all.json"), "rb") as fp:
            # Materialize while the file is still open, in case
            # list_failure_samples reads lazily.
            samples = list(list_failure_samples(fp))

        samples_dir = os.path.join(out_dir, "samples")
        self._reset_dir(samples_dir)

        images = []
        for sample in samples:
            image_name = sample["image_key"]
            image_file_path = os.path.join(eval_config.images_dir, image_name)
            if not os.path.exists(image_file_path):
                continue
            with open(image_file_path, "rb") as image_file:
                # frombuffer replaces the deprecated binary-mode np.fromstring.
                encoded = np.frombuffer(image_file.read(), dtype=np.uint8)
            canvas = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
            if canvas is None:
                # Corrupt or unsupported image; skip it like a missing one.
                continue
            canvas = draw_sample(canvas, sample)

            output_file_path = os.path.join(samples_dir, image_name)
            # image_key may contain subdirectories; imwrite will not create
            # them and just returns False.
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            if not cv2.imwrite(output_file_path, canvas):
                continue

            image = Image(image_name, attrs=self._sample_attrs(sample))
            image.add_slice(data_or_path=output_file_path)
            images.append(image)
        return images