from enum import Enum
import os
from pathlib import Path
import shutil
import subprocess
from typing import Any, Dict

import ruamel.yaml
import torch

from poetry_diacritizer.models.baseline import BaseLineModel
from poetry_diacritizer.models.cbhg import CBHGModel
from poetry_diacritizer.models.gpt import GPTModel
from poetry_diacritizer.models.seq2seq import (
    Decoder as Seq2SeqDecoder,
    Encoder as Seq2SeqEncoder,
    Seq2Seq,
)
from poetry_diacritizer.models.tacotron_based import (
    Decoder as TacotronDecoder,
    Encoder as TacotronEncoder,
    Tacotron,
)

from poetry_diacritizer.options import AttentionType, LossType, OptimizerType
from poetry_diacritizer.util.text_encoders import (
    ArabicEncoderWithStartSymbol,
    BasicArabicEncoder,
    TextEncoder,
)


class ConfigManager:
    """Config manager: loads the experiment YAML and builds models, output
    directories, and text encoders from it."""

    def __init__(self, config_path: str, model_kind: str):
        available_models = ["baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"]
        if model_kind not in available_models:
            raise TypeError(f"model_kind must be in {available_models}")
        self.config_path = Path(config_path)
        self.model_kind = model_kind
        self.yaml = ruamel.yaml.YAML()
        self.config: Dict[str, Any] = self._load_config()
        self.git_hash = self._get_git_hash()
        self.session_name = ".".join(
            [
                self.config["data_type"],
                self.config["session_name"],
                f"{model_kind}",
            ]
        )

        # Session-specific data and output directories.
        self.data_dir = Path(
            os.path.join(self.config["data_directory"], self.config["data_type"])
        )
        self.base_dir = Path(
            os.path.join(self.config["log_directory"], self.session_name)
        )
        self.log_dir = Path(os.path.join(self.base_dir, "logs"))
        self.prediction_dir = Path(os.path.join(self.base_dir, "predictions"))
        self.plot_dir = Path(os.path.join(self.base_dir, "plots"))
        self.models_dir = Path(os.path.join(self.base_dir, "models"))
        if "sp_model_path" in self.config:
            self.sp_model_path = self.config["sp_model_path"]
        else:
            self.sp_model_path = None

        # The text encoder determines the input/target vocabulary sizes.
        self.text_encoder: TextEncoder = self.get_text_encoder()
        self.config["len_input_symbols"] = len(self.text_encoder.input_symbols)
        self.config["len_target_symbols"] = len(self.text_encoder.target_symbols)
        if self.model_kind in ["seq2seq", "tacotron_based"]:
            self.config["attention_type"] = AttentionType[
                self.config["attention_type"]
            ]
        self.config["optimizer"] = OptimizerType[self.config["optimizer_type"]]

    def _load_config(self):
        with open(self.config_path, "rb") as model_yaml:
            _config = self.yaml.load(model_yaml)
        return _config

    @staticmethod
    def _get_git_hash():
        try:
            return (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
        except Exception as e:
            print(f"WARNING: could not retrieve git hash. {e}")

    def _check_hash(self):
        try:
            git_hash = (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
            if self.config["git_hash"] != git_hash:
                print(
                    f"WARNING: git hash mismatch. Current: {git_hash}. "
                    f"Config hash: {self.config['git_hash']}"
                )
        except Exception as e:
            print(f"WARNING: could not check git hash. {e}")
{e}") @staticmethod def _print_dict_values(values, key_name, level=0, tab_size=2): tab = level * tab_size * " " print(tab + "-", key_name, ":", values) def _print_dictionary(self, dictionary, recursion_level=0): for key in dictionary.keys(): if isinstance(key, dict): recursion_level += 1 self._print_dictionary(dictionary[key], recursion_level) else: self._print_dict_values( dictionary[key], key_name=key, level=recursion_level ) def print_config(self): print("\nCONFIGURATION", self.session_name) self._print_dictionary(self.config) def update_config(self): self.config["git_hash"] = self._get_git_hash() def dump_config(self): self.update_config() _config = {} for key, val in self.config.items(): if isinstance(val, Enum): _config[key] = val.name else: _config[key] = val with open(self.base_dir / "config.yml", "w") as model_yaml: self.yaml.dump(_config, model_yaml) def create_remove_dirs( self, clear_dir: bool = False, clear_logs: bool = False, clear_weights: bool = False, clear_all: bool = False, ): self.base_dir.mkdir(exist_ok=True, parents=True) self.plot_dir.mkdir(exist_ok=True) self.prediction_dir.mkdir(exist_ok=True) if clear_dir: delete = input(f"Delete {self.log_dir} AND {self.models_dir}? (y/[n])") if delete == "y": shutil.rmtree(self.log_dir, ignore_errors=True) shutil.rmtree(self.models_dir, ignore_errors=True) if clear_logs: delete = input(f"Delete {self.log_dir}? (y/[n])") if delete == "y": shutil.rmtree(self.log_dir, ignore_errors=True) if clear_weights: delete = input(f"Delete {self.models_dir}? (y/[n])") if delete == "y": shutil.rmtree(self.models_dir, ignore_errors=True) self.log_dir.mkdir(exist_ok=True) self.models_dir.mkdir(exist_ok=True) def get_last_model_path(self): """ Given a checkpoint, get the last save model name Args: checkpoint (str): the path where models are saved """ models = os.listdir(self.models_dir) models = [model for model in models if model[-3:] == ".pt"] if len(models) == 0: return None _max = max(int(m.split(".")[0].split("-")[0]) for m in models) model_name = f"{_max}-snapshot.pt" last_model_path = os.path.join(self.models_dir, model_name) return last_model_path def load_model(self, model_path: str = None): """ loading a model from path Args: checkpoint (str): the path to the model name (str): the name of the model, which is in the path model (Tacotron): the model to load its save state optimizer: the optimizer to load its saved state """ model = self.get_model() with open(self.base_dir / f"{self.model_kind}_network.txt", "w") as file: file.write(str(model)) if model_path is None: last_model_path = self.get_last_model_path() if last_model_path is None: return model, 1 else: last_model_path = model_path saved_model = torch.load(last_model_path) out = model.load_state_dict(saved_model["model_state_dict"]) print(out) global_step = saved_model["global_step"] + 1 return model, global_step def get_model(self, ignore_hash=False): if not ignore_hash: self._check_hash() if self.model_kind == "cbhg": return self.get_cbhg() elif self.model_kind == "seq2seq": return self.get_seq2seq() elif self.model_kind == "tacotron_based": return self.get_tacotron_based() elif self.model_kind == "baseline": return self.get_baseline() elif self.model_kind == "gpt": return self.get_gpt() def get_gpt(self): model = GPTModel( self.config["base_model_path"], freeze=self.config["freeze"], n_layer=self.config["n_layer"], use_lstm=self.config["use_lstm"], ) return model def get_baseline(self): model = BaseLineModel( embedding_dim=self.config["embedding_dim"], 
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            layers_units=self.config["layers_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )
        return model

    def get_cbhg(self):
        model = CBHGModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            use_prenet=self.config["use_prenet"],
            prenet_sizes=self.config["prenet_sizes"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
            post_cbhg_layers_units=self.config["post_cbhg_layers_units"],
            post_cbhg_use_batch_norm=self.config["post_cbhg_use_batch_norm"],
        )
        return model

    def get_seq2seq(self):
        encoder = Seq2SeqEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            layers_units=self.config["encoder_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )

        decoder = TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )

        model = Tacotron(encoder=encoder, decoder=decoder)
        return model

    def get_tacotron_based(self):
        encoder = TacotronEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            prenet_sizes=self.config["prenet_sizes"],
            use_prenet=self.config["use_encoder_prenet"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
        )

        decoder = TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )

        model = Tacotron(encoder=encoder, decoder=decoder)
        return model

    def get_text_encoder(self):
        """Build the TextEncoder specified by the config."""
        if self.config["text_cleaner"] not in [
            "basic_cleaners",
            "valid_arabic_cleaners",
            None,
        ]:
            raise Exception(f"cleaner is not known {self.config['text_cleaner']}")
        if self.config["text_encoder"] == "BasicArabicEncoder":
            text_encoder = BasicArabicEncoder(
                cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
            )
        elif self.config["text_encoder"] == "ArabicEncoderWithStartSymbol":
            text_encoder = ArabicEncoderWithStartSymbol(
                cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
            )
        else:
            raise Exception(
                f"the text encoder is not found {self.config['text_encoder']}"
            )
        return text_encoder

    def get_loss_type(self):
        try:
            loss_type = LossType[self.config["loss_type"]]
        except KeyError as error:
Exception(f"The loss type is not correct {self.config['loss_type']}") return loss_type if __name__ == "__main__": config_path = "config/tacotron-base-config.yml" model_kind = "tacotron" config = ConfigManager(config_path=config_path, model_kind=model_kind)