|
import json |
|
import os |
|
import pickle as pkl |
|
import re |
|
import shutil |
|
import string |
|
from collections import Counter |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
|
|
# Predefined example queries, keyed by a human-readable label.
DEFAULT_QUERIES = {
    "Example Query 1": "Who visited microsoft.com on September 18?",
    "Example Query 2": "Does Kate have a driving licence?",
    "Example Query 3": "What's David Johnson's phone number?",
}

# Length threshold (in characters) that `is_user_query_valid` compares
# free-form user queries against.
MAX_USER_QUERY_LEN = 80
|
|
|
|
|
# Directory containing this module; every path below is built relative to it.
CURRENT_DIR = Path(__file__).parent

# Top-level locations.
DATA_PATH = CURRENT_DIR.joinpath("files")
LOGREG_MODEL_PATH = CURRENT_DIR.joinpath("models", "cml_logreg.model")
DEPLOYMENT_DIR = CURRENT_DIR.joinpath("deployment")
KEYS_DIR = DEPLOYMENT_DIR.joinpath(".fhe_keys")

# Individual files located under DATA_PATH.
ORIGINAL_FILE_PATH = DATA_PATH.joinpath("original_document.txt")
ANONYMIZED_FILE_PATH = DATA_PATH.joinpath("anonymized_document.txt")
MAPPING_UUID_PATH = DATA_PATH.joinpath("original_document_uuid_mapping.json")
MAPPING_SENTENCES_PATH = DATA_PATH.joinpath("mapping_clear_to_anonymized.pkl")
PROMPT_PATH = DATA_PATH.joinpath("chatgpt_prompt.txt")

# Directories that `clean_directory` removes and recreates.
ALL_DIRS = [KEYS_DIR]
|
|
|
# ASCII punctuation with "%" and "$" left out. Despite the `_LIST` suffix the
# final value is a single string, not a list.
PUNCTUATION_LIST = "".join(char for char in string.punctuation if char not in ("%", "$"))
|
|
|
|
|
def clean_directory() -> None:
    """Clear directories.

    Every entry of `ALL_DIRS` that currently is a directory is deleted together
    with its contents, then the directory is created again (including any
    missing parents).
    """
    print("Cleaning...\n")

    for directory in ALL_DIRS:
        # `isdir` is False for a missing path, so it also covers the existence check.
        if os.path.isdir(directory):
            shutil.rmtree(directory)

        # NOTE(review): the source indentation was lost; the directory is assumed
        # to be recreated for every entry, not only for those that existed - confirm.
        directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
    """Get mean-pooled representations of given texts in batches.

    The texts are processed `batch_size` at a time. For each batch the model's
    `last_hidden_state` is averaged along dimension 1, counting only the
    positions where the tokenizer's attention mask is non-zero.

    Args:
        texts: Sequence of texts; it is sliced into batches and each slice is
            passed to `tokenizer`.
        model: Callable invoked as `model(**inputs, output_hidden_states=False)`
            whose result exposes a `last_hidden_state` tensor.
        tokenizer: Callable invoked with `return_tensors="pt"`, `padding=True`
            and `truncation=True`; its result must support `**` unpacking and
            contain an `"attention_mask"` entry.
        batch_size (int): Number of texts per forward pass. Default to 1.

    Returns:
        np.ndarray: The pooled vectors of all batches stacked in input order
            (an empty array when `texts` is empty).
    """
    pooled_rows = []

    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start : start + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)

        # Inference only: no gradient bookkeeping is needed
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=False)

        last_hidden_states = outputs.last_hidden_state

        # Broadcast the mask over the hidden dimension so masked-out positions
        # contribute nothing to the sum
        input_mask_expanded = (
            inputs["attention_mask"].unsqueeze(-1).expand(last_hidden_states.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)

        # Clamp the denominator: a row whose mask is entirely zero would otherwise
        # produce 0 / 0 = NaN. Rows with at least one attended position are unaffected
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        mean_pooled = sum_embeddings / sum_mask
        pooled_rows.extend(mean_pooled.cpu().detach().numpy())

    return np.array(pooled_rows)
|
|
|
|
|
def is_user_query_valid(user_query: str) -> bool:
    """
    Check whether the `user_query` is neither a default query nor within the length limit.

    NOTE(review): the function name suggests that True means "valid", but the
    value returned is True for queries that fail both checks below - confirm
    the polarity expected by the callers before changing anything here.

    Args:
        user_query (str): The input text to be checked.

    Returns:
        bool: True if the `user_query` is not one of the `DEFAULT_QUERIES` values
            and is either None or longer than `MAX_USER_QUERY_LEN` characters,
            False otherwise (this includes the empty string).
    """
    # Default queries are never flagged
    if user_query in DEFAULT_QUERIES.values():
        return False

    # None never fits the limit; otherwise compare the length with the threshold
    fits_length_limit = user_query is not None and len(user_query) <= MAX_USER_QUERY_LEN

    return not fits_length_limit
|
|
|
|
|
def compare_texts_ignoring_extra_spaces(original_text, modified_text):
    """Check if the modified_text is identical to the original_text except for additional spaces.

    Leading and trailing whitespace is dropped and every run of whitespace is
    collapsed to a single space before the two texts are compared.

    Args:
        original_text (str): The original text for comparison.
        modified_text (str): The modified text to compare against the original.

    Returns:
        (bool): True if the modified_text is the same as the original_text except for
            additional spaces; False otherwise.
    """

    def _collapse_whitespace(text):
        return " ".join(text.split())

    return _collapse_whitespace(original_text) == _collapse_whitespace(modified_text)
|
|
|
|
|
def is_strict_deletion_only(original_text, modified_text):
    """Check that the modified_text only contains words taken from the original_text.

    Both texts are tokenized the same way: a space is inserted at every boundary
    between a word character and a punctuation character, the text is lower-cased
    and then split on whitespace. The comparison is done on word counts, so word
    order is ignored, but no word may occur more often in the modified_text than
    in the original_text.

    Args:
        original_text (str): The reference text.
        modified_text (str): The text to validate against the reference.

    Returns:
        (bool): True if every word of the modified_text is covered by the
            original_text, occurrences included; False otherwise.
    """
    # Zero-width position between a word character and a non-word, non-space
    # character, in either order
    boundary = r"(?<=[\w])(?=[^\w\s])|(?<=[^\w\s])(?=[\w])"

    def _word_counts(text):
        return Counter(re.sub(boundary, " ", text).lower().split())

    original_counts = _word_counts(original_text)
    modified_counts = _word_counts(modified_text)

    # Counter subtraction keeps only strictly positive counts, i.e. the words
    # (or extra occurrences) that the original text cannot account for
    surplus = modified_counts - original_counts

    return not surplus
|
|
|
|
|
def read_txt(file_path):
    """Read text from a file.

    Args:
        file_path: Location of the file, opened in text mode as UTF-8.

    Returns:
        str: The whole content of the file.
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        content = handle.read()

    return content
|
|
|
|
|
def write_txt(file_path, data):
    """Write text to a file.

    Args:
        file_path: Location of the file; it is opened in "w" mode as UTF-8, so
            any previous content is replaced.
        data (str): The text to store.
    """
    with open(file_path, mode="w", encoding="utf-8") as handle:
        handle.write(data)
|
|
|
|
|
def write_pickle(file_path, data):
    """Save data to a pickle file.

    Args:
        file_path: Location of the file; it is opened in binary write mode, so
            any previous content is replaced.
        data: The object to serialize.
    """
    with open(file_path, mode="wb") as handle:
        pkl.dump(data, handle)
|
|
|
|
|
def read_pickle(file_name):
    """Load data from a pickle file.

    Unpickling can execute arbitrary code, so this must only be used on files
    coming from a trusted source.

    Args:
        file_name: Location of the pickle file, opened in binary read mode.

    Returns:
        The object stored in the file.
    """
    with open(file_name, mode="rb") as handle:
        loaded = pkl.load(handle)

    return loaded
|
|
|
|
|
def read_json(file_name):
    """Load data from a json file.

    The file is decoded explicitly as UTF-8, like the other text helpers of
    this module, so the result does not depend on the platform's default
    encoding.

    Args:
        file_name: Location of the json file.

    Returns:
        The deserialized content of the file.
    """
    with open(file_name, "r", encoding="utf-8") as file:
        return json.load(file)
|
|
|
|
|
def write_json(file_name, data):
    """Save data to a json file.

    The output is written as UTF-8, indented by 4 spaces, with the keys sorted.

    Args:
        file_name: Location of the file; it is opened in "w" mode, so any
            previous content is replaced.
        data: The json-serializable object to store.
    """
    with open(file_name, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, sort_keys=True, indent=4)
|
|