Spaces:
Sleeping
Sleeping
import json | |
import os | |
import pickle as pkl | |
import re | |
import shutil | |
import string | |
from collections import Counter | |
from pathlib import Path | |
import numpy as np | |
import torch | |
MAX_USER_QUERY_LEN = 80 | |
# List of example queries for easy access | |
DEFAULT_QUERIES = { | |
"Example Query 1": "Who visited microsoft.com on September 18?", | |
"Example Query 2": "Does Kate have a driving licence?", | |
"Example Query 3": "What's David Johnson's phone number?", | |
} | |
CURRENT_DIR = Path(__file__).parent | |
DATA_PATH = CURRENT_DIR / "files" | |
LOGREG_MODEL_PATH = CURRENT_DIR / "models" / "cml_logreg.model" | |
DEPLOYMENT_DIR = CURRENT_DIR / "deployment" | |
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys" | |
ORIGINAL_FILE_PATH = DATA_PATH / "original_document.txt" | |
ANONYMIZED_FILE_PATH = DATA_PATH / "anonymized_document.txt" | |
MAPPING_UUID_PATH = DATA_PATH / "original_document_uuid_mapping.json" | |
MAPPING_SENTENCES_PATH = DATA_PATH / "mapping_clear_to_anonymized.pkl" | |
PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt" | |
ALL_DIRS = [KEYS_DIR] | |
PUNCTUATION_LIST = list(string.punctuation) | |
PUNCTUATION_LIST.remove("%") | |
PUNCTUATION_LIST.remove("$") | |
PUNCTUATION_LIST = "".join(PUNCTUATION_LIST) | |
def clean_directory() -> None: | |
"""Clear direcgtories""" | |
print("Cleaning...\n") | |
for target_dir in ALL_DIRS: | |
if os.path.exists(target_dir) and os.path.isdir(target_dir): | |
shutil.rmtree(target_dir) | |
target_dir.mkdir(exist_ok=True, parents=True) | |
def get_batch_text_representation(texts, model, tokenizer, batch_size=1): | |
"""Get mean-pooled representations of given texts in batches.""" | |
mean_pooled_batch = [] | |
for i in range(0, len(texts), batch_size): | |
batch_texts = texts[i : i + batch_size] | |
inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True) | |
with torch.no_grad(): | |
outputs = model(**inputs, output_hidden_states=False) | |
last_hidden_states = outputs.last_hidden_state | |
input_mask_expanded = ( | |
inputs["attention_mask"].unsqueeze(-1).expand(last_hidden_states.size()).float() | |
) | |
sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1) | |
sum_mask = input_mask_expanded.sum(1) | |
mean_pooled = sum_embeddings / sum_mask | |
mean_pooled_batch.extend(mean_pooled.cpu().detach().numpy()) | |
return np.array(mean_pooled_batch) | |
def is_user_query_valid(user_query: str) -> bool: | |
""" | |
Check if the `user_query` is None and not empty. | |
Args: | |
user_query (str): The input text to be checked. | |
Returns: | |
bool: True if the `user_query` is None or empty, False otherwise. | |
""" | |
# If the query is not part of the default queries | |
is_default_query = user_query in DEFAULT_QUERIES.values() | |
# Check if the query exceeds the length limit | |
is_exceeded_max_length = user_query is not None and len(user_query) <= MAX_USER_QUERY_LEN | |
return not is_default_query and not is_exceeded_max_length | |
def compare_texts_ignoring_extra_spaces(original_text, modified_text): | |
"""Check if the modified_text is identical to the original_text except for additional spaces. | |
Args: | |
original_text (str): The original text for comparison. | |
modified_text (str): The modified text to compare against the original. | |
Returns: | |
(bool): True if the modified_text is the same as the original_text except for | |
additional spaces; False otherwise. | |
""" | |
normalized_original = " ".join(original_text.split()) | |
normalized_modified = " ".join(modified_text.split()) | |
return normalized_original == normalized_modified | |
def is_strict_deletion_only(original_text, modified_text): | |
# Define a regex pattern that matches a word character next to a punctuation | |
# or a punctuation next to a word character, without a space between them. | |
pattern = r"(?<=[\w])(?=[^\w\s])|(?<=[^\w\s])(?=[\w])" | |
# Replace instances found by the pattern with a space | |
original_text = re.sub(pattern, " ", original_text) | |
modified_text = re.sub(pattern, " ", modified_text) | |
# Tokenize the texts into words, considering also punctuation | |
original_words = Counter(original_text.lower().split()) | |
modified_words = Counter(modified_text.lower().split()) | |
base_words = all(item in original_words.keys() for item in modified_words.keys()) | |
base_count = all(original_words[k] >= v for k, v in modified_words.items()) | |
return base_words and base_count | |
def read_txt(file_path): | |
"""Read text from a file.""" | |
with open(file_path, "r", encoding="utf-8") as file: | |
return file.read() | |
def write_txt(file_path, data): | |
"""Write text to a file.""" | |
with open(file_path, "w", encoding="utf-8") as file: | |
file.write(data) | |
def write_pickle(file_path, data): | |
"""Save data to a pickle file.""" | |
with open(file_path, "wb") as f: | |
pkl.dump(data, f) | |
def read_pickle(file_name): | |
"""Load data from a pickle file.""" | |
with open(file_name, "rb") as file: | |
return pkl.load(file) | |
def read_json(file_name): | |
"""Load data from a json file.""" | |
with open(file_name, "r") as file: | |
return json.load(file) | |
def write_json(file_name, data): | |
"""Save data to a json file.""" | |
with open(file_name, "w", encoding="utf-8") as file: | |
json.dump(data, file, indent=4, sort_keys=True) | |