from typing import Tuple

import torch
from tokenizers import Tokenizer

from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from train import get_local_dataset_tokenizer
from tokenizer import get_or_build_local_tokenizer


def load_train_data_and_save_model(config, model_name):
    """
    Load a full training checkpoint (model, optim, scheduler, ...) and
    re-save ONLY the model weights under a new weights filename.

    Args:
        config: project config dict; must have config['model']['preload'] set.
        model_name: name/tag used to build the output weights file path.

    Raises:
        ValueError: if config['model']['preload'] is not set.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    # The dataloaders are unused here; only the tokenizers' vocab sizes
    # are needed to rebuild the model with matching embedding shapes.
    train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = \
        get_local_dataset_tokenizer(config)

    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    # Raise instead of assert: asserts are stripped under `python -O`,
    # and raising matches load_model_tokenizer's validation style.
    if not config['model']['preload']:
        raise ValueError('where to preload model.')

    model_load_filename = get_weights_file_path(config, config['model']['preload'])
    print(f'Preloading model from train data in {model_load_filename}')

    # Training checkpoints store the weights under 'model_state_dict'
    # alongside optimizer/scheduler state.
    state = torch.load(model_load_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Save only the bare state_dict (no optimizer/scheduler state).
    model_save_filename = get_weights_file_path(config, model_name)
    torch.save(model.state_dict(), model_save_filename)
    print(f'Model saved at {model_save_filename}')


def load_model_tokenizer(
    config,
    device=torch.device('cpu'),
    logs: bool = True,
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Load a local model and its source/target tokenizers from a given config.

    Args:
        config: project config dict; must have config['model']['preload'] set.
        device: torch device to map the loaded weights onto.
        logs: when True, print a progress message on completion.
            (Fix: this flag was previously accepted but ignored.)

    Returns:
        Tuple of (model, src_tokenizer, tgt_tokenizer).

    Raises:
        ValueError: if config['model']['preload'] is None.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # ds=None: tokenizers are loaded from disk, not (re)built from a dataset.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )

    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    model_filename = get_weights_file_path(config, config['model']['preload'])
    # Unlike the training checkpoint above, this file holds a bare state_dict.
    model.load_state_dict(
        torch.load(model_filename, map_location=device)
    )

    if logs:  # bug fix: honor the `logs` flag instead of printing unconditionally
        print(f'Finish loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)


if __name__ == '__main__':
    config = load_config(file_name='config_huge.yaml')
    load_train_data_and_save_model(config, 'huge')