|
import copy |
|
|
|
from anonymous_demo.functional.config.config_manager import ConfigManager |
|
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT |
|
|
|
# Baseline TAD hyper-parameters. TADConfigManager's get_tad_config_* methods
# overlay one of the presets below on top of this dict, so keys that only
# appear here (e.g. ``use_amp``) are inherited by every preset's config.
_tad_config_template = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    # Very large value; presumably effectively disables early stopping -- confirm.
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # NOTE(review): -1 presumably means "no cross-validation" -- confirm in trainer.
    cross_validate_fold=-1,
    use_amp=False,
)
|
|
|
# "base" preset: the template values with the microsoft/deberta-v3-base backbone.
_tad_config_base = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    patience=99999,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)
|
|
|
# "english" preset: uses the microsoft/deberta-v3-base backbone.
_tad_config_english = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)
|
|
|
# "multilingual" preset: uses the microsoft/mdeberta-v3-base backbone.
_tad_config_multilingual = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)
|
|
|
# "chinese" preset: uses the bert-base-chinese backbone.
_tad_config_chinese = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    pretrained_bert="bert-base-chinese",
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    cross_validate_fold=-1,
)
|
|
|
|
|
class TADConfigManager(ConfigManager):
    """Configuration manager for TAD models.

    Presets live in the module-level ``_tad_config_*`` dicts. The
    ``set_tad_config*`` methods edit those presets in place. The
    ``get_tad_config_*`` methods return a new ``TADConfigManager`` holding a
    deep copy of ``_tad_config_template`` overlaid with the chosen preset.
    """

    def __init__(self, args, **kwargs):
        """Wrap a TAD configuration mapping.

        ``args`` is normally a dict produced by one of the ``get_tad_config_*``
        factories below. Its keys are those of ``_tad_config_template``:
        model, optimizer, learning_rate, patience, pretrained_bert,
        cache_dataset, warmup_step, show_metric, max_seq_len, dropout, l2reg,
        num_epoch, batch_size, initializer, seed, polarities_dim, log_step,
        evaluate_begin, cross_validate_fold and use_amp.

        :param args: configuration mapping, passed unchanged to ``ConfigManager``.
        :param kwargs: extra keyword arguments forwarded to ``ConfigManager``.
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def _resolve_config(config_type: str) -> dict:
        """Return the module-level preset dict registered under ``config_type``.

        :param config_type: one of template, base, english, chinese,
            multilingual, glove.
        :raises ValueError: if ``config_type`` is unknown, or is ``"glove"``
            while no ``_tad_config_glove`` preset is defined in this module.
        """
        presets = {
            "template": _tad_config_template,
            "base": _tad_config_base,
            "english": _tad_config_english,
            "chinese": _tad_config_chinese,
            "multilingual": _tad_config_multilingual,
        }
        if isinstance(config_type, str) and config_type in presets:
            return presets[config_type]
        if config_type == "glove":
            # The public API exposes a GloVe preset, but this module does not
            # necessarily define one; look it up lazily so a missing preset
            # produces a clear error instead of a NameError.
            glove_config = globals().get("_tad_config_glove")
            if glove_config is None:
                raise ValueError(
                    "No glove config is defined for TAD, please use one from following type: template, base, english, chinese, multilingual"
                )
            return glove_config
        raise ValueError(
            "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
        )

    @staticmethod
    def _make_config(config_type: str) -> ConfigManager:
        """Build a manager from the template overlaid with a preset.

        The overlay is computed on a new dict, so ``_tad_config_template`` is
        never modified. Merging in place would let keys added to one preset
        via ``set_tad_config_*`` leak into configs later requested for
        another preset. The merged dict is deep-copied so callers can mutate
        their config without touching the module-level presets.

        :raises ValueError: propagated from ``_resolve_config``.
        """
        preset = TADConfigManager._resolve_config(config_type)
        merged = {**_tad_config_template, **preset}
        return TADConfigManager(copy.deepcopy(merged))

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """Merge ``newitem`` into the preset named ``configType``, in place.

        :param configType: one of template, base, english, chinese,
            multilingual, glove.
        :param newitem: mapping of config keys to new values.
        :raises TypeError: if ``newitem`` is not a dict.
        :raises ValueError: if ``configType`` does not name an available preset.
        """
        # Type is checked before the preset name, matching the original order.
        if not isinstance(newitem, dict):
            raise TypeError(
                "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
            )
        TADConfigManager._resolve_config(configType).update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        """Merge ``newitem`` into the template preset."""
        TADConfigManager.set_tad_config("template", newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        """Merge ``newitem`` into the base preset."""
        TADConfigManager.set_tad_config("base", newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        """Merge ``newitem`` into the english preset."""
        TADConfigManager.set_tad_config("english", newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        """Merge ``newitem`` into the chinese preset."""
        TADConfigManager.set_tad_config("chinese", newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        """Merge ``newitem`` into the multilingual preset."""
        TADConfigManager.set_tad_config("multilingual", newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        """Merge ``newitem`` into the glove preset (if one is defined)."""
        TADConfigManager.set_tad_config("glove", newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        """Return a manager holding a deep copy of the template preset."""
        return TADConfigManager._make_config("template")

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        """Return a manager for the template overlaid with the base preset."""
        return TADConfigManager._make_config("base")

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        """Return a manager for the template overlaid with the english preset."""
        return TADConfigManager._make_config("english")

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        """Return a manager for the template overlaid with the chinese preset."""
        return TADConfigManager._make_config("chinese")

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        """Return a manager for the template overlaid with the multilingual preset."""
        return TADConfigManager._make_config("multilingual")

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        """Return a manager for the template overlaid with the glove preset.

        :raises ValueError: if no ``_tad_config_glove`` preset is defined.
        """
        return TADConfigManager._make_config("glove")
|
|