LasRuinasCirculares committed
Commit 861bb01
1 Parent(s): 2a5821d

Upload 7 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ knowledge_conflict_entity_based/result/entity_info.json filter=lfs diff=lfs merge=lfs -text
knowledge_conflict_entity_based/.DS_Store ADDED
Binary file (6.15 kB)
 
knowledge_conflict_entity_based/entity_substitute.py ADDED
@@ -0,0 +1,126 @@
+ import spacy
+ import zstandard as zstd
+ import json
+ import typing
+ import os
+ from tqdm import tqdm
+ import multiprocessing
+ import random
+ from langdetect import detect
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_dir', type=str, help='Path to the input directory')
+ args = parser.parse_args()
+ input_dir = args.input_dir
+
+
+ def is_english(text):
+     try:
+         # langdetect raises LangDetectException on empty or undetectable text
+         return detect(text) == 'en'
+     except Exception:
+         return False
+
+
+ def process_text(texts, model, out_f, lock):
+     for text in texts:
+         doc = model(text)
+         # Count entity mentions by surface text, keeping one representative
+         # span per text so label_ and kb_id_ stay available.
+         freq_cnt = {}
+         span_by_text = {}
+         for e in doc.ents:
+             span_by_text[e.text] = e
+             freq_cnt[e.text] = freq_cnt.get(e.text, 0) + 1
+         if len(freq_cnt) == 0:
+             continue
+         sorted_freq = sorted(freq_cnt.items(), key=lambda x: x[1])
+         most_freq = span_by_text[sorted_freq[-1][0]]
+         data = {'text': text, 'main_entity': most_freq.text, 'label': most_freq.label_, 'id': most_freq.kb_id_}
+         json_data = json.dumps(data)
+         with lock:
+             out_f.write(json_data + '\n')
+             out_f.flush()
+
+
+ def run_ner_linking(texts: typing.List[str], ner_model_path: str):
+     nlp = spacy.load(ner_model_path)
+     out_f = open('result/temp_store_data.json', 'w', encoding='utf-8')
+     lock = multiprocessing.Lock()
+     processes = []
+
+     # One worker per 1000-document chunk; relies on fork-style multiprocessing,
+     # so children inherit nlp, out_f and the lock that serializes writes.
+     for i in tqdm(range(0, len(texts), 1000)):
+         p = multiprocessing.Process(target=process_text, args=(texts[i:i+1000], nlp, out_f, lock))
+         processes.append(p)
+         p.start()
+
+     for p in processes:
+         p.join()
+
+     out_f.close()
+     return
+
+
+ # Pass 1: collect English Wikipedia documents from the RedPajama .zst shards.
+ wikipedia_out_path = 'result/wikipedia.json'
+ subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
+ wikipedia_data = []
+ for sub_dir in subdirectories:
+     chunk_dir = sub_dir + '/'
+     zst_files = [f for f in os.listdir(chunk_dir) if f.endswith('.zst')]
+     for file in tqdm(zst_files):
+         with open(chunk_dir + file, 'rb') as compressed_file:
+             decompressor = zstd.ZstdDecompressor()
+             with decompressor.stream_reader(compressed_file) as reader:
+                 decompressed_data = reader.read()
+                 for line in decompressed_data.splitlines():
+                     data = json.loads(line)
+                     if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
+                         if is_english(data['text']):
+                             wikipedia_data.append(data)
+
+ with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
+     for data in wikipedia_data:
+         json_data = json.dumps(data)
+         f.write(json_data + '\n')
+
+ # Pass 2: tag each document's most frequent entity with the NER/EL model.
+ wikipedia_data = []
+ ner_model_path = 'kc-ner-model'
+ with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
+     for line in tqdm(f):
+         data = json.loads(line)
+         wikipedia_data.append(data['text'])
+ run_ner_linking(wikipedia_data, ner_model_path)
+
+ entity_info_path = 'result/entity_info.json'
+ with open(entity_info_path, 'r', encoding='utf-8') as f:
+     entity_info = json.load(f)
+
+ # Group main entities by NER label so substitutes share the entity type.
+ category = {}
+ all_data = []
+ with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
+     for line in f:
+         data = json.loads(line)
+         all_data.append(data)
+         if data['label'] not in category:
+             category[data['label']] = []
+         category[data['label']].append(data['main_entity'])
+
+ # Pass 3: replace the main entity (and its aliases) with a random
+ # same-label entity to create the knowledge conflict.
+ with open('result/processed_data.json', 'w', encoding='utf-8') as f:
+     for data in tqdm(all_data):
+         text = data['text']
+         main_entity = [data['main_entity']]
+         if data['id'] in entity_info:
+             main_entity.extend(entity_info[data['id']]['aliases'])
+         candidates = [e for e in category[data['label']] if e not in main_entity]
+         if not candidates:
+             # No same-label substitute exists; skip instead of looping forever.
+             continue
+         replaced_entity = random.choice(candidates)
+         for entity in main_entity:
+             text = text.replace(entity, replaced_entity)
+         data = {
+             'text': text,
+             'original_main_entity': main_entity,
+             'replaced_entity': replaced_entity
+         }
+         json_data = json.dumps(data)
+         f.write(json_data + '\n')
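Note: the script builds entity-based knowledge conflicts in three passes: filter English RedPajama Wikipedia documents, tag each document's most frequent named entity with the NER/EL model, then rewrite the text by swapping that entity and its aliases for a random entity of the same NER label. A minimal sketch of the substitution step, using a hypothetical record and alias table in place of the real temp_store_data.json and entity_info.json:

import json
import random

# Hypothetical stand-ins for one temp_store_data.json record and entity_info.json.
record = {'text': 'Paris is the capital of France. Paris hosts the Louvre.',
          'main_entity': 'Paris', 'label': 'GPE', 'id': 'Q90'}
entity_info = {'Q90': {'aliases': ['City of Light']}}
category = {'GPE': ['Paris', 'London', 'Berlin']}  # main entities grouped by NER label

names = [record['main_entity']] + entity_info.get(record['id'], {}).get('aliases', [])

# Pick a same-label replacement that is neither the entity nor one of its aliases.
replacement = random.choice([e for e in category[record['label']] if e not in names])

text = record['text']
for name in names:
    text = text.replace(name, replacement)

print(json.dumps({'text': text, 'original_main_entity': names, 'replaced_entity': replacement}))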
knowledge_conflict_entity_based/requirements.txt ADDED
@@ -0,0 +1,5 @@
+ spacy==2.2.4
+ langdetect
+ zstandard
+ tqdm
+ wget
knowledge_conflict_entity_based/result/.DS_Store ADDED
Binary file (6.15 kB)
 
knowledge_conflict_entity_based/result/entity_info.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:423a217fa602456b961b6169b0bac15659ec85c90de2b261ca924c0ebe7d04a4
+ size 742977816
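entity_info.json is stored through Git LFS, so the repository holds only the pointer above: the spec version, the SHA-256 object id, and the size in bytes (about 743 MB). A small sketch, assuming a checkout where `git lfs pull` has not yet been run, that distinguishes a pointer from the real file before entity_substitute.py tries to json.load() it:

def is_lfs_pointer(path: str) -> bool:
    # A Git LFS pointer is a tiny text file whose first line names the spec.
    with open(path, 'rb') as f:
        first = f.readline()
    return first.startswith(b'version https://git-lfs.github.com/spec/')

if is_lfs_pointer('knowledge_conflict_entity_based/result/entity_info.json'):
    print('Run `git lfs pull` first to fetch the real entity_info.json.')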
knowledge_conflict_entity_based/run.sh ADDED
@@ -0,0 +1,2 @@
+ ### The processed data will be written to result/processed_data.json
+ python entity_substitute.py --input_dir /opt/data/private/szc/ml-knowledge-conflicts-main/test
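After the run completes, each line of result/processed_data.json is one JSON object holding the substituted text, the original main entity with its aliases, and the replacement. A minimal sketch for inspecting the output (field names taken from entity_substitute.py above):

import json

with open('result/processed_data.json', 'r', encoding='utf-8') as f:
    for i, line in enumerate(f):
        ex = json.loads(line)
        print(ex['replaced_entity'], '<-', ex['original_main_entity'])
        if i == 4:  # peek at the first five examples
            break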
knowledge_conflict_entity_based/setup.sh ADDED
@@ -0,0 +1,6 @@
+ pip install -r requirements.txt
+
+ # Download the SpaCy Named Entity Recognizer (NER) and Entity Linker (EL) model
+ # See https://spacy.io/usage/linguistic-features#named-entities and https://v2.spacy.io/usage/training#entity-linker
+ wget https://docs-assets.developer.apple.com/ml-research/models/kc-ner/model.gz -O kc-ner-model.gz
+ mkdir -p kc-ner-model && tar -xvzf kc-ner-model.gz -C kc-ner-model
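A quick sanity check after setup, assuming the archive unpacks into kc-ner-model/ in a layout loadable by spaCy v2 (as pinned in requirements.txt), that the NER/EL model loads and emits linked entities:

import spacy

# Assumes setup.sh has unpacked the model into ./kc-ner-model
nlp = spacy.load('kc-ner-model')
doc = nlp('Barack Obama was born in Hawaii.')
for ent in doc.ents:
    # kb_id_ is the linked knowledge-base id; empty if the linker found no match
    print(ent.text, ent.label_, ent.kb_id_)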