"""Scan dataset rows for PII (personally identifiable information) with Presidio.

String columns are batched, each cell is prefixed with a short description of its
column, and Presidio's batch analyzer reports the detected entities per row and column.
"""

import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload

from datasets import Features, Value
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult


Row = dict[str, Any]
T = TypeVar("T")
BATCH_SIZE = 1  # rows are analyzed one by one; see the cache note below
MAX_TEXT_LENGTH = 500  # each cell is truncated to this many characters before analysis
analyzer = AnalyzerEngine()
batch_analyzer = BatchAnalyzerEngine(analyzer)


class PresidioEntity(TypedDict):
    """A single PII entity detected in one cell (row/column) of the dataset."""

    text: str
    type: str
    row_idx: int
    column_name: str

@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Yield lists of up to `n` consecutive items from `it`, optionally paired with their indices."""
    it, indices = iter(it), count()
    while batch := list(islice(it, n)):
        yield (list(islice(indices, len(batch))), batch) if with_indices else batch
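
# A minimal sanity check of `batched` (illustrative example, not part of the original module):
#
# >>> list(batched("abcde", 2))
# [['a', 'b'], ['c', 'd'], ['e']]
# >>> list(batched("abcde", 2, with_indices=True))
# [([0, 1], ['a', 'b']), ([2, 3], ['c', 'd']), ([4], ['e'])]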


def mask(text: str) -> str:
    """Keep at most the first two characters of each word and replace the remaining alphanumerics with `*`."""
    return " ".join(
        word[: min(2, len(word) - 1)] + re.sub("[A-Za-z0-9]", "*", word[min(2, len(word) - 1) :])
        for word in text.split(" ")
    )
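
# Example of the masking scheme (illustrative, derived from the function above):
#
# >>> mask("john.doe@example.com is my email")
# 'jo**.***@*******.*** i* m* em***'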


def get_strings(row_content: Any) -> str:
    """Recursively collect the strings nested in a cell, joined with newlines; non-string leaves are ignored."""
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio data, not text
        row_content = list(row_content.values())
    if isinstance(row_content, list):
        str_items = (get_strings(row_content_item) for row_content_item in row_content)
        return "\n".join(str_item for str_item in str_items if str_item)
    return ""


def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: list[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Run the batch analyzer on `texts`, reusing cached results for texts seen in the previous call."""
    not_cached_results = iter(
        batch_analyzer.analyze_iterator(
            (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
        )
    )
    results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # only keep the current call's results in the cache
    cache.clear()
    cache.update(dict(zip(texts, results)))
    return results
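
# Note: the cache only keeps the previous call's results, so it pays off when consecutive
# batches contain identical texts (e.g. a repeated boilerplate column). Illustrative sketch:
#
# >>> cache: dict[str, list[RecognizerResult]] = {}
# >>> _ = _simple_analyze_iterator_cache(batch_analyzer, ["hello"], "en", 0.8, cache)  # runs Presidio
# >>> _ = _simple_analyze_iterator_cache(batch_analyzer, ["hello"], "en", 0.8, cache)  # served from the cache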


def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Analyze one batch of rows and return the PII entities detected in the scanned columns.

    Each cell is prefixed with a short description of its column to give Presidio some
    context; entities detected inside that prefix itself are filtered out.
    """
    cache = {} if cache is None else cache
    texts = [
        f"The following is {columns_description} data:\n\n{example[column_name] or ''}"
        for example in batch
        for column_name, columns_description in zip(scanned_columns, columns_descriptions)
    ]
    recognizer_results = _simple_analyze_iterator_cache(
        batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache
    )
    return [
        PresidioEntity(
            text=texts[i * len(scanned_columns) + j][recognizer_result.start : recognizer_result.end],
            type=recognizer_result.entity_type,
            row_idx=row_idx,
            column_name=column_name,
        )
        for i, row_idx, recognizer_row_results in zip(
            count(), indices, batched(recognizer_results, len(scanned_columns))
        )
        for j, column_name, columns_description, recognizer_column_results in zip(
            count(), scanned_columns, columns_descriptions, recognizer_row_results
        )
        for recognizer_result in recognizer_column_results
        if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
    ]


def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Scan the given columns of `rows` for PII and yield the detected entities, row by row."""
    cache: dict[str, list[RecognizerResult]] = {}
    rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name])[:MAX_TEXT_LENGTH] for column_name in scanned_columns}
        for row in rows
    )
    for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=cache,
        )
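
# Hedged usage sketch (row contents and detected entities below are illustrative; the
# actual output depends on Presidio's recognizers and the 0.8 score threshold):
#
# >>> rows = [{"text": "My name is John Doe and I live in Paris.", "label": 0}]
# >>> list(presidio_scan_entities(rows, scanned_columns=["text"], columns_descriptions=["text"]))
# [{'text': 'John Doe', 'type': 'PERSON', 'row_idx': 0, 'column_name': 'text'},
#  {'text': 'Paris', 'type': 'LOCATION', 'row_idx': 0, 'column_name': 'text'}]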


def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of the columns whose (possibly nested) feature contains a string type."""
    columns_with_strings: list[str] = []

    for column, feature in features.items():
        str_column = str(column)
        with_string = False

        def classify(feature: FeatureType) -> None:
            nonlocal with_string
            if isinstance(feature, Value) and feature.dtype == "string":
                with_string = True

        _visit(feature, classify)
        if with_string:
            columns_with_strings.append(str_column)
    return columns_with_strings
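
# Illustrative example with `datasets` features:
#
# >>> features = Features({"id": Value("int64"), "text": Value("string"), "answers": {"text": [Value("string")]}})
# >>> get_columns_with_strings(features)
# ['text', 'answers']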


def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column for the analysis prefix, listing its nested field names if any."""
    nested_fields: list[str] = []

    def get_nested_field_names(feature: FeatureType) -> None:
        nonlocal nested_fields
        if isinstance(feature, dict):
            nested_fields += list(feature)

    _visit(feature, get_nested_field_names)
    return f"{column_name} (with {', '.join(nested_fields)})" if nested_fields else column_name