"""Translate a sentence with a trained Seq2Seq model using greedy decoding."""
import torch

from config import Config
from datasets import load_from_disk
from models.seq2seq import Decoder, Encoder, Seq2Seq
from utils.tokenizer import build_vocab

# Special-token strings. The original code looked these up as ``vocab[""]``,
# which made the start- and end-of-sequence markers the same (empty) token and
# broke both the encoder input framing and the decoding stop condition.
# NOTE(review): confirm these names match the specials registered by
# ``build_vocab`` (some pipelines use "<bos>"/"<eos>").
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"


def translate_sentence(sentence, model, src_tokenizer, src_vocab, trg_vocab,
                       device, max_len=30):
    """Greedily decode a target-language translation of ``sentence``.

    Args:
        sentence: Raw source-language string; lowercased before tokenizing.
        model: Trained ``Seq2Seq`` exposing ``encoder`` and ``decoder``
            submodules.
        src_tokenizer: Callable splitting a string into a list of tokens.
        src_vocab: Source token -> index mapping (``vocab[token]``).
        trg_vocab: Target token -> index mapping; must also expose
            ``get_itos()`` (torchtext-style index -> token list).
        device: Device the model's parameters live on.
        max_len: Hard cap on generated tokens so decoding always terminates.

    Returns:
        List of target-language token strings, including the leading SOS
        marker and the trailing EOS marker when one is produced within
        ``max_len`` steps.
    """
    model.eval()
    tokens = src_tokenizer(sentence.lower())

    # Frame the source with SOS/EOS; shape (src_len, 1) = sequence-first,
    # batch of one — assumes the encoder expects seq-first input (TODO confirm).
    src_indexes = ([src_vocab[SOS_TOKEN]]
                   + [src_vocab[t] for t in tokens]
                   + [src_vocab[EOS_TOKEN]])
    src_tensor = torch.tensor(src_indexes).unsqueeze(1).to(device)

    # Single no_grad context for the whole inference pass (the original opened
    # a fresh context on every decoding step).
    with torch.no_grad():
        hidden = model.encoder(src_tensor)

        trg_indexes = [trg_vocab[SOS_TOKEN]]
        for _ in range(max_len):
            # Feed back only the most recent token; the RNN hidden state
            # carries the rest of the history.
            trg_tensor = torch.tensor([trg_indexes[-1]]).to(device)
            output, hidden = model.decoder(trg_tensor, hidden)
            pred_token = output.argmax(1).item()
            trg_indexes.append(pred_token)
            if pred_token == trg_vocab[EOS_TOKEN]:
                break

    # Hoist the index->string table out of the comprehension; the original
    # rebuilt it via get_itos() once per generated token.
    itos = trg_vocab.get_itos()
    return [itos[i] for i in trg_indexes]


def _main():
    """Script entry point: rebuild vocabularies, load weights, translate."""
    cfg = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = load_from_disk("data/raw/")
    src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
    # Target tokenizer is not needed for inference — only the vocab is.
    _, trg_vocab = build_vocab(dataset, cfg.target_lang)

    enc = Encoder(len(src_vocab), cfg.emb_dim, cfg.hid_dim, cfg.n_layers)
    dec = Decoder(len(trg_vocab), cfg.emb_dim, cfg.hid_dim, cfg.n_layers)
    model = Seq2Seq(enc, dec, device).to(device)
    # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
    model.load_state_dict(torch.load(cfg.model_save_path, map_location=device))

    print(translate_sentence("I love cats", model, src_tokenizer,
                             src_vocab, trg_vocab, device))


if __name__ == "__main__":
    _main()