import os
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

# from datasets import load_dataset
from load_dataset import load_local_dataset
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path

from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path

from dataset import BilingualDataset
from bleu import calculate_bleu_score
from decode_method import greedy_decode
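
# NOTE: greedy_decode (from decode_method) is assumed here to take
# (model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer,
# max_len, device) and to return a 1-D tensor of predicted target token ids.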


def run_validation(
    model: Transformer, 
    validation_ds: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device, 
    print_msg,
    global_state,
    writer,
    num_examples:int = 2
):
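    """Run a small qualitative validation pass.

    Greedily decodes up to `num_examples` sentences from `validation_ds`
    and prints source / target / prediction triples via `print_msg`.
    `global_state` and `writer` are accepted but unused in this loop;
    presumably they are reserved for TensorBoard logging.
    """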
    model.eval()

    # Inference: decode a few validation examples and print them for inspection
    count = 0
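    # Collect texts for corpus-level scoring (presumably BLEU via the
    # calculate_bleu_score import above)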
    source_texts = []
    expected = []
    predicted = []

    console_width = 50
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            # Greedily decode the target sequence, capped at 300 output tokens
            model_out = greedy_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            # Convert the predicted token ids back into text
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg("-"*console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")

            if count == num_examples:
                break
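

# Example usage (sketch, not part of the original file): assuming a DataLoader
# `val_dataloader` built from BilingualDataset with batch_size=1, and a tqdm
# progress bar `batch_iterator` wrapping the training loop:
#
#     run_validation(
#         model, val_dataloader, src_tokenizer, tgt_tokenizer, device,
#         lambda msg: batch_iterator.write(msg), global_step, writer,
#     )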