Spaces:
Sleeping
Sleeping
File size: 5,607 Bytes
ede461a 1d02824 b551628 e6ef189 11208e8 6c6da17 bb49074 e6ef189 d0ca54b e6ef189 ede461a e6ef189 d0ca54b e6ef189 0d3784b e6ef189 0d3784b e6ef189 0d3784b e6ef189 0d3784b e6ef189 f12f776 e6ef189 bb49074 e6ef189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload
from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
# A dataset row: column name -> raw cell value.
Row = dict[str, Any]
# Element type parameter for the generic ``batched`` helper.
T = TypeVar("T")

# Number of rows grouped together for each ``analyze`` call in ``presidio_scan_entities``.
BATCH_SIZE = 1
# Each scanned cell's flattened text is cut to this many characters before analysis.
MAX_TEXT_LENGTH = 500

# Presidio engines, created once when the module is imported and shared by all scans.
analyzer = AnalyzerEngine()
batch_analyzer = BatchAnalyzerEngine(analyzer)
class PresidioEntity(TypedDict):
    """A single entity reported by the Presidio scan of one dataset cell."""

    # The flagged span of analyzed text (``text[start:end]`` of the recognizer result).
    text: str
    # Presidio's label for the entity (``RecognizerResult.entity_type``).
    type: str
    # 0-based position of the row the entity was found in.
    row_idx: int
    # Name of the scanned column the entity was found in.
    column_name: str
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Lazily split *it* into consecutive lists of at most *n* items.

    Every chunk except possibly the last holds exactly *n* items. With
    ``with_indices=True`` each chunk is paired with the 0-based positions its items
    had in *it*, e.g. ``([0, 1], [a, b]), ([2], [c])``.

    A non-positive *n* produces no chunks for ``n == 0``; a negative *n* makes
    ``itertools.islice`` raise ``ValueError`` on the first iteration.
    """
    source = iter(it)
    offset = 0
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        if with_indices:
            yield list(range(offset, offset + len(chunk))), chunk
        else:
            yield chunk
        offset += len(chunk)
def mask(text: str) -> str:
    """Obscure *text* word by word, keeping only a short prefix of each word.

    The text is split on single spaces, so runs of spaces are preserved. Each word
    keeps its first ``min(2, len(word) - 1)`` characters: two for words of three or
    more characters, one for two-character words, none for single characters. Every
    remaining letter or digit is replaced by ``*``. Punctuation, underscores and other
    symbols stay visible, so the shape of e.g. an e-mail address is kept.

    Letters and digits are matched with Unicode semantics, so accented and non-Latin
    characters are masked just like ASCII ones (the previous ``[A-Za-z0-9]`` pattern
    let them through unmasked).

    >>> mask("John Smith")
    'Jo** Sm***'
    """
    # ``[^\W_]`` == any Unicode word character except underscore, i.e. letters and digits.
    alnum = re.compile(r"[^\W_]")
    masked_words = []
    for word in text.split(" "):
        # For an empty word keep is -1, and both slices below are empty, which is harmless.
        keep = min(2, len(word) - 1)
        masked_words.append(word[:keep] + alnum.sub("*", word[keep:]))
    return " ".join(masked_words)
def get_strings(row_content: Any) -> str:
    """Flatten a cell value into a single newline-joined string of its text.

    - ``str`` values are returned unchanged.
    - ``dict`` values containing a ``"src"`` key give ``""`` (such cells could be image
      or audio); other dicts are flattened through their values.
    - ``list`` values are flattened item by item, recursively.
    - Anything else (numbers, ``None``, tuples, ...) gives ``""``.

    Empty pieces are dropped before joining with ``"\\n"``.
    """
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio
        children = list(row_content.values())
    elif isinstance(row_content, list):
        children = row_content
    else:
        return ""
    pieces = []
    for child in children:
        piece = get_strings(child)
        if piece:
            pieces.append(piece)
    return "\n".join(pieces)
def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze *texts* with Presidio, reusing results remembered from the previous call.

    Texts already present in *cache* are not re-analyzed. Missing texts are sent to
    ``batch_analyzer.analyze_iterator`` once each; duplicates within *texts* share a
    single analysis. On return *cache* is cleared and refilled with exactly this
    call's texts and results, so it only ever remembers the most recent batch.

    Args:
        batch_analyzer: Presidio batch engine used for uncached texts.
        texts: Texts to analyze. Any iterable is accepted; it is consumed once.
        language: Language code forwarded to ``analyze_iterator``.
        score_threshold: Minimum score forwarded to ``analyze_iterator``.
        cache: text -> results mapping, mutated in place as described above.

    Returns:
        One list of recognizer results per entry of *texts*, in the same order.

    Raises:
        RuntimeError: if the analyzer returns a different number of results than the
            number of texts it was given.
    """
    # Materialize once: texts is read several times below, and a one-shot iterator
    # (e.g. a generator) would be exhausted after the first pass.
    text_list = list(texts)
    # Unique uncached texts, in first-seen order.
    pending = list(dict.fromkeys(text for text in text_list if text not in cache))
    fresh: dict[str, list[RecognizerResult]] = {}
    if pending:
        analyzed = list(
            batch_analyzer.analyze_iterator(pending, language=language, score_threshold=score_threshold)
        )
        if len(analyzed) != len(pending):
            raise RuntimeError(f"expected {len(pending)} analyzer results, got {len(analyzed)}")
        fresh = dict(zip(pending, analyzed))
    results = [cache[text] if text in cache else fresh[text] for text in text_list]
    # cache the last results only, which keeps the cache bounded to one batch
    cache.clear()
    cache.update(zip(text_list, results))
    return results
def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Run Presidio over every scanned cell of *batch* and collect the detected entities.

    Each cell is analyzed as a short English sentence naming the column's description
    ("The following is <description> data:"), a blank line, then the cell value. The
    prefix presumably gives the recognizers context about the column. Entities that
    start inside that prefix are discarded, so only spans from the cell value are
    reported. Analysis uses language ``"en"`` and a score threshold of ``0.8``.

    Args:
        batch_analyzer: Presidio batch engine used for texts not found in *cache*.
        batch: Rows to analyze, each mapping column name to its cell text; a falsy
            value is analyzed as an empty string.
        indices: Row index to report for each row of *batch*, in the same order.
        scanned_columns: Columns analyzed in every row.
        columns_descriptions: Description of each scanned column, aligned with
            *scanned_columns*.
        cache: text -> results mapping shared across calls and updated in place by the
            caching helper; a fresh dict is used when ``None``.

    Returns:
        Entities ordered by row, then column, then the order Presidio reported them.

    Raises:
        ValueError: if *scanned_columns* and *columns_descriptions* differ in length.
    """
    if len(scanned_columns) != len(columns_descriptions):
        # zip() would silently truncate, and the flat list of texts would then no longer
        # line up with row/column positions, attributing entities to the wrong cells.
        raise ValueError(
            f"scanned_columns ({len(scanned_columns)}) and columns_descriptions "
            f"({len(columns_descriptions)}) must have the same length"
        )
    cache = {} if cache is None else cache
    # Built once per column and reused both to compose texts and to filter results.
    prefixes = [f"The following is {columns_description} data:\n\n" for columns_description in columns_descriptions]
    # Flat row-major list: the text of (row i, column j) is at i * n_columns + j.
    texts = [
        f"{prefix}{example[column_name] or ''}"
        for example in batch
        for column_name, prefix in zip(scanned_columns, prefixes)
    ]
    results = _simple_analyze_iterator_cache(
        batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache
    )
    n_columns = len(scanned_columns)
    entities: list[PresidioEntity] = []
    for row_pos, (row_idx, _example) in enumerate(zip(indices, batch)):
        for col_pos, (column_name, prefix) in enumerate(zip(scanned_columns, prefixes)):
            flat_pos = row_pos * n_columns + col_pos
            text = texts[flat_pos]
            entities.extend(
                PresidioEntity(
                    text=text[recognizer_result.start : recognizer_result.end],
                    type=recognizer_result.entity_type,
                    row_idx=row_idx,
                    column_name=column_name,
                )
                for recognizer_result in results[flat_pos]
                # Hits inside the synthetic prefix are not part of the cell value.
                if recognizer_result.start >= len(prefix)
            )
    return entities
def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan *rows* for PII and yield every entity Presidio detects.

    Only *scanned_columns* are examined: each cell is flattened to text with
    ``get_strings`` and cut to ``MAX_TEXT_LENGTH`` characters. Rows are analyzed
    ``BATCH_SIZE`` at a time with the module-level ``batch_analyzer``, and each batch's
    entities are yielded as soon as that batch is done. ``row_idx`` in the output is
    the 0-based position of the row in *rows*.

    One results cache is shared by all batches, so cell texts that repeat from one
    batch to the next are not analyzed again.
    """
    previous_batch_results: dict[str, list[RecognizerResult]] = {}

    def project(row: Row) -> dict[str, str]:
        # Keep only the scanned columns, as truncated plain text.
        return {name: get_strings(row[name])[:MAX_TEXT_LENGTH] for name in scanned_columns}

    for row_indices, row_batch in batched(map(project, rows), BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=row_batch,
            indices=row_indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=previous_batch_results,
        )
def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of the columns whose feature type contains a string value.

    A column qualifies when its feature, or any feature reached while traversing it with
    ``datasets``' ``_visit``, is a ``Value`` with dtype ``"string"`` or
    ``"large_string"``. Both dtypes hold Python ``str`` cells; previously only
    ``"string"`` was recognized, so ``large_string`` columns were never scanned.
    Names are returned as ``str``, in the order of *features*.
    """
    string_dtypes = ("string", "large_string")

    def has_string(feature: FeatureType) -> bool:
        found = False

        def classify(node: FeatureType) -> None:
            # Returns None so that _visit leaves the feature unchanged.
            nonlocal found
            if isinstance(node, Value) and node.dtype in string_dtypes:
                found = True

        _visit(feature, classify)
        return found

    return [str(column) for column, feature in features.items() if has_string(feature)]
def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column by its name plus the keys of any dict features inside it.

    Every dict node reached while traversing *feature* with ``datasets``' ``_visit``
    contributes its keys, in the order ``_visit`` reports them. Without any nested
    dicts the bare *column_name* is returned; otherwise the result looks like
    ``"answers (with text, answer_start)"``.
    """
    field_names: list[str] = []

    def record_dict_keys(node: FeatureType) -> None:
        # Returns None so that _visit leaves the feature unchanged.
        if isinstance(node, dict):
            field_names.extend(node)

    _visit(feature, record_dict_keys)
    if not field_names:
        return column_name
    return f"{column_name} (with {', '.join(field_names)})"
|