# Author: PFEemp2024
# Commit 4a1df2e: solving GPU error for previous version
import copy
from anonymous_demo.functional.config.config_manager import ConfigManager
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
# Defaults shared by every TAD preset. The language-specific preset dicts in
# this module are overlaid on top of this one by the
# TADConfigManager.get_tad_config_* helpers.
_tad_config_template = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Number of cross-validation folds; e.g. 5 would split the train and test
    # datasets into 5 folds. -1 presumably disables it -- confirm in trainer.
    cross_validate_fold=-1,
    use_amp=False,
)
# "base" preset overlay. It differs from the template defaults only in
# pretrained_bert (monolingual DeBERTa-v3 instead of mDeBERTa-v3).
_tad_config_base = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    patience=99999,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Number of cross-validation folds; e.g. 5 would split the train and test
    # datasets into 5 folds. -1 presumably disables it -- confirm in trainer.
    cross_validate_fold=-1,
)
# "english" preset overlay. It differs from the template defaults only in
# pretrained_bert (monolingual DeBERTa-v3).
_tad_config_english = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Number of cross-validation folds; e.g. 5 would split the train and test
    # datasets into 5 folds. -1 presumably disables it -- confirm in trainer.
    cross_validate_fold=-1,
)
# "multilingual" preset overlay. Its values match the template defaults
# (mDeBERTa-v3 backbone); it just omits use_amp.
_tad_config_multilingual = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Number of cross-validation folds; e.g. 5 would split the train and test
    # datasets into 5 folds. -1 presumably disables it -- confirm in trainer.
    cross_validate_fold=-1,
)
# "chinese" preset overlay. It differs from the template defaults only in
# pretrained_bert (bert-base-chinese).
_tad_config_chinese = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    pretrained_bert="bert-base-chinese",
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Number of cross-validation folds; e.g. 5 would split the train and test
    # datasets into 5 folds. -1 presumably disables it -- confirm in trainer.
    cross_validate_fold=-1,
)
# No GloVe-specific TAD defaults are defined in this module, but the glove
# accessors of TADConfigManager reference this name. This empty overlay keeps
# them from raising NameError: get_tad_config_glove() yields the template
# defaults plus anything added through TADConfigManager.set_tad_config_glove().
_tad_config_glove = {}


class TADConfigManager(ConfigManager):
    """Config manager for TAD models, built from the module-level preset dicts.

    Each ``get_tad_config_<type>`` helper returns a new manager over a deep
    copy of ``_tad_config_template`` overlaid with the matching preset. Neither
    dict is modified by the getters. Each ``set_tad_config_<type>`` helper
    updates the named preset dict in place.
    """

    def __init__(self, args, **kwargs):
        """Create a TAD config manager.

        :param args: dict of config params, typically ``_tad_config_template``
            merged with one of the ``_tad_config_*`` presets. Keys include
            ``model``, ``optimizer``, ``learning_rate``, ``pretrained_bert``,
            ``max_seq_len``, ``dropout``, ``l2reg``, ``num_epoch``,
            ``batch_size``, ``seed``, ``polarities_dim`` and
            ``cross_validate_fold``.
        :param kwargs: forwarded unchanged to ``ConfigManager.__init__``.
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def _preset(config_type: str) -> dict:
        """Return the module-level preset dict registered under ``config_type``.

        :raises ValueError: if ``config_type`` is not a known preset name.
        """
        presets = {
            "template": _tad_config_template,
            "base": _tad_config_base,
            "english": _tad_config_english,
            "chinese": _tad_config_chinese,
            "multilingual": _tad_config_multilingual,
            "glove": _tad_config_glove,
        }
        try:
            return presets[config_type]
        # TypeError covers unhashable config_type values, which the original
        # equality-based dispatch also reported as a ValueError.
        except (KeyError, TypeError):
            raise ValueError(
                "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
            ) from None

    @staticmethod
    def _build(preset: dict) -> ConfigManager:
        """Return a manager over a deep copy of the template overlaid with ``preset``.

        The merge is done on a fresh dict, so ``_tad_config_template`` is never
        modified. A previous getter call can therefore no longer leak its
        values into later calls of other getters.
        """
        merged = dict(_tad_config_template)
        merged.update(preset)
        return TADConfigManager(copy.deepcopy(merged))

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """Merge ``newitem`` into the preset named ``configType`` (in place).

        :param configType: one of template, base, english, chinese,
            multilingual, glove.
        :param newitem: dict of config keys/values to add or override.
        :raises TypeError: if ``newitem`` is not a dict.
        :raises ValueError: if ``configType`` is not a known preset name.
        """
        if not isinstance(newitem, dict):
            raise TypeError(
                "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
            )
        TADConfigManager._preset(configType).update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        TADConfigManager.set_tad_config("template", newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        TADConfigManager.set_tad_config("base", newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        TADConfigManager.set_tad_config("english", newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        TADConfigManager.set_tad_config("chinese", newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        TADConfigManager.set_tad_config("multilingual", newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        TADConfigManager.set_tad_config("glove", newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        return TADConfigManager._build(_tad_config_template)

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        return TADConfigManager._build(_tad_config_base)

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        return TADConfigManager._build(_tad_config_english)

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        return TADConfigManager._build(_tad_config_chinese)

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        return TADConfigManager._build(_tad_config_multilingual)

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        return TADConfigManager._build(_tad_config_glove)