"""Print the prompt templates of a dataset together with rendered examples.

The script takes the path of a ``templates.yaml`` file as its only argument.
The path is split on "/" and the dataset is identified from its components:
component 2 is the dataset name and, when the path has exactly five
components, component 3 is the subset name.  Paths whose component 2 is one
of ``INCLUDED_USERS`` are skipped.
"""
import argparse
import textwrap

from promptsource.templates import TemplateCollection, INCLUDED_USERS
from promptsource.utils import get_dataset

# Column at which rendered prompts and targets are wrapped.
WIDTH = 80
# Number of non-blank rendered examples shown per template and split.
MAX_EXAMPLES_PER_SPLIT = 10


def _print_wrapped(text, width=WIDTH):
    """Print *text* wrapped to *width* columns, one tab-prefixed line at a time."""
    for line in textwrap.wrap(text, width=width, replace_whitespace=False):
        # Whitespace is kept as-is, so a wrapped chunk can still contain
        # newlines; indent what follows them too.
        print("\t", line.replace("\n", "\n\t"))


def _show_examples(template, split_name, dataset_split):
    """Render *template* on examples of one split and print the results.

    Stops after ``MAX_EXAMPLES_PER_SPLIT`` non-blank results; examples whose
    rendering is blank are reported but do not count towards that limit.
    """
    shown = 0
    for example in dataset_split:
        print("\t--------")
        print("\tSplit ", split_name)
        print("\tExample ", example)
        print("\t--------")

        output = template.apply(example)
        if output[0].strip() == "" or (len(output) > 1 and output[1].strip() == ""):
            print("\t Blank result")
            continue

        # The blank check above allows for a rendering with a single part, so
        # do not assume exactly two parts here: a missing second part is shown
        # as an empty target, and parts beyond the second are not shown.
        prompt_text = output[0]
        target_text = output[1] if len(output) > 1 else ""

        print()
        print("\tPrompt | X")
        _print_wrapped(prompt_text)
        print()
        print("\tY")
        _print_wrapped(target_text)

        shown += 1
        if shown >= MAX_EXAMPLES_PER_SPLIT:
            break


def main():
    """Parse the command line and print templates and examples for one dataset.

    Returns silently when the given path does not mention ``templates.yaml``.
    Exits through ``parser.error`` (usage message, status 2) when the path has
    too few components to contain a dataset name.
    """
    parser = argparse.ArgumentParser(
        description="Show the prompt templates of a dataset and examples rendered with them."
    )
    parser.add_argument("dataset_path", type=str, help="path to dataset name")
    args = parser.parse_args()

    # Anything that is not a templates file is ignored without an error.
    if "templates.yaml" not in args.dataset_path:
        return

    path = args.dataset_path.split("/")
    if len(path) < 3:
        # Component 2 is read below; fail with a clear message instead of an
        # IndexError traceback.
        parser.error(
            "cannot find a dataset name in {!r}: expected at least 3 "
            "'/'-separated path components".format(args.dataset_path)
        )

    if path[2] in INCLUDED_USERS:
        print("Skipping showing templates for community dataset.")
        return

    dataset_name = path[2]
    subset_name = path[3] if len(path) == 5 else ""

    template_collection = TemplateCollection()

    dataset = get_dataset(dataset_name, subset_name)
    splits = list(dataset.keys())

    dataset_templates = template_collection.get_dataset(dataset_name, subset_name)
    template_list = dataset_templates.all_template_names

    print("DATASET ", args.dataset_path)

    # First show all the templates.
    for template_name in template_list:
        template = dataset_templates[template_name]
        print("TEMPLATE")
        print("NAME:", template_name)
        print("Is Original Task: ", template.metadata.original_task)
        print(template.jinja)
        print()

    # Show examples of the templates.
    for template_name in template_list:
        template = dataset_templates[template_name]
        print()
        print("TEMPLATE")
        print("NAME:", template_name)
        print("REFERENCE:", template.reference)
        print("--------")
        print()
        print(template.jinja)
        print()

        for split_name in splits:
            _show_examples(template, split_name, dataset[split_name])


if __name__ == "__main__":
    main()