# wiki_kb_qa_migrate.py
#### qa_env
#from conf import *
import os
# Point the `synonyms` package at a release mirror for its word2vec weights;
# the variable must be in the environment before `from qa import *` pulls in
# `synonyms` below.
d = dict(SYNONYMS_WORD2VEC_BIN_URL_ZH_CN = \
"https://github.com/chatopera/Synonyms/releases/download/3.15.0/words.vector.gz")
for k, v in d.items():
os.environ[k] = v
from qa import *
from translate_by_api import *
from extract_by_api import *
from extract_et_by_api import *
import logging
import subprocess
from pathlib import Path
from haystack.nodes import Text2SparqlRetriever
from haystack.document_stores import InMemoryKnowledgeGraph
#from haystack.utils import fetch_archive_from_http
import pandas as pd
import numpy as np
import sys
#import jieba
from functools import partial, reduce, lru_cache
#from easynmt import EasyNMT
#from sentence_transformers.util import pytorch_cos_sim
#from sentence_transformers import SentenceTransformer
from time import time
from itertools import product
#import pickle as pkl
from urllib.parse import unquote
import requests
import json
#import faiss
from rapidfuzz import fuzz
import synonyms
#sys.path.insert(0 ,"/Users/svjack/temp/HP_kbqa/script")
#from trans_toolkit import *
#from easynmt import EasyNMT
#zh_en_naive_model = EasyNMT("m2m_100_418M")
'''
p00 = os.path.join(model_path, "zh_en_m2m")
assert os.path.exists(p00)
zh_en_naive_model = EasyNMT(p00)
zh_en_naive_model.translate(["宁波在哪?"], source_lang="zh", target_lang = "en")
'''
'''
from haystack.nodes import FARMReader
#question_reader_save_path = "/Users/svjack/temp/model/en_zh_question_reader_save_epc_2_spo"
question_reader_save_path = os.path.join(model_path, "en_zh_question_reader_save_epc_2_spo")
assert os.path.exists(question_reader_save_path)
en_zh_reader = FARMReader(model_name_or_path=question_reader_save_path, use_gpu=False,
num_processes = 0
)
'''
# Build an in-memory knowledge graph and load the Harry Potter triples from TTL.
kg = InMemoryKnowledgeGraph(index="tutorial_10_index")
kg.delete_index()
kg.create_index()
kg.import_from_ttl_file(index="tutorial_10_index", path=Path("data") / "triples.ttl")
#kg.get_params()
#all_triples = kg.get_all_triples()
#spo_df = pd.DataFrame(all_triples)
#### helper collections for KB augmentation (kb_aug)
import re
def transform_namespace_to_prefix_str(g):
namespaces = g.namespaces()
return "\n".join(map(lambda x: "PREFIX {}: <{}>".format(x[0], x[1]), namespaces))
#print(transform_namespace_to_prefix_str(kg.indexes["tutorial_10_index"]))
### -> captured output of the print above:
wiki_prefix = '''
PREFIX brick: <https://brickschema.org/schema/Brick#>
PREFIX csvw: <http://www.w3.org/ns/csvw#>
PREFIX dc: <http://purl.org/dc/elements/1.1/>
PREFIX dcat: <http://www.w3.org/ns/dcat#>
PREFIX dcmitype: <http://purl.org/dc/dcmitype/>
PREFIX dcterms: <http://purl.org/dc/terms/>
PREFIX dcam: <http://purl.org/dc/dcam/>
PREFIX doap: <http://usefulinc.com/ns/doap#>
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX odrl: <http://www.w3.org/ns/odrl/2/>
PREFIX org: <http://www.w3.org/ns/org#>
PREFIX owl: <http://www.w3.org/2002/07/owl#>
PREFIX prof: <http://www.w3.org/ns/dx/prof/>
PREFIX prov: <http://www.w3.org/ns/prov#>
PREFIX qb: <http://purl.org/linked-data/cube#>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX schema: <https://schema.org/>
PREFIX sh: <http://www.w3.org/ns/shacl#>
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
PREFIX sosa: <http://www.w3.org/ns/sosa/>
PREFIX ssn: <http://www.w3.org/ns/ssn/>
PREFIX time: <http://www.w3.org/2006/time#>
PREFIX vann: <http://purl.org/vocab/vann/>
PREFIX void: <http://rdfs.org/ns/void#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
PREFIX xml: <http://www.w3.org/XML/1998/namespace>
PREFIX hp: <https://deepset.ai/harry_potter/>
'''
# Prefix names parsed from the header above (currently unused; kept for reference).
prefix_s = pd.Series(wiki_prefix.split("\n")).map(
lambda x: x if x.startswith("PREFIX") else np.nan
).dropna().map(
lambda x: re.findall("PREFIX (.*): <", x)
).map(lambda x: x[0])
prefix_url_dict = dict(map(
lambda y: (y.split(" ")[1].replace(":", ""), y.split(" ")[2].strip()[1:-1])
,filter(
lambda x: x.strip()
, wiki_prefix.split("\n"))))
url_prefix_dict = dict(map(lambda t2: t2[::-1], prefix_url_dict.items()))
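#### illustrative sketch: the two dicts invert each other, e.g. for the hp namespace
'''
prefix_url_dict["hp"]                                # -> "https://deepset.ai/harry_potter/"
url_prefix_dict["https://deepset.ai/harry_potter/"]  # -> "hp"
'''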
all_triples = kg.get_all_triples()
spo_df = pd.DataFrame(all_triples)
# Shorten full URIs to prefixed form (e.g. "https://deepset.ai/harry_potter/X"
# -> "hp:X") and URL-decode what remains.
spo_df_simple = spo_df.copy()
spo_df_simple = spo_df_simple.applymap(lambda x: x["value"]).applymap(lambda x:
(list(filter(lambda t2: x.startswith(t2[0]), url_prefix_dict.items()))[0], x) if any(map(lambda t2: x.startswith(t2[0]), url_prefix_dict.items())) else (None, x)
).applymap(
lambda t2: t2[1].replace(t2[0][0], "{}:".format(t2[0][1])) if t2[0] is not None else t2[1]
).applymap(unquote)
'''
#### like property in wikidata
spo_df_simple["p"].map(
lambda x: x[3:] if x.startswith("hp:") else np.nan
).dropna().value_counts()
#### others in p col (rdf:type)
spo_df_simple["p"].map(
lambda x: x if not x.startswith("hp:") else np.nan
).dropna().value_counts()
#### groupby different entity type view
pd.concat(
list(map(
lambda t2: t2[1].head(2),
list(spo_df_simple[
spo_df_simple["p"] == "rdf:type"
].sort_values(by = ["o", "s"]).groupby("o"))
)), axis = 0).head(30)
'''
#### spo s(type)o
#### lookup table built with DeepL translations
#spo_trans_total_df = pd.read_csv("../data/spo_trans_total.csv")
spo_trans_total_df = pd.read_csv("data/spo_trans_total.csv")
spo_trans_dict = dict(spo_trans_total_df.values.tolist())
'''
with open("../data/spo_trans_dict.json", "w") as f:
json.dump(spo_trans_dict, f)
'''
spo_trans_back_dict = dict(map(lambda t2: t2[::-1], spo_trans_dict.items()))
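#### illustrative sketch (assuming "Stephen Cornfoot" is a key in spo_trans_total.csv,
#### matching the sample lookup further below):
'''
spo_trans_dict["Stephen Cornfoot"]    # -> "斯蒂芬-康福特"
spo_trans_back_dict["斯蒂芬-康福特"]   # -> "Stephen Cornfoot"
'''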
spo_df_simple_keyed = spo_df_simple.copy()
def map_to_trans_key(src):
    # Turn an "hp:"-prefixed value into the plain-English key used by spo_trans_dict.
    x = str(src)
if not x.startswith("hp:"):
return np.nan
return x[3:].replace('"', '').replace("'", '').replace("_", " ")
spo_df_simple_trans = spo_df_simple_keyed.applymap(
lambda x: (x ,map_to_trans_key(x))
).applymap(
lambda t2: spo_trans_dict.get(t2[1], t2[0]) if type(t2[1]) == type("") else t2[0]
)
'''
pd.concat(
list(map(
lambda t2: t2[1].head(2),
list(spo_df_simple_trans[
spo_df_simple_trans["p"] == "rdf:type"
].sort_values(by = ["o", "s"]).groupby("o"))
)), axis = 0).head(50)
spo_df_simple_trans[
spo_df_simple_trans["s"] == "斯蒂芬-康福特"
]
'''
model_dir = "data/"
# Fine-tuned text-to-SPARQL model checkpoint stored under data/hp_v3.4.
kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")
def decode_query(eng_query ,kgqa_retriever, top_k = 3):
    # Reuse the retriever's tokenizer and model directly so the raw generated
    # SPARQL strings can be inspected instead of being executed right away.
    self = kgqa_retriever
inputs = self.tok([eng_query], max_length=100, truncation=True, return_tensors="pt")
# generate top_k+2 SPARQL queries so that we can dismiss some queries with wrong syntax
temp = self.model.generate(
inputs["input_ids"], num_beams=max(5, top_k + 2), max_length=100, num_return_sequences=top_k + 2, early_stopping=True
)
sparql_queries = [
self.tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in temp
]
return sparql_queries
from uuid import uuid1
import jionlp as jio
special_match_token_list = [
" filter(",
]
def fill_bk(str_):
    # Return the prefix of str_ up to and including the first balanced ")".
    #assert str_[0] == "("
req = []
cnt = 0
have_match_one = False
for char in str_:
#print(req)
if char == "(":
cnt += 1
have_match_one = True
if char == ")":
cnt -= 1
req.append(char)
if cnt == 0 and have_match_one:
break
return "".join(req)
def match_special_token(query, special_match_token_list):
    # Locate each special token (e.g. " filter(") in the query and return it
    # together with its full parenthesized clause.
assert type(query) == type("")
assert type(special_match_token_list) == type([])
special_match_token_list_ = list(filter(lambda x: x in query, special_match_token_list))
if not special_match_token_list_:
return []
return list(map(lambda x: (x ,
fill_bk(
query[query.find(x):]
)
), special_match_token_list_))
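#### illustrative sketch: fill_bk cuts at the balancing close paren, and
#### match_special_token pulls whole filter(...) clauses out of a query
'''
fill_bk(" filter(lang(?a) = 'en') . ?s ?p ?o")
# -> " filter(lang(?a) = 'en')"
match_special_token("select ?a { ?s hp:leader ?a . filter(lang(?a) = 'en')}", special_match_token_list)
# -> [(" filter(", " filter(lang(?a) = 'en')")]
'''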
def retrieve_sent_split(sent,
stops_split_pattern = "|".join(map(lambda x: r"\{}".format(x),
" "))
):
if not sent.strip():
return []
split_list = re.split(stops_split_pattern, sent)
return split_list
ask_l = [
"?answer", "?value", "?obj", "?sbj", "?s", "?x", "?a"
]
ask_ner = jio.ner.LexiconNER({
"ask": ask_l
})
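#### illustrative sketch: ask_ner tags the SPARQL ask-variables so query_to_t3
#### can pad them with spaces; each hit is a dict carrying (at least) a
#### "text" field (exact jionlp output format may vary by version)
'''
ask_ner("select ?answer { ?s hp:head ?answer }")
# -> hits for "?answer" and "?s", e.g. [{"text": "?answer", ...}, ...]
'''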
def query_to_t3(query, filter_list = [], ask_ner = ask_ner):
    '''
    Split a generated SPARQL query into its (s, p, o) triple token lists.
    Ask-variables (?answer, ?value, ...) found by ask_ner are padded with
    spaces so the whitespace split keeps them as single tokens, and any
    filter(...) clauses are stripped out before the body is split on ".".
    '''
l = ask_ner(query)
l = sorted(set(map(lambda x: x["text"], l)), key = len, reverse = True)
for k in l:
query = query.replace(k, " {} ".format(k))
'''
if "where" not in query and "WHERE" not in query:
return []
'''
special_token_list = match_special_token(query, special_match_token_list)
#return special_token_list
if special_token_list:
special_token_list = list(set(map(lambda t2: t2[1] ,special_token_list)))
uid_special_token_dict = dict(map(lambda x: (str(uuid1()), x), special_token_list))
special_token_uid_dict = dict(map(lambda t2: t2[::-1], uid_special_token_dict.items()))
assert len(special_token_uid_dict) == len(uid_special_token_dict)
for k, v in sorted(special_token_uid_dict.items(), key = lambda t2: len(t2[0]), reverse = True):
if k in query:
#query = query.replace(k, v)
query = query.replace(k, "")
else:
uid_special_token_dict = {}
special_token_uid_dict = {}
'''
if "where" in query:
tail = "where".join(query.split("where")[1:])
elif "WHERE" in query:
tail = "WHERE".join(query.split("WHERE")[1:])
'''
#return query
query = query.strip()
if not query.endswith("}"):
query = query + "}"
tail = re.findall(r"{(.*)}", query)
#return tail
#return t3_list
if not tail:
return []
else:
tail = tail[0]
t3_list = list(map(lambda x: x.strip() ,tail.split(".")))
t3_list_ = []
for ele in t3_list:
for k, v in uid_special_token_dict.items():
if k in ele:
ele = ele.replace(k, v)
t3_list_.append(ele)
t3_list = t3_list_
if filter_list:
t3_list = list(filter(lambda x:
any(map(lambda y: y in x ,filter_list))
, t3_list))
t3_list = list(map(lambda x:
list(filter(lambda y: y.strip() ,retrieve_sent_split(x)))
, t3_list))
return t3_list
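#### illustrative sketch: a generated query reduced to its triple tokens
'''
query_to_t3("select ?a { hp:Divination_homework_meeting hp:leader ?a }")
# -> [["hp:Divination_homework_meeting", "hp:leader", "?a"]]
'''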
def decode_property(eng_query ,kgqa_retriever, top_k = 3):
    # Decode candidate SPARQL queries and keep only the predicate (middle
    # element) of every complete (s, p, o) triple.
    sparql_queries = decode_query(eng_query, kgqa_retriever, top_k = top_k)
if not sparql_queries:
return []
t3_nest_list = list(map(lambda x: query_to_t3(x), sparql_queries))
####return t3_nest_list
p_nest_list = []
for ele in t3_nest_list:
for e in ele:
if len(e) == 3:
p_nest_list.append(e)
#p_nest_list = list(filter(lambda x: len(x) == 3, t3_nest_list))
if not p_nest_list:
return []
p_nest_list = list(map(lambda x: x[1], p_nest_list))
return p_nest_list
'''
#### original query decoder
query = "Harry Potter live in which house?"
query = "when was Stephen cornfoot born?"
decode_query(query, kgqa_retriever)
#### original query decoder, keeping only the property part
query = "Harry Potter live in which house in 1920?"
query = "Harry live in where?"
query = "when was Stephen cornfoot born?"
query = "what is Stephen's loyalty?"
decode_property(query, kgqa_retriever)
query = "who is the leader of Divination homework meeting?"
'''
def template_fullfill_reconstruct_query(entity_list = ["http://www.wikidata.org/entity/Q42780"]
, property_list = ["http://www.wikidata.org/prop/direct/P131",
"http://www.wikidata.org/prop/direct/P150"
],
generate_t3_func = lambda el, pl: pd.Series(list(product(el, pl))).map(
lambda ep: [(ep[0], ep[1], "?a"), ("?a", ep[1], ep[0])]
).explode().dropna().drop_duplicates().tolist()
    ):
    '''
    Build plain "select ?a { ... }" queries from the cartesian product of
    entity_list and property_list, trying both (e, p, ?a) and (?a, p, e).
    '''
assert type(entity_list) == type([])
assert type(property_list) == type([])
if not entity_list or not property_list:
return []
query_list = list(map(list ,generate_t3_func(entity_list, property_list)))
if not query_list:
return []
req = list(map(lambda x: "select ?a {" + " ".join(x) + "}", query_list))
return req
'''
sparql_queries_reconstruct = template_fullfill_reconstruct_query(
["hp:Divination_homework_meeting"],
["hp:leader"]
)
sparql_queries_reconstruct
'''
def run_sparql_queries(sparql_queries, kgqa_retriever, top_k = 3):
self = kgqa_retriever
answers = []
for sparql_query in sparql_queries:
ans, query = self._query_kg(sparql_query=sparql_query)
if len(ans) > 0:
answers.append((ans, query))
# if there are no answers we still want to return something
if len(answers) == 0:
answers.append(("", ""))
results = answers[:top_k]
results = [self.format_result(result) for result in results]
return results
'''
#### one conclusion
run_sparql_queries(sparql_queries_reconstruct, kgqa_retriever)
'''
#### entity/type extraction (formerly served over HTTP by kbqa_protable_service,
#### now called in-process)
def retrieve_et(zh_question, only_e = True):
    # Extract entity (E-TAG) and type (T-TAG) mentions from the Chinese
    # question; with only_e=True return just the entity list.
    assert type(zh_question) == type("")
'''
qst = zh_question
rep = requests.post(
url = "http://localhost:8855/extract_et",
data = {
"question": qst
}
)
output = json.loads(rep.content.decode())
'''
output = call_entity_property_extract(zh_question)
if only_e:
return output.get("E-TAG", [])
return output
'''
#### start qa server
def retrieve_head(zh_question):
req = requests.post(
url = "http://localhost:8811/qa_downstream_process",
data = {
"entity": "",
"question": zh_question,
"context": zh_question
}
)
output = json.loads(req.content.decode())
if "head" in output:
return output["head"]
return ""
'''
def retrieve_head(zh_question):
    # Extract the question's ask-head (the span naming what is asked about)
    # via the in-process qa_downstream_process.
output = qa_downstream_process(
"", zh_question, zh_question
)
assert type(output) == type({})
if "head" in output:
return output["head"]
return ""
'''
zh_question = "谁是占卜会议的领导者?"
retrieve_et(zh_question)
'''
def property_and_type_slice(spo_df_simple_trans, p_l = [], type_l = []):
req = spo_df_simple_trans.copy()
if type_l:
s_l = req[
req["o"].isin(type_l)
]["s"].drop_duplicates().dropna().values.tolist()
req = req[
req["s"].isin(s_l)
]
if req.size == 0:
return None
if p_l:
s_l = req[
req["p"].isin(p_l)
]["s"].drop_duplicates().dropna().values.tolist()
req = req[
req["s"].isin(s_l)
]
if req.size == 0:
return None
return req
'''
### Organisation_ sample
property_and_type_slice(
spo_df_simple_trans, p_l = ["创立"], type_l = ["hp:Organisation_"]
).sort_values(by = "s")["s"].drop_duplicates().sample(n = 30)
### people sample
property_and_type_slice(
spo_df_simple_trans, p_l = ["出生"], type_l = ["hp:Individual_"]
).sort_values(by = "s")["s"].drop_duplicates().sample(n = 30)
zh_question = "谁是占卜会议的领导者?"
en_question = zh_en_naive_model.translate([zh_question], source_lang="zh", target_lang = "en")[0]
en_properties = decode_property(en_question, kgqa_retriever)
en_properties
'''
all_en_p = spo_df_simple["p"].drop_duplicates().dropna().values.tolist()
all_en_p_tokens = pd.Series(list(map(lambda x: x[3:].split("_") ,filter(lambda x: x.startswith("hp:"), all_en_p)))).explode().dropna().map(
lambda x: x if bool(x) else np.nan
).dropna().drop_duplicates().values.tolist()
###all_en_p_tokens[:10]
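#### illustrative sketch: each hp: property contributes its underscore-split
#### tokens, e.g. a (hypothetical) property "hp:blood_status" would add the
#### tokens "blood" and "status" to all_en_p_tokens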
all_p_df = pd.Series(all_en_p).reset_index().iloc[:, 1:]
all_p_df.columns = ["en_p"]
all_p_df = all_p_df[
all_p_df["en_p"] != "rdf:type"
]
all_p_df["zh_p"] = all_p_df["en_p"].map(
lambda x: spo_trans_dict.get(x.replace("hp:", "").replace("_", " "), x.replace("hp:", "").replace("_", " "))
)
#all_p_df
#### decoder property mapping: remap decoder outputs onto properties that
#### actually exist in the KB
decode_map_config_dict = {
"hp:birth": 'hp:born',
'hp:birthday': "hp:born"
}
#### decoder similar-property mapping: property pairs the decoder cannot
#### reliably distinguish
decode_sim_config_dict = {
'hp:ingredients': "hp:characteristics",
"hp:characteristics": 'hp:ingredients'
}
def decode_property_link_to_ori(decode_property, all_en_p, all_en_p_tokens, equal_threshold = 80):
    # Link a decoded property onto one that exists in the KB: exact hit first,
    # then the configured remap, then token-filtered fuzzy ranking.
    # (equal_threshold is currently unused.)
    if not decode_property.startswith("hp:") or not len(decode_property) >= 3:
        return None
if decode_property in all_en_p:
return [(decode_property, 100.0)]
if decode_property in decode_map_config_dict:
return [(decode_map_config_dict[decode_property], 99.0)]
def filter_by_p_tokens(decode_property):
req = []
for ele in decode_property[3:].split("_"):
if ele in all_en_p_tokens:
req.append(ele)
return "hp:{}".format("_".join(req))
if decode_property == "hp:":
return None
decode_property = filter_by_p_tokens(decode_property)
order_list = sorted(map(lambda x: (x, fuzz.ratio(x, decode_property)), all_en_p), key = lambda t2: t2[1], reverse = True)
return order_list[:10]
'''
#### keep only tokens known to the KB, then rank by fuzzy similarity
decode_property_link_to_ori("hp:born", all_en_p, all_en_p_tokens, equal_threshold = 80)
decode_property_link_to_ori("hp:birth", all_en_p, all_en_p_tokens, equal_threshold = 80)
decode_property_link_to_ori("hp:head_of_the_assembly", all_en_p, all_en_p_tokens, equal_threshold = 80)
'''
def output_to_dict(output, trans_keys = ["answers"]):
    # Convert haystack Answer objects under trans_keys into plain dicts so the
    # whole result is JSON-serializable.
non_trans_t2_list = list(filter(lambda t2: t2[0] not in trans_keys, output.items()))
trans_t2_list = list(map(lambda tt2: (
tt2[0],
list(map(lambda x: x.to_dict(), tt2[1]))
) ,filter(lambda t2: t2[0] in trans_keys, output.items())))
#return trans_t2_list
return dict(trans_t2_list + non_trans_t2_list)
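#### illustrative sketch: make a reader result JSON-serializable
#### (Answer below is haystack.schema.Answer; the values are made up)
'''
from haystack.schema import Answer
output_to_dict({"query": "born", "answers": [Answer(answer="1980")]})
# -> {"answers": [{"answer": "1980", ...}], "query": "born"}
'''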
def zh_question_to_p_zh_en_map(zh_question, top_k = 3):
    '''
    Translate the question to English, decode candidate properties with the
    text-to-SPARQL model, link them back onto KB properties, and attach a
    Chinese surface form extracted from the question by the reader.
    '''
#zh_question = "谁是占卜会议的领导者?"
#en_question = zh_en_naive_model.translate([zh_question], source_lang="zh", target_lang = "en")[0]
en_question = call_zh_en_naive_model(zh_question)
en_properties = decode_property(en_question, kgqa_retriever, top_k = top_k)
if not en_properties:
return None
en_properties_top_sort = pd.Series(en_properties).value_counts().index.tolist()
en_properties_mapped = list(map(
lambda x: decode_property_link_to_ori(x, all_en_p, all_en_p_tokens, equal_threshold = 80), en_properties_top_sort
))
en_properties_mapped = list(filter(lambda x: hasattr(x, "__len__") and len(x) >= 1, en_properties_mapped))
if not en_properties_mapped:
return None
en_properties_mapped = list(map(lambda x: x[0] ,en_properties_mapped))
en_properties_mapped_df = pd.DataFrame(en_properties_mapped)
assert en_properties_mapped_df.shape[1] == 2
en_properties_mapped_df.columns = ["en_property", "score"]
'''
en_properties_mapped_df["zh_property"] = en_properties_mapped_df["en_property"].map(
lambda x: en_zh_reader.predict_on_texts(
question=x.replace("hp:", ""),
texts=[zh_question]
)
).map(output_to_dict)
'''
en_properties_mapped_df["zh_property"] = en_properties_mapped_df["en_property"].map(
lambda x: call_en_zh_reader(
x.replace("hp:", ""),
zh_question
)
)
en_properties_mapped_df["zh_property"] = en_properties_mapped_df["zh_property"].map(lambda x: x["answers"][0] if x["answers"] else {})
en_properties_mapped_df = en_properties_mapped_df[
en_properties_mapped_df["zh_property"].map(bool)
]
if en_properties_mapped_df is None or en_properties_mapped_df.size == 0:
return None
#return nerd_df
en_properties_mapped_df["ext_score"] = en_properties_mapped_df["zh_property"].map(
lambda x: x["score"]
)
en_properties_mapped_df["zh_property"] = en_properties_mapped_df["zh_property"].map(
lambda x: x["answer"]
)
'''
en_properties_mapped_df = en_properties_mapped_df[
en_properties_mapped_df["ext_score"].map(lambda x: x > score_threshold)
]
'''
if en_properties_mapped_df is None or en_properties_mapped_df.size == 0:
return None
ask_head = retrieve_head(zh_question)
#if type(ask_head) == type("") and "什么" in ask_head:
if type(ask_head) == type(""):
#ask_head = ask_head.replace("什么", "")
first_d = en_properties_mapped_df.iloc[0].to_dict()
first_d["zh_property"] = ask_head
en_properties_mapped_df = pd.DataFrame(
[first_d] + en_properties_mapped_df.apply(lambda x: x.to_dict(), axis = 1).values.tolist()
)
else:
pass
en_properties_mapped_df = en_properties_mapped_df[
en_properties_mapped_df["zh_property"].map(lambda x: bool(x))
].drop_duplicates()
return en_properties_mapped_df
def search_sym_p(question_p_df, all_p_df):
    # Score every KB property against each candidate (zh_property, en_property)
    # row, using synonyms similarity on the Chinese side.
#zh_p_l = question_p_df["zh_property"].drop_duplicates().values.tolist()
#en_p_l = question_p_df["en_property"].drop_duplicates().values.tolist()
req = []
for idx, r in question_p_df.iterrows():
all_p_score_df = all_p_df.copy()
all_p_score_df["zh_property"] = [r["zh_property"]] * len(all_p_score_df)
all_p_score_df["en_property"] = [r["en_property"]] * len(all_p_score_df)
req.append(all_p_score_df)
req = pd.concat(req, axis = 0)
req["zh_sim"] = req.apply(
lambda x: synonyms.compare(x["zh_property"], x["zh_p"]), axis = 1
)
req = req.sort_values(by = "zh_sim", ascending = False)
return req
all_en_ents = pd.Series(spo_df_simple[["s", "o"]].values.reshape([-1])).drop_duplicates().values.tolist()
all_ents_df = pd.Series(all_en_ents).reset_index().iloc[:, 1:]
all_ents_df.columns = ["en_ent"]
all_ents_df = all_ents_df[
all_ents_df["en_ent"] != "rdf:type"
]
all_ents_df["zh_ent"] = all_ents_df["en_ent"].map(
lambda x: spo_trans_dict.get(x.replace("hp:", "").replace("_", " "), x.replace("hp:", "").replace("_", " "))
)
#all_ents_df
def search_sym_entity(entity_str, all_ents_df, use_syn = False):
    # Rank all KB entities against entity_str, by synonyms similarity or
    # (default) rapidfuzz ratio on the Chinese surface form.
#zh_p_l = question_p_df["zh_property"].drop_duplicates().values.tolist()
#en_p_l = question_p_df["en_property"].drop_duplicates().values.tolist()
req = all_ents_df.copy()
req["entity_str"] = [entity_str] * len(req)
if use_syn:
req["zh_sim"] = req.apply(
lambda x: synonyms.compare(x["zh_ent"], x["entity_str"]), axis = 1
)
else:
req["zh_sim"] = req.apply(
lambda x: fuzz.ratio(x["zh_ent"], x["entity_str"]), axis = 1
)
req = req.sort_values(by = "zh_sim", ascending = False)
return req
zh_question = "谁是占卜会议的领导者?"
zh_question = "洛林出生在哪个国家?"
zh_question = "洛林出生在哪个地方?"
zh_question = "洛林的血缘是什么?"
zh_question = "洛林的生日是什么?"
zh_question = "洛林的家族是什么?"
zh_question = "洛林的性别是什么?"
zh_question = "洛林的标题是什么?"
zh_question = "洛林的主题是什么?"
zh_question = "这个物品的特征是什么?"
zh_question = "强效祛斑药水的特征是什么?"
zh_question = "魔法学校的成立日期是什么?"
zh_question = "魔法学校的校长是谁?"
question_p_df = zh_question_to_p_zh_en_map(zh_question)
#question_p_df
#### take the top en_p candidates (those with the highest zh_sim)
#### all_p_df must be preloaded so all candidates are precalculated
sym_p_df = search_sym_p(question_p_df, all_p_df)
#sym_p_df
'''
#### this works; quality depends mostly on translation accuracy
entity_str = "占卜会议"   # "divination meeting"
search_sym_entity(entity_str, all_ents_df)
#### re-translate the full key list in bulk
pd.Series(list(spo_trans_dict.keys())).to_csv("../data/all_consider.csv", index = False)
'''
#### ->
'''
sparql_queries_reconstruct = template_fullfill_reconstruct_query(
["hp:Divination_homework_meeting"],
["hp:leader"]
)
sparql_queries_reconstruct
'''
def from_zh_question_to_consider_queries(zh_question, top_k = 32, top_p_k = 5, top_e_k = 50, kgqa_retriever = kgqa_retriever,):
    '''
    End-to-end retrieval: extract entities, map the question onto KB
    properties, rank entity/property candidates, fill the SPARQL templates,
    and run them. Returns (sparql_queries, results), or None when any stage
    comes up empty.
    '''
    zh_ents = retrieve_et(zh_question)
if type(zh_ents) != type([]) or not zh_ents:
return None
question_p_df = zh_question_to_p_zh_en_map(zh_question, top_k = top_p_k)
if not hasattr(question_p_df, "size") or question_p_df.size == 0:
return None
### en_p
sym_p_df = search_sym_p(question_p_df, all_p_df)
if not hasattr(sym_p_df, "size") or sym_p_df.size == 0:
return None
sim_entity_df_list = []
for entity_str in zh_ents:
sym_ent_df = search_sym_entity(entity_str, all_ents_df)
if not hasattr(sym_ent_df, "size") or sym_ent_df.size == 0:
continue
sim_entity_df_list.append(sym_ent_df)
if type(sim_entity_df_list) != type([]) or not sim_entity_df_list:
return None
#### en_ent
sym_ent_df = pd.concat(sim_entity_df_list, axis = 0).sort_values(by = "zh_sim", ascending = False)
#return sym_p_df, sym_ent_df
top_p = sym_p_df["en_p"].drop_duplicates().dropna().head(top_p_k).values.tolist()
top_e = sym_ent_df["en_ent"].drop_duplicates().dropna().head(top_e_k).values.tolist()
    print(top_e)
    print(top_p)
if not top_p or not top_e:
return None
sparql_queries_reconstruct = template_fullfill_reconstruct_query(
top_e,
top_p
)
#return sparql_queries_reconstruct
if not sparql_queries_reconstruct:
return None
output = run_sparql_queries(sparql_queries_reconstruct, kgqa_retriever, top_k = top_k)
return sparql_queries_reconstruct ,output
def trans_output(zh_question ,output):
    # Map hp: URIs in answers and queries back to their Chinese surface forms,
    # then sort candidates by similarity between question and query.
    if type(output) != type([]):
        return output
def single_trans(d):
assert type(d) == type({})
if not d:
return d
req = {}
answer = d.get("answer")
if type(answer) == type([]):
answer = list(map(lambda x:
spo_trans_dict.get(x.split("/")[-1].replace("_", " "),
x.split("/")[-1].replace("_", " ")
) if x.startswith("https://deepset.ai/harry_potter") else x
, answer))
sparql_query = d.get("prediction_meta")
if sparql_query is not None:
sparql_query = sparql_query.get("sparql_query")
if type(sparql_query) == type(""):
t3_in_query = query_to_t3(sparql_query)
hp_l = pd.Series(np.asarray(t3_in_query).reshape([-1])).map(lambda x: x[3:] if x.startswith("hp:") else np.nan).dropna().drop_duplicates().values.tolist()
for ele in sorted(hp_l, key = len, reverse = True):
sparql_query = sparql_query.replace(ele, spo_trans_dict.get(ele.split("/")[-1].replace("_", " "),
ele.split("/")[-1].replace("_", " ")))
if answer is not None:
req["answer"] = answer
if sparql_query is not None:
req["sparql_query"] = sparql_query
return req
    output_trans = list(map(single_trans, output))
    # Rank candidates by similarity between the question and the translated query.
    output_trans = sorted(output_trans, key = lambda d:
        synonyms.compare(zh_question, d.get("sparql_query") or " ") if type(d) == type({}) else 0.0
        , reverse = True)
    return output_trans
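#### illustrative sketch (made-up result dict; assumes spo_trans_dict maps
#### "Albus Dumbledore" and "born" to their Chinese translations):
'''
trans_output("邓布利多的出生日期是什么?", [{
    "answer": ["https://deepset.ai/harry_potter/1881"],
    "prediction_meta": {"sparql_query": "select ?a { hp:Albus_Dumbledore hp:born ?a }"},
}])
# -> roughly [{"answer": ["1881"],
#              "sparql_query": "select ?a { hp:阿不思·邓布利多 hp:出生 ?a }"}]
'''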
def ranking_output(zh_question, zh_output):
    '''
    Re-rank answer rows by how well the entity/type strings inside each SPARQL
    query match the question's E-TAG / T-TAG: first with fuzz.ratio, and if the
    best combined score is weak (< 50), again with synonyms.compare.
    '''
    e_t_dict = retrieve_et(zh_question, only_e=False)
    e = e_t_dict.get("E-TAG", [])
    t = e_t_dict.get("T-TAG", [])
    e, t = map(" ".join, [e, t])
    print(e, t)
    df = pd.DataFrame(zh_output)
    df = df.explode("answer")
    def et_scores(df, scorer):
        # e_score: E-TAG vs the non-variable subject/object tokens of the query body.
        e_score = df["sparql_query"].map(lambda x: re.findall("{(.*)}" ,x)[0]).map(lambda x:
            list(filter(lambda y: "?" not in y ,
                list(np.asarray(x.split())[[0, -1]])
            ))
        ).map(" ".join).map(lambda x:
            [e, x.split(":")[-1]]
        ).map(lambda x: list(map(lambda y:
            y.replace(" ", "") ,x))).map(lambda x:
            scorer(*x))
        # t_score: T-TAG vs the predicate (middle token), skipped when it is a variable.
        t_score = df["sparql_query"].map(lambda x: re.findall("{(.*)}" ,x)[0]).map(lambda x:
            list(filter(lambda y: "?" not in y ,
                [x.split()[1]]
            ))
        ).map(" ".join).map(lambda x:
            [t, x.split(":")[-1]]
        ).map(lambda x: list(map(lambda y:
            y.replace(" ", "") ,x))).map(lambda x:
            scorer(*x))
        return e_score, t_score
    df["e_score"], df["t_score"] = et_scores(df, fuzz.ratio)
    #df["a_score"] = df["answer"].map(lambda x: [x, t]).map(lambda x: synonyms.compare(*x)) * 100
    df["et_score"] = df[["e_score", "t_score", ]].sum(axis = 1)
    df = df.sort_values(by = "et_score", ascending = False)
    if df["et_score"].iloc[0] >= 50:
        return df
    #### weak fuzzy match: fall back to semantic similarity
    df["e_score"], df["t_score"] = et_scores(df, synonyms.compare)
    df["et_score"] = df[["e_score", "t_score", ]].sum(axis = 1)
    df = df.sort_values(by = "et_score", ascending = False)
    return df
if __name__ == "__main__":
    #### 血缘 (blood status) still needs finetuning; ranking_output compensates for it
    #### recall design: widened from top-3 to top-5 candidates
    def answer_zh_question(zh_question, top_k = 32, top_p_k = 30, top_e_k = 50):
        output = from_zh_question_to_consider_queries(zh_question,
            top_k = top_k, top_p_k = top_p_k, top_e_k = top_e_k
        )
        if type(output) == type((1,)):
            query_list, output = output
            zh_output = trans_output(zh_question ,output)
        else:
            zh_output = None
        if zh_output is None:
            return None
        return ranking_output(zh_question, zh_output)
    test_questions = [
        "哈利波特的血缘是什么?",           # What is Harry Potter's blood status?
        "哈利波特的生日是什么?",           # What is Harry Potter's birthday?
        "史内普的生日是什么时候?",         # When is Snape's birthday?
        "占卜会议的领导者是谁?",           # Who is the leader of the divination meeting?
        "纽约卫生局的创立时间是什么?",     # When was the New York health department founded?
        "法兰西魔法部记录室位于哪个城市?", # Which city hosts the French Ministry of Magic records room?
        "邓布利多的出生日期是什么?",       # What is Dumbledore's date of birth?
        "哥布林叛乱发生在什么日期?",       # On what date did the goblin rebellion happen?
        "决斗比赛的参与者是谁?",           # Who were the participants of the dueling contest?
        "赫敏的丈夫是谁?",                 # Who is Hermione's husband?
    ]
    for zh_question in test_questions:
        print(zh_question)
        print(answer_zh_question(zh_question))