"""Module for building dynamic fields."""
from collections import Counter, defaultdict

import torch

from onmt.utils.logging import logger
from onmt.utils.misc import check_path
from onmt.inputters.inputter import (
    get_fields,
    _load_vocab,
    _build_fields_vocab,
)


def _get_dynamic_fields(opts):
    """Create the text fields described by ``opts`` (vocab not built yet).

    Reads ``src_feats_vocab``, ``copy_attn``, ``src_seq_length_trunc``,
    ``tgt_seq_length_trunc``, ``data_task`` and, when present,
    ``lambda_align`` from ``opts`` and forwards them to ``get_fields``.
    """
    # NOTE: target-side features are not supported yet.
    tgt_feats = None
    # Only request alignment support from get_fields when an alignment
    # weight is configured and strictly positive.
    with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
    fields = get_fields(
        'text', opts.src_feats_vocab, tgt_feats,
        dynamic_dict=opts.copy_attn,
        src_truncate=opts.src_seq_length_trunc,
        tgt_truncate=opts.tgt_seq_length_trunc,
        with_align=with_align,
        data_task=opts.data_task)
    return fields


def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
    """Build fields for dynamic data loading, including loading & building vocab.

    Vocabulary files named in ``opts`` are read via ``_load_vocab`` into a
    ``defaultdict(Counter)`` keyed by field name ('src', 'tgt' and each
    source feature name), which is then handed to ``_build_fields_vocab``
    together with the fields from ``_get_dynamic_fields``.

    Args:
        opts: option namespace. Besides the attributes read by
            ``_get_dynamic_fields`` it must provide ``src_vocab``,
            ``src_words_min_frequency``, ``tgt_vocab``,
            ``tgt_words_min_frequency``, ``share_vocab``,
            ``vocab_size_multiple``, ``src_vocab_size`` and
            ``tgt_vocab_size``.
        src_specials: forwarded unchanged to ``_build_fields_vocab``
            (presumably extra special tokens for the source vocab — confirm).
        tgt_specials: same as ``src_specials``, for the target side.

    Returns:
        The fields object returned by ``_build_fields_vocab``.

    Raises:
        ValueError: if ``opts.tgt_vocab`` is falsy and ``opts.share_vocab``
            is not set.
    """
    fields = _get_dynamic_fields(opts)
    counters = defaultdict(Counter)
    logger.info("Loading vocab from text file...")
    # _load_vocab presumably fills counters[<name>] in place (counters is what
    # _build_fields_vocab consumes below); its return values are not needed.
    _load_vocab(opts.src_vocab, 'src', counters,
                min_freq=opts.src_words_min_frequency)
    if opts.src_feats_vocab:
        # Source feature vocabs are loaded with min_freq=0.
        for feat_name, filepath in opts.src_feats_vocab.items():
            _load_vocab(filepath, feat_name, counters, min_freq=0)

    if opts.tgt_vocab:
        _load_vocab(opts.tgt_vocab, 'tgt', counters,
                    min_freq=opts.tgt_words_min_frequency)
    elif opts.share_vocab:
        logger.info("Sharing src vocab to tgt...")
        # Aliases the same Counter object rather than copying it.
        counters['tgt'] = counters['src']
    else:
        raise ValueError("-tgt_vocab should be specified if not share_vocab.")

    logger.info("Building fields with vocab in counters...")
    fields = _build_fields_vocab(
        fields, counters, 'text', opts.share_vocab,
        opts.vocab_size_multiple,
        opts.src_vocab_size, opts.src_words_min_frequency,
        opts.tgt_vocab_size, opts.tgt_words_min_frequency,
        src_specials=src_specials, tgt_specials=tgt_specials)
    return fields


def get_vocabs(fields):
    """Return a dict containing the src & tgt vocabs extracted from fields.

    Returns:
        ``{'src': fields['src'].base_field.vocab,
        'tgt': fields['tgt'].base_field.vocab}``
    """
    src_vocab = fields['src'].base_field.vocab
    tgt_vocab = fields['tgt'].base_field.vocab
    vocabs = {'src': src_vocab, 'tgt': tgt_vocab}
    return vocabs


def save_fields(fields, save_data, overwrite=True):
    """Dump ``fields`` to ``"{save_data}.vocab.pt"`` with ``torch.save``.

    Args:
        fields: the fields object to serialize.
        save_data: path prefix; ``.vocab.pt`` is appended to it.
        overwrite: passed to ``check_path`` as ``exist_ok`` (presumably
            decides whether an existing file may be replaced — confirm).
    """
    fields_path = "{}.vocab.pt".format(save_data)
    check_path(fields_path, exist_ok=overwrite, log=logger.warning)
    logger.info(f"Saving fields to {fields_path}...")
    torch.save(fields, fields_path)


def load_fields(save_data, checkpoint=None):
    """Load a dumped fields object from ``save_data`` or ``checkpoint``.

    Args:
        save_data: path prefix used by ``save_fields``; the file read is
            ``"{save_data}.vocab.pt"``. Ignored when ``checkpoint`` is given.
        checkpoint: optional mapping; when not None its ``'vocab'`` entry is
            returned and nothing is read from disk.

    Returns:
        The fields object.
    """
    if checkpoint is not None:
        logger.info("Loading fields from checkpoint...")
        fields = checkpoint['vocab']
    else:
        fields_path = "{}.vocab.pt".format(save_data)
        logger.info(f"Loading fields from {fields_path}...")
        # SECURITY(review): torch.load unpickles arbitrary Python objects —
        # only load vocab files from trusted sources.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
        # may reject pickled field objects; confirm against the torch
        # version this project pins.
        fields = torch.load(fields_path)
    return fields