|
"""

AugmentCommand class
===========================

"""
|
|
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentError, ArgumentParser |
|
import csv |
|
import os |
|
import time |
|
|
|
import tqdm |
|
|
|
import textattack |
|
from textattack.augment_args import AUGMENTATION_RECIPE_NAMES |
|
from textattack.commands import TextAttackCommand |
|
|
|
|
|
class AugmentCommand(TextAttackCommand):
    """The TextAttack augment module:

    A command line parser to run data augmentation from user
    specifications.
    """

    # Horizontal rule printed between sections of the interactive session.
    _SEPARATOR = "--------------------------------------------------------"

    def run(self, args):
        """Run data augmentation interactively or over a CSV file.

        ``args`` is converted to ``textattack.AugmenterArgs``. When
        ``args.interactive`` is set, sentences are read from stdin and their
        augmentations are printed. Otherwise a CSV is read, augmented, and an
        augmented CSV is written; all columns are preserved except for the
        input (augmented) column.

        Args:
            args: Parsed command-line namespace; ``vars(args)`` must be valid
                keyword arguments for ``textattack.AugmenterArgs``.

        Raises:
            ArgumentError: In CSV mode, if the input CSV, input column or
                output CSV is missing.
            FileNotFoundError: In CSV mode, if the input CSV does not exist.
            OSError: In CSV mode, if the output file exists and ``overwrite``
                is not set.
            ValueError: In CSV mode, if the CSV has no data rows or does not
                contain the input column.
        """
        args = textattack.AugmenterArgs(**vars(args))
        if args.interactive:
            self._run_interactive(args)
        else:
            self._run_csv(args)

    @staticmethod
    def _build_augmenter(args, interactive=False):
        """Instantiate the augmenter recipe selected by ``args.recipe``.

        ``enable_advanced_metrics`` is only forwarded in interactive mode,
        because only the interactive printing code unpacks the metrics that
        ``augment`` then returns alongside the augmentations.
        """
        kwargs = {
            "pct_words_to_swap": args.pct_words_to_swap,
            "transformations_per_example": args.transformations_per_example,
            "high_yield": args.high_yield,
            "fast_augment": args.fast_augment,
        }
        if interactive:
            kwargs["enable_advanced_metrics"] = args.enable_advanced_metrics
        # NOTE(review): eval() is applied to a value looked up in
        # AUGMENTATION_RECIPE_NAMES, never to the raw user string; an unknown
        # recipe name raises KeyError before anything is evaluated. Keep that
        # mapping limited to trusted expressions.
        return eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(**kwargs)

    @staticmethod
    def _restore_quotes(value):
        """Undo the ``"/`` quote markers inserted before CSV parsing.

        Walks ``value`` once: a ``/`` directly preceded by ``"`` is removed,
        any other ``/`` is replaced by ``"``.

        NOTE(review): a ``/`` that was genuinely part of the input text is
        indistinguishable from a marker and is rewritten as well.
        """
        i = 0
        while i < len(value):
            if value[i] == "/":
                # For i == 0 the index -1 wraps to the last character; this
                # mirrors the long-standing behaviour of this routine.
                if value[i - 1] == '"':
                    value = value[:i] + value[i + 1 :]
                else:
                    value = value[:i] + '"' + value[i + 1 :]
            i += 1
        return value

    def _run_interactive(self, args):
        """Read sentences from stdin and print their augmentations.

        ``q`` quits, ``c`` shows the current arguments and optionally lets
        the user change them, an empty line is ignored.
        """
        print("\nRunning in interactive mode...\n")
        augmenter = self._build_augmenter(args, interactive=True)
        print(self._SEPARATOR)

        while True:
            print(
                '\nEnter a sentence to augment, "q" to quit, "c" to view/change arguments:\n'
            )
            text = input()

            if text == "q":
                break

            if text == "c":
                print(
                    f"\nCurrent Arguments:\n\n\t augmentation recipe: {args.recipe}, "
                    f"\n\t pct_words_to_swap: {args.pct_words_to_swap}, "
                    f"\n\t transformations_per_example: {args.transformations_per_example}\n"
                )

                change = input(
                    "Enter 'c' again to change arguments, any other keys to opt out\n"
                )
                if change == "c":
                    print("\nChanging augmenter arguments...\n")
                    recipe = input(
                        "\tAugmentation recipe name ('r' to see available recipes): "
                    )
                    if recipe == "r":
                        recipe_display = " ".join(AUGMENTATION_RECIPE_NAMES.keys())
                        print(f"\n\t{recipe_display}\n")
                        args.recipe = input("\tAugmentation recipe name: ")
                    else:
                        args.recipe = recipe

                    args.pct_words_to_swap = float(
                        input("\tPercentage of words to swap (0.0 ~ 1.0): ")
                    )
                    args.transformations_per_example = int(
                        input("\tTransformations per input example: ")
                    )

                    print("\nGenerating new augmenter...\n")
                    # Rebuild with the same options as the initial augmenter.
                    # Dropping enable_advanced_metrics here would leave the
                    # metrics branch below indexing into a result that was
                    # produced without that flag.
                    augmenter = self._build_augmenter(args, interactive=True)
                    print(self._SEPARATOR)

                continue

            if not text:
                continue

            print("\nAugmenting...\n")
            print(self._SEPARATOR)

            if args.enable_advanced_metrics:
                # With advanced metrics the result is indexed as
                # (augmentations, perplexity stats, USE stats).
                results = augmenter.augment(text)
                print("Augmentations:\n")
                for augmentation in results[0]:
                    print(augmentation, "\n")
                print()
                print(
                    f"Average Original Perplexity Score: {results[1]['avg_original_perplexity']}"
                )
                print(
                    f"Average Augment Perplexity Score: {results[1]['avg_attack_perplexity']}"
                )
                print(
                    f"Average Augment USE Score: {results[2]['avg_attack_use_score']}\n"
                )
            else:
                for augmentation in augmenter.augment(text):
                    print(augmentation, "\n")
            print(self._SEPARATOR)

    def _run_csv(self, args):
        """Read ``args.input_csv``, augment it and write ``args.output_csv``."""
        textattack.shared.utils.set_seed(args.random_seed)
        start_time = time.time()
        if not (args.input_csv and args.input_column and args.output_csv):
            # ArgumentError takes (argument, message); passing the message as
            # the only positional argument raises TypeError instead.
            raise ArgumentError(
                None,
                "The following arguments are required: "
                "--input-csv, --input-column/--i, --output-csv",
            )

        # Validate input/output paths.
        if not os.path.exists(args.input_csv):
            raise FileNotFoundError(f"Can't find CSV at location {args.input_csv}")
        if os.path.exists(args.output_csv):
            if args.overwrite:
                textattack.shared.logger.info(
                    f"Preparing to overwrite {args.output_csv}."
                )
            else:
                raise OSError(
                    f"Outfile {args.output_csv} exists and --overwrite not set."
                )

        # Read the CSV as a list of dictionaries, letting the sniffer pick
        # between ";" and "," from the header line.
        with open(args.input_csv, "r") as csv_file:
            try:
                dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
            except csv.Error:
                # The sniffer cannot determine a delimiter for e.g. a
                # single-column file; fall back to the standard dialect.
                dialect = csv.excel
            csv_file.seek(0)
            # Every double quote is tagged with a trailing "/" before parsing;
            # _restore_quotes() turns the tags back into quote characters in
            # the parsed values.
            marked_lines = (line.replace('"', '"/') for line in csv_file)
            rows = list(
                csv.DictReader(
                    marked_lines,
                    dialect=dialect,
                    skipinitialspace=True,
                )
            )

        if not rows:
            raise ValueError(f"No data rows found in {args.input_csv}.")

        for row in rows:
            for key, value in row.items():
                # Short rows yield None values (and over-long rows a list
                # under the None key); only real strings carry markers.
                if isinstance(value, str):
                    row[key] = self._restore_quotes(value)

        # Validate input column.
        row_keys = set(rows[0].keys())
        if args.input_column not in row_keys:
            raise ValueError(
                f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
            )
        textattack.shared.logger.info(
            f"Read {len(rows)} rows from {args.input_csv}. Found columns {row_keys}."
        )

        augmenter = self._build_augmenter(args)

        output_rows = []
        for row in tqdm.tqdm(rows, desc="Augmenting rows"):
            text_input = row[args.input_column]
            if not args.exclude_original:
                output_rows.append(row)
            for augmentation in augmenter.augment(text_input):
                augmented_row = row.copy()
                augmented_row[args.input_column] = augmentation
                output_rows.append(augmented_row)

        # The header normally comes from the first output row; fall back to
        # the input columns when nothing was produced (exclude_original set
        # and no augmentations), so a header-only file is still written.
        header_row = output_rows[0] if output_rows else rows[0]

        # newline="" is what the csv module asks for on files it writes to;
        # it prevents an extra "\r" per row on platforms that translate "\n".
        with open(args.output_csv, "w", newline="") as outfile:
            # "/" serves as the quote character so that literal double quotes
            # in the data are written through untouched.
            csv_writer = csv.writer(
                outfile, delimiter=",", quotechar="/", quoting=csv.QUOTE_MINIMAL
            )
            csv_writer.writerow(header_row.keys())
            for row in output_rows:
                csv_writer.writerow(row.values())

        textattack.shared.logger.info(
            f"Wrote {len(output_rows)} augmentations to {args.output_csv} in {time.time() - start_time}s."
        )

        # Remove the "/" quoting characters from the written file.
        # NOTE(review): this removes every "/" in the output, including any
        # that belong to the data (e.g. in URLs or augmented text).
        with open(args.output_csv, "r") as file:
            data = [line.replace("/", "") for line in file.readlines()]
        with open(args.output_csv, "w") as file:
            file.writelines(data)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``augment`` subcommand on ``main_parser``.

        Adds the ``AugmenterArgs`` options to the new sub-parser and sets an
        ``AugmentCommand`` instance as its ``func`` default.
        """
        parser = main_parser.add_parser(
            "augment",
            help="augment text data",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser = textattack.AugmenterArgs._add_parser_args(parser)
        parser.set_defaults(func=AugmentCommand())
|
|