from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List
# SPECIAL_SYMBOLS and UNK_IDX used below are assumed to be defined in
# models.PhonemeTransformer and pulled in by this wildcard import.
from models.PhonemeTransformer import *

# We need to modify the URLs for the dataset since the links to the original
# dataset are broken.
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
# for more info.
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Placeholders for the per-language tokenizers and vocabularies
token_transform = {}
vocab_transform = {}

token_transform[SRC_LANGUAGE] = get_tokenizer(
    'spacy', language='de_core_news_sm'
)
token_transform[TGT_LANGUAGE] = get_tokenizer(
    'spacy', language='en_core_web_sm'
)


# Helper function to yield lists of tokens for one side of the language pair
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    for data_sample in data_iter:
        yield token_transform[language](
            data_sample[language_index[language]]
        )


for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data iterator (re-created per language, since it is consumed)
    train_iter = Multi30k(
        split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE)
    )
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=SPECIAL_SYMBOLS,
        special_first=True
    )

# Set ``UNK_IDX`` as the default index. This index is returned when a queried
# token is not found. If not set, the Vocab raises a ``RuntimeError`` for
# out-of-vocabulary tokens.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

english_tokenizer = token_transform[TGT_LANGUAGE]

text = """
The proletariat is the social class of wage-earners who are those members of
a society whose only possession of significant economic value is their
labour power
""".strip()

tokens = english_tokenizer(text)
print('TOKENS', tokens)
print('END >>>')
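
# A minimal follow-up sketch (not part of the original pipeline above): torchtext
# Vocab objects are callable and map a list of tokens to integer indices, with
# out-of-vocabulary tokens falling back to the default index set via
# set_default_index (UNK_IDX, assumed to come from models.PhonemeTransformer).
# Rare words such as 'proletariat' are likely absent from the Multi30k training
# vocabulary and should therefore all map to UNK_IDX.
token_ids = vocab_transform[TGT_LANGUAGE](tokens)
print('TOKEN IDS', token_ids)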