svjack commited on
Commit
03c1af1
1 Parent(s): b2f024b

Upload 5 files

Browse files
Files changed (5) hide show
  1. extract_by_api.py +17 -0
  2. extract_et_by_api.py +14 -0
  3. qa.py +138 -0
  4. translate_by_api.py +14 -0
  5. wiki_kb_qa_migrate.py +1012 -0
extract_by_api.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ def call_en_zh_reader(English_Span, Chinese_Sentence):
4
+ assert type(English_Span) == type("")
5
+ assert type(Chinese_Sentence) == type("")
6
+ response = requests.post("https://svjack-extract-similar-chinese-span-by--5daeb83.hf.space/run/predict", json={
7
+ "data": [
8
+ English_Span,
9
+ Chinese_Sentence,
10
+ ]}).json()
11
+ data = response["data"]
12
+ if data:
13
+ data = data[0]
14
+ pass
15
+ else:
16
+ pass
17
+ return data
extract_et_by_api.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ def call_entity_property_extract(zh_question):
4
+ response = requests.post("https://svjack-entity-property-extractor-zh.hf.space/run/predict", json={
5
+ "data": [
6
+ zh_question,
7
+ ]}).json()
8
+ data = response["data"]
9
+ if data:
10
+ data = data[0]
11
+ pass
12
+ else:
13
+ pass
14
+ return data
qa.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from conf import *
2
+ import os
3
+ import sys
4
+ import re
5
+ from rapidfuzz import fuzz
6
+ import requests
7
+ import json
8
+
9
+ #assert os.path.exists(flair_ner_model_path)
10
+ #loaded_model: SequenceTagger = SequenceTagger.load(os.path.join(flair_ner_model_path ,"best-model.pt"))
11
+
12
+ '''
13
+ def one_item_process(r, loaded_model):
14
+ #assert type(r) == type(pd.Series())
15
+ zh = r["question"]
16
+ zh = zh.replace(" ", "").strip()
17
+ sentence = Sentence(" ".join(list(zh)))
18
+ loaded_model.predict(sentence)
19
+ sentence_str = str(sentence)
20
+ ask_spans = re.findall(r'\["(.+?)"/ASK\]', sentence_str)
21
+ sentence = re.findall(r'Sentence: "(.+?)"', sentence_str)
22
+ if ask_spans:
23
+ ask_spans = ask_spans[0]
24
+ else:
25
+ ask_spans = ""
26
+ if sentence:
27
+ sentence = sentence[0]
28
+ else:
29
+ sentence = ""
30
+ ask_spans, sentence = map(lambda x: x.replace(" ", "").strip(), [ask_spans, sentence])
31
+ return ask_spans, sentence
32
+ '''
33
+
34
+ def one_item_process_by_request(r):
35
+ zh = r["question"]
36
+ zh = zh.replace(" ", "").strip()
37
+ response = requests.post("https://svjack-question-words-extractor-zh.hf.space/run/predict", json={
38
+ "data": [
39
+ zh,
40
+ ]}).json()
41
+ data = response["data"]
42
+ #data = json.loads(data)
43
+ if data:
44
+ data = data[0]
45
+ Question_words = data["Question words"]
46
+ else:
47
+ Question_words = ""
48
+ return Question_words, zh
49
+
50
+
51
+ def retrieve_sent_split(sent,
52
+ stops_split_pattern = "|".join(map(lambda x: r"\{}".format(x),
53
+ ",." + ",。" + ":?? "))
54
+ ):
55
+ if not sent.strip():
56
+ return []
57
+
58
+ split_list = re.split(stops_split_pattern, sent)
59
+ return split_list
60
+
61
+ def find_min_text_contain_entity_span(sent, entity_str, statement):
62
+ #assert entity_str in sent
63
+ span_list = list(filter(lambda x: entity_str in x ,retrieve_sent_split(sent)))
64
+ if not span_list:
65
+ return sent
66
+ span_list = list(map(lambda x: (x, fuzz.ratio(x, statement)), span_list))
67
+ return sorted(span_list, key = lambda t2: t2[1], reverse = True)[0][0]
68
+ #return sorted(span_list, key = len)[0]
69
+
70
+ def to_statement(r):
71
+ entity = r["entity"]
72
+ question = r["question"]
73
+ head = r["head"]
74
+ context = r["context"]
75
+ statement = question.replace(head, entity).replace("?", "").replace("?", "")
76
+ top_chip = find_min_text_contain_entity_span(context, entity, statement)
77
+ return statement, top_chip
78
+
79
+ '''
80
+ r = {'entity': '1901年',
81
+ 'question': '荷兰国会何时通过伦理政策?',
82
+ 'title': '爪哇岛',
83
+ 'context': '伊斯兰教被接受的同时,其教义也被融入了当地人长久以来的一些信仰,所以爪哇岛的伊斯兰教带有明显的本地特色 “荷兰东印度公司”在巴达维亚(今天的雅加达)建立了“贸易和行政管理总部” 在殖民统治时期,荷兰人将注意力集中在雅加达和其他一些海滨城市,例如三宝垄和泗水 荷兰殖民者还通过一些归顺的本土势力,间接对这个多山的岛屿进行统治,例如爪哇岛中部的马打兰王国 19世纪,荷兰政府从荷兰东印度公司手上接管了东印度群岛,1830年荷兰统治者开始实行所谓“耕种制”(荷兰语cultuurstelsel en cultuurprocenten)的变相奴役制度,导致了大范围的饥荒和贫困 随即发生了各种政治和社会反抗运动,其中一位名叫Multatuli的荷兰作家写了一本名叫《Max Havelaar》的小说,以抗议当时的社会状况 迫于各种反抗运动此起彼伏,1901年荷兰国会通过伦理政策(Etnisch beleid),客观上使一部分爪哇人接触到荷兰式教育,在这些人中,出现了很多杰出的印尼民族主义者,并且在二战后的印尼独立运动中起到了重要作用'}
84
+
85
+ qa_downstream_process(
86
+ r["entity"],
87
+ r["question"],
88
+ r["context"],
89
+ loaded_model
90
+ )
91
+
92
+ {'entity': '1901年',
93
+ 'question': '荷兰国会何时通过伦理政策?',
94
+ 'context': '伊斯兰教被接受的同时,其教义也被融入了当地人长久以来的一些信仰,所以爪哇岛的伊斯兰教带有明显的本地特色 “荷兰东印度公司”在巴达维亚(今天的雅加达)建立了“贸易和行政管理总部” 在殖民统治时期,荷兰人将注意力集中在雅加达和其他一些海滨城市,例如三宝垄和泗水 荷兰殖民者还通过一些归顺的本土势力,间接对这个多山的岛屿进行统治,例如爪哇岛中部的马打兰王国 19世纪,荷兰政府从荷兰东印度公司手上接管了东印度群岛,1830年荷兰统治者开始实行所谓“耕种制”(荷兰语cultuurstelsel en cultuurprocenten)的变相奴役制度,导致了大范围的饥荒和贫困 随即发生了各种政治和社会反抗运动,其中一位名叫Multatuli的荷兰作家写了一本名叫《Max Havelaar》的小说,以抗议当时的社会状况 迫于各种反抗运动此起彼伏,1901年荷兰国会通过伦理政策(Etnisch beleid),客观上使一部分爪哇人接触到荷兰式教育,在这些人中,出现了很多杰出的印尼民族主义者,并且在二战后的印尼独立运动中起到了重要作用',
95
+ 'head': '何时',
96
+ 'statement': '荷兰国会1901年通过伦理政策',
97
+ 'top_chip': '1901年荷兰国会通过伦理政策(Etnisch'}
98
+ '''
99
+ #def qa_downstream_process(entity, question, context, loaded_model = loaded_model):
100
+ def qa_downstream_process(entity, question, context):
101
+ if entity not in context:
102
+ return None
103
+ d = {
104
+ "entity": entity,
105
+ "question": question,
106
+ "context": context
107
+ }
108
+ #head_qst = one_item_process(d, loaded_model)
109
+ head_qst = one_item_process_by_request(d)
110
+ head, _ = head_qst
111
+ d["head"] = head
112
+ statement, top_chip = to_statement(d)
113
+ d["statement"] = statement
114
+ d["top_chip"] = top_chip
115
+ return d
116
+
117
+ '''
118
+ @csrf_exempt
119
+ def qa_downstream_process_part(request):
120
+ assert request.method == "POST"
121
+ post_data = request.POST
122
+ entity = post_data["entity"]
123
+ question = post_data["question"]
124
+ context = post_data["context"]
125
+ output = qa_downstream_process(entity, question, context)
126
+ if output is None:
127
+ return HttpResponse(json.dumps(
128
+ {"output": "No Answer"}
129
+ ))
130
+ assert type(output) == type({})
131
+ req_str = json.dumps(output)
132
+ return HttpResponse(
133
+ req_str
134
+ )
135
+ '''
136
+
137
+ if __name__ == "__main__":
138
+ pass
translate_by_api.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ def call_zh_en_naive_model(zh_question):
4
+ response = requests.post("https://svjack-translate-chinese-to-english.hf.space/run/predict", json={
5
+ "data": [
6
+ zh_question,
7
+ ]}).json()
8
+ data = response["data"]
9
+ if data:
10
+ data = data[0]
11
+ English_Question = data["English Question"]
12
+ else:
13
+ English_Question = ""
14
+ return English_Question
wiki_kb_qa_migrate.py ADDED
@@ -0,0 +1,1012 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### qa_env
2
+ #from conf import *
3
+ from qa import *
4
+ from translate_by_api import *
5
+ from extract_by_api import *
6
+ from extract_et_by_api import *
7
+
8
+ import os
9
+ import logging
10
+ import subprocess
11
+ import time
12
+ from pathlib import Path
13
+
14
+ from haystack.nodes import Text2SparqlRetriever
15
+ from haystack.document_stores import GraphDBKnowledgeGraph, InMemoryKnowledgeGraph
16
+ #from haystack.utils import fetch_archive_from_http
17
+
18
+ import pandas as pd
19
+ import numpy as np
20
+ import os
21
+ import sys
22
+
23
+ #import jieba
24
+ from functools import partial, reduce, lru_cache
25
+ #from easynmt import EasyNMT
26
+
27
+ #from sentence_transformers.util import pytorch_cos_sim
28
+ #from sentence_transformers import SentenceTransformer
29
+ from time import time
30
+
31
+ from itertools import product
32
+
33
+ #import pickle as pkl
34
+ from urllib.parse import unquote
35
+
36
+ import requests
37
+ import json
38
+
39
+ import pandas as pd
40
+ import numpy as np
41
+ import os
42
+ import sys
43
+
44
+ #import jieba
45
+ from functools import partial, reduce, lru_cache
46
+ #from easynmt import EasyNMT
47
+
48
+ #from sentence_transformers.util import pytorch_cos_sim
49
+ #from sentence_transformers import SentenceTransformer
50
+ from time import time
51
+
52
+ from itertools import product
53
+
54
+ #import pickle as pkl
55
+ #import faiss
56
+
57
+ from rapidfuzz import fuzz
58
+ import synonyms
59
+
60
+ import sys
61
+ #sys.path.insert(0 ,"/Users/svjack/temp/HP_kbqa/script")
62
+ #from trans_toolkit import *
63
+
64
+ #from easynmt import EasyNMT
65
+ #zh_en_naive_model = EasyNMT("m2m_100_418M")
66
+ '''
67
+ p00 = os.path.join(model_path, "zh_en_m2m")
68
+ assert os.path.exists(p00)
69
+ zh_en_naive_model = EasyNMT(p00)
70
+ zh_en_naive_model.translate(["宁波在哪?"], source_lang="zh", target_lang = "en")
71
+ '''
72
+
73
+ '''
74
+ from haystack.nodes import FARMReader
75
+ #question_reader_save_path = "/Users/svjack/temp/model/en_zh_question_reader_save_epc_2_spo"
76
+ question_reader_save_path = os.path.join(model_path, "en_zh_question_reader_save_epc_2_spo")
77
+ assert os.path.exists(question_reader_save_path)
78
+ en_zh_reader = FARMReader(model_name_or_path=question_reader_save_path, use_gpu=False,
79
+ num_processes = 0
80
+ )
81
+ '''
82
+
83
+ kg = InMemoryKnowledgeGraph(index="tutorial_10_index")
84
+ kg.delete_index()
85
+ kg.create_index()
86
+
87
+ kg.import_from_ttl_file(index="tutorial_10_index", path=Path("data") / "triples.ttl")
88
+ #kg.get_params()
89
+ #all_triples = kg.get_all_triples()
90
+ #spo_df = pd.DataFrame(all_triples)
91
+
92
+ #### some collection in kb_aug
93
+ import re
94
+ def transform_namespace_to_prefix_str(g):
95
+ namespaces = g.namespaces()
96
+ return "\n".join(map(lambda x: "PREFIX {}: <{}>".format(x[0], x[1]), namespaces))
97
+
98
+ #print(transform_namespace_to_prefix_str(kg.indexes["tutorial_10_index"]))
99
+ ### ->
100
+
101
+ wiki_prefix = '''
102
+ PREFIX brick: <https://brickschema.org/schema/Brick#>
103
+ PREFIX csvw: <http://www.w3.org/ns/csvw#>
104
+ PREFIX dc: <http://purl.org/dc/elements/1.1/>
105
+ PREFIX dcat: <http://www.w3.org/ns/dcat#>
106
+ PREFIX dcmitype: <http://purl.org/dc/dcmitype/>
107
+ PREFIX dcterms: <http://purl.org/dc/terms/>
108
+ PREFIX dcam: <http://purl.org/dc/dcam/>
109
+ PREFIX doap: <http://usefulinc.com/ns/doap#>
110
+ PREFIX foaf: <http://xmlns.com/foaf/0.1/>
111
+ PREFIX odrl: <http://www.w3.org/ns/odrl/2/>
112
+ PREFIX org: <http://www.w3.org/ns/org#>
113
+ PREFIX owl: <http://www.w3.org/2002/07/owl#>
114
+ PREFIX prof: <http://www.w3.org/ns/dx/prof/>
115
+ PREFIX prov: <http://www.w3.org/ns/prov#>
116
+ PREFIX qb: <http://purl.org/linked-data/cube#>
117
+ PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
118
+ PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
119
+ PREFIX schema: <https://schema.org/>
120
+ PREFIX sh: <http://www.w3.org/ns/shacl#>
121
+ PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
122
+ PREFIX sosa: <http://www.w3.org/ns/sosa/>
123
+ PREFIX ssn: <http://www.w3.org/ns/ssn/>
124
+ PREFIX time: <http://www.w3.org/2006/time#>
125
+ PREFIX vann: <http://purl.org/vocab/vann/>
126
+ PREFIX void: <http://rdfs.org/ns/void#>
127
+ PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
128
+ PREFIX xml: <http://www.w3.org/XML/1998/namespace>
129
+ PREFIX hp: <https://deepset.ai/harry_potter/>
130
+ '''
131
+
132
+ prefix_s = pd.Series(wiki_prefix.split("\n")).map(
133
+ lambda x: x if x.startswith("PREFIX") else np.nan
134
+ ).dropna().map(
135
+ lambda x: re.findall("PREFIX (.*): <", x)
136
+ ).map(lambda x: x[0])
137
+
138
+
139
+ prefix_url_dict = dict(map(
140
+ lambda y: (y.split(" ")[1].replace(":", ""), y.split(" ")[2].strip()[1:-1])
141
+ ,filter(
142
+ lambda x: x.strip()
143
+ , wiki_prefix.split("\n"))))
144
+
145
+ url_prefix_dict = dict(map(lambda t2: t2[::-1], prefix_url_dict.items()))
146
+
147
+ all_triples = kg.get_all_triples()
148
+ spo_df = pd.DataFrame(all_triples)
149
+ spo_df_simple = spo_df.copy()
150
+ spo_df_simple = spo_df_simple.applymap(lambda x: x["value"]).applymap(lambda x:
151
+ (list(filter(lambda t2: x.startswith(t2[0]), url_prefix_dict.items()))[0], x) if any(map(lambda t2: x.startswith(t2[0]), url_prefix_dict.items())) else (None, x)
152
+ ).applymap(
153
+ lambda t2: t2[1].replace(t2[0][0], "{}:".format(t2[0][1])) if t2[0] is not None else t2[1]
154
+ ).applymap(unquote)
155
+
156
+ '''
157
+ #### like property in wikidata
158
+ spo_df_simple["p"].map(
159
+ lambda x: x[3:] if x.startswith("hp:") else np.nan
160
+ ).dropna().value_counts()
161
+
162
+ #### others in p col (rdf:type)
163
+ spo_df_simple["p"].map(
164
+ lambda x: x if not x.startswith("hp:") else np.nan
165
+ ).dropna().value_counts()
166
+
167
+ #### groupby different entity type view
168
+ pd.concat(
169
+ list(map(
170
+ lambda t2: t2[1].head(2),
171
+ list(spo_df_simple[
172
+ spo_df_simple["p"] == "rdf:type"
173
+ ].sort_values(by = ["o", "s"]).groupby("o"))
174
+ )), axis = 0).head(30)
175
+ '''
176
+
177
+ #### spo s(type)o
178
+
179
+ #### use deepl translate to lookup
180
+ #spo_trans_total_df = pd.read_csv("../data/spo_trans_total.csv")
181
+ spo_trans_total_df = pd.read_csv("data/spo_trans_total.csv")
182
+ spo_trans_dict = dict(spo_trans_total_df.values.tolist())
183
+ '''
184
+ with open("../data/spo_trans_dict.json", "w") as f:
185
+ json.dump(spo_trans_dict, f)
186
+ '''
187
+
188
+ spo_trans_back_dict = dict(map(lambda t2: t2[::-1], spo_trans_dict.items()))
189
+ spo_df_simple_keyed = spo_df_simple.copy()
190
+
191
+ def map_to_trans_key(src):
192
+ x = str(src)
193
+ if not x.startswith("hp:"):
194
+ return np.nan
195
+ return x[3:].replace('"', '').replace("'", '').replace("_", " ")
196
+
197
+ spo_df_simple_trans = spo_df_simple_keyed.applymap(
198
+ lambda x: (x ,map_to_trans_key(x))
199
+ ).applymap(
200
+ lambda t2: spo_trans_dict.get(t2[1], t2[0]) if type(t2[1]) == type("") else t2[0]
201
+ )
202
+
203
+ '''
204
+ pd.concat(
205
+ list(map(
206
+ lambda t2: t2[1].head(2),
207
+ list(spo_df_simple_trans[
208
+ spo_df_simple_trans["p"] == "rdf:type"
209
+ ].sort_values(by = ["o", "s"]).groupby("o"))
210
+ )), axis = 0).head(50)
211
+
212
+ spo_df_simple_trans[
213
+ spo_df_simple_trans["s"] == "斯蒂芬-康福特"
214
+ ]
215
+ '''
216
+
217
+ model_dir = "data/"
218
+ kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")
219
+
220
+ def decode_query(eng_query ,kgqa_retriever, top_k = 3):
221
+ self = kgqa_retriever
222
+ inputs = self.tok([eng_query], max_length=100, truncation=True, return_tensors="pt")
223
+ # generate top_k+2 SPARQL queries so that we can dismiss some queries with wrong syntax
224
+ temp = self.model.generate(
225
+ inputs["input_ids"], num_beams=max(5, top_k + 2), max_length=100, num_return_sequences=top_k + 2, early_stopping=True
226
+ )
227
+ sparql_queries = [
228
+ self.tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in temp
229
+ ]
230
+ return sparql_queries
231
+
232
+ import re
233
+ from uuid import uuid1
234
+ import jionlp as jio
235
+
236
+ special_match_token_list = [
237
+ " filter(",
238
+ ]
239
+
240
+ def fill_bk(str_):
241
+ #assert str_[0] == "("
242
+ req = []
243
+ cnt = 0
244
+ have_match_one = False
245
+ for char in str_:
246
+ #print(req)
247
+ if char == "(":
248
+ cnt += 1
249
+ have_match_one = True
250
+ if char == ")":
251
+ cnt -= 1
252
+ req.append(char)
253
+ if cnt == 0 and have_match_one:
254
+ break
255
+ return "".join(req)
256
+
257
+ def match_special_token(query, special_match_token_list):
258
+ assert type(query) == type("")
259
+ assert type(special_match_token_list) == type([])
260
+ special_match_token_list_ = list(filter(lambda x: x in query, special_match_token_list))
261
+ if not special_match_token_list_:
262
+ return []
263
+ return list(map(lambda x: (x ,
264
+ fill_bk(
265
+ query[query.find(x):]
266
+ )
267
+ ), special_match_token_list_))
268
+
269
+ def retrieve_sent_split(sent,
270
+ stops_split_pattern = "|".join(map(lambda x: r"\{}".format(x),
271
+ " "))
272
+ ):
273
+ if not sent.strip():
274
+ return []
275
+
276
+ split_list = re.split(stops_split_pattern, sent)
277
+ return split_list
278
+
279
+ import jionlp as jio
280
+
281
+ ask_l = [
282
+ "?answer", "?value", "?obj", "?sbj", "?s", "?x", "?a"
283
+ ]
284
+ ask_ner = jio.ner.LexiconNER({
285
+ "ask": ask_l
286
+ })
287
+
288
+ def query_to_t3(query, filter_list = [], ask_ner = ask_ner):
289
+ '''
290
+ query = query.replace("?answer", " ?answer ")
291
+ query = query.replace("?value", " ?value ")
292
+ query = query.replace("?obj", " ?obj ")
293
+ query = query.replace("?sbj", " ?sbj ")
294
+ query = query.replace("?s", " ?s ")
295
+ query = query.replace("?x", " ?x ")
296
+ '''
297
+ l = ask_ner(query)
298
+ l = sorted(set(map(lambda x: x["text"], l)), key = len, reverse = True)
299
+
300
+ for k in l:
301
+ query = query.replace(k, " {} ".format(k))
302
+
303
+ '''
304
+ if "where" not in query and "WHERE" not in query:
305
+ return []
306
+ '''
307
+
308
+ special_token_list = match_special_token(query, special_match_token_list)
309
+ #return special_token_list
310
+ if special_token_list:
311
+ special_token_list = list(set(map(lambda t2: t2[1] ,special_token_list)))
312
+ uid_special_token_dict = dict(map(lambda x: (str(uuid1()), x), special_token_list))
313
+ special_token_uid_dict = dict(map(lambda t2: t2[::-1], uid_special_token_dict.items()))
314
+ assert len(special_token_uid_dict) == len(uid_special_token_dict)
315
+
316
+ for k, v in sorted(special_token_uid_dict.items(), key = lambda t2: len(t2[0]), reverse = True):
317
+ if k in query:
318
+ #query = query.replace(k, v)
319
+ query = query.replace(k, "")
320
+ else:
321
+ uid_special_token_dict = {}
322
+ special_token_uid_dict = {}
323
+
324
+ '''
325
+ if "where" in query:
326
+ tail = "where".join(query.split("where")[1:])
327
+ elif "WHERE" in query:
328
+ tail = "WHERE".join(query.split("WHERE")[1:])
329
+ '''
330
+ #return query
331
+ query = query.strip()
332
+ if not query.endswith("}"):
333
+ query = query + "}"
334
+ tail = re.findall(r"{(.*)}", query)
335
+ #return tail
336
+ #return t3_list
337
+ if not tail:
338
+ return []
339
+ else:
340
+ tail = tail[0]
341
+
342
+ t3_list = list(map(lambda x: x.strip() ,tail.split(".")))
343
+ t3_list_ = []
344
+ for ele in t3_list:
345
+ for k, v in uid_special_token_dict.items():
346
+ if k in ele:
347
+ ele = ele.replace(k, v)
348
+ t3_list_.append(ele)
349
+ t3_list = t3_list_
350
+
351
+ if filter_list:
352
+ t3_list = list(filter(lambda x:
353
+ any(map(lambda y: y in x ,filter_list))
354
+ , t3_list))
355
+ t3_list = list(map(lambda x:
356
+ list(filter(lambda y: y.strip() ,retrieve_sent_split(x)))
357
+ , t3_list))
358
+ return t3_list
359
+
360
+ def decode_property(eng_query ,kgqa_retriever, top_k = 3):
361
+ sparql_queries = decode_query(eng_query, kgqa_retriever, top_k = top_k)
362
+ if not sparql_queries:
363
+ return []
364
+ t3_nest_list = list(map(lambda x: query_to_t3(x), sparql_queries))
365
+ ####return t3_nest_list
366
+ p_nest_list = []
367
+ for ele in t3_nest_list:
368
+ for e in ele:
369
+ if len(e) == 3:
370
+ p_nest_list.append(e)
371
+ #p_nest_list = list(filter(lambda x: len(x) == 3, t3_nest_list))
372
+ if not p_nest_list:
373
+ return []
374
+ p_nest_list = list(map(lambda x: x[1], p_nest_list))
375
+ return p_nest_list
376
+
377
+ '''
378
+ #### ori query decoder
379
+ query = "Harry Potter live in which house?"
380
+ query = "when was Stephen cornfoot born?"
381
+ decode_query(query, kgqa_retriever)
382
+
383
+ #### ori query decoder only maintain property part
384
+ query = "Harry Potter live in which house in 1920?"
385
+ query = "Harry live in where?"
386
+ query = "Harry live in where?"
387
+ query = "when was Stephen cornfoot born?"
388
+ query = "what is Stephen's loyalty?"
389
+ decode_property(query, kgqa_retriever)
390
+
391
+ query = "who is the leader of Divination homework meeting?"
392
+ '''
393
+
394
+ def template_fullfill_reconstruct_query(entity_list = ["http://www.wikidata.org/entity/Q42780"]
395
+ , property_list = ["http://www.wikidata.org/prop/direct/P131",
396
+ "http://www.wikidata.org/prop/direct/P150"
397
+ ],
398
+ generate_t3_func = lambda el, pl: pd.Series(list(product(el, pl))).map(
399
+ lambda ep: [(ep[0], ep[1], "?a"), ("?a", ep[1], ep[0])]
400
+ ).explode().dropna().drop_duplicates().tolist()
401
+ ):
402
+ assert type(entity_list) == type([])
403
+ assert type(property_list) == type([])
404
+ if not entity_list or not property_list:
405
+ return []
406
+ query_list = list(map(list ,generate_t3_func(entity_list, property_list)))
407
+ if not query_list:
408
+ return []
409
+ req = list(map(lambda x: "select ?a {" + " ".join(x) + "}", query_list))
410
+ return req
411
+
412
+ '''
413
+ sparql_queries_reconstruct = template_fullfill_reconstruct_query(
414
+ ["hp:Divination_homework_meeting"],
415
+ ["hp:leader"]
416
+ )
417
+ sparql_queries_reconstruct
418
+ '''
419
+
420
+ def run_sparql_queries(sparql_queries, kgqa_retriever, top_k = 3):
421
+ self = kgqa_retriever
422
+ answers = []
423
+ for sparql_query in sparql_queries:
424
+ ans, query = self._query_kg(sparql_query=sparql_query)
425
+ if len(ans) > 0:
426
+ answers.append((ans, query))
427
+ # if there are no answers we still want to return something
428
+ if len(answers) == 0:
429
+ answers.append(("", ""))
430
+ results = answers[:top_k]
431
+ results = [self.format_result(result) for result in results]
432
+ return results
433
+
434
+ '''
435
+ #### one conclusion
436
+ run_sparql_queries(sparql_queries_reconstruct, kgqa_retriever)
437
+ '''
438
+
439
+ #### start kbqa_protable_service (server)
440
+ def retrieve_et(zh_question, only_e = True):
441
+ assert type(zh_question) == type("")
442
+ '''
443
+ qst = zh_question
444
+ rep = requests.post(
445
+ url = "http://localhost:8855/extract_et",
446
+ data = {
447
+ "question": qst
448
+ }
449
+ )
450
+ output = json.loads(rep.content.decode())
451
+ '''
452
+ output = call_entity_property_extract(zh_question)
453
+ if only_e:
454
+ return output.get("E-TAG", [])
455
+ return output
456
+
457
+ '''
458
+ #### start qa server
459
+ def retrieve_head(zh_question):
460
+ req = requests.post(
461
+ url = "http://localhost:8811/qa_downstream_process",
462
+ data = {
463
+ "entity": "",
464
+ "question": zh_question,
465
+ "context": zh_question
466
+ }
467
+ )
468
+ output = json.loads(req.content.decode())
469
+ if "head" in output:
470
+ return output["head"]
471
+ return ""
472
+ '''
473
+ def retrieve_head(zh_question):
474
+ output = qa_downstream_process(
475
+ "", zh_question, zh_question
476
+ )
477
+ assert type(output) == type({})
478
+ if "head" in output:
479
+ return output["head"]
480
+ return ""
481
+
482
+ '''
483
+ zh_question = "谁是占卜会议的领导者?"
484
+ retrieve_et(zh_question)
485
+ '''
486
+
487
+ def property_and_type_slice(spo_df_simple_trans, p_l = [], type_l = []):
488
+ req = spo_df_simple_trans.copy()
489
+ if type_l:
490
+ s_l = req[
491
+ req["o"].isin(type_l)
492
+ ]["s"].drop_duplicates().dropna().values.tolist()
493
+ req = req[
494
+ req["s"].isin(s_l)
495
+ ]
496
+ if req.size == 0:
497
+ return None
498
+ if p_l:
499
+ s_l = req[
500
+ req["p"].isin(p_l)
501
+ ]["s"].drop_duplicates().dropna().values.tolist()
502
+ req = req[
503
+ req["s"].isin(s_l)
504
+ ]
505
+ if req.size == 0:
506
+ return None
507
+ return req
508
+
509
+ '''
510
+ ### Organisation_ sanple
511
+ property_and_type_slice(
512
+ spo_df_simple_trans, p_l = ["创立"], type_l = ["hp:Organisation_"]
513
+ ).sort_values(by = "s")["s"].drop_duplicates().sample(n = 30)
514
+
515
+ ### people sample
516
+ property_and_type_slice(
517
+ spo_df_simple_trans, p_l = ["出生"], type_l = ["hp:Individual_"]
518
+ ).sort_values(by = "s")["s"].drop_duplicates().sample(n = 30)
519
+
520
+ zh_question = "谁是占卜会议的领导者?"
521
+ en_question = zh_en_naive_model.translate([zh_question], source_lang="zh", target_lang = "en")[0]
522
+ en_properties = decode_property(en_question, kgqa_retriever)
523
+ en_properties
524
+ '''
525
+
526
+ all_en_p = spo_df_simple["p"].drop_duplicates().dropna().values.tolist()
527
+ all_en_p_tokens = pd.Series(list(map(lambda x: x[3:].split("_") ,filter(lambda x: x.startswith("hp:"), all_en_p)))).explode().dropna().map(
528
+ lambda x: x if bool(x) else np.nan
529
+ ).dropna().drop_duplicates().values.tolist()
530
+ ###all_en_p_tokens[:10]
531
+
532
+ all_p_df = pd.Series(all_en_p).reset_index().iloc[:, 1:]
533
+ all_p_df.columns = ["en_p"]
534
+ all_p_df = all_p_df[
535
+ all_p_df["en_p"] != "rdf:type"
536
+ ]
537
+ all_p_df["zh_p"] = all_p_df["en_p"].map(
538
+ lambda x: spo_trans_dict.get(x.replace("hp:", "").replace("_", " "), x.replace("hp:", "").replace("_", " "))
539
+ )
540
+ #all_p_df
541
+
542
+ #### decoder property mapping: (map decoder to kb exists)
543
+ decode_map_config_dict = {
544
+ "hp:birth": 'hp:born',
545
+ 'hp:birthday': "hp:born"
546
+ }
547
+
548
+ #### decoder sim property mapping: (decoder that can not distinguish)
549
+ decode_sim_config_dict = {
550
+ 'hp:ingredients': "hp:characteristics",
551
+ "hp:characteristics": 'hp:ingredients'
552
+ }
553
+
554
+ def decode_property_link_to_ori(decode_property, all_en_p, all_en_p_tokens, equal_threshold = 80):
555
+ if not decode_property.startswith("hp:") or not len(decode_property) >= 3:
556
+ return None
557
+ if decode_property in all_en_p:
558
+ return [(decode_property, 100.0)]
559
+ if decode_property in decode_map_config_dict:
560
+ return [(decode_map_config_dict[decode_property], 99.0)]
561
+ def filter_by_p_tokens(decode_property):
562
+ req = []
563
+ for ele in decode_property[3:].split("_"):
564
+ if ele in all_en_p_tokens:
565
+ req.append(ele)
566
+ return "hp:{}".format("_".join(req))
567
+ if decode_property == "hp:":
568
+ return None
569
+ decode_property = filter_by_p_tokens(decode_property)
570
+ order_list = sorted(map(lambda x: (x, fuzz.ratio(x, decode_property)), all_en_p), key = lambda t2: t2[1], reverse = True)
571
+ return order_list[:10]
572
+
573
+ '''
574
+ #### minimize maintain one token sorted.
575
+ decode_property_link_to_ori("hp:born", all_en_p, all_en_p_tokens, equal_threshold = 80)
576
+ decode_property_link_to_ori("hp:birth", all_en_p, all_en_p_tokens, equal_threshold = 80)
577
+ decode_property_link_to_ori("hp:head_of_the_assembly", all_en_p, all_en_p_tokens, equal_threshold = 80)
578
+ '''
579
+
580
+
581
+ def output_to_dict(output, trans_keys = ["answers"]):
582
+ non_trans_t2_list = list(filter(lambda t2: t2[0] not in trans_keys, output.items()))
583
+ trans_t2_list = list(map(lambda tt2: (
584
+ tt2[0],
585
+ list(map(lambda x: x.to_dict(), tt2[1]))
586
+ ) ,filter(lambda t2: t2[0] in trans_keys, output.items())))
587
+ #return trans_t2_list
588
+ return dict(trans_t2_list + non_trans_t2_list)
589
+
590
+ def zh_question_to_p_zh_en_map(zh_question, top_k = 3):
591
+ #zh_question = "谁是占卜会议的领导者?"
592
+ #en_question = zh_en_naive_model.translate([zh_question], source_lang="zh", target_lang = "en")[0]
593
+ en_question = call_zh_en_naive_model(zh_question)
594
+ en_properties = decode_property(en_question, kgqa_retriever, top_k = top_k)
595
+ if not en_properties:
596
+ return None
597
+ en_properties_top_sort = pd.Series(en_properties).value_counts().index.tolist()
598
+ en_properties_mapped = list(map(
599
+ lambda x: decode_property_link_to_ori(x, all_en_p, all_en_p_tokens, equal_threshold = 80), en_properties_top_sort
600
+ ))
601
+ en_properties_mapped = list(filter(lambda x: hasattr(x, "__len__") and len(x) >= 1, en_properties_mapped))
602
+ if not en_properties_mapped:
603
+ return None
604
+ en_properties_mapped = list(map(lambda x: x[0] ,en_properties_mapped))
605
+ en_properties_mapped_df = pd.DataFrame(en_properties_mapped)
606
+ assert en_properties_mapped_df.shape[1] == 2
607
+ en_properties_mapped_df.columns = ["en_property", "score"]
608
+ '''
609
+ en_properties_mapped_df["zh_property"] = en_properties_mapped_df["en_property"].map(
610
+ lambda x: en_zh_reader.predict_on_texts(
611
+ question=x.replace("hp:", ""),
612
+ texts=[zh_question]
613
+ )
614
+ ).map(output_to_dict)
615
+ '''
616
+ en_properties_mapped_df["zh_property"] = en_properties_mapped_df["en_property"].map(
617
+ lambda x: call_en_zh_reader(
618
+ x.replace("hp:", ""),
619
+ zh_question
620
+ )
621
+ )
622
+ en_properties_mapped_df["zh_property"] = en_properties_mapped_df["zh_property"].map(lambda x: x["answers"][0] if x["answers"] else {})
623
+ en_properties_mapped_df = en_properties_mapped_df[
624
+ en_properties_mapped_df["zh_property"].map(bool)
625
+ ]
626
+ if en_properties_mapped_df is None or en_properties_mapped_df.size == 0:
627
+ return None
628
+ #return nerd_df
629
+ en_properties_mapped_df["ext_score"] = en_properties_mapped_df["zh_property"].map(
630
+ lambda x: x["score"]
631
+ )
632
+ en_properties_mapped_df["zh_property"] = en_properties_mapped_df["zh_property"].map(
633
+ lambda x: x["answer"]
634
+ )
635
+ '''
636
+ en_properties_mapped_df = en_properties_mapped_df[
637
+ en_properties_mapped_df["ext_score"].map(lambda x: x > score_threshold)
638
+ ]
639
+ '''
640
+ if en_properties_mapped_df is None or en_properties_mapped_df.size == 0:
641
+ return None
642
+ ask_head = retrieve_head(zh_question)
643
+ #if type(ask_head) == type("") and "什么" in ask_head:
644
+ if type(ask_head) == type(""):
645
+ #ask_head = ask_head.replace("什么", "")
646
+ first_d = en_properties_mapped_df.iloc[0].to_dict()
647
+ first_d["zh_property"] = ask_head
648
+ en_properties_mapped_df = pd.DataFrame(
649
+ [first_d] + en_properties_mapped_df.apply(lambda x: x.to_dict(), axis = 1).values.tolist()
650
+ )
651
+ else:
652
+ pass
653
+ en_properties_mapped_df = en_properties_mapped_df[
654
+ en_properties_mapped_df["zh_property"].map(lambda x: bool(x))
655
+ ].drop_duplicates()
656
+ return en_properties_mapped_df
657
+
658
+ def search_sym_p(question_p_df, all_p_df):
659
+ #zh_p_l = question_p_df["zh_property"].drop_duplicates().values.tolist()
660
+ #en_p_l = question_p_df["en_property"].drop_duplicates().values.tolist()
661
+ req = []
662
+ for idx, r in question_p_df.iterrows():
663
+ all_p_score_df = all_p_df.copy()
664
+ all_p_score_df["zh_property"] = [r["zh_property"]] * len(all_p_score_df)
665
+ all_p_score_df["en_property"] = [r["en_property"]] * len(all_p_score_df)
666
+ req.append(all_p_score_df)
667
+ req = pd.concat(req, axis = 0)
668
+ req["zh_sim"] = req.apply(
669
+ lambda x: synonyms.compare(x["zh_property"], x["zh_p"]), axis = 1
670
+ )
671
+ req = req.sort_values(by = "zh_sim", ascending = False)
672
+ return req
673
+
674
+ all_en_ents = pd.Series(spo_df_simple[["s", "o"]].values.reshape([-1])).drop_duplicates().values.tolist()
675
+ all_ents_df = pd.Series(all_en_ents).reset_index().iloc[:, 1:]
676
+ all_ents_df.columns = ["en_ent"]
677
+ all_ents_df = all_ents_df[
678
+ all_ents_df["en_ent"] != "rdf:type"
679
+ ]
680
+ all_ents_df["zh_ent"] = all_ents_df["en_ent"].map(
681
+ lambda x: spo_trans_dict.get(x.replace("hp:", "").replace("_", " "), x.replace("hp:", "").replace("_", " "))
682
+ )
683
+ #all_ents_df
684
+ def search_sym_entity(entity_str, all_ents_df, use_syn = False):
685
+ #zh_p_l = question_p_df["zh_property"].drop_duplicates().values.tolist()
686
+ #en_p_l = question_p_df["en_property"].drop_duplicates().values.tolist()
687
+ req = all_ents_df.copy()
688
+ req["entity_str"] = [entity_str] * len(req)
689
+ if use_syn:
690
+ req["zh_sim"] = req.apply(
691
+ lambda x: synonyms.compare(x["zh_ent"], x["entity_str"]), axis = 1
692
+ )
693
+ else:
694
+ req["zh_sim"] = req.apply(
695
+ lambda x: fuzz.ratio(x["zh_ent"], x["entity_str"]), axis = 1
696
+ )
697
+ req = req.sort_values(by = "zh_sim", ascending = False)
698
+ return req
699
+
700
+ zh_question = "谁是占卜会议的领导者?"
701
+ zh_question = "洛林出生在哪个国家?"
702
+ zh_question = "洛林出生在哪个地方?"
703
+ zh_question = "洛林的血缘是什么?"
704
+ zh_question = "洛林的生日是什么?"
705
+ zh_question = "洛林的家族是什么?"
706
+ zh_question = "洛林的性别是什么?"
707
+ zh_question = "洛林的标题是什么?"
708
+ zh_question = "洛林的主题是什么?"
709
+ zh_question = "这个物品的特征是什么?"
710
+ zh_question = "强效祛斑药水的特征是什么?"
711
+ zh_question = "魔法学校的成立日期是什么?"
712
+ zh_question = "魔法学校的校长是谁?"
713
+ question_p_df = zh_question_to_p_zh_en_map(zh_question)
714
+ #question_p_df
715
+
716
+ #### top en_p as consider (high zh_sim)
717
+ #### need preload to precaculate all candidates in all_p_df
718
+ sym_p_df = search_sym_p(question_p_df, all_p_df)
719
+ #sym_p_df
720
+
721
+ '''
722
+ #### this can be done, all related with translate accurate
723
+ entity_str = "占卜会议"
724
+ search_sym_entity(entity_str, all_ents_df)
725
+
726
+ #### re translate in massive times
727
+ pd.Series(list(spo_trans_dict.keys())).to_csv("../data/all_consider.csv", index = False)
728
+ '''
729
+
730
+ #### ->
731
+ '''
732
+ sparql_queries_reconstruct = template_fullfill_reconstruct_query(
733
+ ["hp:Divination_homework_meeting"],
734
+ ["hp:leader"]
735
+ )
736
+ sparql_queries_reconstruct
737
+ '''
738
+
739
+ def from_zh_question_to_consider_queries(zh_question, top_k = 32, top_p_k = 5, top_e_k = 50, kgqa_retriever = kgqa_retriever,):
740
+ zh_ents = retrieve_et(zh_question)
741
+ if type(zh_ents) != type([]) or not zh_ents:
742
+ return None
743
+ question_p_df = zh_question_to_p_zh_en_map(zh_question, top_k = top_p_k)
744
+ if not hasattr(question_p_df, "size") or question_p_df.size == 0:
745
+ return None
746
+ ### en_p
747
+ sym_p_df = search_sym_p(question_p_df, all_p_df)
748
+ if not hasattr(sym_p_df, "size") or sym_p_df.size == 0:
749
+ return None
750
+ sim_entity_df_list = []
751
+ for entity_str in zh_ents:
752
+ sym_ent_df = search_sym_entity(entity_str, all_ents_df)
753
+ if not hasattr(sym_ent_df, "size") or sym_ent_df.size == 0:
754
+ continue
755
+ sim_entity_df_list.append(sym_ent_df)
756
+ if type(sim_entity_df_list) != type([]) or not sim_entity_df_list:
757
+ return None
758
+
759
+ #### en_ent
760
+ sym_ent_df = pd.concat(sim_entity_df_list, axis = 0).sort_values(by = "zh_sim", ascending = False)
761
+ #return sym_p_df, sym_ent_df
762
+
763
+ top_p = sym_p_df["en_p"].drop_duplicates().dropna().head(top_p_k).values.tolist()
764
+ top_e = sym_ent_df["en_ent"].drop_duplicates().dropna().head(top_e_k).values.tolist()
765
+
766
+ print(
767
+ top_e
768
+ )
769
+ print(
770
+ top_p
771
+ )
772
+
773
+ if not top_p or not top_e:
774
+ return None
775
+
776
+ sparql_queries_reconstruct = template_fullfill_reconstruct_query(
777
+ top_e,
778
+ top_p
779
+ )
780
+ #return sparql_queries_reconstruct
781
+ if not sparql_queries_reconstruct:
782
+ return None
783
+
784
+ output = run_sparql_queries(sparql_queries_reconstruct, kgqa_retriever, top_k = top_k)
785
+ return sparql_queries_reconstruct ,output
786
+
787
+ def trans_output(zh_question ,output):
788
+ if type(output) != type([]):
789
+ return output
790
+ def single_trans(d):
791
+ assert type(d) == type({})
792
+ if not d:
793
+ return d
794
+ req = {}
795
+ answer = d.get("answer")
796
+ if type(answer) == type([]):
797
+ answer = list(map(lambda x:
798
+ spo_trans_dict.get(x.split("/")[-1].replace("_", " "),
799
+ x.split("/")[-1].replace("_", " ")
800
+ ) if x.startswith("https://deepset.ai/harry_potter") else x
801
+ , answer))
802
+ sparql_query = d.get("prediction_meta")
803
+ if sparql_query is not None:
804
+ sparql_query = sparql_query.get("sparql_query")
805
+ if type(sparql_query) == type(""):
806
+ t3_in_query = query_to_t3(sparql_query)
807
+ hp_l = pd.Series(np.asarray(t3_in_query).reshape([-1])).map(lambda x: x[3:] if x.startswith("hp:") else np.nan).dropna().drop_duplicates().values.tolist()
808
+ for ele in sorted(hp_l, key = len, reverse = True):
809
+ sparql_query = sparql_query.replace(ele, spo_trans_dict.get(ele.split("/")[-1].replace("_", " "),
810
+ ele.split("/")[-1].replace("_", " ")))
811
+ if answer is not None:
812
+ req["answer"] = answer
813
+ if sparql_query is not None:
814
+ req["sparql_query"] = sparql_query
815
+ return req
816
+ output_trans = list(map(single_trans, output))
817
+ output_trans = sorted(output_trans, key = lambda d:
818
+ synonyms.compare(zh_question, " " if d.get("sparql_query", " ") else " ") if type(d) == type({}) else 0.0
819
+ , reverse = True)
820
+ return output_trans
821
+
822
+ def ranking_output(zh_question, zh_output):
823
+ e_t_dict = retrieve_et(zh_question, only_e=False)
824
+ e = e_t_dict.get("E-TAG", [])
825
+ t = e_t_dict.get("T-TAG", [])
826
+ e, t = map(" ".join, [e, t])
827
+ print(e, t)
828
+ df = pd.DataFrame(zh_output)
829
+ df = df.explode("answer")
830
+ #### e query
831
+ df["e_score"] = df["sparql_query"].map(lambda x: re.findall("{(.*)}" ,x)[0]).map(lambda x:
832
+ list(filter(lambda y: "?" not in y ,
833
+ list(np.asarray(x.split())[[0, -1]])
834
+ ))
835
+ ).map(" ".join).map(lambda x:
836
+ [e, x.split(":")[-1]]
837
+ ).map(lambda x: list(map(lambda y:
838
+ y.replace(" ", "") ,x))).map(lambda x:
839
+ fuzz.ratio(*x))
840
+ df["t_score"] = df["sparql_query"].map(lambda x: re.findall("{(.*)}" ,x)[0]).map(lambda x:
841
+ list(filter(lambda y: "?" not in y ,
842
+ x.split()[1]
843
+ ))
844
+ ).map(" ".join).map(lambda x:
845
+ [t, x.split(":")[-1]]
846
+ ).map(lambda x: list(map(lambda y:
847
+ y.replace(" ", "") ,x))).map(lambda x:
848
+ fuzz.ratio(*x))
849
+
850
+
851
+ #df["a_score"] = df["answer"].map(lambda x: [x, t]).map(lambda x: synonyms.compare(*x)) * 100
852
+ df["et_score"] = df[["e_score", "t_score", ]].sum(axis = 1)
853
+ df = df.sort_values(by = "et_score", ascending = False)
854
+ if df["et_score"].iloc[0] >= 50:
855
+ return df
856
+ df["e_score"] = df["sparql_query"].map(lambda x: re.findall("{(.*)}" ,x)[0]).map(lambda x:
857
+ list(filter(lambda y: "?" not in y ,
858
+ list(np.asarray(x.split())[[0, -1]])
859
+ ))
860
+ ).map(" ".join).map(lambda x:
861
+ [e, x.split(":")[-1]]
862
+ ).map(lambda x: list(map(lambda y:
863
+ y.replace(" ", "") ,x))).map(lambda x:
864
+ synonyms.compare(*x))
865
+ df["t_score"] = df["sparql_query"].map(lambda x: re.findall("{(.*)}" ,x)[0]).map(lambda x:
866
+ list(filter(lambda y: "?" not in y ,
867
+ x.split()[1]
868
+ ))
869
+ ).map(" ".join).map(lambda x:
870
+ [t, x.split(":")[-1]]
871
+ ).map(lambda x: list(map(lambda y:
872
+ y.replace(" ", "") ,x))).map(lambda x:
873
+ synonyms.compare(*x))
874
+
875
+ #df["a_score"] = df["answer"].map(lambda x: [x, t]).map(lambda x: synonyms.compare(*x))
876
+ #df["a_score"] = df["a_score"] / 100.0
877
+ df["et_score"] = df[["e_score", "t_score", ]].sum(axis = 1)
878
+ df = df.sort_values(by = "et_score", ascending = False)
879
+ return df
880
+
881
+ if __name__ == "__main__":
882
+ #### 血缘 need fintune, tackle with ranking_output
883
+ #### top3 to top5 recall design
884
+ zh_question = "哈利波特的血缘是什么?"
885
+ #output = from_zh_question_to_consider_queries(zh_question)
886
+ output = from_zh_question_to_consider_queries(zh_question,
887
+ top_k = 32, top_p_k = 30, top_e_k = 50
888
+ )
889
+ if type(output) == type((1,)):
890
+ query_list, output = output
891
+ zh_output = trans_output(zh_question ,output)
892
+ else:
893
+ zh_output = None
894
+ zh_output
895
+ ranking_output(zh_question, zh_output)
896
+
897
+ zh_question = "哈利波特的生日是什么?"
898
+ #output = from_zh_question_to_consider_queries(zh_question)
899
+ output = from_zh_question_to_consider_queries(zh_question,
900
+ top_k = 32, top_p_k = 30, top_e_k = 50
901
+ )
902
+ if type(output) == type((1,)):
903
+ query_list, output = output
904
+ zh_output = trans_output(zh_question ,output)
905
+ else:
906
+ zh_output = None
907
+ zh_output
908
+ ranking_output(zh_question, zh_output)
909
+
910
+ zh_question = "史内普的生日是什么时候?"
911
+ #output = from_zh_question_to_consider_queries(zh_question)
912
+ output = from_zh_question_to_consider_queries(zh_question,
913
+ top_k = 32, top_p_k = 30, top_e_k = 50
914
+ )
915
+ if type(output) == type((1,)):
916
+ query_list, output = output
917
+ zh_output = trans_output(zh_question ,output)
918
+ else:
919
+ zh_output = None
920
+ zh_output
921
+ ranking_output(zh_question, zh_output)
922
+
923
+ zh_question = "占卜会议的领导者是谁?"
924
+ #output = from_zh_question_to_consider_queries(zh_question)
925
+ output = from_zh_question_to_consider_queries(zh_question,
926
+ top_k = 32, top_p_k = 30, top_e_k = 50
927
+ )
928
+ if type(output) == type((1,)):
929
+ query_list, output = output
930
+ zh_output = trans_output(zh_question ,output)
931
+ else:
932
+ zh_output = None
933
+ zh_output
934
+ ranking_output(zh_question, zh_output)
935
+
936
+ zh_question = "纽约卫生局的创立时间是什么?"
937
+ #output = from_zh_question_to_consider_queries(zh_question)
938
+ output = from_zh_question_to_consider_queries(zh_question,
939
+ top_k = 32, top_p_k = 30, top_e_k = 50
940
+ )
941
+ if type(output) == type((1,)):
942
+ query_list, output = output
943
+ zh_output = trans_output(zh_question ,output)
944
+ else:
945
+ zh_output = None
946
+ zh_output
947
+ ranking_output(zh_question, zh_output)
948
+
949
+ zh_question = "法兰西魔法部记录室位于哪个城市?"
950
+ #output = from_zh_question_to_consider_queries(zh_question)
951
+ output = from_zh_question_to_consider_queries(zh_question,
952
+ top_k = 32, top_p_k = 30, top_e_k = 50
953
+ )
954
+ if type(output) == type((1,)):
955
+ query_list, output = output
956
+ zh_output = trans_output(zh_question ,output)
957
+ else:
958
+ zh_output = None
959
+ zh_output
960
+ ranking_output(zh_question, zh_output)
961
+
962
+ zh_question = "邓布利多的出生日期是什么?"
963
+ #output = from_zh_question_to_consider_queries(zh_question)
964
+ output = from_zh_question_to_consider_queries(zh_question,
965
+ top_k = 32, top_p_k = 30, top_e_k = 50
966
+ )
967
+ if type(output) == type((1,)):
968
+ query_list, output = output
969
+ zh_output = trans_output(zh_question ,output)
970
+ else:
971
+ zh_output = None
972
+ zh_output
973
+ ranking_output(zh_question, zh_output)
974
+
975
+ zh_question = "哥布林叛乱发生在什么日期?"
976
+ #output = from_zh_question_to_consider_queries(zh_question, top_p_k = 50)
977
+ output = from_zh_question_to_consider_queries(zh_question,
978
+ top_k = 32, top_p_k = 30, top_e_k = 50
979
+ )
980
+ if type(output) == type((1,)):
981
+ query_list, output = output
982
+ zh_output = trans_output(zh_question ,output)
983
+ else:
984
+ zh_output = None
985
+ zh_output
986
+ ranking_output(zh_question, zh_output)
987
+
988
+ zh_question = "决斗比赛的参与者是谁?"
989
+ #output = from_zh_question_to_consider_queries(zh_question)
990
+ output = from_zh_question_to_consider_queries(zh_question,
991
+ top_k = 32, top_p_k = 30, top_e_k = 50
992
+ )
993
+ if type(output) == type((1,)):
994
+ query_list, output = output
995
+ zh_output = trans_output(zh_question ,output)
996
+ else:
997
+ zh_output = None
998
+ zh_output
999
+ ranking_output(zh_question, zh_output)
1000
+
1001
+ zh_question = "赫敏的丈夫是谁?"
1002
+ #output = from_zh_question_to_consider_queries(zh_question)
1003
+ output = from_zh_question_to_consider_queries(zh_question,
1004
+ top_k = 32, top_p_k = 30, top_e_k = 50
1005
+ )
1006
+ if type(output) == type((1,)):
1007
+ query_list, output = output
1008
+ zh_output = trans_output(zh_question ,output)
1009
+ else:
1010
+ zh_output = None
1011
+ zh_output
1012
+ ranking_output(zh_question, zh_output)