anonymous8
update
d65ddc0
raw
history blame
9.33 kB
"""
AugmentCommand class
===========================
"""
from argparse import ArgumentDefaultsHelpFormatter, ArgumentError, ArgumentParser
import csv
import os
import time
import tqdm
import textattack
from textattack.augment_args import AUGMENTATION_RECIPE_NAMES
from textattack.commands import TextAttackCommand
class AugmentCommand(TextAttackCommand):
"""The TextAttack attack module:
A command line parser to run data augmentation from user
specifications.
"""
def run(self, args):
"""Reads in a CSV, performs augmentation, and outputs an augmented CSV.
Preserves all columns except for the input (augmneted) column.
"""
args = textattack.AugmenterArgs(**vars(args))
if args.interactive:
print("\nRunning in interactive mode...\n")
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
pct_words_to_swap=args.pct_words_to_swap,
transformations_per_example=args.transformations_per_example,
high_yield=args.high_yield,
fast_augment=args.fast_augment,
enable_advanced_metrics=args.enable_advanced_metrics,
)
print("--------------------------------------------------------")
while True:
print(
'\nEnter a sentence to augment, "q" to quit, "c" to view/change arguments:\n'
)
text = input()
if text == "q":
break
elif text == "c":
print(
f"\nCurrent Arguments:\n\n\t augmentation recipe: {args.recipe}, "
f"\n\t pct_words_to_swap: {args.pct_words_to_swap}, "
f"\n\t transformations_per_example: {args.transformations_per_example}\n"
)
change = input(
"Enter 'c' again to change arguments, any other keys to opt out\n"
)
if change == "c":
print("\nChanging augmenter arguments...\n")
recipe = input(
"\tAugmentation recipe name ('r' to see available recipes): "
)
if recipe == "r":
recipe_display = " ".join(AUGMENTATION_RECIPE_NAMES.keys())
print(f"\n\t{recipe_display}\n")
args.recipe = input("\tAugmentation recipe name: ")
else:
args.recipe = recipe
args.pct_words_to_swap = float(
input("\tPercentage of words to swap (0.0 ~ 1.0): ")
)
args.transformations_per_example = int(
input("\tTransformations per input example: ")
)
print("\nGenerating new augmenter...\n")
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
pct_words_to_swap=args.pct_words_to_swap,
transformations_per_example=args.transformations_per_example,
)
print(
"--------------------------------------------------------"
)
continue
elif not text:
continue
print("\nAugmenting...\n")
print("--------------------------------------------------------")
if args.enable_advanced_metrics:
results = augmenter.augment(text)
print("Augmentations:\n")
for augmentation in results[0]:
print(augmentation, "\n")
print()
print(
f"Average Original Perplexity Score: {results[1]['avg_original_perplexity']}"
)
print(
f"Average Augment Perplexity Score: {results[1]['avg_attack_perplexity']}"
)
print(
f"Average Augment USE Score: {results[2]['avg_attack_use_score']}\n"
)
else:
for augmentation in augmenter.augment(text):
print(augmentation, "\n")
print("--------------------------------------------------------")
else:
textattack.shared.utils.set_seed(args.random_seed)
start_time = time.time()
if not (args.input_csv and args.input_column and args.output_csv):
raise ArgumentError(
"The following arguments are required: --csv, --input-column/--i"
)
# Validate input/output paths.
if not os.path.exists(args.input_csv):
raise FileNotFoundError(f"Can't find CSV at location {args.input_csv}")
if os.path.exists(args.output_csv):
if args.overwrite:
textattack.shared.logger.info(
f"Preparing to overwrite {args.output_csv}."
)
else:
raise OSError(
f"Outfile {args.output_csv} exists and --overwrite not set."
)
# Read in CSV file as a list of dictionaries. Use the CSV sniffer to
# try and automatically infer the correct CSV format.
csv_file = open(args.input_csv, "r")
# mark where commas and quotes occur within the text value
def markQuotes(lines):
for row in lines:
row = row.replace('"', '"/')
yield row
dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
csv_file.seek(0)
rows = [
row
for row in csv.DictReader(
markQuotes(csv_file),
dialect=dialect,
skipinitialspace=True,
)
]
# replace markings with quotations and commas
for row in rows:
for item in row:
i = 0
while i < len(row[item]):
if row[item][i] == "/":
if row[item][i - 1] == '"':
row[item] = row[item][:i] + row[item][i + 1 :]
else:
row[item] = row[item][:i] + '"' + row[item][i + 1 :]
i += 1
# Validate input column.
row_keys = set(rows[0].keys())
if args.input_column not in row_keys:
raise ValueError(
f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
)
textattack.shared.logger.info(
f"Read {len(rows)} rows from {args.input_csv}. Found columns {row_keys}."
)
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
pct_words_to_swap=args.pct_words_to_swap,
transformations_per_example=args.transformations_per_example,
high_yield=args.high_yield,
fast_augment=args.fast_augment,
)
output_rows = []
for row in tqdm.tqdm(rows, desc="Augmenting rows"):
text_input = row[args.input_column]
if not args.exclude_original:
output_rows.append(row)
for augmentation in augmenter.augment(text_input):
augmented_row = row.copy()
augmented_row[args.input_column] = augmentation
output_rows.append(augmented_row)
# Print to file.
with open(args.output_csv, "w") as outfile:
csv_writer = csv.writer(
outfile, delimiter=",", quotechar="/", quoting=csv.QUOTE_MINIMAL
)
# Write header.
csv_writer.writerow(output_rows[0].keys())
# Write rows.
for row in output_rows:
csv_writer.writerow(row.values())
textattack.shared.logger.info(
f"Wrote {len(output_rows)} augmentations to {args.output_csv} in {time.time() - start_time}s."
)
# Remove extra markings in output file
with open(args.output_csv, "r") as file:
data = file.readlines()
for i in range(len(data)):
data[i] = data[i].replace("/", "")
with open(args.output_csv, "w") as file:
file.writelines(data)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"augment",
help="augment text data",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser = textattack.AugmenterArgs._add_parser_args(parser)
parser.set_defaults(func=AugmentCommand())