import re

import datasets
import tensorflow as tf

import promptsource.utils


def feature_to_spec(feature, length=False):
    """Convert a HF ``datasets`` feature into a (possibly nested) ``tf.TensorSpec``.

    Args:
        feature: a ``datasets`` feature — ``ClassLabel``, ``Value``,
            ``Sequence``, a list or dict of features, or any object exposing
            ``dtype`` and ``shape`` attributes.
        length: sequence length propagated by an enclosing ``Sequence``;
            ``False`` (the default) means scalar, ``-1`` means unknown length.

    Returns:
        A ``tf.TensorSpec``, or a nested list/dict of them mirroring ``feature``.

    Raises:
        ValueError: if the feature type is not one of the handled kinds.
    """
    if isinstance(feature, datasets.ClassLabel):
        # ClassLabel values are integer class ids.
        return tf.TensorSpec(shape=() if not length else (None if length == -1 else length,), dtype=tf.int64)
    elif isinstance(feature, datasets.Value):
        return tf.TensorSpec(
            shape=() if not length else (None if length == -1 else length,), dtype=getattr(tf.dtypes, feature.dtype)
        )
    elif hasattr(feature, "dtype") and hasattr(feature, "shape"):
        return tf.TensorSpec(shape=feature.shape, dtype=feature.dtype)
    elif isinstance(feature, datasets.Sequence):
        # Propagate the sequence's declared length down to the inner feature.
        return feature_to_spec(feature.feature, length=feature.length)
    elif isinstance(feature, list):
        return [feature_to_spec(f, length=length) for f in feature]
    elif isinstance(feature, dict):
        return {k: feature_to_spec(v, length=length) for k, v in feature.items()}
    else:
        raise ValueError(f"Unparseable feature type {type(feature)}")


def hf_dataset_to_tf_dataset(dataset):
    """Wrap a HF dataset as a ``tf.data.Dataset`` backed by its iterator."""
    return tf.data.Dataset.from_generator(
        dataset.__iter__, output_signature={k: feature_to_spec(v) for k, v in dataset.features.items()}
    )


def apply_template(dataset, template):
    """Apply a promptsource ``template`` to every example of ``dataset``.

    Maps each example through the template to produce ``inputs``/``targets``
    (plus ``answer_choices`` when the template defines them), filters out
    examples the template could not render, and drops the original columns.

    Args:
        dataset: a HF ``datasets.Dataset``.
        template: a promptsource template with ``apply`` and
            ``get_answer_choices_list`` methods.

    Returns:
        A new dataset whose columns are ``inputs``, ``targets``, and
        optionally ``answer_choices``.
    """

    def map_fn(ex):
        ex = promptsource.utils.removeHyphen(ex)
        inputs_and_targets = template.apply(ex)
        answer_choices = template.get_answer_choices_list(ex)
        if len(inputs_and_targets) == 2:
            # Fix: the original had an `if targets == ""` branch whose body was
            # byte-identical to the else branch — collapsed here.
            # NOTE(review): the dead branch may once have assigned a sentinel
            # target for empty labels — confirm against the template's callers.
            inputs, targets = inputs_and_targets
            ex = {"inputs": inputs, "targets": targets}
        else:
            # When a template produces an empty example, template.apply
            # returns [""]; a badly split template can also yield len > 2.
            # Mark these empty so filter_fn drops them below.
            ex = {"inputs": "", "targets": ""}
        if answer_choices:
            ex["answer_choices"] = answer_choices
        return ex

    def filter_fn(ex):
        # Drop examples where the template produced no input or no target.
        return len(ex["inputs"]) > 0 and len(ex["targets"]) > 0

    original_columns = dataset.column_names
    dataset = dataset.map(map_fn).filter(filter_fn)
    # map keeps the original columns; remove everything except ours.
    # Fix: remove_columns is documented to take str or List[str] — pass a
    # list, not a set.
    return dataset.remove_columns(list(set(original_columns) - {"inputs", "targets", "answer_choices"}))


def get_dataset_splits(dataset_name, subset_name=None):
    """Return the split metadata for ``dataset_name``.

    Falls back to the first subset listed in the dataset infos when
    ``subset_name`` is not given.
    """
    info = datasets.get_dataset_infos(dataset_name)
    subset_name = subset_name or list(info.keys())[0]
    return info[subset_name].splits


def task_clean(text):
    """Replace every run of characters not allowed in a task name with ``_``.

    Allowed characters are word characters, ``.`` and ``_`` (``\\d`` in the
    pattern is redundant — it is already covered by ``\\w`` — but kept for
    byte-identical behavior).
    """
    return re.sub(r"[^\w\d\._]+", "_", text)


def get_task_name(dataset_name, subset_name, template_name):
    """Build a cleaned task name from dataset, optional subset, and template."""
    return task_clean(dataset_name + (f"_{subset_name}_" if subset_name is not None else "_") + template_name)