# encrypted-anonymization / anonymize_file_clear.py
# jfrery-zama's picture
# space working with chatgpt 4
# df6182e
# raw
# history blame
# No virus
# 2.21 kB
import argparse
import json
import re
import uuid
from pathlib import Path
import gensim
from concrete.ml.common.serialization.loaders import load
def load_models():
    """Load the two model artifacts stored in the same directory as this file.

    Returns:
        A ``(embeddings_model, fhe_ner_detection)`` tuple: the gensim
        FastText model read from ``embedded_model.model`` and the object
        that Concrete ML's ``load`` deserializes from ``cml_xgboost.model``.
    """
    here = Path(__file__).parent
    fasttext_path = here / "embedded_model.model"
    classifier_path = here / "cml_xgboost.model"

    # The FastText loader is handed a plain string path rather than a Path.
    embeddings_model = gensim.models.FastText.load(str(fasttext_path))

    # The serialized classifier is read through a text-mode file handle.
    with open(classifier_path, "r") as handle:
        fhe_ner_detection = load(file=handle)

    return embeddings_model, fhe_ner_detection
def anonymize_text(text: str, embeddings_model, fhe_ner_detection, threshold: float = 0.5) -> dict:
    """Map every word the classifier flags as sensitive to a short random ID.

    The text is split into word-like tokens and separator runs. Each distinct
    word is embedded through ``embeddings_model.wv`` and scored with
    ``fhe_ner_detection.predict_proba``. Words whose positive-class
    probability reaches ``threshold`` are assigned the first 8 characters of a
    fresh UUID4; repeated occurrences of a word share one identifier.

    Only the mapping is produced; the anonymized text itself is not built.

    Args:
        text: Text to scan.
        embeddings_model: Object whose ``wv[token]`` returns the token's
            embedding. The embedding is indexed with ``[None]`` before being
            passed on (presumably a NumPy vector gaining a batch axis --
            confirm against the model in use).
        fhe_ner_detection: Classifier exposing ``predict_proba``; element
            ``[0][1]`` of its result is read as the probability that the
            token is sensitive.
        threshold: Minimum probability for a token to be anonymized.
            Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        Dict mapping each flagged token to its 8-character identifier.
    """
    # Compile once instead of re-resolving the patterns for every token.
    token_pattern = re.compile(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)")
    # A token counts as a word when it starts with a word character; this is
    # equivalent to the former ``token.strip() and re.match(r"\w+", token)``.
    starts_with_word_char = re.compile(r"\w")

    uuid_map = {}
    # token -> bool verdict, so a repeated word is embedded and scored once.
    decisions = {}

    for token in token_pattern.findall(text):
        if not starts_with_word_char.match(token):
            continue  # separators (spaces, punctuation) are never classified

        if token not in decisions:
            x = embeddings_model.wv[token][None]
            probability = fhe_ner_detection.predict_proba(x)[0][1]
            decisions[token] = bool(probability >= threshold)

        if decisions[token] and token not in uuid_map:
            uuid_map[token] = str(uuid.uuid4())[:8]

    return uuid_map
def main():
    """Command-line entry point: write the UUID mapping for one text file.

    Parses the positional ``file_path`` argument, loads both models, reads
    the file as UTF-8, runs ``anonymize_text`` on its contents and dumps the
    returned mapping as sorted, indented JSON to
    ``<input stem>_uuid_mapping.json``. That output path is relative, so the
    JSON file is created in the current working directory.
    """
    cli = argparse.ArgumentParser(
        description="Anonymize named entities in a text file and save the mapping to a JSON file."
    )
    cli.add_argument("file_path", type=str, help="The path to the file to be processed.")
    source_path = cli.parse_args().file_path

    # Models are loaded before the input is touched, as in the original flow.
    embeddings_model, fhe_ner_detection = load_models()

    text = Path(source_path).read_text(encoding='utf-8')

    uuid_map = anonymize_text(text, embeddings_model, fhe_ner_detection)

    # Only the stem of the input name is kept; its directory is dropped.
    mapping_path = Path(source_path).stem + "_uuid_mapping.json"
    with open(mapping_path, 'w', encoding='utf-8') as sink:
        json.dump(uuid_map, sink, indent=4, sort_keys=True)

    print(f"UUID mapping saved to {mapping_path}")
# Run the command-line entry point only when this file is executed as a
# script, so importing the module does not trigger argument parsing.
if __name__ == "__main__":
    main()