lhoestq HF staff committed on
Commit
e6ef189
1 Parent(s): a28e8f3

add analyze code

Browse files
Files changed (1) hide show
  1. analyze.py +162 -0
analyze.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload

from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
6
+
7
+
8
+ Row = dict[str, Any]
9
+ T = TypeVar("T")
10
+ BATCH_SIZE = 10
11
+ batch_analyzer: Optional[BatchAnalyzerEngine] = None
12
+
13
+
14
+ class PresidioEntity(TypedDict):
15
+ text: str
16
+ type: str
17
+ row_idx: int
18
+ column_name: str
19
+
20
+
21
+ @overload
22
+ def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
23
+ ...
24
+
25
+
26
+ @overload
27
+ def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
28
+ ...
29
+
30
+
31
+ @overload
32
+ def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
33
+ ...
34
+
35
+
36
+ def batched(
37
+ it: Iterable[T], n: int, with_indices: bool = False
38
+ ) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
39
+ it, indices = iter(it), count()
40
+ while batch := list(islice(it, n)):
41
+ yield (list(islice(indices, len(batch))), batch) if with_indices else batch
42
+
43
+
44
+ def mask(text: str) -> str:
45
+ return " ".join(
46
+ word[: min(2, len(word) - 1)] + re.sub("[A-Za-z0-9]", "*", word[min(2, len(word) - 1) :])
47
+ for word in text.split(" ")
48
+ )
49
+
50
+
51
+ def get_strings(row_content: Any) -> str:
52
+ if isinstance(row_content, str):
53
+ return row_content
54
+ if isinstance(row_content, dict):
55
+ row_content = list(row_content.values())
56
+ if isinstance(row_content, list):
57
+ str_items = (get_strings(row_content_item) for row_content_item in row_content)
58
+ return "\n".join(str_item for str_item in str_items if str_item)
59
+ return ""
60
+
61
+
62
+ def _simple_analyze_iterator_cache(
63
+ batch_analyzer: BatchAnalyzerEngine,
64
+ texts: Iterable[str],
65
+ language: str,
66
+ score_threshold: float,
67
+ cache: dict[str, list[RecognizerResult]],
68
+ ) -> list[list[RecognizerResult]]:
69
+ not_cached_results = iter(
70
+ batch_analyzer.analyze_iterator(
71
+ (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
72
+ )
73
+ )
74
+ results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
75
+ # cache the last results
76
+ cache.clear()
77
+ cache.update(dict(zip(texts, results)))
78
+ return results
79
+
80
+
81
+ def analyze(
82
+ batch_analyzer: BatchAnalyzerEngine,
83
+ batch: list[dict[str, str]],
84
+ indices: Iterable[int],
85
+ scanned_columns: list[str],
86
+ columns_descriptions: list[str],
87
+ cache: Optional[dict[str, list[RecognizerResult]]] = None,
88
+ ) -> list[PresidioEntity]:
89
+ cache = {} if cache is None else cache
90
+ texts = [
91
+ f"The following is {columns_description} data:\n\n{example[column_name] or ''}"
92
+ for example in batch
93
+ for column_name, columns_description in zip(scanned_columns, columns_descriptions)
94
+ ]
95
+ return [
96
+ PresidioEntity(
97
+ text=mask(texts[i][recognizer_result.start : recognizer_result.end]),
98
+ type=recognizer_result.entity_type,
99
+ row_idx=row_idx,
100
+ column_name=column_name,
101
+ )
102
+ for i, row_idx, recognizer_results in zip(
103
+ count(),
104
+ indices,
105
+ _simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache),
106
+ )
107
+ for column_name, columns_description, recognizer_result in zip(
108
+ scanned_columns, columns_descriptions, recognizer_results
109
+ )
110
+ if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
111
+ ]
112
+
113
+
114
+ def presidio_scan_entities(
115
+ rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
116
+ ) -> Iterable[PresidioEntity]:
117
+ global batch_analyzer
118
+ cache: dict[str, list[RecognizerResult]] = {}
119
+ if batch_analyzer is None:
120
+ batch_analyser = BatchAnalyzerEngine(AnalyzerEngine())
121
+ rows_with_scanned_columns_only = (
122
+ {column_name: get_strings(row[column_name]) for column_name in scanned_columns} for row in rows
123
+ )
124
+ for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
125
+ yield from analyze(
126
+ batch_analyzer=batch_analyser,
127
+ batch=batch,
128
+ indices=indices,
129
+ scanned_columns=scanned_columns,
130
+ columns_descriptions=columns_descriptions,
131
+ cache=cache,
132
+ )
133
+
134
+
135
+ def get_columns_with_strings(features: Features) -> list[str]:
136
+ columns_with_strings: list[str] = []
137
+
138
+ for column, feature in features.items():
139
+ str_column = str(column)
140
+ with_string = False
141
+
142
+ def classify(feature: FeatureType) -> None:
143
+ nonlocal with_string
144
+ if isinstance(feature, Value) and feature.dtype == "string":
145
+ with_string = True
146
+
147
+ _visit(feature, classify)
148
+ if with_string:
149
+ columns_with_strings.append(str_column)
150
+ return columns_with_strings
151
+
152
+
153
+ def get_column_description(column_name: str, feature: FeatureType) -> str:
154
+ nested_fields: list[str] = []
155
+
156
+ def get_nested_field_names(feature: FeatureType) -> None:
157
+ nonlocal nested_fields
158
+ if isinstance(feature, dict):
159
+ nested_fields += list(feature)
160
+
161
+ _visit(feature, get_nested_field_names)
162
+ return f"{column_name} (with {', '.join(nested_fields)})" if nested_fields else column_name