Spaces:
Running
Running
import csv | |
import os | |
import sys | |
__dir__ = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(__dir__) | |
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) | |
import numpy as np | |
from tools.data import build_dataloader | |
from tools.engine import Config, Trainer | |
from tools.utility import ArgsParser | |
def parse_args():
    """Parse the command line with the project's ``ArgsParser``.

    Returns:
        The object produced by ``ArgsParser().parse_args()``. ``main`` reads
        its ``config`` attribute and, after ``vars()``, its ``opt`` entry.
    """
    return ArgsParser().parse_args()
# Index of the first entry kept from every per-length ("each_*") metric list.
# NOTE(review): presumably the list index equals the text length, so this
# keeps lengths >= 26 -- confirm against the RecMetricLong implementation.
_LONG_TEXT_START = 26

# (label suffix, slice into the truncated per-length arrays) for the three
# summary rows. Slices are relative to _LONG_TEXT_START, i.e. index 0 is
# entry 26 of the original metric list.
_LENGTH_BUCKETS = (
    ('26_35', slice(0, 10)),
    ('36_55', slice(10, 30)),
    ('56', slice(30, None)),
)


def _bucket_average(values, counts):
    """Return the count-weighted average of ``values`` for one bucket.

    Returns 0.0 when the bucket holds no samples, mirroring the zero-count
    handling used for the per-length averages (avoids 0/0 -> nan).
    """
    total = counts.sum()
    if total == 0:
        return 0.0
    return (values * counts / total).sum().tolist()


def _write_per_length_summary(csv_w, each_long):
    """Write per-length metrics merged over all label files, plus bucket rows.

    Args:
        csv_w: ``csv.writer`` receiving the rows.
        each_long: maps each "each_*" metric name to a list with one numpy
            array per evaluated label file; must contain ``'each_len_num'``
            (the per-length sample counts used as weights).
    """
    counts_per_file = np.array(each_long['each_len_num'])
    sum_all = counts_per_file.sum(0)
    # Replace zero counts by 1 for the division only; those positions are
    # forced back to 0 afterwards.
    sum_all_pad = np.where(sum_all == 0, 1., sum_all)
    for name, per_file in each_long.items():
        if name == 'each_len_num':
            csv_w.writerow([name] + sum_all.tolist())
            continue
        weighted = (np.array(per_file) * counts_per_file).sum(0)
        v_all = np.where(sum_all == 0, 0., weighted / sum_all_pad)
        csv_w.writerow([name] + v_all.tolist())
        for suffix, span in _LENGTH_BUCKETS:
            bucket_counts = sum_all[span]
            csv_w.writerow([name + suffix] +
                           [_bucket_average(v_all[span], bucket_counts)] +
                           [bucket_counts.sum().tolist()])


def main():
    """Evaluate a checkpoint on the ultra-long text label files and dump a CSV.

    Loads the config named on the command line, forces the evaluation
    settings used for long text (AMP off, max length 200, ``RecMetricLong``),
    evaluates every label file of every group in ``data_dirs_list`` and
    writes per-file metrics, their sample-weighted and plain averages, and
    per-length / per-bucket summaries to
    ``<output_dir>/<basename(output_dir)>_result1_1_test_all_long_simple_bi_bs1.csv``.
    """
    FLAGS = parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)

    global_cfg = cfg.cfg['Global']
    global_cfg['use_amp'] = False
    # Drop a single trailing slash so split('/')[-1] below yields the
    # directory name; endswith() also tolerates an empty string.
    if global_cfg['output_dir'].endswith('/'):
        global_cfg['output_dir'] = global_cfg['output_dir'][:-1]
    global_cfg['max_text_length'] = 200
    cfg.cfg['Architecture']['Decoder']['max_len'] = 200
    cfg.cfg['Metric']['name'] = 'RecMetricLong'
    if global_cfg['pretrained_model'] is None:
        global_cfg['pretrained_model'] = global_cfg['output_dir'] + '/best.pth'

    trainer = Trainer(cfg, mode='eval')

    best_model_dict = trainer.status.get('metrics', {})
    trainer.logger.info('metric in ckpt ***************')
    for k, v in best_model_dict.items():
        trainer.logger.info('{}:{}'.format(k, v))

    data_dirs_list = [
        [
            '../ltb/ultra_long_26_35_list.txt',
            '../ltb/ultra_long_36_55_list.txt',
            '../ltb/ultra_long_56_list.txt',
        ],
    ]

    cfg = cfg.cfg
    cfg['Eval']['dataset']['name'] = 'SimpleDataSet'
    output_dir = cfg['Global']['output_dir']
    csv_path = (output_dir + '/' + output_dir.split('/')[-1] +
                '_result1_1_test_all_long_simple_bi_bs1.csv')

    # `with` closes the file even if evaluation raises; newline='' is what the
    # csv module expects so rows are not written with '\r\r\n' on Windows.
    with open(csv_path, 'w', newline='') as file_csv:
        csv_w = csv.writer(file_csv)
        for data_dirs in data_dirs_list:
            acc_each = []
            acc_each_num = []
            acc_each_dis = []
            each_long = {}
            for datadir in data_dirs:
                # Shallow copy on purpose (unchanged from the original): the
                # nested 'Eval' dict stays shared with `cfg`.
                config_each = cfg.copy()
                config_each['Eval']['dataset']['label_file_list'] = [datadir]
                valid_dataloader = build_dataloader(config_each, 'Eval',
                                                    trainer.logger)
                trainer.logger.info(
                    f'{datadir} valid dataloader has {len(valid_dataloader)} iters'
                )
                trainer.valid_dataloader = valid_dataloader
                metric = trainer.eval()
                acc_each.append(metric['acc'] * 100)
                acc_each_dis.append(metric['norm_edit_dis'])
                acc_each_num.append(metric['all_num'])
                trainer.logger.info('metric eval ***************')
                for k, v in metric.items():
                    trainer.logger.info('{}:{}'.format(k, v))
                    if 'each' in k:
                        long_part = v[_LONG_TEXT_START:]
                        csv_w.writerow([k] + long_part)
                        each_long.setdefault(k, []).append(np.array(long_part))

            sample_counts = np.array(acc_each_num)
            total_samples = sum(acc_each_num)

            # Accuracy row: per-file values, sample-weighted mean, plain mean.
            weighted_acc = np.array(acc_each) * sample_counts / total_samples
            acc_row = (acc_each + [weighted_acc.sum().tolist()] +
                       [sum(acc_each) / len(acc_each)])
            csv_w.writerow(acc_row)
            print(acc_row)

            # Same layout for the normalized edit distance.
            weighted_dis = (np.array(acc_each_dis) * sample_counts /
                            total_samples)
            csv_w.writerow(acc_each_dis + [weighted_dis.sum().tolist()] +
                           [sum(acc_each_dis) / len(acc_each_dis)])

            _write_per_length_summary(csv_w, each_long)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()