| import torch |
| from typing import Dict |
|
|
def character_lookup(text: str):
    """Derive a lowercase character vocabulary from *text*.

    The distinct characters of ``text.lower()`` are sorted and numbered from 0
    in that sorted order. The vocabulary size and the vocabulary characters
    (as a repr'd string) are printed to stdout.

    Returns:
        A tuple ``(vocab_len, char2id, id2char)``: the number of distinct
        characters, a char -> id mapping, and the inverse id -> char mapping.
    """
    assert isinstance(text, str), "text must be a string object"
    symbols = sorted(set(text.lower()))
    size = len(symbols)
    print(f"vocab len: {size}")
    print(f"vocab characters: {''.join(symbols)!r}")
    # Build id -> char first, then invert it so both maps share one numbering.
    index_to_char = dict(enumerate(symbols))
    char_to_index = {ch: i for i, ch in index_to_char.items()}
    return size, char_to_index, index_to_char
|
|
def text_encoder(text: str, char2id: Dict[str, int]) -> torch.Tensor:
    """Map each character of ``text.lower()`` to its id via *char2id*.

    Returns a 1-D ``torch.long`` tensor with one id per character (empty for
    an empty string). A character missing from *char2id* raises ``KeyError``.
    """
    assert isinstance(text, str), "text must be a string object"
    lowered = text.lower()
    ids = list(map(char2id.__getitem__, lowered))
    return torch.tensor(ids, dtype=torch.long)
|
|
def text_decoder(token_ids: torch.Tensor, id2char: Dict[int, str]) -> str:
    """Decode a tensor of character ids back into a string.

    Args:
        token_ids: Tensor of ids. It is flattened in row-major order, so a
            0-d (scalar) tensor decodes to a single character.
        id2char: Mapping from id to character.

    Returns:
        The concatenation of ``id2char[i]`` for every id in *token_ids*.

    Raises:
        KeyError: if an id is not present in *id2char*.
    """
    assert isinstance(token_ids, torch.Tensor), "token_ids must be a torch tensor object"
    # A single bulk .tolist() replaces a per-element .item() call, which is a
    # separate host round-trip for every id. reshape(-1) also lets a 0-d tensor
    # decode instead of failing with "iteration over a 0-d tensor".
    return "".join(id2char[idx] for idx in token_ids.reshape(-1).tolist())