test1 / job (2) /code /base.py
ehovel2023's picture
Upload 11 files
e287bc1
raw
history blame
No virus
5.98 kB
import os
from dataclasses import dataclass
from typing import List, Union
import yaml
from dataclasses_json import DataClassJsonMixin
from aidisdk import AIDIClient
from aidisdk.experiment import Image, Table
@dataclass
class RunConfig(DataClassJsonMixin):
    """Dataclass for the config of an evaluation run.

    The ``*_dataset_id`` fields are consumed by ``BaseEvaluation.get_data``,
    which expects references of the form ``"dataset://<integer id>"``.
    """

    # AIDI service endpoint; passed to AIDIClient(endpoint=...).
    endpoint: str
    # Not read by BaseEvaluation; presumably an auth token — TODO confirm.
    token: str
    # Experiment group name; passed to client.experiment.init_group.
    group_name: str
    # Dataset holding the evaluation images.
    images_dataset_id: str
    # Ground-truth dataset (detection). For semantic segmentation it is
    # optional: when falsy, no images JSON is fetched.
    gt_dataset_id: str
    # Dataset of ground-truth label images (semantic segmentation).
    labels_dataset_id: str
    # Dataset of predicted images (semantic segmentation).
    predictions_dataset_id: str
    # "<artifact name>/<file name>" of the detection prediction file.
    prediction_name: str
    # Path of the YAML evaluation setting file.
    setting_file_name: str
@dataclass
class DetectionEvalConfig(DataClassJsonMixin):
    """Dataclass for the inputs of a detection evaluation."""

    images_dir: str  # directory containing the evaluation-set images
    gt: str  # path of the ground-truth data file
    prediction: str  # path of the prediction-result file to evaluate
    setting: str  # path of the evaluation setting (YAML) file
@dataclass
class SemanticSegmentationEvalConfig(DataClassJsonMixin):
    """Dataclass for the inputs of a semantic-segmentation evaluation."""

    images_dir: str  # directory of the original evaluation images
    labels_dir: str  # directory of the ground-truth label images
    prediction_dir: str  # directory of the predicted images to evaluate
    setting: str  # path of the evaluation setting (YAML) file
    # JSON file describing the original evaluation images. It is None when
    # the run config has no gt_dataset_id, so the annotation admits None.
    images_json: Union[str, None]
@dataclass
class EvalResult(DataClassJsonMixin):
    """Dataclass for the result of an evaluation, logged by ``postprocess``."""

    # Summary metrics (e.g. AP / AR); logged via experiment.log_summary.
    summary: dict
    # NOTE: BaseEvaluation.postprocess skips tables / plots / images when
    # they are None, so producers may leave them unset with None.
    tables: List[Table]  # result tables; each logged via log_table
    # Plot specs: dicts with a "Table" entry (a Table whose .name is used as
    # the plot name) and a "Line" entry, both passed to log_plot.
    plots: List[dict]
    images: List[Image]  # result images; each logged via log_image
class BaseEvaluation:
    """Base class for an AIDI evaluation job.

    Provides helpers that gather evaluation inputs from AIDI datasets and
    experiment artifacts (``detection_preprocess`` /
    ``semantic_segmentation_preprocess``), an ``evaluate`` hook that concrete
    subclasses must override, and ``postprocess``, which logs an
    ``EvalResult`` to the experiment.
    """

    def __init__(self, run_config: RunConfig):
        """Keep ``run_config`` and create an ``AIDIClient`` for its endpoint.

        NOTE(review): ``run_config.token`` is not passed to ``AIDIClient``
        here; confirm the client authenticates some other way.
        """
        self.run_config = run_config
        self.client = AIDIClient(endpoint=run_config.endpoint)

    def get_data(self, dataset_id_info: str, file_type: str) -> str:
        """Download an AIDI dataset and return a path taken from it.

        Args:
            dataset_id_info: Dataset reference ``"dataset://<integer id>"``.
            file_type: ``"gt"`` returns the first file of the dataset;
                ``"images_dir"`` returns the directory containing that file.

        Returns:
            The path of the first file returned by
            ``file_list(download=True)`` (``"gt"``) or its directory
            (``"images_dir"``). Presumably a local path — TODO confirm.

        Raises:
            ValueError: if the reference does not use the ``dataset://``
                scheme or its id is not an integer.
            NotImplementedError: if ``file_type`` is unsupported or the
                dataset has no files (exception type kept from the original
                implementation for backward compatibility).
        """
        # Validate everything up front so a bad reference or file_type fails
        # fast instead of after downloading the whole dataset.
        scheme, sep, raw_id = dataset_id_info.partition("://")
        if not sep or scheme != "dataset":
            raise ValueError("dataset version not supported")
        dataset_id = int(raw_id)
        if file_type not in ("gt", "images_dir"):
            raise NotImplementedError(
                f"unsupported file_type: {file_type!r}"
            )

        dataset_interface = self.client.dataset.load(dataset_id)
        data_path_list = dataset_interface.file_list(download=True)
        if not data_path_list:
            raise NotImplementedError(
                f"dataset {dataset_id_info!r} contains no files"
            )

        first_path = data_path_list[0]
        if file_type == "gt":
            return first_path
        return os.path.dirname(first_path)

    def _log_setting(self) -> None:
        """Log the YAML evaluation setting file as the experiment config."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Acceptable for a trusted local setting file; switch to
        # yaml.safe_load if this path can ever come from untrusted input.
        with open(
            self.run_config.setting_file_name, "r", encoding="utf-8"
        ) as fid:
            cfg_dict = yaml.load(fid, Loader=yaml.Loader)
        self.client.experiment.log_config(cfg_dict)

    def detection_preprocess(self) -> DetectionEvalConfig:
        """Collect the inputs of a detection evaluation.

        Downloads the images and ground-truth datasets, initialises the
        experiment group, fetches the prediction file from the experiment
        artifact named in ``run_config.prediction_name`` and logs the YAML
        setting file as the experiment config.

        Raises:
            ValueError: if ``prediction_name`` is not of the form
                ``"<artifact name>/<file name>"``.
        """
        # Split on the first "/" only: the artifact name comes first and the
        # remainder (which may itself contain "/") names the file inside it.
        # Assumes artifact names never contain "/" — TODO confirm.
        prediction_name = self.run_config.prediction_name
        pr_name, sep, file_name = prediction_name.partition("/")
        if not (pr_name and sep and file_name):
            raise ValueError(
                "prediction_name must be '<artifact name>/<file name>', "
                f"got {prediction_name!r}"
            )

        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        gt_path = self.get_data(self.run_config.gt_dataset_id, "gt")
        self.client.experiment.init_group(self.run_config.group_name)
        # Prediction results come from an experiment-management artifact.
        pr_file = self.client.experiment.use_artifact(name=pr_name).get_file(
            name=file_name
        )
        self._log_setting()
        return DetectionEvalConfig(
            images_dir=images_dir,
            gt=gt_path,
            prediction=pr_file,
            setting=self.run_config.setting_file_name,
        )

    def semantic_segmentation_preprocess(
        self,
    ) -> SemanticSegmentationEvalConfig:
        """Collect the inputs of a semantic-segmentation evaluation.

        Downloads the images, label and prediction datasets (returning their
        directories), optionally the images JSON from ``gt_dataset_id``,
        initialises the experiment group and logs the YAML setting file as
        the experiment config. ``images_json`` is None when
        ``gt_dataset_id`` is falsy.
        """
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        labels_dir = self.get_data(
            self.run_config.labels_dataset_id, "images_dir"
        )
        prediction_dir = self.get_data(
            self.run_config.predictions_dataset_id, "images_dir"
        )
        if self.run_config.gt_dataset_id:
            # images_json is a JSON *file*, so fetch the file itself ("gt")
            # rather than the directory that contains it ("images_dir").
            images_json_file = self.get_data(
                self.run_config.gt_dataset_id, "gt"
            )
        else:
            images_json_file = None
        self.client.experiment.init_group(self.run_config.group_name)
        self._log_setting()
        return SemanticSegmentationEvalConfig(
            images_dir=images_dir,
            labels_dir=labels_dir,
            prediction_dir=prediction_dir,
            setting=self.run_config.setting_file_name,
            images_json=images_json_file,
        )

    def evaluate(
        self,
        eval_config: Union[
            DetectionEvalConfig, SemanticSegmentationEvalConfig
        ],
    ) -> EvalResult:
        """Run the evaluation on the inputs gathered by a preprocess method.

        Concrete subclasses must override this; the base version raises.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def postprocess(self, eval_result: EvalResult) -> None:
        """Log ``eval_result`` to the current experiment.

        ``summary`` is always logged. ``tables``, ``plots`` and ``images``
        are skipped when None. Each plot is a dict whose ``"Table"`` entry
        (a Table, whose ``name`` becomes the plot name) and ``"Line"`` entry
        are passed to ``log_plot``.
        """
        self.client.experiment.log_summary(eval_result.summary)
        if eval_result.tables is not None:
            for table in eval_result.tables:
                self.client.experiment.log_table(table)
        if eval_result.plots is not None:
            for plot in eval_result.plots:
                plot_table = plot["Table"]
                self.client.experiment.log_plot(
                    plot_table.name, plot_table, plot["Line"]
                )
        if eval_result.images is not None:
            for image in eval_result.images:
                self.client.experiment.log_image(image)