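"""Gradio app that scans the first rows of a Hugging Face dataset for PII entities using Presidio."""
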
from collections import Counter
from itertools import count, groupby, islice
from operator import itemgetter
from typing import Any, Iterable, Iterator, TypeVar

import gradio as gr
import pandas as pd
import requests
from datasets import Features
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from requests.adapters import HTTPAdapter, Retry

from analyze import PresidioEntity, analyzer, get_column_description, get_columns_with_strings, mask, presidio_scan_entities

MAX_ROWS = 100
T = TypeVar("T")
session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
# All datasets-server requests below use HTTPS, so mount the retrying adapter on "https://"
session.mount("https://", HTTPAdapter(max_retries=retries))
DEFAULT_PRESIDIO_ENTITIES = sorted([
    'PERSON',
    'CREDIT_CARD',
    'US_SSN',
    'US_DRIVER_LICENSE',
    'PHONE_NUMBER',
    'US_PASSPORT',
    'EMAIL_ADDRESS',
    'IP_ADDRESS',
    'US_BANK_NUMBER',
    'IBAN_CODE',
    'EMAIL',
])
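# Note: "EMAIL" is not a built-in Presidio entity type (unlike "EMAIL_ADDRESS");
# it is presumably provided by a custom recognizer registered in analyze.py.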

def stream_rows(dataset: str, config: str, split: str) -> Iterable[dict[str, Any]]:
    """Yield the rows of a dataset split by paging through the datasets-server /rows API."""
    batch_size = 100
    for i in count():
        params = {"dataset": dataset, "config": config, "split": split, "offset": i * batch_size, "length": batch_size}
        rows_resp = session.get("https://datasets-server.huggingface.co/rows", params=params, timeout=10).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        if not rows_resp["rows"]:
            break
        for row_item in rows_resp["rows"]:
            yield row_item["row"]
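
# Example usage (a sketch; assumes the dataset is public and has this config and split):
#     for row in islice(stream_rows("tatsu-lab/alpaca", "default", "train"), 3):
#         print(row)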

class track_iter:
    """Wrap an iterable and count how many items have been consumed, for progress reporting."""

    def __init__(self, it: Iterable[T]):
        self.it = it
        self.next_idx = 0

    def __iter__(self) -> Iterator[T]:
        for item in self.it:
            self.next_idx += 1
            yield item


def presidio_report(presidio_entities: list[PresidioEntity], next_row_idx: int, num_rows: int) -> dict[str, float]:
    """Build the label/score mapping shown in the gr.Label output: the title's score is the
    scan progress (rows scanned / num_rows) and each entity type maps to the fraction of
    scanned rows in which it was detected."""
    title = f"Scan finished: {len(presidio_entities)} entities found" if num_rows == next_row_idx else "Scan in progress..."
    # Counting the title once per scanned row makes its score next_row_idx / num_rows
    counter = Counter([title] * next_row_idx)
    for _row_idx, presidio_entities_per_row in groupby(presidio_entities, itemgetter("row_idx")):
        counter.update(set("% of rows with " + presidio_entity["type"] for presidio_entity in presidio_entities_per_row))
    return {entity_type: row_count / num_rows for entity_type, row_count in counter.most_common()}
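# For example, with 10 of 100 rows scanned and PERSON detected in 3 of them, this returns
# {"Scan in progress...": 0.1, "% of rows with PERSON": 0.03}.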


def analyze_dataset(
    dataset: str,
    enabled_presidio_entities: list[str] = DEFAULT_PRESIDIO_ENTITIES,
    show_texts_without_masks: bool = False,
) -> Iterator[tuple[dict[str, float] | str, pd.DataFrame]]:
    """Scan the first MAX_ROWS rows of a dataset and stream (report, entities table) updates."""
    info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
    if "error" in info_resp:
        yield "❌ " + info_resp["error"], pd.DataFrame()
        return
    # Prefer the "default" config and the "train" split, falling back to the first available ones
    config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
    features = Features.from_dict(info_resp["dataset_info"][config]["features"])
    split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(iter(info_resp["dataset_info"][config]["splits"]))
    num_rows = min(info_resp["dataset_info"][config]["splits"][split]["num_examples"], MAX_ROWS)
    scanned_columns = get_columns_with_strings(features)
    columns_descriptions = [
        get_column_description(column_name, features[column_name]) for column_name in scanned_columns
    ]
    rows = track_iter(islice(stream_rows(dataset, config, split), MAX_ROWS))
    presidio_entities = []
    for presidio_entity in presidio_scan_entities(
        rows, scanned_columns=scanned_columns, columns_descriptions=columns_descriptions
    ):
        if not show_texts_without_masks:
            presidio_entity["text"] = mask(presidio_entity["text"])
        if presidio_entity["type"] in enabled_presidio_entities:
            presidio_entities.append(presidio_entity)
            yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
    # Final update so the report reflects the completed scan even if the last rows had no matches
    yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)

with gr.Blocks() as demo:
    gr.Markdown("# Scan datasets using Presidio")
    gr.Markdown("The space takes an HF dataset name as an input, and returns the list of entities detected by Presidio in the first samples.")
    inputs = [
        HuggingfaceHubSearch(
            label="Hub Dataset ID",
            placeholder="Search for a dataset on the Hugging Face Hub",
            search_type="dataset",
        ),
        gr.CheckboxGroup(
            label="Presidio entities",
            choices=sorted(analyzer.get_supported_entities()),
            value=DEFAULT_PRESIDIO_ENTITIES,
            interactive=True,
        ),
        gr.Checkbox(label="Show texts without masks", value=False),
    ]
    button = gr.Button("Run Presidio Scan")
    outputs = [
        gr.Label(show_label=False),
        gr.DataFrame(),
    ]
    button.click(analyze_dataset, inputs, outputs)
    gr.Examples(
        [
            ["microsoft/orca-math-word-problems-200k"],
            ["tatsu-lab/alpaca"],
            ["Anthropic/hh-rlhf"],
            ["OpenAssistant/oasst1"],
            ["sidhq/email-thread-summary"],
            ["lhoestq/fake_name_and_ssn"]
        ],
        inputs,
        outputs,
        fn=analyze_dataset,
        run_on_click=True,
        cache_examples=False,
    )

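# When deployed as a Gradio Space, this file is executed directly (typically saved as app.py);
# running it locally works the same way: it serves the app at the URL that launch() prints.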
demo.launch()