| | import torch |
| | import tiktoken |
| | from constants.tokens import special_tokens |
| | from services.model import load_model, get_device |
| |
|
| | |
| | _tokenizer = tiktoken.get_encoding("cl100k_base") |
| |
|
| | |
| | def generate_word(words, model, vocab, inv_vocab, max_length=64): |
| | device = get_device() |
| | |
| | |
| | input_text = ",".join(words) |
| | input_tokens = _tokenizer.encode(input_text) |
| | input_tensor = torch.tensor([vocab.get(str(tok), vocab["<pad>"]) for tok in input_tokens]).unsqueeze(0).to(device) |
| | |
| | |
| | target = torch.tensor([[vocab["<sos>"]]]).to(device) |
| | |
| | |
| | with torch.no_grad(): |
| | for _ in range(max_length): |
| | output = model(input_tensor, target) |
| | next_token = output[:, -1, :].argmax(dim=-1, keepdim=True) |
| | |
| | |
| | if next_token.item() == vocab["<eos>"]: |
| | break |
| | |
| | target = torch.cat([target, next_token], dim=1) |
| | |
| | |
| | output_tokens = target[0].cpu().numpy() |
| | output_text = _tokenizer.decode([int(inv_vocab[tok]) for tok in output_tokens if tok not in special_tokens.values()]) |
| | |
| | return output_text |
| |
|
| | def main(words=None): |
| | |
| | model, vocab, inv_vocab = load_model() |
| | |
| | |
| | if words is None: |
| | words = ["muito", "grande", "imenso"] |
| | |
| | result = generate_word(words, model, vocab, inv_vocab) |
| | print(f"Input words: {', '.join(words)}") |
| | print(f"Generated: {result}") |
| | return result |
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|