import os
from typing import Dict, Optional

from diacritization_evaluation import der, wer
import torch
from torch import nn
from torch import optim
from torch.cuda.amp import autocast
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from tqdm import trange
from diacritization_evaluation import util

from .config_manager import ConfigManager
from .dataset import load_iterators
from .diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer
from .options import OptimizerType
import gdown


class Trainer:
    def run(self):
        raise NotImplementedError


class GeneralTrainer(Trainer):
    def __init__(self, config_path: str, model_kind: str) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.losses = []
        self.lr = 0
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.config_manager.create_remove_dirs()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)

        self.model = self.config_manager.get_model()
        self.optimizer = self.get_optimizer()
        self.model = self.model.to(self.device)

        self.load_model(model_path=self.config.get("train_resume_model_path"))
        self.load_diacritizer()

        self.initialize_model()

    def set_device(self):
        # Prefer the device given in the config; otherwise fall back to CUDA when available.
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_diacritizer(self):
        # Pick the diacritizer wrapper that matches the model kind.
        if self.model_kind in ["cbhg", "baseline"]:
            self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["seq2seq", "tacotron_based"]:
            self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)

    def initialize_model(self):
        # Only initialize weights when starting from scratch (not when resuming).
        if self.global_step > 1:
            return
        if self.model_kind == "transformer":
            print("Initializing using xavier_uniform_")
            self.model.apply(initialize_weights)

    def load_model(self, model_path: Optional[str] = None, load_optimizer: bool = True):
        # Dump the network architecture for reference.
        with open(
            self.config_manager.base_dir / f"{self.model_kind}_network.txt", "w"
        ) as file:
            file.write(str(self.model))

        if model_path is None:
            last_model_path = self.config_manager.get_last_model_path()
            if last_model_path is None:
                self.global_step = 1
                return
        else:
            last_model_path = model_path

        print(f"loading from {last_model_path}")
        # Map the checkpoint onto the device selected by set_device().
        saved_model = torch.load(last_model_path, map_location=torch.device(self.device))
        self.model.load_state_dict(saved_model["model_state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
        self.global_step = saved_model["global_step"] + 1


class DiacritizationTester(GeneralTrainer):
    def __init__(self, config_path: str, model_kind: str, model_path: str) -> None:
        # if config_path == 'config/test.yml' or config_path == "Arabic_Diacritization/config/test.yml":
        #     print("Exporting the pretrained models ... ")
        #     url = 'https://drive.google.com/uc?id=12aYNY7cbsLNzhdPdC2K3u1sgrb1lpzwO'
        #     gdown.cached_download(url, 'model.zip', quiet=False, postprocess=gdown.extractall)
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        # print(self.config)
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id

        self.model = self.config_manager.get_model()
        self.model = self.model.to(self.device)

        # Inference only: load the checkpoint without its optimizer state.
        self.load_model(model_path=model_path, load_optimizer=False)
        self.load_diacritizer()
        self.diacritizer.set_model(self.model)

        self.initialize_model()

    def collate_fn(self, data):
        """Pad the source and target sequences to the longest sequence in the batch."""

        def merge(sequences):
            lengths = [len(seq) for seq in sequences]
            padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
            for i, seq in enumerate(sequences):
                end = lengths[i]
                padded_seqs[i, :end] = seq[:end]
            return padded_seqs, lengths

        # Sort batch elements by source length, longest first.
        data.sort(key=lambda x: len(x[0]), reverse=True)

        # Separate source and target sequences.
        src_seqs, trg_seqs, original = zip(*data)

        # Merge sequences (from tuple of 1D tensors to a single 2D tensor).
        src_seqs, src_lengths = merge(src_seqs)
        trg_seqs, trg_lengths = merge(trg_seqs)

        batch = {
            "original": original,
            "src": src_seqs,
            "target": trg_seqs,
            "lengths": torch.LongTensor(src_lengths),  # src_lengths == trg_lengths
        }
        return batch

    def get_batch(self, sentence):
        # Split the sentence into base characters and their diacritics,
        # then build a single-sample batch.
        data = self.text_encoder.clean(sentence)
        text, inputs, diacritics = util.extract_haraqat(data)
        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
        batch = self.collate_fn([(inputs, diacritics, text)])
        return batch

    def infer(self, sentence):
        # Diacritize a single sentence and return the predicted text.
        self.model.eval()
        batch = self.get_batch(sentence)
        predicted = self.diacritizer.diacritize_batch(batch)
        return predicted[0]