|
import os |
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
from torch.utils.data import DataLoader, Dataset |
|
import torchmetrics |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
from load_dataset import load_local_dataset |
|
from transformer import get_model, Transformer |
|
from config import load_config, get_weights_file_path |
|
|
|
from tokenizers import Tokenizer |
|
from tokenizers.models import WordLevel, BPE |
|
from tokenizers.trainers import WordLevelTrainer, BpeTrainer |
|
from tokenizers.pre_tokenizers import Whitespace |
|
|
|
from pathlib import Path |
|
|
|
from dataset import BilingualDataset |
|
from bleu import calculate_bleu_score |
|
from decode_method import greedy_decode |
|
|
|
|
|
def run_validation(
    model: Transformer,
    validation_ds: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device,
    print_msg,
    global_state,
    writer,
    num_examples:int = 2
):
    """Greedy-decode a few validation batches and print source/target/prediction.

    Iterates ``validation_ds`` (which must yield batches of size 1), decodes each
    example with :func:`greedy_decode`, and prints the source sentence, the
    reference translation, and the model's prediction via ``print_msg``. The
    decoded triples are also accumulated in lists.

    Args:
        model: Transformer to evaluate; switched to eval mode here.
        validation_ds: DataLoader yielding dicts with keys ``'encoder_input'``,
            ``'encoder_mask'``, ``'src_text'``, ``'tgt_text'``; batch size 1.
        src_tokenizer: Source-language tokenizer (passed through to the decoder).
        tgt_tokenizer: Target-language tokenizer; used to turn output ids
            back into text.
        device: Device the input tensors are moved to (torch device or str).
        print_msg: Callable used for console output (e.g. a tqdm-safe print).
        global_state: Not used in the visible portion of this function —
            presumably a step counter consumed further below; TODO confirm.
        writer: Not used in the visible portion of this function — presumably
            a SummaryWriter used for metric logging further below; TODO confirm.
        num_examples: Number of validation batches to decode before stopping.
    """
    model.eval()  # disable dropout etc. for deterministic inference

    count = 0
    source_texts = []  # raw source sentences seen so far
    expected = []      # reference translations
    predicted = []     # model outputs (decoded text)

    console_width = 50  # width of the "-" separator printed between examples
    with torch.no_grad():  # inference only — no gradient bookkeeping
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            # Greedy decoding below assumes exactly one example per batch.
            # NOTE(review): assert is stripped under `python -O`; an explicit
            # raise would be more robust — left unchanged here.
            assert encoder_input.size(0) == 1, "batch_size = 1 for validation"

            # Decode with a hard cap of 300 generated target tokens.
            model_out = greedy_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            # Move predicted token ids to CPU and convert them back to text.
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the triple through print_msg so output interleaves
            # cleanly with a tqdm progress bar.
            print_msg("-"*console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")

            if count == num_examples:
                break