# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input readers and document/token generators for datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import csv
import os
import random

# Dependency imports

import tensorflow as tf

from data import data_utils

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('dataset', '', 'Which dataset to generate data for.')

# Preprocessing config
flags.DEFINE_boolean('output_unigrams', True, 'Whether to output unigrams.')
flags.DEFINE_boolean('output_bigrams', False, 'Whether to output bigrams.')
flags.DEFINE_boolean('output_char', False, 'Whether to output characters.')
flags.DEFINE_boolean('lowercase', True, 'Whether to lowercase document terms.')

# IMDB
flags.DEFINE_string('imdb_input_dir', '', 'The input directory containing the '
                    'IMDB sentiment dataset.')
flags.DEFINE_integer('imdb_validation_pos_start_id', 10621, 'File id of the '
                     'first file in the pos sentiment validation set.')
flags.DEFINE_integer('imdb_validation_neg_start_id', 10625, 'File id of the '
                     'first file in the neg sentiment validation set.')

# DBpedia
flags.DEFINE_string('dbpedia_input_dir', '',
                    'Path to DBpedia directory containing train.csv and '
                    'test.csv.')

# Reuters Corpus (rcv1)
flags.DEFINE_string('rcv1_input_dir', '',
                    'Path to rcv1 directory containing train.csv, unlab.csv, '
                    'and test.csv.')

# Rotten Tomatoes
flags.DEFINE_string('rt_input_dir', '',
                    'The Rotten Tomatoes dataset input directory.')

# The Amazon reviews input file to use in either the RT or IMDB datasets.
flags.DEFINE_string('amazon_unlabeled_input_file', '',
                    'The unlabeled Amazon Reviews dataset input file. If set, '
                    'the input file is used to augment RT and IMDB vocab.')

Document = namedtuple('Document',
                      'content is_validation is_test label add_tokens')


def documents(dataset='train',
              include_unlabeled=False,
              include_validation=False):
  """Generates Documents for the dataset selected by FLAGS.dataset.

  Args:
    dataset: str, identifies the data subset to read, 'train' or 'test'.
    include_unlabeled: bool, whether to include unlabeled documents. Only valid
      when dataset='train'.
    include_validation: bool, whether to include validation data.

  Yields:
    Document

  Raises:
    ValueError: if include_unlabeled is True but dataset is not 'train'.
  """
  if include_unlabeled and dataset != 'train':
    raise ValueError('If include_unlabeled=True, must use train dataset')

  # Set the random seed so that we have the same validation set when running
  # gen_data and gen_vocab.
  random.seed(302)

  ds = FLAGS.dataset
  if ds == 'imdb':
    docs_gen = imdb_documents
  elif ds == 'dbpedia':
    docs_gen = dbpedia_documents
  elif ds == 'rcv1':
    docs_gen = rcv1_documents
  elif ds == 'rt':
    docs_gen = rt_documents
  else:
    raise ValueError('Unrecognized dataset %s' % FLAGS.dataset)

  for doc in docs_gen(dataset, include_unlabeled, include_validation):
    yield doc


def tokens(doc):
  """Given a Document, produces character or word tokens.

  Tokens can be either characters, or word-level tokens (unigrams and/or
  bigrams).

  Args:
    doc: Document to produce tokens from.

  Yields:
    token

  Raises:
    ValueError: if all FLAGS.{output_unigrams, output_bigrams, output_char}
      are False.
  """
  if not (FLAGS.output_unigrams or FLAGS.output_bigrams or FLAGS.output_char):
    raise ValueError(
        'At least one of {FLAGS.output_unigrams, FLAGS.output_bigrams, '
        'FLAGS.output_char} must be true')

  content = doc.content.strip()
  if FLAGS.lowercase:
    content = content.lower()

  if FLAGS.output_char:
    for char in content:
      yield char
  else:
    tokens_ = data_utils.split_by_punct(content)
    for i, token in enumerate(tokens_):
      if FLAGS.output_unigrams:
        yield token

      if FLAGS.output_bigrams:
        previous_token = (tokens_[i - 1] if i > 0 else data_utils.EOS_TOKEN)
        bigram = '_'.join([previous_token, token])
        yield bigram
        if (i + 1) == len(tokens_):
          bigram = '_'.join([token, data_utils.EOS_TOKEN])
          yield bigram


def imdb_documents(dataset='train',
                   include_unlabeled=False,
                   include_validation=False):
  """Generates Documents for IMDB dataset.

  Data from http://ai.stanford.edu/~amaas/data/sentiment/

  Args:
    dataset: str, identifies folder within IMDB data directory, test or train.
    include_unlabeled: bool, whether to include the unsup directory. Only valid
      when dataset=train.
    include_validation: bool, whether to include validation data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.imdb_input_dir is empty.
  """
  if not FLAGS.imdb_input_dir:
    raise ValueError('Must provide FLAGS.imdb_input_dir')

  tf.logging.info('Generating IMDB documents...')

  def check_is_validation(filename, class_label):
    if class_label is None:
      return False
    file_idx = int(filename.split('_')[0])
    is_pos_valid = (class_label and
                    file_idx >= FLAGS.imdb_validation_pos_start_id)
    is_neg_valid = (not class_label and
                    file_idx >= FLAGS.imdb_validation_neg_start_id)
    return is_pos_valid or is_neg_valid

  dirs = [(dataset + '/pos', True), (dataset + '/neg', False)]
  if include_unlabeled:
    dirs.append(('train/unsup', None))

  for d, class_label in dirs:
    for filename in os.listdir(os.path.join(FLAGS.imdb_input_dir, d)):
      is_validation = check_is_validation(filename, class_label)
      if is_validation and not include_validation:
        continue

      with open(os.path.join(FLAGS.imdb_input_dir, d, filename),
                encoding='utf-8') as imdb_f:
        content = imdb_f.read()
      yield Document(
          content=content,
          is_validation=is_validation,
          is_test=False,
          label=class_label,
          add_tokens=True)

  if FLAGS.amazon_unlabeled_input_file and include_unlabeled:
    with open(FLAGS.amazon_unlabeled_input_file, encoding='utf-8') as rt_f:
      for content in rt_f:
        yield Document(
            content=content,
            is_validation=False,
            is_test=False,
            label=None,
            add_tokens=False)


def dbpedia_documents(dataset='train',
                      include_unlabeled=False,
                      include_validation=False):
  """Generates Documents for DBpedia dataset.

  Dataset linked to at https://github.com/zhangxiangxiao/Crepe.

  Args:
    dataset: str, identifies the csv file within the DBpedia data directory,
      test or train.
    include_unlabeled: bool, unused.
    include_validation: bool, whether to include validation data, which is a
      randomly selected 10% of the data.
  Yields:
    Document

  Raises:
    ValueError: if FLAGS.dbpedia_input_dir is empty.
  """
  del include_unlabeled

  if not FLAGS.dbpedia_input_dir:
    raise ValueError('Must provide FLAGS.dbpedia_input_dir')

  tf.logging.info('Generating DBpedia documents...')

  with open(os.path.join(FLAGS.dbpedia_input_dir, dataset + '.csv')) as db_f:
    reader = csv.reader(db_f)
    for row in reader:
      # 10% of the data is randomly held out
      is_validation = random.randint(1, 10) == 1
      if is_validation and not include_validation:
        continue

      content = row[1] + ' ' + row[2]
      yield Document(
          content=content,
          is_validation=is_validation,
          is_test=False,
          label=int(row[0]) - 1,  # Labels should start from 0
          add_tokens=True)


def rcv1_documents(dataset='train',
                   include_unlabeled=True,
                   include_validation=False):
  # pylint:disable=line-too-long
  """Generates Documents for Reuters Corpus (rcv1) dataset.

  Dataset described at
  http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm

  Args:
    dataset: str, identifies the csv file within the rcv1 data directory.
    include_unlabeled: bool, whether to include the unlab file. Only valid
      when dataset=train.
    include_validation: bool, whether to include validation data, which is a
      randomly selected 10% of the data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.rcv1_input_dir is empty.
  """
  # pylint:enable=line-too-long

  if not FLAGS.rcv1_input_dir:
    raise ValueError('Must provide FLAGS.rcv1_input_dir')

  tf.logging.info('Generating rcv1 documents...')

  datasets = [dataset]
  if include_unlabeled:
    if dataset == 'train':
      datasets.append('unlab')
  for dset in datasets:
    with open(os.path.join(FLAGS.rcv1_input_dir, dset + '.csv')) as db_f:
      reader = csv.reader(db_f)
      for row in reader:
        # 10% of the data is randomly held out
        is_validation = random.randint(1, 10) == 1
        if is_validation and not include_validation:
          continue

        content = row[1]
        yield Document(
            content=content,
            is_validation=is_validation,
            is_test=False,
            label=int(row[0]),
            add_tokens=True)


def rt_documents(dataset='train',
                 include_unlabeled=True,
                 include_validation=False):
  # pylint:disable=line-too-long
  """Generates Documents for the Rotten Tomatoes dataset.

  Dataset available at http://www.cs.cornell.edu/people/pabo/movie-review-data/
  In this dataset, Amazon reviews are used for the unlabeled data.

  Args:
    dataset: str, identifies the data subdirectory.
    include_unlabeled: bool, whether to include the unlabeled data. Only valid
      when dataset=train.
    include_validation: bool, whether to include validation data, which is a
      randomly selected 10% of the data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.rt_input_dir is empty.
""" # pylint:enable=line-too-long if not FLAGS.rt_input_dir: raise ValueError('Must provide FLAGS.rt_input_dir') tf.logging.info('Generating rt documents...') data_files = [] input_filenames = os.listdir(FLAGS.rt_input_dir) for inp_fname in input_filenames: if inp_fname.endswith('.pos'): data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), True)) elif inp_fname.endswith('.neg'): data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), False)) if include_unlabeled and FLAGS.amazon_unlabeled_input_file: data_files.append((FLAGS.amazon_unlabeled_input_file, None)) for filename, class_label in data_files: with open(filename) as rt_f: for content in rt_f: if class_label is None: # Process Amazon Review data for unlabeled dataset if content.startswith('review/text'): yield Document( content=content, is_validation=False, is_test=False, label=None, add_tokens=False) else: # 10% of the data is randomly held out for the validation set and # another 10% of it is randomly held out for the test set random_int = random.randint(1, 10) is_validation = random_int == 1 is_test = random_int == 2 if (is_test and dataset != 'test') or (is_validation and not include_validation): continue yield Document( content=content, is_validation=is_validation, is_test=is_test, label=class_label, add_tokens=True)