# SER_AUDIO / utils.py
import json
import torch
import random
import numpy as np
import transformers
import argparse
import logging
from pathlib import Path
# silence transformers warnings (verbosity 40 == ERROR)
transformers.logging.set_verbosity(40)
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
def set_seed(seed):
    # make runs reproducible across torch (CPU/GPU), NumPy, and Python's RNG
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
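# Usage sketch (hypothetical seed value): call set_seed once at startup, before
# any model or dataloader is constructed, so the cuDNN, CUDA, NumPy, and Python
# RNG streams all start from a known state:
#   set_seed(8)
#   model = torch.nn.Linear(10, 4)  # now initialized deterministically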
def get_results(input_dict):
    # keep only the metrics tracked per split: UAR, accuracy, and loss
    return_dict = dict()
    return_dict["uar"] = input_dict["uar"]
    return_dict["acc"] = input_dict["acc"]
    return_dict["loss"] = input_dict["loss"]
    return return_dict
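# Example (hypothetical values): extra keys in the input are dropped, only the
# tracked metrics survive:
#   get_results({"uar": 0.71, "acc": 0.68, "loss": 0.93, "preds": [0, 2, 1]})
#   -> {"uar": 0.71, "acc": 0.68, "loss": 0.93}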
def log_epoch_result(
    result_hist_dict: dict,
    epoch: int,
    train_result: dict,
    dev_result: dict,
    test_result: dict,
    log_dir: Path,
    fold_idx: int
):
    # record the metrics for this epoch
    result_hist_dict[epoch] = dict()
    result_hist_dict[epoch]["train"] = get_results(train_result)
    result_hist_dict[epoch]["dev"] = get_results(dev_result)
    result_hist_dict[epoch]["test"] = get_results(test_result)
    # dump the running history for this fold
    with open(str(log_dir.joinpath(f'fold_{fold_idx}.json')), "w") as json_file:
        json_file.write(json.dumps(result_hist_dict, indent=4))
def log_best_result(
    result_hist_dict: dict,
    epoch: int,
    best_dev_uar: float,
    best_dev_acc: float,
    best_test_uar: float,
    best_test_acc: float,
    log_dir: Path,
    fold_idx: int
):
    # log the best dev/test metrics observed over training
    result_hist_dict["best"] = dict()
    result_hist_dict["best"]["dev"], result_hist_dict["best"]["test"] = dict(), dict()
    result_hist_dict["best"]["dev"]["uar"] = best_dev_uar
    result_hist_dict["best"]["dev"]["acc"] = best_dev_acc
    result_hist_dict["best"]["test"]["uar"] = best_test_uar
    result_hist_dict["best"]["test"]["acc"] = best_test_acc
    # save results for this fold
    with open(str(log_dir.joinpath(f'fold_{fold_idx}.json')), "w") as json_file:
        json_file.write(json.dumps(result_hist_dict, indent=4))
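# Usage sketch (hypothetical names): inside a cross-validation training loop,
# dump the running history after every epoch, then record the best checkpoint's
# metrics once the fold finishes. Each *_res dict must carry "uar", "acc", "loss":
#   result_hist = dict()
#   for epoch in range(num_epochs):
#       train_res, dev_res, test_res = run_epoch(...)
#       log_epoch_result(result_hist, epoch, train_res, dev_res, test_res,
#                        Path(args.log_dir), fold_idx)
#   log_best_result(result_hist, best_epoch, best_dev_uar, best_dev_acc,
#                   best_test_uar, best_test_acc, Path(args.log_dir), fold_idx)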
def parse_finetune_args():
    # parser
    parser = argparse.ArgumentParser(description='emo2vec finetune experiments')
    parser.add_argument(
        '--data_dir',
        default='/media/data/projects/speech-privacy/trust-ser/audio',
        type=str,
        help='raw audio path'
    )
    parser.add_argument(
        '--model_dir',
        default='/media/data/projects/speech-privacy/trust-ser/model',
        type=str,
        help='model save path'
    )
    parser.add_argument(
        '--split_dir',
        default='/media/data/projects/speech-privacy/trust-ser/train_split',
        type=str,
        help='train split path'
    )
    parser.add_argument(
        '--log_dir',
        default='log/finetune',
        type=str,
        help='finetune log path'
    )
    parser.add_argument(
        '--uar_dir',
        default='log/uar',
        type=str,
        help='model UAR history path'
    )
    parser.add_argument(
        '--attack_dir',
        default='/media/data/projects/speech-privacy/trust-ser/attack',
        type=str,
        help='attack data'
    )
    parser.add_argument(
        '--privacy_attack_dir',
        default='/media/data/projects/speech-privacy/trust-ser/privacy',
        type=str,
        help='privacy attack method data'
    )
    parser.add_argument(
        '--privacy_attack',
        default='gender',
        type=str,
        help='privacy attack method'
    )
    parser.add_argument(
        '--fairness_dir',
        default='/media/data/projects/speech-privacy/trust-ser/fairness',
        type=str,
        help='fairness evaluation save path'
    )
    parser.add_argument(
        '--sustainability_dir',
        default='/media/data/projects/speech-privacy/trust-ser/sustainability',
        type=str,
        help='sustainability evaluation save path'
    )
    parser.add_argument(
        '--attack_method',
        default='pgd',
        type=str,
        help='attack method'
    )
    parser.add_argument(
        '--pretrain_model',
        default='wav2vec2_0',
        type=str,
        help="pretrained model type"
    )
    parser.add_argument(
        '--finetune',
        default='frozen',
        type=str,
        help="whether to partially finetune the pretrained model or keep it frozen"
    )
    parser.add_argument(
        '--learning_rate',
        default=0.0002,
        type=float,
        help="learning rate",
    )
    parser.add_argument(
        '--num_epochs',
        default=50,
        type=int,
        help="total training epochs",
    )
    parser.add_argument(
        '--optimizer',
        default='adam',
        type=str,
        help="optimizer",
    )
    parser.add_argument(
        '--dataset',
        default="iemocap",
        type=str,
        help="dataset name",
    )
    parser.add_argument(
        '--audio_duration',
        default=6,
        type=int,
        help="audio length for training"
    )
    parser.add_argument(
        '--downstream_model',
        default='rnn',
        type=str,
        help="downstream model type"
    )
    parser.add_argument(
        '--num_layers',
        default=1,
        type=int,
        help="num of layers",
    )
    parser.add_argument(
        '--snr',
        default=45,
        type=int,
        help="SNR of the audio",
    )
    parser.add_argument(
        '--conv_layers',
        default=3,
        type=int,
        help="num of conv layers",
    )
    parser.add_argument(
        '--hidden_size',
        default=256,
        type=int,
        help="hidden size",
    )
    parser.add_argument(
        '--pooling',
        default='att',
        type=str,
        help="pooling method: att, average",
    )
    parser.add_argument(
        '--norm',
        default='nonorm',
        type=str,
        help="normalization or not",
    )
    parser.add_argument(
        '--finetune_method',
        default='finetune',
        type=str,
        help='finetune method: finetune, adapter, adapter_l, embedding_prompt, lora, combined'
    )
    parser.add_argument(
        '--adapter_hidden_dim',
        default=128,
        type=int,
        help='adapter hidden dimension'
    )
    parser.add_argument(
        '--finetune_emb',
        default="all",
        type=str,
        help='which embeddings to finetune (default: all)'
    )
    parser.add_argument(
        '--embedding_prompt_dim',
        default=5,
        type=int,
        help='embedding prompt dimension'
    )
    parser.add_argument(
        '--lora_rank',
        default=16,
        type=int,
        help='lora rank'
    )
    # argparse's type=bool treats any non-empty string as True, so use
    # store_true flags for the filter switches instead
    parser.add_argument(
        '--LPF',
        action='store_true',
        help='apply a low-pass filter to the audio'
    )
    parser.add_argument(
        '--HPF',
        action='store_true',
        help='apply a high-pass filter to the audio'
    )
    args = parser.parse_args()
    # encode the key hyperparameters into a run-identifier string
    lr_str = str(args.learning_rate).replace(".", "")
    if args.finetune_method == "adapter" or args.finetune_method == "adapter_l":
        setting = f'lr{lr_str}_ep{args.num_epochs}_{args.finetune_method}_{args.adapter_hidden_dim}'
    elif args.finetune_method == "embedding_prompt":
        setting = f'lr{lr_str}_ep{args.num_epochs}_{args.finetune_method}_{args.embedding_prompt_dim}'
    elif args.finetune_method == "lora":
        setting = f'lr{lr_str}_ep{args.num_epochs}_{args.finetune_method}_{args.lora_rank}'
    elif args.finetune_method == "finetune":
        setting = f'lr{lr_str}_ep{args.num_epochs}_{args.finetune_method}'
    elif args.finetune_method == "combined":
        setting = f'lr{lr_str}_ep{args.num_epochs}_{args.finetune_method}_{args.adapter_hidden_dim}_{args.embedding_prompt_dim}_{args.lora_rank}'
    else:
        # fail fast instead of hitting a NameError on an unknown method
        raise ValueError(f'unknown finetune method: {args.finetune_method}')
    args.setting = setting
    if args.finetune_emb != "all":
        args.setting = args.setting + "_avgtok"
    return args
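
# Minimal smoke test (assumption: run as a script with no CLI flags so every
# default applies); logs the derived run-identifier string, e.g. lr00002_ep50_finetune.
if __name__ == "__main__":
    set_seed(8)
    args = parse_finetune_args()
    logging.info(f'experiment setting: {args.setting}')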