import torch

from train import get_model, greedy_decode, get_or_build_tokenizer
from config import get_config
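# Usage note: this script assumes train.py and config.py from the training code are
# importable and that a trained checkpoint exists at weights/tmodel_19.pt.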

INPUT_TEXT = "sun rises in the night"


def inference():
    # Pick the GPU when available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = get_config()

    # Rebuild (or load) the same tokenizers that were used during training.
    tokenizer_src = get_or_build_tokenizer(config, None, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, None, config["lang_target"])

    # Build the model with the source/target vocabulary sizes and move it to the device.
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    # Restore the trained weights and switch to eval mode (disables dropout).
    model_filename = "weights/tmodel_19.pt"
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    # Tokenize the source sentence and wrap it in the start/end special tokens.
    tokens = tokenizer_src.encode(INPUT_TEXT).ids
    tokens = [tokenizer_src.token_to_id("[SOS]")] + tokens + [tokenizer_src.token_to_id("[EOS]")]
    encoder_input = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_len)

    # Mask is 1 for real tokens and 0 for [PAD], shaped (1, 1, 1, seq_len) so it
    # broadcasts across attention heads; for a single unpadded sentence it is all ones.
    encoder_mask = (encoder_input != tokenizer_src.token_to_id("[PAD]")).unsqueeze(0).unsqueeze(0).to(device)

    # Decode greedily up to seq_len tokens; no_grad avoids building the autograd graph.
    with torch.no_grad():
        model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config["seq_len"], device)

    # Convert the predicted token ids back into text.
    output_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

    print("Source:", INPUT_TEXT)
    print("Predicted:", output_text)


if __name__ == "__main__":
    inference()