"""Gradio app that scans the first rows of a Hugging Face dataset with Presidio."""

from collections import Counter
from itertools import count, groupby, islice
from operator import itemgetter
from typing import Any, Iterable, Iterator, TypeVar, Union

import gradio as gr
import requests
import pandas as pd
from datasets import Features
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from requests.adapters import HTTPAdapter, Retry

from analyze import (
    PresidioEntity,
    analyzer,
    get_column_description,
    get_columns_with_strings,
    mask,
    presidio_scan_entities,
)

# Maximum number of rows streamed and scanned per dataset.
MAX_ROWS = 100
T = TypeVar("T")

session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
session.mount('http://', HTTPAdapter(max_retries=retries))
# Every request in this module targets an https:// URL, so the retry policy
# must be mounted on that prefix as well; otherwise it is never used.
session.mount('https://', HTTPAdapter(max_retries=retries))

DEFAULT_PRESIDIO_ENTITIES = sorted([
    'PERSON',
    'CREDIT_CARD',
    'US_SSN',
    'US_DRIVER_LICENSE',
    'PHONE_NUMBER',
    'US_PASSPORT',
    'EMAIL_ADDRESS',
    'IP_ADDRESS',
    'US_BANK_NUMBER',
    'IBAN_CODE',
    'EMAIL',
])


def stream_rows(dataset: str, config: str, split: str) -> Iterator[dict[str, Any]]:
    """Yield the rows of a dataset split, fetched page by page.

    Pages of ``batch_size`` rows are requested from the datasets-server
    ``/rows`` endpoint until an empty page is returned.

    Args:
        dataset: Dataset id on the Hub.
        config: Configuration name.
        split: Split name.

    Yields:
        The ``"row"`` value of each item in the response's ``"rows"`` list.

    Raises:
        RuntimeError: If a response contains an ``"error"`` key.
    """
    batch_size = 100
    for i in count():
        rows_resp = session.get(
            "https://datasets-server.huggingface.co/rows",
            # Let requests URL-encode the values instead of interpolating
            # them raw into the query string.
            params={
                "dataset": dataset,
                "config": config,
                "split": split,
                "offset": i * batch_size,
                "length": batch_size,
            },
            timeout=10,
        ).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        if not rows_resp["rows"]:
            break
        for row_item in rows_resp["rows"]:
            yield row_item["row"]


class track_iter:
    """Iterable wrapper that counts how many items have been pulled so far.

    ``next_idx`` is incremented just before each item is handed out, so
    while an item is being processed it equals that item's index plus one.
    """

    def __init__(self, it: Iterable[T]):
        self.it = it
        self.next_idx = 0

    def __iter__(self) -> Iterator[T]:
        for item in self.it:
            self.next_idx += 1
            yield item


def presidio_report(presidio_entities: list[PresidioEntity], next_row_idx: int, num_rows: int) -> dict[str, float]:
    """Build the label -> fraction mapping shown for a scan.

    Args:
        presidio_entities: Entities kept so far; each is indexed by
            ``"row_idx"`` and ``"type"``.
        next_row_idx: Number of rows pulled from the row stream so far.
        num_rows: Total number of rows the scan is expected to cover.

    Returns:
        A dict, ordered by decreasing value, holding a title entry whose
        value is ``next_row_idx / num_rows`` and one
        ``"% of rows with <type>"`` entry per entity type whose value is the
        number of row groups containing that type divided by ``num_rows``.
    """
    title = f"Scan finished: {len(presidio_entities)} entities found" if num_rows == next_row_idx else "Scan in progress..."
    # The title is counted once per pulled row so that its share doubles as a
    # progress indicator; with zero rows pulled it is absent from the report.
    counter = Counter([title] * next_row_idx)
    # groupby only merges consecutive items, so this assumes entities of the
    # same row are adjacent in the list. The inner set counts each entity
    # type at most once per row group.
    for _row_idx, presidio_entities_per_row in groupby(presidio_entities, itemgetter("row_idx")):
        counter.update({
            "% of rows with " + presidio_entity["type"]
            for presidio_entity in presidio_entities_per_row
        })
    return {
        label: row_count / num_rows
        for label, row_count in counter.most_common()
    }


def analyze_dataset(
    dataset: str,
    enabled_presidio_entities: list[str] = DEFAULT_PRESIDIO_ENTITIES,
    show_texts_without_masks: bool = False,
) -> Iterator[tuple[Union[str, dict[str, float]], pd.DataFrame]]:
    """Scan the first rows of a dataset and stream the results.

    The config is ``"default"`` when available, otherwise the first one
    listed; the split is ``"train"`` when available, otherwise the first one
    listed. At most ``MAX_ROWS`` rows are scanned.

    Args:
        dataset: Dataset id on the Hub.
        enabled_presidio_entities: Entity types to keep in the results. The
            default list is only read, never mutated.
        show_texts_without_masks: When false, each entity's ``"text"`` is
            replaced by ``mask(text)`` before being reported.

    Yields:
        ``(report, dataframe)`` pairs: one after each kept entity and a final
        one when the scan ends. If the info endpoint returns an error, a
        single ``("❌ <error>", empty DataFrame)`` pair is yielded instead.
    """
    info_resp = session.get(
        "https://datasets-server.huggingface.co/info",
        params={"dataset": dataset},
        timeout=3,
    ).json()
    if "error" in info_resp:
        yield "❌ " + info_resp["error"], pd.DataFrame()
        return

    dataset_info = info_resp["dataset_info"]
    config = "default" if "default" in dataset_info else next(iter(dataset_info))
    config_info = dataset_info[config]
    features = Features.from_dict(config_info["features"])
    splits = config_info["splits"]
    split = "train" if "train" in splits else next(iter(splits))
    num_rows = min(splits[split]["num_examples"], MAX_ROWS)

    scanned_columns = get_columns_with_strings(features)
    columns_descriptions = [
        get_column_description(column_name, features[column_name])
        for column_name in scanned_columns
    ]
    # track_iter exposes how many rows the scanner has consumed so far.
    rows = track_iter(islice(stream_rows(dataset, config, split), MAX_ROWS))
    presidio_entities = []
    for presidio_entity in presidio_scan_entities(
        rows, scanned_columns=scanned_columns, columns_descriptions=columns_descriptions
    ):
        if not show_texts_without_masks:
            presidio_entity["text"] = mask(presidio_entity["text"])
        if presidio_entity["type"] in enabled_presidio_entities:
            presidio_entities.append(presidio_entity)
            yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
    # Final update, emitted even when no entity was kept.
    yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)


with gr.Blocks() as demo:
    gr.Markdown("# Scan datasets using Presidio")
    gr.Markdown("The space takes an HF dataset name as an input, and returns the list of entities detected by Presidio in the first samples.")
    inputs = [
        HuggingfaceHubSearch(
            label="Hub Dataset ID",
            placeholder="Search for dataset id on Huggingface",
            search_type="dataset",
        ),
        gr.CheckboxGroup(
            label="Presidio entities",
            choices=sorted(analyzer.get_supported_entities()),
            value=DEFAULT_PRESIDIO_ENTITIES,
            interactive=True,
        ),
        gr.Checkbox(label="Show texts without masks", value=False),
    ]
    button = gr.Button("Run Presidio Scan")
    outputs = [
        gr.Label(show_label=False),
        gr.DataFrame(),
    ]
    button.click(analyze_dataset, inputs, outputs)
    gr.Examples(
        [
            ["microsoft/orca-math-word-problems-200k"],
            ["tatsu-lab/alpaca"],
            ["Anthropic/hh-rlhf"],
            ["OpenAssistant/oasst1"],
            ["sidhq/email-thread-summary"],
            ["lhoestq/fake_name_and_ssn"]
        ],
        inputs,
        outputs,
        fn=analyze_dataset,
        run_on_click=True,
        cache_examples=False,
    )

demo.launch()