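"""Stand-alone evaluation entry point for a trained diacritization model: builds
the model described by the config, restores the weights at
config["test_model_path"], and reports loss, accuracy, and DER/WER error rates
on the test split."""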
from .config_manager import ConfigManager
import os
from typing import Dict

from torch import nn
from tqdm import tqdm, trange

from dataset import load_iterators
from trainer import GeneralTrainer


class DiacritizationTester(GeneralTrainer):
    def __init__(self, config_path: str, model_kind: str) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config

        # Padding index 0 is ignored by the loss so padded positions do not
        # contribute to the reported loss.
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id

        # Build the model defined by the config and move it to the target device.
        self.model = self.config_manager.get_model()
        self.model = self.model.to(self.device)

        # Restore the trained weights; the optimizer state is not needed for testing.
        self.load_model(model_path=self.config["test_model_path"], load_optimizer=False)
        self.load_diacritizer()
        self.diacritizer.set_model(self.model)

        self.initialize_model()

        self.print_config()

    def run(self):
        # Only the test split is needed; skip loading train/validation data.
        self.config_manager.config["load_training_data"] = False
        self.config_manager.config["load_validation_data"] = False
        self.config_manager.config["load_test_data"] = True
        _, test_iterator, _ = load_iterators(self.config_manager)
        tqdm_eval = trange(0, len(test_iterator), leave=True)
        tqdm_error_rates = trange(0, len(test_iterator), leave=True)

        # Loss/accuracy over the test set, then diacritic and word error rates.
        loss, acc = self.evaluate(test_iterator, tqdm_eval, log=False)
        error_rates, _ = self.evaluate_with_error_rates(
            test_iterator, tqdm_error_rates, log=False
        )

        tqdm_eval.close()
        tqdm_error_rates.close()

        # Diacritic error rate (DER) and word error rate (WER); the starred
        # variants are the second pair reported by evaluate_with_error_rates
        # (in the diacritization literature, typically the same rates computed
        # without case endings).
        WER = error_rates["WER"]
        DER = error_rates["DER"]
        DER1 = error_rates["DER*"]
        WER1 = error_rates["WER*"]

        error_rates_report = f"DER: {DER}, WER: {WER}, DER*: {DER1}, WER*: {WER1}"

        print(f"global step: {self.global_step}")
        print(f"Evaluate {self.global_step}: accuracy: {acc}, loss: {loss}")
        print(f"WER/DER {self.global_step}: {error_rates_report}")
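

# Example invocation: a minimal sketch, not part of the original module. The flag
# names below are illustrative only. Because this file uses a relative import
# (.config_manager), run it as a module inside its package
# (python -m <package>.<this_module>) rather than as a bare script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a trained diacritization model on the test set."
    )
    parser.add_argument("--config", required=True, help="Path to the config file.")
    parser.add_argument(
        "--model_kind", required=True, help="Model kind key understood by ConfigManager."
    )
    args = parser.parse_args()

    tester = DiacritizationTester(config_path=args.config, model_kind=args.model_kind)
    tester.run()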