File size: 5,978 Bytes
e287bc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import os
from dataclasses import dataclass
from typing import List, Optional, Union

import yaml
from aidisdk import AIDIClient
from aidisdk.experiment import Image, Table
from dataclasses_json import DataClassJsonMixin
@dataclass
class RunConfig(DataClassJsonMixin):
    """Configuration for one evaluation run.

    Holds the service endpoint, the experiment group and the dataset /
    artifact references that ``BaseEvaluation`` reads to fetch its inputs.
    Dataset references use the form ``"dataset://<integer id>"``.
    """

    endpoint: str  # AIDI endpoint passed to AIDIClient
    # access token -- NOTE(review): not passed to AIDIClient by
    # BaseEvaluation as written; confirm where it is consumed
    token: str
    group_name: str  # experiment group passed to experiment.init_group
    images_dataset_id: str  # "dataset://<id>" of the evaluation images
    # "dataset://<id>" of the ground-truth data; semantic segmentation
    # treats an empty value as "no gt dataset"
    gt_dataset_id: str
    labels_dataset_id: str  # "dataset://<id>" of segmentation label images
    # "dataset://<id>" of segmentation prediction images
    predictions_dataset_id: str
    # "<artifact name>/<file name>" of the detection prediction file
    prediction_name: str
    setting_file_name: str  # path of the YAML evaluation settings file
@dataclass
class DetectionEvalConfig(DataClassJsonMixin):
    """Inputs for a detection evaluation, built by ``detection_preprocess``."""

    images_dir: str  # path of the evaluation-set image directory
    gt: str  # path of the ground-truth data file
    # prediction results to evaluate, as returned by the artifact's
    # get_file -- presumably a local file path; confirm
    prediction: str
    setting: str  # path of the YAML evaluation settings file
@dataclass
class SemanticSegmentationEvalConfig(DataClassJsonMixin):
    """Inputs for a semantic-segmentation evaluation.

    Built by ``BaseEvaluation.semantic_segmentation_preprocess``.
    """

    images_dir: str  # directory of the original evaluation images
    labels_dir: str  # directory of the ground-truth label images
    prediction_dir: str  # directory of the predicted images
    setting: str  # path of the YAML evaluation settings file
    # JSON data describing the original evaluation images; None when the
    # run has no gt dataset, hence Optional.
    images_json: Optional[str]
@dataclass
class EvalResult(DataClassJsonMixin):
    """Result of an evaluation, consumed by ``BaseEvaluation.postprocess``.

    ``tables``, ``plots`` and ``images`` may be None; postprocess skips
    whichever of them is None.
    """

    summary: dict  # summary metrics, e.g. AP / AR
    tables: Optional[List[Table]]  # result tables
    # plot specs; each dict must have a "Table" entry (a Table whose .name
    # becomes the plot name) and a "Line" entry
    plots: Optional[List[dict]]
    images: Optional[List[Image]]  # result images
class BaseEvaluation:
    """Base class for AIDI-backed evaluation jobs.

    Flow: a ``*_preprocess`` method downloads the inputs and builds an eval
    config, ``evaluate`` (which subclasses must override) turns it into an
    ``EvalResult``, and ``postprocess`` logs that result to the experiment.
    """

    # Scheme accepted in dataset references ("dataset://<integer id>").
    _DATASET_SCHEME = "dataset"
    # file_type values understood by get_data.
    _SUPPORTED_FILE_TYPES = ("gt", "images_dir")

    def __init__(self, run_config: RunConfig):
        """Store the run config and create a client for its endpoint.

        NOTE(review): ``run_config.token`` is not passed to ``AIDIClient``
        here; confirm the client obtains credentials some other way.
        """
        self.run_config = run_config
        self.client = AIDIClient(endpoint=run_config.endpoint)

    def get_data(self, dataset_id_info: str, file_type: str) -> str:
        """Download a dataset and return a local path derived from it.

        Args:
            dataset_id_info: Dataset reference ``"dataset://<id>"`` where
                ``<id>`` is an integer.
            file_type: ``"gt"`` returns the path of the first downloaded
                file; ``"images_dir"`` returns the directory containing the
                first downloaded file.

        Returns:
            A local file path (``"gt"``) or directory path
            (``"images_dir"``).

        Raises:
            ValueError: If the reference is malformed, uses a scheme other
                than ``dataset``, or the dataset contains no files.
            NotImplementedError: If ``file_type`` is not supported.
        """
        scheme, sep, raw_id = dataset_id_info.partition("://")
        if not sep or scheme != self._DATASET_SCHEME:
            raise ValueError(
                f"dataset version not supported: {dataset_id_info!r}"
            )
        try:
            dataset_id = int(raw_id)
        except ValueError as err:
            raise ValueError(
                f"invalid dataset id in {dataset_id_info!r}"
            ) from err
        # Reject unknown types before downloading, so a typo does not
        # trigger a full dataset download first.
        if file_type not in self._SUPPORTED_FILE_TYPES:
            raise NotImplementedError(f"unsupported file_type: {file_type!r}")

        dataset_interface = self.client.dataset.load(dataset_id)
        data_path_list = dataset_interface.file_list(download=True)
        if not data_path_list:
            raise ValueError(f"dataset {dataset_id_info!r} contains no files")
        first_path = data_path_list[0]
        if file_type == "gt":
            # Assumes the gt dataset holds a single file (or that the first
            # listed file is the intended one) -- TODO confirm.
            return first_path
        # "images_dir": assumes all files share the first file's directory
        # -- TODO confirm.
        return os.path.dirname(first_path)

    def _log_setting(self) -> None:
        """Load the YAML settings file and log it as the experiment config."""
        # Binary mode lets PyYAML detect the encoding itself (UTF-8, or
        # UTF-16 with a BOM) instead of depending on the process locale.
        with open(self.run_config.setting_file_name, "rb") as fid:
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; if settings files may come from untrusted users,
            # switch to yaml.safe_load.
            cfg_dict = yaml.load(fid, Loader=yaml.Loader)
        self.client.experiment.log_config(cfg_dict)

    def detection_preprocess(self) -> DetectionEvalConfig:
        """Gather the inputs for a detection evaluation.

        Downloads the image and gt datasets, initialises the experiment
        group, fetches the prediction file from an experiment artifact and
        logs the evaluation settings.

        ``run_config.prediction_name`` must be
        ``"<artifact name>/<file name>"``; everything after the first ``/``
        is the file name.

        Raises:
            ValueError: If ``prediction_name`` is not of that form.
        """
        # Validate up front so a bad name fails before any download.
        pr_name, sep, file_name = self.run_config.prediction_name.partition(
            "/"
        )
        if not (sep and pr_name and file_name):
            raise ValueError(
                "prediction_name must be '<artifact>/<file>', got "
                f"{self.run_config.prediction_name!r}"
            )
        # Image directory and gt file come from aidi-data.
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        gt_path = self.get_data(self.run_config.gt_dataset_id, "gt")
        self.client.experiment.init_group(self.run_config.group_name)
        # Prediction results come from an experiment-management artifact.
        pr_file = self.client.experiment.use_artifact(name=pr_name).get_file(
            name=file_name
        )
        self._log_setting()
        return DetectionEvalConfig(
            images_dir=images_dir,
            gt=gt_path,
            prediction=pr_file,
            setting=self.run_config.setting_file_name,
        )

    def semantic_segmentation_preprocess(
        self,
    ) -> SemanticSegmentationEvalConfig:
        """Gather the inputs for a semantic-segmentation evaluation.

        Downloads the images, labels and predictions datasets (and the gt
        dataset when ``gt_dataset_id`` is non-empty), initialises the
        experiment group and logs the evaluation settings.
        """
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        labels_dir = self.get_data(
            self.run_config.labels_dataset_id, "images_dir"
        )
        prediction_dir = self.get_data(
            self.run_config.predictions_dataset_id, "images_dir"
        )
        # The gt dataset is optional for segmentation.
        # NOTE(review): this fetches with "images_dir" (the directory of the
        # first file) although the result is stored as the images JSON
        # file; confirm whether "gt" (the file path) was intended.
        if self.run_config.gt_dataset_id:
            images_json_file = self.get_data(
                self.run_config.gt_dataset_id, "images_dir"
            )
        else:
            images_json_file = None
        self.client.experiment.init_group(self.run_config.group_name)
        self._log_setting()
        return SemanticSegmentationEvalConfig(
            images_dir=images_dir,
            labels_dir=labels_dir,
            prediction_dir=prediction_dir,
            setting=self.run_config.setting_file_name,
            images_json=images_json_file,
        )

    def evaluate(
        self,
        eval_config: Union[
            DetectionEvalConfig, SemanticSegmentationEvalConfig
        ],
    ) -> EvalResult:
        """Run the evaluation described by ``eval_config``.

        Concrete subclasses must override this with their own evaluation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def postprocess(self, eval_result: EvalResult) -> None:
        """Log ``eval_result`` to the current experiment.

        The summary is always logged; tables, plots and images are logged
        only when they are not None. Each plot dict must provide a
        ``"Table"`` entry (whose ``.name`` is used as the plot name) and a
        ``"Line"`` entry.
        """
        self.client.experiment.log_summary(eval_result.summary)
        if eval_result.tables is not None:
            for table in eval_result.tables:
                self.client.experiment.log_table(table)
        if eval_result.plots is not None:
            for plot in eval_result.plots:
                self.client.experiment.log_plot(
                    plot["Table"].name, plot["Table"], plot["Line"]
                )
        if eval_result.images is not None:
            for image in eval_result.images:
                self.client.experiment.log_image(image)
|