import streamlit as st
# from gliner import GLiNER
from datasets import load_dataset
from peft import PeftModel, PeftConfig
import threading
import time
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import DebertaV2ForTokenClassification, DebertaV2Tokenizer, pipeline
# Guards the shared entity counter: five worker threads update it concurrently,
# and a dict read-modify-write is not atomic.
counter_lock = threading.Lock()

def predict_entities(text, labels, entity_set):
    """Run NER over `text` and tally entity counts into the shared `entity_set` dict."""
    if labels == []:
        # No GLiNER labels supplied: use the fine-tuned DeBERTa NER pipeline
        entities = recognizer(text)
        key = 'entity'
    else:
        # Use the GLiNER model with the supplied label set
        entities = model.predict_entities(text, labels, threshold=0.7)
        key = 'label'
    with counter_lock:
        for entity in entities:
            entity_set[entity[key]] = entity_set.get(entity[key], 0) + 1
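# For reference, a raw (non-aggregated) transformers NER pipeline returns dicts of the form
#   {'entity': 'B-FIRSTNAME', 'score': 0.99, 'index': 3, 'word': 'Jane', 'start': 6, 'end': 10}
# (the tag names here are illustrative; they depend on the model's label map), so the
# counts above are keyed by raw BIO tags rather than merged entity spans.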
def process_datasets(start, end, unmasked_text, sizes, index, entity_set, labels):
    """Accumulate rows [start, end) into ~700-character chunks and run NER on each chunk."""
    size = 0
    text = ""
    for i in range(start, end):
        if len(text) < 700:
            text = text + " " + unmasked_text[i]
        else:
            size += len(text)
            predict_entities(text, labels, entity_set)
            text = unmasked_text[i]
    # Flush the final chunk so the tail of the range is not silently dropped
    if text:
        size += len(text)
        predict_entities(text, labels, entity_set)
    sizes[index] = size
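# Single-threaded usage sketch (commented out; `unmasked_text` is defined further below,
# and the threaded section does the same work in parallel). An empty label list selects
# the DeBERTa pipeline path:
# demo_sizes = [0]
# demo_counts = {}
# process_datasets(0, 10, unmasked_text, demo_sizes, 0, demo_counts, [])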
| print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| # True | |
| if torch.cuda.is_available(): | |
| print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
# Load the fine-tuned DeBERTa token-classification model for PII detection
st.write('Loading the pretrained model ...')
model_name = "CarolXia/pii-kd-deberta-v2"
# config = PeftConfig.from_pretrained(model_name)
model = DebertaV2ForTokenClassification.from_pretrained(model_name, token=st.secrets["HUGGINGFACE_TOKEN"])
if torch.cuda.is_available():
    model = model.to("cuda")
# Try quantization instead
# model = AutoModelForTokenClassification.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/mdeberta-v3-base", token=st.secrets["HUGGINGFACE_TOKEN"])
recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
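# Optional variant, if merged entity spans are preferred over raw BIO tags:
# aggregation_strategy is a standard transformers pipeline argument that groups subword
# tokens into whole entities. Note the result dicts then carry an 'entity_group' key
# instead of 'entity', so predict_entities would need to read that key instead.
# recognizer = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")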
| # model_name = "urchade/gliner_multi_pii-v1" | |
| # model = GLiNER.from_pretrained(model_name) | |
| # print weights | |
| pytorch_total_params = sum(p.numel() for p in model.parameters()) | |
| torch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| print(f'total params: {pytorch_total_params}. tunable params: {torch_total_params}') | |
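# Rough footprint sketch, assuming fp32 weights (4 bytes per parameter); actual memory
# use will differ if the model is loaded in fp16 or 8-bit.
print(f'approx. weight memory: {pytorch_total_params * 4 / 1e6:.1f} MB')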
# Sample text containing PII/PHI entities
text = """
Hello Jane Doe. Your AnyCompany Financial Services, LLC credit card account
4111-0000-1111-0000 has a minimum payment of $24.53 that is due by July 31st.
Based on your autopay settings, we will withdraw your payment on the due date from
your bank account XXXXXX1111 with the routing number XXXXX0000.
Your latest statement was mailed to 100 Main Street, Anytown, WA 98121.
After your payment is received, you will receive a confirmation text message
at 206-555-0100.
If you have questions about your bill, AnyCompany Customer Service is available by
phone at 206-555-0199 or email at support@anycompany.com.
"""
# Define the labels for PII/PHI entities (used only on the GLiNER path of predict_entities)
labels = [
    "medical_record_number",
    "date_of_birth",
    "ssn",
    "date",
    "first_name",
    "email",
    "last_name",
    "customer_id",
    "employee_id",
    "name",
    "street_address",
    "phone_number",
    "ipv4",
    "credit_card_number",
    "license_plate",
    "address",
    "user_name",
    "device_identifier",
    "bank_routing_number",
    "date_time",
    "company_name",
    "unique_identifier",
    "biometric_identifier",
    "account_number",
    "city",
    "certificate_license_number",
    "time",
    "postcode",
    "vehicle_identifier",
    "coordinate",
    "country",
    "api_key",
    "ipv6",
    "password",
    "health_plan_beneficiary_number",
    "national_id",
    "tax_id",
    "url",
    "state",
    "swift_bic",
    "cvv",
    "pin",
]
st.write('Trying a sample first')
st.write(text)
# Predict entities with a confidence threshold of 0.7
# entities = model.predict_entities(text, labels, threshold=0.7)
entities = recognizer(text)
# Display the detected entities
for entity in entities:
    st.write(entity)
st.write('Processing the full dataset now ...')
entity_set = dict()
dataset = load_dataset("Isotonic/pii-masking-200k", split="train")
# Materialize the entire column in memory up front to avoid per-row I/O delays later
unmasked_text = dataset['unmasked_text']
st.write('Number of rows in the dataset ', dataset.num_rows)
sizes = [0] * 5
start = time.time()
# Split the first 50 rows across five worker threads, 10 rows each
threads = [
    threading.Thread(target=process_datasets, args=(i * 10, (i + 1) * 10, unmasked_text, sizes, i, entity_set, []))
    for i in range(5)
]
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
#     process_datasets(0, 50, unmasked_text, sizes, 0, entity_set, [])
for t in threads:
    t.start()
for t in threads:
    t.join()
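# Equivalent sketch with a thread pool from the standard library, if you prefer not to
# manage Thread objects by hand (same work split; assumes the names defined above):
# from concurrent.futures import ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=5) as pool:
#     for i in range(5):
#         pool.submit(process_datasets, i * 10, (i + 1) * 10, unmasked_text, sizes, i, entity_set, [])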
end = time.time()
length = end - start
# Show the results: this can be altered however you like
st.write('Bytes processed ', sum(sizes))
st.write("It took", length, "seconds!")
# Display the summary
st.write('Total entities found')
for key in entity_set:
    st.write(key, ' => ', entity_set[key])
# Only valid when the profiler block above is enabled; otherwise `prof` is undefined
# st.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))