File size: 1,566 Bytes
e8a4189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from typing import Tuple
import torch
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from train import get_local_dataset_tokenizer
from tokenizer import get_or_build_local_tokenizer
from tokenizers import Tokenizer
def load_model_tokenizer(
    config,
    device: torch.device = torch.device('cpu'),
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Load a pretrained Transformer and its source/target tokenizers.

    Args:
        config: Parsed configuration mapping. Must provide
            ``config['model']['preload']`` (checkpoint tag to load) and the
            ``config['dataset']`` language/tokenizer settings.
        device: Device the model and checkpoint are mapped onto
            (defaults to CPU).

    Returns:
        Tuple of ``(model, src_tokenizer, tgt_tokenizer)``.

    Raises:
        ValueError: If ``config['model']['preload']`` is None.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')
    # ds=None: the tokenizers are expected to already exist locally;
    # there is no dataset here to train a new one on.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )
    # Vocab sizes come from the tokenizers so the embedding layers match them.
    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)
    model_filename = get_weights_file_path(config, config['model']['preload'])
    # map_location lets checkpoints saved on GPU load on CPU-only hosts.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    print('Finished loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)
def _main() -> None:
    """Script entry point: load the final model and its tokenizers."""
    config = load_config(file_name='config/config_final.yaml')
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)


if __name__ == '__main__':
    _main()
|