"""
Tera.VO Text Processing Module
Full text normalization and encoding pipeline built from scratch.
"""

import re
from typing import Dict, Iterable, List

import inflect
import numpy as np


# Shared inflect engine; TextProcessor uses it to spell out numbers and
# ordinals (number_to_words / ordinal).
_inflect_engine = inflect.engine()


# Special marker symbols.
_pad = '_'  # padding
_eos = '~'  # end of sequence
_bos = '^'  # beginning of sequence
_punctuation = '!\'(),.:;? -"'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

# Full symbol inventory. A symbol's ID is its position in this list, so the
# order (pad, bos, eos, punctuation, letters) must not change: doing so would
# renumber every symbol.
symbols = [_pad, _bos, _eos, *_punctuation, *_letters]
symbol_to_id = {symbol: index for index, symbol in enumerate(symbols)}
id_to_symbol = dict(enumerate(symbols))
NUM_SYMBOLS = len(symbols)


class TextProcessor:
    """Complete text processing pipeline for Tera.VO.

    Normalizes raw text (abbreviations, currency, percentages, ordinals,
    plain numbers, a handful of symbols, whitespace) and converts text to
    and from integer ID sequences using the module-level symbol tables.
    """

    # Digits with an optional fractional part. The fraction needs at least
    # one digit so that a sentence-ending period ("$5.") is not swallowed.
    _DECIMAL = r'\d+(?:\.\d+)?'
    _CURRENCY_RE = re.compile(r'\$(' + _DECIMAL + r')')
    _PERCENT_RE = re.compile(r'(' + _DECIMAL + r')%')
    _ORDINAL_RE = re.compile(r'(\d+)(st|nd|rd|th)\b')
    _NUMBER_RE = re.compile(r'\b' + _DECIMAL + r'\b')
    _WHITESPACE_RE = re.compile(r'\s+')

    # Single-character symbols and the words they are replaced with.
    _SYMBOL_WORDS = {
        '&': ' and ', '@': ' at ', '#': ' hash ',
        '+': ' plus ', '=': ' equals ', '/': ' slash ',
    }
    _SYMBOL_TABLE = str.maketrans(_SYMBOL_WORDS)

    def __init__(self):
        """Bind the module-level symbol tables and set up abbreviations."""
        self.symbol_to_id: Dict[str, int] = symbol_to_id
        self.id_to_symbol: Dict[int, str] = id_to_symbol
        self.num_symbols: int = NUM_SYMBOLS

        # Abbreviation (lowercase, with trailing period) -> spoken form.
        self.abbreviations: Dict[str, str] = {
            'mr.': 'mister', 'mrs.': 'missus', 'dr.': 'doctor',
            'prof.': 'professor', 'sr.': 'senior', 'jr.': 'junior',
            'st.': 'saint', 'vs.': 'versus', 'etc.': 'etcetera',
            'govt.': 'government', 'dept.': 'department',
            'jan.': 'january', 'feb.': 'february', 'mar.': 'march',
            'apr.': 'april', 'aug.': 'august', 'sep.': 'september',
            'oct.': 'october', 'nov.': 'november', 'dec.': 'december',
            'approx.': 'approximately', 'univ.': 'university',
        }

    def normalize_text(self, text: str) -> str:
        """Run the full normalization pipeline on ``text``.

        Steps, in order: strip, expand abbreviations, expand numbers,
        expand symbols, collapse whitespace.
        """
        text = text.strip()
        text = self._expand_abbreviations(text)
        text = self._expand_numbers(text)
        text = self._expand_symbols(text)
        return self._collapse_whitespace(text)

    def _expand_abbreviations(self, text: str) -> str:
        """Replace known abbreviations (case-insensitively) with full words.

        An abbreviation only matches when it is not preceded by a word
        character; otherwise ordinary sentence endings such as "best." or
        "grammar." would be rewritten through 'st.' and 'mar.'.
        """
        for abbr, full in self.abbreviations.items():
            pattern = r'(?<!\w)' + re.escape(abbr)
            text = re.sub(pattern, full, text, flags=re.IGNORECASE)
        return text

    def _expand_numbers(self, text: str) -> str:
        """Spell out currency, percentages, ordinals and plain numbers.

        The specific forms are handled before the generic number pass so
        their digits are consumed together with the '$', '%' or suffix.
        """
        text = self._CURRENCY_RE.sub(
            lambda m: self._currency(m.group(1)), text
        )
        text = self._PERCENT_RE.sub(
            lambda m: self._number_words(m.group(1)) + ' percent', text
        )
        text = self._ORDINAL_RE.sub(
            lambda m: self._ordinal(int(m.group(1))), text
        )
        return self._NUMBER_RE.sub(
            lambda m: self._number_words(m.group(0)), text
        )

    def _currency(self, amount_str: str) -> str:
        """Spell out a dollar amount such as ``'3'`` or ``'3.50'``.

        Cents are taken from the first two fractional digits, right-padded
        with zero (``'.5'`` means fifty cents). The cents clause is added
        only when that value is non-zero.

        Raises:
            ValueError: if the dollar or cent part is not an integer string.
        """
        parts = amount_str.split('.')
        dollars = int(parts[0])
        fraction = parts[1] if len(parts) > 1 else ''

        words = self._number_words(str(dollars))
        words += ' dollar' + ('' if dollars == 1 else 's')

        # Decide on the cents actually spoken, not on the full fraction:
        # "1.005" has a non-zero fraction but zero whole cents.
        cents = int(fraction[:2].ljust(2, '0')) if fraction else 0
        if cents > 0:
            words += ' and ' + self._number_words(str(cents))
            words += ' cent' + ('' if cents == 1 else 's')
        return words

    def _number_words(self, num_str: str) -> str:
        """Convert a numeric string to words via the shared inflect engine.

        Whole numbers are converted as integers; other values are handed to
        inflect as the original string. Input that cannot be parsed as a
        number (or parses to nan/inf) is returned unchanged.
        """
        try:
            value = float(num_str)
        except (ValueError, TypeError):
            return num_str
        if value != value or value in (float('inf'), float('-inf')):
            # nan, or a digit string too long for a float: nothing to speak.
            return num_str

        try:
            if not value.is_integer():
                return _inflect_engine.number_to_words(num_str)
            try:
                # Parse the digits directly: going through float loses
                # precision for integers above 2**53.
                whole = int(num_str)
            except (ValueError, TypeError):
                whole = int(value)  # forms like "5." or "5.0"
            return _inflect_engine.number_to_words(whole)
        except (ValueError, TypeError):
            return num_str

    def _ordinal(self, num: int) -> str:
        """Return ``num`` as an ordinal in words, or ``str(num)`` on failure."""
        try:
            return _inflect_engine.ordinal(
                _inflect_engine.number_to_words(num)
            )
        except Exception:
            # Deliberate best-effort: fall back to the bare digits, which
            # the generic number pass can still pick up afterwards.
            return str(num)

    def _expand_symbols(self, text: str) -> str:
        """Replace '&', '@', '#', '+', '=' and '/' with space-padded words."""
        return text.translate(self._SYMBOL_TABLE)

    def _collapse_whitespace(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and trim both ends."""
        return self._WHITESPACE_RE.sub(' ', text).strip()

    def text_to_sequence(self, text: str) -> List[int]:
        """Normalize ``text`` and encode it as a list of symbol IDs.

        The result starts with the beginning-of-sequence ID and ends with
        the end-of-sequence ID. Characters missing from the symbol table
        are silently dropped.
        """
        normalized = self.normalize_text(text)
        table = self.symbol_to_id
        ids = [table[_bos]]
        ids.extend(table[ch] for ch in normalized if ch in table)
        ids.append(table[_eos])
        return ids

    def sequence_to_text(self, sequence: Iterable[int]) -> str:
        """Decode symbol IDs back to text.

        Unknown IDs and the pad / beginning / end markers are skipped.
        """
        markers = (_pad, _bos, _eos)
        decoded = (self.id_to_symbol.get(idx) for idx in sequence)
        return ''.join(
            s for s in decoded if s is not None and s not in markers
        )

    def pad_sequence(self, seq: List[int], max_len: int) -> List[int]:
        """Return ``seq`` truncated or right-padded to ``max_len`` items.

        Truncation keeps the first ``max_len`` items, so it can cut off a
        trailing end-of-sequence ID. Padding uses the pad symbol's ID.
        """
        if len(seq) >= max_len:
            return seq[:max_len]
        return seq + [self.symbol_to_id[_pad]] * (max_len - len(seq))

    def get_vocab_size(self) -> int:
        """Return the number of symbols in the vocabulary."""
        return self.num_symbols


# Module-level TextProcessor instance, created at import time.
text_processor = TextProcessor()