import os
from dataclasses import dataclass
from typing import List, Optional, Union

import yaml
from dataclasses_json import DataClassJsonMixin

from aidisdk import AIDIClient
from aidisdk.experiment import Image, Table


@dataclass
class RunConfig(DataClassJsonMixin):
    """Dataclass for config of run."""

    endpoint: str  # AIDI service endpoint passed to AIDIClient
    token: str
    group_name: str  # experiment group opened by the preprocess methods
    # Dataset ids below use the form "<version>://<integer id>"; only the
    # "dataset" version is accepted by BaseEvaluation.get_data.
    images_dataset_id: str
    gt_dataset_id: str  # may be empty for semantic segmentation
    labels_dataset_id: str
    predictions_dataset_id: str
    # "<artifact name>/<file name inside the artifact>"
    prediction_name: str
    setting_file_name: str  # path to the YAML evaluation setting file


@dataclass
class DetectionEvalConfig(DataClassJsonMixin):
    """Dataclass for config of evaluation."""

    images_dir: str  # folder containing the evaluation images
    gt: str  # path to the ground-truth data file
    prediction: str  # prediction result file fetched from the artifact
    setting: str  # path to the YAML evaluation setting file


@dataclass
class SemanticSegmentationEvalConfig(DataClassJsonMixin):
    """Dataclass for config of evaluation."""

    images_dir: str  # folder containing the original images
    labels_dir: str  # folder containing the ground-truth label images
    prediction_dir: str  # folder containing the predicted images
    setting: str  # path to the YAML evaluation setting file
    # JSON describing the original images; None when gt_dataset_id is empty.
    images_json: Optional[str]


@dataclass
class EvalResult(DataClassJsonMixin):
    """Dataclass for result of evaluation."""

    summary: dict  # summary metrics (e.g. ap, ar)
    tables: List[Table]  # table results; postprocess skips None
    # Plot results; each dict holds a "Table" (with a .name) and a "Line".
    plots: List[dict]
    images: List[Image]  # image results; postprocess skips None


class BaseEvaluation:
    """Evaluation pipeline backed by AIDI: preprocess -> evaluate -> postprocess.

    A preprocess method gathers inputs from AIDI datasets/artifacts and logs
    the setting file, subclasses implement :meth:`evaluate`, and
    :meth:`postprocess` logs the results to the experiment.
    """

    def __init__(self, run_config: RunConfig):
        self.run_config = run_config
        # NOTE(review): run_config.token is not passed to AIDIClient here —
        # confirm the client obtains its credentials some other way.
        self.client = AIDIClient(endpoint=run_config.endpoint)

    def get_data(self, dataset_id_info: str, file_type: str) -> str:
        """Download a dataset and return a local path derived from its files.

        Args:
            dataset_id_info: "<version>://<integer id>"; only the "dataset"
                version is supported.
            file_type: "gt" to get the first file of the dataset, or
                "images_dir" to get the directory containing that file.

        Returns:
            The local path selected by ``file_type``.

        Raises:
            ValueError: malformed id, unsupported version, or empty dataset.
            NotImplementedError: unsupported ``file_type``.
        """
        version, sep, raw_id = dataset_id_info.partition("://")
        if not sep:
            raise ValueError(
                f"malformed dataset id {dataset_id_info!r}; "
                "expected '<version>://<id>'"
            )
        dataset_id = int(raw_id)
        if version != "dataset":
            raise ValueError("dataset version not supported")
        # Validate before downloading so a bad file_type costs nothing.
        if file_type not in ("gt", "images_dir"):
            raise NotImplementedError(f"unsupported file_type {file_type!r}")

        dataset_interface = self.client.dataset.load(dataset_id)
        data_path_list = dataset_interface.file_list(download=True)
        if not data_path_list:
            raise ValueError(f"dataset {dataset_id_info!r} contains no files")

        first_path = data_path_list[0]
        if file_type == "gt":
            return first_path
        return os.path.dirname(first_path)

    def _log_setting(self) -> None:
        """Load the YAML setting file and log it as the experiment config."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # use yaml.safe_load if the setting file may come from untrusted
        # sources.
        with open(
            self.run_config.setting_file_name, "r", encoding="utf-8"
        ) as fid:
            cfg_dict = yaml.load(fid, Loader=yaml.Loader)
        self.client.experiment.log_config(cfg_dict)

    def detection_preprocess(self) -> DetectionEvalConfig:
        """Collect detection inputs from AIDI and log the evaluation setting.

        Fetches the image folder and gt file from datasets, opens the
        experiment group, fetches the prediction file from an experiment
        artifact, then logs the setting file as the experiment config.

        Raises:
            ValueError: ``prediction_name`` has no "/" separator, or
                :meth:`get_data` rejects a dataset id.
        """
        # Ground truth and images come from AIDI datasets.
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        gt_path = self.get_data(self.run_config.gt_dataset_id, "gt")
        self.client.experiment.init_group(self.run_config.group_name)

        # Predictions come from an experiment artifact. Split on the first
        # "/" only so nested file names inside the artifact stay intact.
        pr_name, sep, file_name = self.run_config.prediction_name.partition(
            "/"
        )
        if not sep:
            raise ValueError(
                f"prediction_name {self.run_config.prediction_name!r} must "
                "be '<artifact name>/<file name>'"
            )
        pr_file = self.client.experiment.use_artifact(name=pr_name).get_file(
            name=file_name
        )

        self._log_setting()
        return DetectionEvalConfig(
            images_dir=images_dir,
            gt=gt_path,
            prediction=pr_file,
            setting=self.run_config.setting_file_name,
        )

    def semantic_segmentation_preprocess(
        self,
    ) -> SemanticSegmentationEvalConfig:
        """Collect segmentation inputs from AIDI and log the setting.

        Fetches the image, label, and prediction folders (and, when
        ``gt_dataset_id`` is set, the images JSON) from datasets, opens the
        experiment group, then logs the setting file as the experiment config.
        """
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        labels_dir = self.get_data(
            self.run_config.labels_dataset_id, "images_dir"
        )
        prediction_dir = self.get_data(
            self.run_config.predictions_dataset_id, "images_dir"
        )
        if self.run_config.gt_dataset_id:
            # NOTE(review): "images_dir" yields the directory of the first
            # file, not the file itself, although this is stored as the
            # images JSON — confirm whether "gt" was intended.
            images_json_file = self.get_data(
                self.run_config.gt_dataset_id, "images_dir"
            )
        else:
            images_json_file = None
        self.client.experiment.init_group(self.run_config.group_name)

        self._log_setting()
        return SemanticSegmentationEvalConfig(
            images_dir=images_dir,
            labels_dir=labels_dir,
            prediction_dir=prediction_dir,
            setting=self.run_config.setting_file_name,
            images_json=images_json_file,
        )

    def evaluate(
        self,
        eval_config: Union[
            DetectionEvalConfig, SemanticSegmentationEvalConfig
        ],
    ) -> EvalResult:
        """Run the evaluation on the config produced by a preprocess method.

        Subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def postprocess(self, eval_result: EvalResult) -> None:
        """Log an evaluation result to the experiment.

        The summary is always logged; tables, plots, and images are logged
        one by one when the corresponding field is not None.
        """
        self.client.experiment.log_summary(eval_result.summary)
        if eval_result.tables is not None:
            for table in eval_result.tables:
                self.client.experiment.log_table(table)
        if eval_result.plots is not None:
            for plot in eval_result.plots:
                self.client.experiment.log_plot(
                    plot["Table"].name, plot["Table"], plot["Line"]
                )
        if eval_result.images is not None:
            for image in eval_result.images:
                self.client.experiment.log_image(image)