from pathlib import Path
from typing import Dict, NamedTuple, Union
import numpy as np
import torch

NULL_CHAR = '\x00'


class PrimingData(NamedTuple):
    """Combines the data required for priming the HandwritingRNN sampling."""
    stroke_tensors: torch.Tensor      # (batch_size, num_prime_strokes, 3)
    char_seq_tensors: torch.Tensor    # (batch_size, num_prime_chars)
    char_seq_lengths: torch.Tensor    # (batch_size,)

def construct_alphabet_list(alphabet_string: str) -> list[str]:
    """Builds the full alphabet list, prepending NULL_CHAR at index 0."""
    if not isinstance(alphabet_string, str):
        raise TypeError("alphabet_string must be a string")

    return [NULL_CHAR] + list(alphabet_string)

def get_alphabet_map(alphabet_list: list[str]) -> Dict[str, int]:
    """Creates a char-to-index map from the full alphabet list."""
    return {char: idx for idx, char in enumerate(alphabet_list)}

def encode_text(text: str, char_to_index_map: Dict[str, int],
                max_length: int, add_eos: bool = True, eos_char_index: int = 0
                ) -> tuple[np.ndarray, int]:
    """Encodes text into a (1, max_length) array of integer indices, returning
    the array and the true (unpadded) sequence length."""
    # characters missing from the map fall back to the EOS/null index
    encoded = [char_to_index_map.get(c, eos_char_index) for c in text]
    if add_eos:
        encoded.append(eos_char_index)

    true_length = len(encoded)

    if true_length <= max_length:
        # pad the tail with the EOS/null index up to max_length
        padded_encoded = np.full(max_length, eos_char_index, dtype=np.int64)
        padded_encoded[:true_length] = encoded
    else:
        # truncate to max_length; note this can drop the appended EOS marker
        padded_encoded = np.array(encoded[:max_length], dtype=np.int64)
        true_length = max_length

    return np.array([padded_encoded]), true_length
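
# Worked example (illustrative only; "abc" is a stand-in alphabet, not the
# model's real one): with alphabet ['\x00', 'a', 'b', 'c'] mapped via
# get_alphabet_map, encode_text("ab", char_map, max_length=5) returns
# (array([[1, 2, 0, 0, 0]]), 3): 'a' -> 1, 'b' -> 2, the EOS/null index 0
# is appended, and the remainder is padding.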


def convert_offsets_to_absolute_coords(stroke_offsets: list[list[float]]) -> list[list[float]]:
    """Converts (dx, dy, pen) offset rows into absolute (x, y, pen) coordinates."""
    if not stroke_offsets:
        return []

    # convert to numpy for vectorized operations
    strokes_array = np.array(stroke_offsets)

    # cumulative sums turn per-step offsets into absolute positions;
    # the pen state in column 2 is left untouched
    strokes_array[:, 0] = np.cumsum(strokes_array[:, 0])  # dx -> absolute x
    strokes_array[:, 1] = np.cumsum(strokes_array[:, 1])  # dy -> absolute y

    return strokes_array.tolist()
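
# Example: offsets [[1.0, 1.0, 0.0], [2.0, -1.0, 1.0]] become absolute
# coordinates [[1.0, 1.0, 0.0], [3.0, 0.0, 1.0]]; the pen bit passes through.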


def load_np_strokes(stroke_path: Union[Path, str]) -> np.ndarray:
    """Loads a stroke sequence array from stroke_path."""
    stroke_path = Path(stroke_path)
    if not stroke_path.exists():
        raise FileNotFoundError(f"Style strokes file not found at {stroke_path}")

    return np.load(stroke_path)

def load_text(text_path: Union[Path, str]) -> str:
    """Loads text from text_path."""
    text_path = Path(text_path)
    if not text_path.exists():
        raise FileNotFoundError(f"Text file not found at {text_path}")
    if not text_path.is_file():
        raise IsADirectoryError(f"Path is a directory, not a file: {text_path}")

    try:
        with open(text_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        raise IOError(f"Error reading text file {text_path}: {e}") from e

def load_priming_data(style: int) -> tuple[str, np.ndarray]:
    """Loads the priming text and stroke offsets for the given style id."""
    priming_text = load_text(f"./styles/style{style}.txt")
    priming_strokes = load_np_strokes(f"./styles/style{style}.npy")

    return priming_text, priming_strokes
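

def build_priming_data(style: int, alphabet_string: str,
                       max_char_length: int) -> PrimingData:
    """Illustrative sketch (not part of the original module) of how the
    helpers above could be combined into a batch-of-one PrimingData; the
    actual sampling code may assemble these tensors differently. Assumes the
    stored strokes are a (num_prime_strokes, 3) array of (dx, dy, pen) rows.
    """
    priming_text, priming_strokes = load_priming_data(style)

    char_map = get_alphabet_map(construct_alphabet_list(alphabet_string))
    encoded, true_length = encode_text(priming_text.strip(), char_map,
                                       max_char_length)

    return PrimingData(
        # (1, num_prime_strokes, 3) float tensor of stroke offsets
        stroke_tensors=torch.from_numpy(priming_strokes).float().unsqueeze(0),
        # (1, max_char_length) int64 tensor of character indices
        char_seq_tensors=torch.from_numpy(encoded),
        # (1,) tensor holding the unpadded character sequence length
        char_seq_lengths=torch.tensor([true_length]),
    )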