LH-Tech-AI commited on
Commit
89874f1
·
verified ·
1 Parent(s): 9178b45

Create use.py

Browse files
Files changed (1) hide show
  1. use.py +77 -0
use.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
3
+
4
+ class ShieldFilter:
5
+ def __init__(self, model_path="LH-Tech-AI/Shield-82M"):
6
+ print(f"Loading Shield-82M from {model_path}...")
7
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
8
+ self.model = AutoModelForTokenClassification.from_pretrained(model_path)
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ self.model.to(self.device)
11
+ self.model.eval()
12
+
13
+ self.group_map = {
14
+ "FIRSTNAME": "PERSON", "MIDDLENAME": "PERSON", "LASTNAME": "PERSON",
15
+ "BUILDINGNUMBER": "ADDRESS", "STREET": "ADDRESS", "CITY": "ADDRESS",
16
+ "STATE": "ADDRESS", "ZIPCODE": "ADDRESS", "SECONDARYADDRESS": "ADDRESS",
17
+ "EMAIL": "EMAIL", "PHONENUMBER": "PHONE", "PHONEIMEI": "PHONE",
18
+ "DATE": "DOB", "TIME": "DOB"
19
+ }
20
+
21
+ def protect(self, text):
22
+ inputs = self.tokenizer(
23
+ text,
24
+ return_tensors="pt",
25
+ truncation=True,
26
+ max_length=512,
27
+ return_offsets_mapping=True
28
+ ).to(self.device)
29
+
30
+ offsets = inputs.pop("offset_mapping")[0].cpu().numpy()
31
+
32
+ with torch.no_grad():
33
+ outputs = self.model(**inputs).logits
34
+
35
+ predictions = torch.argmax(outputs, dim=2)[0].cpu().numpy()
36
+ id2label = self.model.config.id2label
37
+
38
+ spans_to_replace = []
39
+ current_group = None
40
+ start_char = -1
41
+ last_char = -1
42
+
43
+ for idx, (pred_id, offset) in enumerate(zip(predictions, offsets)):
44
+ if offset[0] == 0 and offset[1] == 0:
45
+ continue
46
+
47
+ label = id2label[pred_id]
48
+
49
+ if label == "O":
50
+ if current_group is not None:
51
+ spans_to_replace.append((start_char, last_char, current_group))
52
+ current_group = None
53
+ else:
54
+ group_tag = self.group_map.get(label, label)
55
+
56
+ if current_group != group_tag:
57
+ if current_group is not None:
58
+ spans_to_replace.append((start_char, last_char, current_group))
59
+ current_group = group_tag
60
+ start_char = offset[0]
61
+
62
+ last_char = offset[1]
63
+
64
+ if current_group is not None:
65
+ spans_to_replace.append((start_char, last_char, current_group))
66
+
67
+ filtered_text = text
68
+ for start, end, tag in sorted(spans_to_replace, key=lambda x: x[0], reverse=True):
69
+ filtered_text = filtered_text[:start] + f"[{tag}]" + filtered_text[end:]
70
+
71
+ return filtered_text
72
+
73
+ if __name__ == "__main__":
74
+ shield = ShieldFilter()
75
+ sample = "My name is John Doe. Email: john@example.com. Phone: +49 123 45678."
76
+ print(f"Original: {sample}")
77
+ print(f"Protected: {shield.protect(sample)}")