File size: 1,566 Bytes
e8a4189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Tuple

import torch
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from train import get_local_dataset_tokenizer
from tokenizer import get_or_build_local_tokenizer

from tokenizers import Tokenizer
   

def load_model_tokenizer(
    config,
    device = torch.device('cpu'),
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Load a pretrained Transformer model and its source/target tokenizers.

    Args:
        config: Parsed configuration mapping; must provide
            config['model']['preload'] (checkpoint tag) and the
            config['dataset'] language / tokenizer-type entries.
        device: Device the model and checkpoint are mapped onto.
            Defaults to CPU so loading works without a GPU.

    Returns:
        A ``(model, src_tokenizer, tgt_tokenizer)`` tuple, with the model
        already moved to ``device`` and weights restored.

    Raises:
        ValueError: If ``config['model']['preload']`` is ``None``.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # ds=None: no training corpus is supplied here, so an already-built
    # tokenizer is expected to be found on disk for each language.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )

    # Vocab sizes fix the embedding / output-projection dimensions.
    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    model_filename = get_weights_file_path(config, config['model']['preload'])
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources. weights_only=True would be safer
    # if the checkpoint contains only tensors/containers — confirm format.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Plain string: the original f-string had no placeholders (ruff F541).
    print('Finish loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)

if __name__ == '__main__':
    # Smoke-test entry point: restore the final-config model and tokenizers.
    cfg = load_config(file_name='config/config_final.yaml')
    loaded = load_model_tokenizer(cfg)
    model, src_tokenizer, tgt_tokenizer = loaded