""" |
|
Note: |
|
In order to calculate the significance tests, |
|
the downstream expert's forward() method should log a metric score for each testing sample. |
|
The `records['sample_wise_metric']` should be a list containing the testing result of each sample. |
|
|
|
For example: |
|
``` |
|
python |
|
# for frame-wise classification |
|
for sample in samples: |
|
records['sample_wise_metric'] += [torch.FloatTensor(sample).mean().item()] |
|
# for utterance-wise classification |
|
records['sample_wise_metric'] += (predicted_classid == labels).view(-1).cpu().tolist() |
|
``` |
|
""" |
|
import os
import glob
import random
import argparse
from tqdm import tqdm
from argparse import Namespace
from collections import defaultdict

import torch
import torchaudio
import numpy as np
from statsmodels.stats.contingency_tables import mcnemar
from scipy import stats

from s3prl.downstream.runner import Runner
from s3prl.utility.helper import hack_isinstance, override
|
|
def get_ttest_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('-m', '--mode', choices=['ttest', 'fisher', 'mcnemar'], default='ttest')

    parser.add_argument('-em', '--evaluate_metric', default='acc')
    parser.add_argument('-t', '--evaluate_split', default='test')
    parser.add_argument('-o', '--override', help='Used to override args and config; this has the highest priority')

    parser.add_argument('-e1', '--past_exp1', metavar='{CKPT_PATH,CKPT_DIR}', help='Load from a checkpoint')
    parser.add_argument('-e2', '--past_exp2', metavar='{CKPT_PATH,CKPT_DIR}', help='Load from another checkpoint')
    parser.add_argument('-u1', '--upstream1', default='default', type=str, help='Used to override the upstream string for checkpoint e1')
    parser.add_argument('-u2', '--upstream2', default='default', type=str, help='Used to override the upstream string for checkpoint e2')

    parser.add_argument('--seed', default=1337, type=int)
    parser.add_argument('--verbose', action='store_true', help='Print model information')
    parser.add_argument('--ckpt_name', default='best-states-dev',
                        help='The string used for searching the checkpoint, '
                             'example choices: `states-*`, `best-states-dev`, `best-states-test`.')
    args = parser.parse_args()

    args1, config1 = get_past_exp(args, args.past_exp1, args.ckpt_name)
    args2, config2 = get_past_exp(args, args.past_exp2, args.ckpt_name)
    if args.upstream1 != 'default': args1.upstream = args.upstream1
    if args.upstream2 != 'default': args2.upstream = args.upstream2

    return args.mode, args1, config1, args2, config2
|
|
def get_past_exp(args, past_exp, name):
    if os.path.isdir(past_exp):
        # search the experiment directory for the requested checkpoint name
        ckpt_pths = glob.glob(os.path.join(past_exp, f'{name}.ckpt'))
        assert len(ckpt_pths) > 0
        if len(ckpt_pths) == 1:
            ckpt_pth = ckpt_pths[0]
        else:
            # pick the latest checkpoint by the step number encoded in the filename
            ckpt_pths = sorted(ckpt_pths, key=lambda pth: int(pth.split('-')[-1].split('.')[0]))
            ckpt_pth = ckpt_pths[-1]
    else:
        ckpt_pth = past_exp

    print(f'[Runner] - Loading from {ckpt_pth}')
    ckpt = torch.load(ckpt_pth, map_location='cpu')

    def update_args(old, new, preserve_list=()):
        out_dict = vars(old)
        new_dict = vars(new)
        for key in list(new_dict.keys()):
            if key in preserve_list:
                new_dict.pop(key)
        out_dict.update(new_dict)
        return Namespace(**out_dict)

    # the command-line args below should not be overwritten by the args stored in the checkpoint
    cannot_overwrite_args = [
        'mode', 'evaluate_split', 'override',
        'backend', 'local_rank', 'past_exp',
    ]
    args = update_args(args, ckpt['Args'], preserve_list=cannot_overwrite_args)

    args.init_ckpt = ckpt_pth
    args.mode = 'evaluate'
    config = ckpt['Config']

    if args.override:
        override(args.override, args, config)
    return args, config
|
|
class Tester(Runner):
    """
    Handles the evaluation loop and returns the per-sample testing records for the significance tests.
    """
    def __init__(self, args, config):
        super().__init__(args, config)

    def evaluate(self):
        """evaluate function will always be called on a single process even during distributed training"""
        split = self.args.evaluate_split

        # fix seeds to guarantee the same evaluation protocol across runs
        random.seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.args.seed)
            with torch.cuda.device(self.args.device):
                torch.cuda.empty_cache()

        self.downstream.eval()
        self.upstream.eval()

        dataloader = self.downstream.get_dataloader(split)

        records = defaultdict(list)
        for batch_id, (wavs, *others) in enumerate(tqdm(dataloader, dynamic_ncols=True, desc=split)):
            wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs]
            with torch.no_grad():
                features = self.upstream(wavs)
                self.downstream(
                    split,
                    features, *others,
                    records=records,
                )
        return records
|
|
def process_records(records, metric):
    assert 'sample_wise_metric' in records, \
        'An utterance-wise / sample-wise metric is required to run the significance tests.'
    average = torch.FloatTensor(records[metric]).mean().item()
    return average, records['sample_wise_metric']
|
|
def main():
    torch.multiprocessing.set_sharing_strategy('file_system')
    torchaudio.set_audio_backend('sox_io')
    hack_isinstance()

    mode, args1, config1, args2, config2 = get_ttest_args()

    random.seed(args1.seed)
    np.random.seed(args1.seed)
    torch.manual_seed(args1.seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(args1.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    tester1 = Tester(args1, config1)
    records1 = getattr(tester1, args1.mode)()
    average1, sample_metric1 = process_records(records1, args1.evaluate_metric)

    tester2 = Tester(args2, config2)
    records2 = getattr(tester2, args2.mode)()
    average2, sample_metric2 = process_records(records2, args2.evaluate_metric)
|
    assert len(sample_metric1) == len(sample_metric2), \
        'The two checkpoints must be evaluated on the same samples so the metrics can be paired.'

    if mode == 'ttest':
        # paired-sample t-test on the per-sample scores
        statistic, p_value = stats.ttest_rel(sample_metric1, sample_metric2)
    elif mode == 'fisher':
        # Fisher's exact test on the 2x2 table of correct / incorrect counts;
        # note that fisher_exact returns the odds ratio as its "statistic"
        correct1 = sample_metric1.count(True)
        correct2 = sample_metric2.count(True)
        contingency_table = [[correct1, correct2],
                             [len(sample_metric1) - correct1, len(sample_metric2) - correct2]]
        statistic, p_value = stats.fisher_exact(contingency_table)
    elif mode == 'mcnemar':
        # McNemar's test requires the paired 2x2 table built from per-sample (dis)agreement,
        # not the marginal correct / incorrect counts
        paired = list(zip(sample_metric1, sample_metric2))
        both_correct = sum(1 for s1, s2 in paired if s1 and s2)
        only1_correct = sum(1 for s1, s2 in paired if s1 and not s2)
        only2_correct = sum(1 for s1, s2 in paired if not s1 and s2)
        both_wrong = sum(1 for s1, s2 in paired if not s1 and not s2)
        contingency_table = [[both_correct, only1_correct],
                             [only2_correct, both_wrong]]
        b = mcnemar(contingency_table, exact=True)
        statistic, p_value = b.statistic, b.pvalue
    else:
        raise NotImplementedError
|
    print(f'[Runner] - The testing scores of the two ckpts are {average1} and {average2}, respectively.')
    print(f'[Runner] - The statistic of the significance test between the two ckpts is {statistic}')
    print(f'[Runner] - The p-value of the significance test between the two ckpts is {p_value}')
|
|
if __name__ == '__main__':
    main()
|
|