homemade_lo_vi / load_and_save_model.py
moiduy04's picture
Upload 12 files
b8a6dde
raw
history blame
2.53 kB
from typing import Tuple
import torch
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from train import get_local_dataset_tokenizer
from tokenizer import get_or_build_local_tokenizer
from tokenizers import Tokenizer
def load_train_data_and_save_model(config, model_name):
    """Load a full training checkpoint and re-save ONLY the model weights.

    The preloaded checkpoint (named by ``config['model']['preload']``) is a
    training-time snapshot whose ``model_state_dict`` entry holds the model
    parameters; everything else in it (optimizer, scheduler, ...) is dropped,
    and the bare ``state_dict`` is written under ``model_name``.

    Args:
        config: project configuration dict (dataset, model, file paths).
        model_name: name used by ``get_weights_file_path`` to build the
            output weights filename.

    Raises:
        ValueError: if ``config['model']['preload']`` is empty/unset.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    # Tokenizers are needed to size the model's embeddings; the dataloaders
    # returned alongside them are not used here.
    _, _, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)

    model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)

    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, and this matches load_model_tokenizer's error handling.
    if not config['model']['preload']:
        raise ValueError('where to preload model.')
    model_load_filename = get_weights_file_path(config, config['model']['preload'])
    print(f'Preloading model from train data in {model_load_filename}')

    state = torch.load(model_load_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    model_save_filename = get_weights_file_path(config, model_name)
    torch.save(model.state_dict(), model_save_filename)
    print(f'Model saved at {model_save_filename}')
def load_model_tokenizer(
    config,
    device = torch.device('cpu'),
    logs: bool = True,
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """Load a local Transformer model and its source/target tokenizers.

    Args:
        config: project configuration dict; ``config['model']['preload']``
            names the checkpoint whose weights are loaded.
        device: torch device to map the loaded weights onto (default CPU).
        logs: when True, print a confirmation message after loading.

    Returns:
        Tuple of ``(model, src_tokenizer, tgt_tokenizer)``.

    Raises:
        ValueError: if ``config['model']['preload']`` is None.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # ds=None: tokenizers must already exist on disk — nothing to train from.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )

    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    model_filename = get_weights_file_path(config, config['model']['preload'])
    model.load_state_dict(
        torch.load(model_filename, map_location=device)
    )

    # Fix: honor the `logs` flag, which was previously accepted but ignored.
    if logs:
        print('Finish loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)
if __name__ == '__main__':
    # Script entry point: strip the 'huge' training checkpoint down to
    # bare model weights and save them.
    cfg = load_config(file_name='config_huge.yaml')
    load_train_data_and_save_model(cfg, 'huge')