""".. _attacked_text:

Attacked Text Class
=====================

A helper class that represents a string that can be attacked.
"""

from collections import OrderedDict
import math

import flair
from flair.data import Sentence
import numpy as np
import torch

import textattack

from .utils import device, words_from_text

flair.device = device


class AttackedText:
    """A helper class that represents a string that can be attacked.

    Models that take multiple sentences as input separate them by ``SPLIT_TOKEN``.
    Attacks "see" the entire input, joined into one string, without the split token.

    ``AttackedText`` instances that were perturbed from other ``AttackedText``
    objects contain a pointer to the previous text
    (``attack_attrs["previous_attacked_text"]``), so that the full chain of
    perturbations can be reconstructed by following this key as a linked
    list.

    Args:
       text (string): The string that this AttackedText represents
       attack_attrs (dict): Dictionary of various attributes stored
           during the course of an attack.
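
    Example (illustrative)::

       >>> from textattack.shared import AttackedText
       >>> at = AttackedText("the quick brown fox")
       >>> at.words
       ['the', 'quick', 'brown', 'fox']
       >>> at.replace_word_at_index(1, "slow").text
       'the slow brown fox'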
    """

    SPLIT_TOKEN = "<SPLIT>"

    def __init__(self, text_input, attack_attrs=None):
        # Read in ``text_input`` as a string or OrderedDict.
        if isinstance(text_input, str):
            self._text_input = OrderedDict([("text", text_input)])
        elif isinstance(text_input, OrderedDict):
            self._text_input = text_input
        else:
            raise TypeError(
                f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
            )
        # Process input lazily.
        self._words = None
        self._words_per_input = None
        self._pos_tags = None
        self._ner_tags = None
        # Format text inputs.
        self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
        if attack_attrs is None:
            self.attack_attrs = dict()
        elif isinstance(attack_attrs, dict):
            self.attack_attrs = attack_attrs
        else:
            raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
        # Indices of words from the *original* text. Allows us to map
        # indices between original text and this text, and vice-versa.
        self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words))
        # A list of all indices in *this* text that have been modified.
        self.attack_attrs.setdefault("modified_indices", set())

    def __eq__(self, other):
        """Compares two text instances to make sure they have the same attack
        attributes.

        Since some elements stored in ``self.attack_attrs`` may be numpy
        arrays, we have to take special care when comparing them.
        """
        if not (self.text == other.text):
            return False
        if len(self.attack_attrs) != len(other.attack_attrs):
            return False
        for key in self.attack_attrs:
            if key not in other.attack_attrs:
                return False
            elif isinstance(self.attack_attrs[key], np.ndarray):
                if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape):
                    return False
                elif not (self.attack_attrs[key] == other.attack_attrs[key]).all():
                    return False
            else:
                if not self.attack_attrs[key] == other.attack_attrs[key]:
                    return False
        return True

    def __hash__(self):
        return hash(self.text)

    def free_memory(self):
        """Delete items that take up memory.

        Can be called once the AttackedText is only needed for display.
        """
        if "previous_attacked_text" in self.attack_attrs:
            self.attack_attrs["previous_attacked_text"].free_memory()
            self.attack_attrs.pop("previous_attacked_text", None)

        self.attack_attrs.pop("last_transformation", None)

        # Iterate over a copy of the keys, since entries are removed mid-loop.
        for key in list(self.attack_attrs.keys()):
            if isinstance(self.attack_attrs[key], torch.Tensor):
                self.attack_attrs.pop(key, None)

    def text_window_around_index(self, index, window_size):
        """The text window of ``window_size`` words centered around
        ``index``."""
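        # Illustrative example (assuming the default word parsing):
        #   AttackedText("the quick brown fox jumps").text_window_around_index(2, 3)
        #   returns "quick brown fox".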
        length = self.num_words
        half_size = (window_size - 1) / 2.0
        if index - half_size < 0:
            start = 0
            end = min(window_size - 1, length - 1)
        elif index + half_size >= length:
            start = max(0, length - window_size)
            end = length - 1
        else:
            start = index - math.ceil(half_size)
            end = index + math.floor(half_size)
        text_idx_start = self._text_index_of_word_index(start)
        text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
        return self.text[text_idx_start:text_idx_end]

    def pos_of_word_index(self, desired_word_idx):
        """Returns the part-of-speech of the word at index `word_idx`.

        Uses FLAIR part-of-speech tagger.
        """
        if not self._pos_tags:
            sentence = Sentence(
                self.text,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            textattack.shared.utils.flair_tag(sentence)
            self._pos_tags = sentence
        flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
            self._pos_tags
        )

        for word_idx, word in enumerate(self.words):
            assert (
                word in flair_word_list
            ), "word not found in the part-of-speech tags returned by flair"
            word_idx_in_flair_tags = flair_word_list.index(word)
            if word_idx == desired_word_idx:
                return flair_pos_list[word_idx_in_flair_tags]
            else:
                flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
                flair_pos_list = flair_pos_list[word_idx_in_flair_tags + 1 :]

        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair POS tag"
        )

    def ner_of_word_index(self, desired_word_idx, model_name="ner"):
        """Returns the ner tag of the word at index `word_idx`.

        Uses FLAIR ner tagger.
        """
        if not self._ner_tags:
            sentence = Sentence(
                self.text,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            textattack.shared.utils.flair_tag(sentence, model_name)
            self._ner_tags = sentence
        flair_word_list, flair_ner_list = textattack.shared.utils.zip_flair_result(
            self._ner_tags, "ner"
        )

        for word_idx, word in enumerate(self.words):
            assert (
                word in flair_word_list
            ), "word not found in the NER tags returned by flair"
            word_idx_in_flair_tags = flair_word_list.index(word)
            if word_idx == desired_word_idx:
                return flair_ner_list[word_idx_in_flair_tags]
            else:
                flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
                flair_ner_list = flair_ner_list[word_idx_in_flair_tags + 1 :]

        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair NER tags"
        )

    def _text_index_of_word_index(self, i):
        """Returns the index of word ``i`` in self.text."""
        pre_words = self.words[: i + 1]
        lower_text = self.text.lower()
        # Find all words until `i` in string.
        look_after_index = 0
        for word in pre_words:
            look_after_index = lower_text.find(word.lower(), look_after_index) + len(
                word
            )
        look_after_index -= len(self.words[i])
        return look_after_index

    def text_until_word_index(self, i):
        """Returns the text before the beginning of word at index ``i``."""
        look_after_index = self._text_index_of_word_index(i)
        return self.text[:look_after_index]

    def text_after_word_index(self, i):
        """Returns the text after the end of word at index ``i``."""
        # Get index of beginning of word then jump to end of word.
        look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
        return self.text[look_after_index:]

    def first_word_diff(self, other_attacked_text):
        """Returns the first word in self.words that differs from
        other_attacked_text.

        Useful for word swap strategies.
        """
        w1 = self.words
        w2 = other_attacked_text.words
        for i in range(min(len(w1), len(w2))):
            if w1[i] != w2[i]:
                return w1[i]
        return None

    def first_word_diff_index(self, other_attacked_text):
        """Returns the index of the first word in self.words that differs from
        other_attacked_text.

        Useful for word swap strategies.
        """
        w1 = self.words
        w2 = other_attacked_text.words
        for i in range(min(len(w1), len(w2))):
            if w1[i] != w2[i]:
                return i
        return None

    def all_words_diff(self, other_attacked_text):
        """Returns the set of indices for which this and other_attacked_text
        have different words."""
        indices = set()
        w1 = self.words
        w2 = other_attacked_text.words
        for i in range(min(len(w1), len(w2))):
            if w1[i] != w2[i]:
                indices.add(i)
        return indices

    def ith_word_diff(self, other_attacked_text, i):
        """Returns whether the word at index i differs from
        other_attacked_text."""
        w1 = self.words
        w2 = other_attacked_text.words
        if len(w1) - 1 < i or len(w2) - 1 < i:
            return True
        return w1[i] != w2[i]

    def words_diff_num(self, other_attacked_text):
        """Returns the number of word insertions, deletions, and substitutions
        needed to turn this text into ``other_attacked_text``, i.e. the
        word-level Levenshtein edit distance (each distinct word is first
        mapped to an integer token)."""
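        # Illustrative example:
        #   AttackedText("a b c").words_diff_num(AttackedText("a x c y"))
        #   returns 2 (one substitution plus one insertion).
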
        def generate_tokens(words):
            result = {}
            idx = 1
            for w in words:
                if w not in result:
                    result[w] = idx
                    idx += 1
            return result

        def words_to_tokens(words, tokens):
            result = []
            for w in words:
                result.append(tokens[w])
            return result

        def edit_distance(w1_t, w2_t):
            matrix = [
                [i + j for j in range(len(w2_t) + 1)] for i in range(len(w1_t) + 1)
            ]

            for i in range(1, len(w1_t) + 1):
                for j in range(1, len(w2_t) + 1):
                    if w1_t[i - 1] == w2_t[j - 1]:
                        d = 0
                    else:
                        d = 1
                    matrix[i][j] = min(
                        matrix[i - 1][j] + 1,
                        matrix[i][j - 1] + 1,
                        matrix[i - 1][j - 1] + d,
                    )

            return matrix[len(w1_t)][len(w2_t)]

        def cal_dif(w1, w2):
            tokens = generate_tokens(w1 + w2)
            w1_t = words_to_tokens(w1, tokens)
            w2_t = words_to_tokens(w2, tokens)
            return edit_distance(w1_t, w2_t)

        w1 = self.words
        w2 = other_attacked_text.words
        return cal_dif(w1, w2)

    def convert_from_original_idxs(self, idxs):
        """Takes indices of words from original string and converts them to
        indices of the same words in the current string.

        Uses information from
        ``self.attack_attrs['original_index_map']``, which maps word
        indices from the original to perturbed text.
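
        Example (illustrative)::

            >>> at = AttackedText("the quick brown fox")
            >>> at2 = at.delete_word_at_index(0)
            >>> int(at2.convert_from_original_idxs([2])[0])  # "brown" is now word 1
            1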
        """
        if len(self.attack_attrs["original_index_map"]) == 0:
            return idxs
        elif isinstance(idxs, set):
            idxs = list(idxs)

        elif not isinstance(idxs, (list, np.ndarray)):
            raise TypeError(
                f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
            )

        return [self.attack_attrs["original_index_map"][i] for i in idxs]

    def replace_words_at_indices(self, indices, new_words):
        """This code returns a new AttackedText object where the word at
        ``index`` is replaced with a new word."""
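        # Illustrative example:
        #   AttackedText("the quick brown fox").replace_words_at_indices(
        #       [1, 3], ["slow", "dog"]).text  ->  "the slow brown dog"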
        if len(indices) != len(new_words):
            raise ValueError(
                f"Cannot replace {len(new_words)} words at {len(indices)} indices."
            )
        words = self.words[:]
        for i, new_word in zip(indices, new_words):
            if not isinstance(new_word, str):
                raise TypeError(
                    f"replace_words_at_indices requires ``str`` words, got {type(new_word)}"
                )
            if (i < 0) or (i >= len(words)):
                raise ValueError(f"Cannot assign word at index {i}")
            words[i] = new_word
        return self.generate_new_attacked_text(words)

    def replace_word_at_index(self, index, new_word):
        """This code returns a new AttackedText object where the word at
        ``index`` is replaced with a new word."""
        if not isinstance(new_word, str):
            raise TypeError(
                f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
            )
        return self.replace_words_at_indices([index], [new_word])

    def delete_word_at_index(self, index):
        """This code returns a new AttackedText object where the word at
        ``index`` is removed."""
        return self.replace_word_at_index(index, "")

    def insert_text_after_word_index(self, index, text):
        """Inserts a string before word at index ``index`` and attempts to add
        appropriate spacing."""
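        # Illustrative example:
        #   AttackedText("the fox").insert_text_after_word_index(1, "ran").text
        #   ->  "the fox ran"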
        if not isinstance(text, str):
            raise TypeError(f"text must be an str, got type {type(text)}")
        word_at_index = self.words[index]
        new_text = " ".join((word_at_index, text))
        return self.replace_word_at_index(index, new_text)

    def insert_text_before_word_index(self, index, text):
        """Inserts a string before word at index ``index`` and attempts to add
        appropriate spacing."""
        if not isinstance(text, str):
            raise TypeError(f"text must be an str, got type {type(text)}")
        word_at_index = self.words[index]
        # TODO if ``word_at_index`` is at the beginning of a sentence, we should
        # optionally capitalize ``text``.
        new_text = " ".join((text, word_at_index))
        return self.replace_word_at_index(index, new_text)

    def get_deletion_indices(self):
        """Returns the indices (in the original text) of words that have been
        deleted from this text."""
        original_index_map = self.attack_attrs["original_index_map"]
        return np.arange(len(original_index_map))[original_index_map == -1]

    def generate_new_attacked_text(self, new_words):
        """Returns a new AttackedText object and replaces old list of words
        with a new list of words, but preserves the punctuation and spacing of
        the original message.

        ``self.words`` is a list of the words in the current text with
        punctuation removed. However, each "word" in ``new_words`` could
        be an empty string, representing a word deletion, or a string
        with multiple space-separated words, representing an insertion
        of one or more words.
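
        Example (illustrative)::

            >>> at = AttackedText("the quick brown fox")
            >>> at.generate_new_attacked_text(["the", "", "brown", "very sly fox"]).text
            'the brown very sly fox'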
        """
        perturbed_text = ""
        original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
        new_attack_attrs = dict()
        if "label_names" in self.attack_attrs:
            new_attack_attrs["label_names"] = self.attack_attrs["label_names"]
        new_attack_attrs["newly_modified_indices"] = set()
        # Point to the previously attacked text.
        new_attack_attrs["previous_attacked_text"] = self
        # Use `new_attack_attrs` to track indices with respect to the original
        # text.
        new_attack_attrs["modified_indices"] = self.attack_attrs[
            "modified_indices"
        ].copy()
        new_attack_attrs["original_index_map"] = self.attack_attrs[
            "original_index_map"
        ].copy()
        new_i = 0
        # Create the new attacked text by swapping out words from the original
        # text with a sequence of 0+ words in the new text.
        for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
            word_start = original_text.index(input_word)
            word_end = word_start + len(input_word)
            perturbed_text += original_text[:word_start]
            original_text = original_text[word_end:]
            adv_words = words_from_text(adv_word_seq)
            adv_num_words = len(adv_words)
            num_words_diff = adv_num_words - len(words_from_text(input_word))
            # Track indices on insertions and deletions.
            if num_words_diff != 0:
                # Re-calculate the modified indices: if words were inserted
                # or deleted, the indices may have shifted.
                shifted_modified_indices = set()
                for modified_idx in new_attack_attrs["modified_indices"]:
                    if modified_idx < i:
                        shifted_modified_indices.add(modified_idx)
                    elif modified_idx > i:
                        shifted_modified_indices.add(modified_idx + num_words_diff)
                    else:
                        pass
                new_attack_attrs["modified_indices"] = shifted_modified_indices
                # Track insertions and deletions wrt original text.
                # original_modification_idx = i
                new_idx_map = new_attack_attrs["original_index_map"].copy()
                if num_words_diff == -1:
                    # Word deletion
                    new_idx_map[new_idx_map == i] = -1
                new_idx_map[new_idx_map > i] += num_words_diff

                if num_words_diff > 0 and input_word != adv_words[0]:
                    # If insertion happens before the `input_word`
                    new_idx_map[new_idx_map == i] += num_words_diff

                new_attack_attrs["original_index_map"] = new_idx_map
            # Move pointer and save indices of new modified words.
            for j in range(i, i + adv_num_words):
                if input_word != adv_word_seq:
                    new_attack_attrs["modified_indices"].add(new_i)
                    new_attack_attrs["newly_modified_indices"].add(new_i)
                new_i += 1
            # Check spaces for deleted text.
            if adv_num_words == 0 and len(original_text):
                # Remove extra space (or else there would be two spaces for each
                # deleted word).
                # @TODO What to do with punctuation in this case? This behavior is undefined.
                if i == 0:
                    # If the first word was deleted, take a subsequent space.
                    if original_text[0] == " ":
                        original_text = original_text[1:]
                else:
                    # If a word other than the first was deleted, take a preceding space.
                    if perturbed_text[-1] == " ":
                        perturbed_text = perturbed_text[:-1]
            # Add substitute word(s) to new sentence.
            perturbed_text += adv_word_seq
        perturbed_text += original_text  # Add all of the ending punctuation.

        # Reform perturbed_text into an OrderedDict.
        perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
        perturbed_input = OrderedDict(
            zip(self._text_input.keys(), perturbed_input_texts)
        )
        return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)

    def words_diff_ratio(self, x):
        """Returns the fraction of words that differ between the current text
        and `x`.

        Note that the current text and `x` must have the same number of words.
        """
        assert self.num_words == x.num_words
        # Compare word-by-word; bare Python lists would compare as one bool.
        return (
            float(np.sum(np.array(self.words) != np.array(x.words))) / self.num_words
        )

    def align_with_model_tokens(self, model_wrapper):
        """Align AttackedText's `words` with target model's tokenization scheme
        (e.g. word, character, subword). Specifically, we map each word to list
        of indices of tokens that compose the word (e.g. embedding --> ["em",
        "##bed", "##ding"])

        Args:
            model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model

        Returns:
            word2token_mapping (dict[int, list[int]]): Dictionary that maps i-th word to list of indices.
        """
        tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
        word2token_mapping = {}
        j = 0
        last_matched = 0

        for i, word in enumerate(self.words):
            matched_tokens = []
            while j < len(tokens) and len(word) > 0:
                token = tokens[j].lower()
                idx = word.lower().find(token)
                if idx == 0:
                    word = word[idx + len(token) :]
                    matched_tokens.append(j)
                    last_matched = j
                j += 1

            if not matched_tokens:
                word2token_mapping[i] = None
                j = last_matched
            else:
                word2token_mapping[i] = matched_tokens

        return word2token_mapping

    @property
    def tokenizer_input(self):
        """The tuple of inputs to be passed to the tokenizer."""
        input_tuple = tuple(self._text_input.values())
        # Prefer to return a string instead of a tuple with a single value.
        if len(input_tuple) == 1:
            return input_tuple[0]
        else:
            return input_tuple

    @property
    def column_labels(self):
        """Returns the labels for this text's columns.

        For single-sequence inputs, this simply returns ['text'].
        """
        return list(self._text_input.keys())

    @property
    def words_per_input(self):
        """Returns a list of lists of words corresponding to each input."""
        if not self._words_per_input:
            self._words_per_input = [
                words_from_text(_input) for _input in self._text_input.values()
            ]
        return self._words_per_input

    @property
    def words(self):
        if not self._words:
            self._words = words_from_text(self.text)
        return self._words

    @property
    def text(self):
        """Represents full text input.

        Multiple inputs are joined with a line break.
        """
        return "\n".join(self._text_input.values())

    @property
    def num_words(self):
        """Returns the number of words in the sequence."""
        return len(self.words)

    @property
    def newly_swapped_words(self):
        return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]]

    def printable_text(self, key_color="bold", key_color_method=None):
        """Represents full text input. Adds field descriptions.

        For example, entailment inputs look like:
            ```
            premise: ...
            hypothesis: ...
            ```
        """
        # For single-sequence inputs, don't show a prefix.
        if len(self._text_input) == 1:
            return next(iter(self._text_input.values()))
        # For multiple-sequence inputs, show a prefix and a colon. Optionally,
        # color the key.
        else:
            if key_color_method:

                def ck(k):
                    return textattack.shared.utils.color_text(
                        k, key_color, key_color_method
                    )

            else:

                def ck(k):
                    return k

            return "\n".join(
                f"{ck(key.capitalize())}: {value}"
                for key, value in self._text_input.items()
            )

    def __repr__(self):
        return f'<AttackedText "{self.text}">'