File size: 5,978 Bytes
e287bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
from dataclasses import dataclass
from typing import List, Union

import yaml
from dataclasses_json import DataClassJsonMixin

from aidisdk import AIDIClient
from aidisdk.experiment import Image, Table


@dataclass
class RunConfig(DataClassJsonMixin):
    """Configuration for one evaluation run.

    Holds the connection settings and resource references that
    ``BaseEvaluation`` uses to fetch evaluation inputs and log results.
    """

    endpoint: str  # AIDI endpoint passed to AIDIClient
    token: str  # NOTE(review): not read by any code visible in this file
    group_name: str  # experiment group passed to experiment.init_group
    # Dataset references, parsed by BaseEvaluation.get_data as
    # "dataset://<integer id>". gt_dataset_id may be empty for semantic
    # segmentation runs (the images JSON is then skipped).
    images_dataset_id: str
    gt_dataset_id: str
    labels_dataset_id: str
    predictions_dataset_id: str
    # "<artifact name>/<file name>" locating the prediction result file in
    # an experiment artifact (used by detection preprocessing).
    prediction_name: str
    setting_file_name: str  # path to the YAML evaluation setting file


@dataclass
class DetectionEvalConfig(DataClassJsonMixin):
    """Inputs for a detection evaluation.

    Produced by ``BaseEvaluation.detection_preprocess``.
    """

    images_dir: str  # directory containing the evaluation-set images
    gt: str  # path to the ground-truth data file
    prediction: str  # prediction file fetched from the experiment artifact
    setting: str  # path to the evaluation setting (YAML) file


@dataclass
class SemanticSegmentationEvalConfig(DataClassJsonMixin):
    """Inputs for a semantic-segmentation evaluation.

    Produced by ``BaseEvaluation.semantic_segmentation_preprocess``.
    """

    images_dir: str  # directory of the original evaluation images
    labels_dir: str  # directory of the ground-truth label images
    prediction_dir: str  # directory of the predicted label images
    setting: str  # path to the evaluation setting (YAML) file
    # JSON file describing the original evaluation images. None when the
    # run has no gt_dataset_id (the preprocess step passes None then), so
    # the annotation must admit None.
    images_json: Union[str, None]


@dataclass
class EvalResult(DataClassJsonMixin):
    """Result of an evaluation, consumed by ``BaseEvaluation.postprocess``.

    ``tables``, ``plots`` and ``images`` may be ``None``; postprocess skips
    whichever of them is ``None``.
    """

    summary: dict  # summary metrics (e.g. AP / AR), logged via log_summary
    tables: Union[List[Table], None]  # tables, each logged via log_table
    # Each plot dict must provide a "Table" entry (an object with a .name)
    # and a "Line" entry; postprocess passes both to log_plot.
    plots: Union[List[dict], None]
    images: Union[List[Image], None]  # images, each logged via log_image


class BaseEvaluation:
    """Base class for an AIDI evaluation pipeline.

    The flow is ``*_preprocess`` -> ``evaluate`` -> ``postprocess``:
    preprocessing downloads inputs from AIDI datasets / experiment
    artifacts and logs the evaluation setting, subclasses implement
    ``evaluate``, and ``postprocess`` logs the results to the experiment.
    """

    def __init__(self, run_config: RunConfig):
        """Store ``run_config`` and create an AIDI client for its endpoint."""
        self.run_config = run_config
        # NOTE(review): run_config.token is not passed to the client;
        # confirm AIDIClient obtains credentials some other way.
        self.client = AIDIClient(endpoint=run_config.endpoint)

    def get_data(self, dataset_id_info: str, file_type: str) -> str:
        """Download an AIDI dataset and return a local path from it.

        Args:
            dataset_id_info: Dataset reference ``"dataset://<integer id>"``.
            file_type: ``"gt"`` to get the first downloaded file, or
                ``"images_dir"`` to get the directory containing it.

        Returns:
            The selected local file path or directory.

        Raises:
            ValueError: If ``dataset_id_info`` lacks ``"://"``, its scheme
                is not ``"dataset"``, or the id is not an integer.
            NotImplementedError: If ``file_type`` is unsupported or the
                dataset contains no files.
        """
        scheme, sep, raw_id = dataset_id_info.partition("://")
        if not sep:
            raise ValueError(
                f"invalid dataset reference {dataset_id_info!r}; "
                "expected 'dataset://<id>'"
            )
        if scheme != "dataset":
            raise ValueError("dataset version not supported")
        dataset_id = int(raw_id)
        # Reject unknown file types before paying for the download.
        if file_type not in ("gt", "images_dir"):
            raise NotImplementedError(f"unsupported file_type {file_type!r}")

        dataset_interface = self.client.dataset.load(dataset_id)
        data_path_list = dataset_interface.file_list(download=True)
        if not data_path_list:
            # Exception type kept as before for backward compatibility.
            raise NotImplementedError(
                f"dataset {dataset_id} contains no files"
            )
        first_path = data_path_list[0]
        if file_type == "gt":
            return first_path
        return os.path.dirname(first_path)

    def _log_setting(self) -> None:
        """Load the YAML setting file and log it as the experiment config."""
        # SECURITY(review): yaml.Loader can construct arbitrary Python
        # objects, so setting_file_name must be trusted input. Switch to
        # yaml.safe_load if the settings never rely on Python-specific tags.
        with open(
            self.run_config.setting_file_name, "r", encoding="utf-8"
        ) as fid:
            cfg_dict = yaml.load(fid, Loader=yaml.Loader)
        self.client.experiment.log_config(cfg_dict)

    def detection_preprocess(self) -> DetectionEvalConfig:
        """Fetch detection inputs and log the evaluation setting.

        Downloads the image and ground-truth datasets, initializes the
        experiment group, fetches the prediction file from the artifact
        named by ``run_config.prediction_name`` (``"<artifact>/<file>"``),
        and logs the YAML setting file as the experiment config.

        Raises:
            ValueError: If ``prediction_name`` is not
                ``"<artifact>/<file>"``.
        """
        # Parse up front so a malformed name fails before any download.
        # partition keeps nested file paths intact ("art/sub/pred.json" ->
        # "sub/pred.json") instead of truncating them to "sub".
        prediction_name = self.run_config.prediction_name
        pr_name, sep, file_name = prediction_name.partition("/")
        if not (sep and pr_name and file_name):
            raise ValueError(
                f"invalid prediction_name {prediction_name!r}; "
                "expected '<artifact name>/<file name>'"
            )

        # Images and ground truth come from AIDI datasets.
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        gt_path = self.get_data(self.run_config.gt_dataset_id, "gt")

        self.client.experiment.init_group(self.run_config.group_name)

        # Predictions come from an experiment artifact.
        pr_file = self.client.experiment.use_artifact(name=pr_name).get_file(
            name=file_name
        )

        self._log_setting()

        return DetectionEvalConfig(
            images_dir=images_dir,
            gt=gt_path,
            prediction=pr_file,
            setting=self.run_config.setting_file_name,
        )

    def semantic_segmentation_preprocess(
        self,
    ) -> SemanticSegmentationEvalConfig:
        """Fetch semantic-segmentation inputs and log the evaluation setting.

        Downloads the image, label and prediction datasets (returning their
        directories), optionally the images JSON file when
        ``run_config.gt_dataset_id`` is set, initializes the experiment
        group, and logs the YAML setting file as the experiment config.
        """
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        labels_dir = self.get_data(
            self.run_config.labels_dataset_id, "images_dir"
        )
        prediction_dir = self.get_data(
            self.run_config.predictions_dataset_id, "images_dir"
        )
        # The images JSON is optional. It is a single file, so request the
        # file path ("gt") rather than its parent directory ("images_dir").
        images_json_file = (
            self.get_data(self.run_config.gt_dataset_id, "gt")
            if self.run_config.gt_dataset_id
            else None
        )

        self.client.experiment.init_group(self.run_config.group_name)

        self._log_setting()

        return SemanticSegmentationEvalConfig(
            images_dir=images_dir,
            labels_dir=labels_dir,
            prediction_dir=prediction_dir,
            setting=self.run_config.setting_file_name,
            images_json=images_json_file,
        )

    def evaluate(
        self,
        eval_config: Union[
            DetectionEvalConfig, SemanticSegmentationEvalConfig
        ],
    ) -> EvalResult:
        """Run the evaluation on the inputs produced by a preprocess method.

        Subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement evaluate()"
        )

    def postprocess(self, eval_result: EvalResult) -> None:
        """Log an evaluation result to the current experiment.

        The summary is always logged; ``tables``, ``plots`` and ``images``
        are logged item by item and skipped when ``None``.
        """
        self.client.experiment.log_summary(eval_result.summary)
        if eval_result.tables is not None:
            for table in eval_result.tables:
                self.client.experiment.log_table(table)
        if eval_result.plots is not None:
            for plot in eval_result.plots:
                # Each plot dict carries a "Table" (its .name is passed as
                # the first log_plot argument) and a "Line" entry.
                plot_table = plot["Table"]
                self.client.experiment.log_plot(
                    plot_table.name, plot_table, plot["Line"]
                )
        if eval_result.images is not None:
            for image in eval_result.images:
                self.client.experiment.log_image(image)