| import os |
| import prettytable as pt |
|
|
| from evaluation.metrics import evaluator |
| from config import Config |
|
|
|
|
# Module-level configuration instance; evaluate() reads data_root_dir, task and verbose_eval from it.
| config = Config() |
|
|
def _unused_report_path(method):
    """Return a path under 'evaluation/' for this method's report that does not exist yet.

    The first candidate is 'eval-<method>.txt'; if taken, '_1', '_2', ... are
    appended to the stem until a free name is found.
    """
    # Build every candidate from the stem. str.rstrip('.txt') would strip any
    # trailing '.', 't' or 'x' characters (mangling methods such as 'test'),
    # and str.replace could rewrite an earlier match inside the path.
    stem = os.path.join('evaluation', 'eval-{}'.format(method))
    candidate = stem + '.txt'
    id_suffix = 0
    while os.path.exists(candidate):
        id_suffix += 1
        candidate = '{}_{}.txt'.format(stem, id_suffix)
    return candidate


def evaluate(pred_dir, method, testset, only_S_MAE=False, epoch=0):
    """Score the predictions of `method` on `testset` and write a one-row report table.

    Ground-truth files are listed from
    `<config.data_root_dir>/<config.task>/<testset>/gt` and predictions from
    `<pred_dir>/<method>/<testset>`; both listings are sorted before being
    handed to `evaluator`. The resulting table is printed and written to a
    new 'evaluation/eval-<method>[_N].txt' file.

    Args:
        pred_dir: Root directory that contains the per-method prediction folders.
        method: Name of the prediction folder; also used in the report file name.
        testset: Name of the test-set folder.
        only_S_MAE: If True, request only the 'S' and 'MAE' metrics and report
            only those two columns.
        epoch: Appended to `method` in the "Method" column of the full table.

    Returns:
        dict with keys 'e_max', 'e_mean', 'e_adp', 'sm', 'mae', 'f_max',
        'f_mean', 'f_wfm', 'f_adp' and 'hce'.
    """
    filename = _unused_report_path(method)

    gt_dir = os.path.join(config.data_root_dir, config.task, testset, 'gt')
    pred_subdir = os.path.join(pred_dir, method, testset)
    gt_paths = sorted(os.path.join(gt_dir, p) for p in os.listdir(gt_dir))
    pred_paths = sorted(os.path.join(pred_subdir, p) for p in os.listdir(pred_subdir))

    metrics = ['S', 'MAE'] if only_S_MAE else ['S', 'MAE', 'E', 'F', 'HCE']

    # Make sure the report directory exists before opening the file.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'a+') as file_to_write:
        em, sm, fm, mae, wfm, hce = evaluator(
            gt_paths=gt_paths,
            pred_paths=pred_paths,
            metrics=metrics,
            verbose=config.verbose_eval,
        )
        # NOTE(review): these lines run even when only 'S' and 'MAE' were
        # requested, so this assumes evaluator still returns indexable
        # placeholders for the metrics it skipped -- TODO confirm.
        e_max, e_mean, e_adp = em['curve'].max(), em['curve'].mean(), em['adp'].mean()
        f_max, f_mean, f_wfm, f_adp = fm['curve'].max(), fm['curve'].mean(), wfm, fm['adp']

        tb = pt.PrettyTable()
        if only_S_MAE:
            # Header and row must have the same length; HCE is not requested
            # in this mode, so it gets no column.
            tb.field_names = ["Dataset", "Method", 'MAE', "Smeasure"]
            row = [testset, method, mae.round(3), sm.round(3)]
        else:
            tb.field_names = [
                "Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "maxEm", "meanFm",
                "adpEm", "adpFm", 'HCE',
            ]
            # Values follow the header order: test set under "Dataset",
            # method (tagged with the epoch) under "Method". The adpEm cell
            # uses the same e_adp value that is returned to the caller.
            row = [
                testset, method + str(epoch), f_max.round(3), f_wfm.round(3), mae.round(3), sm.round(3),
                e_mean.round(3), e_max.round(3), f_mean.round(3), e_adp.round(3), f_adp.round(3), hce.round(3),
            ]
        tb.add_row(row)

        print(tb)
        # '+' table corners become '|' in the saved copy; the `with` block closes the file.
        file_to_write.write(str(tb).replace('+', '|') + '\n')

    return {
        'e_max': e_max, 'e_mean': e_mean, 'e_adp': e_adp,
        'sm': sm, 'mae': mae,
        'f_max': f_max, 'f_mean': f_mean, 'f_wfm': f_wfm, 'f_adp': f_adp,
        'hce': hce,
    }
|
|
|
|
def main():
    """Run the full-metric evaluation of the 'tmp_val' predictions on each hard-coded test set."""
    # Test sets are given as one '+'-separated string and evaluated one by one.
    testsets = 'DIS-VD+DIS-TE1'
    for testset in testsets.split('+'):
        evaluate('.', 'tmp_val', testset, only_S_MAE=False)
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it is imported.
| if __name__ == '__main__': |
| main() |
|
|