|
|
|
|
|
from qa import * |
|
from translate_by_api import * |
|
from extract_by_api import * |
|
from extract_et_by_api import * |
|
|
|
import os |
|
import logging |
|
import subprocess |
|
import time |
|
from pathlib import Path |
|
|
|
from haystack.nodes import Text2SparqlRetriever |
|
from haystack.document_stores import GraphDBKnowledgeGraph, InMemoryKnowledgeGraph |
|
|
|
|
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import sys |
|
|
|
|
|
from functools import partial, reduce, lru_cache |
|
|
|
|
|
|
|
|
|
from time import time |
|
|
|
from itertools import product |
|
|
|
|
|
from urllib.parse import unquote |
|
|
|
import requests |
|
import json |
|
|
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import sys |
|
|
|
|
|
from functools import partial, reduce, lru_cache |
|
|
|
|
|
|
|
|
|
from time import time |
|
|
|
from itertools import product |
|
|
|
|
|
|
|
|
|
from rapidfuzz import fuzz |
|
import synonyms |
|
|
|
import sys |
|
|
|
|
|
|
|
|
|
|
|
''' |
|
p00 = os.path.join(model_path, "zh_en_m2m") |
|
assert os.path.exists(p00) |
|
zh_en_naive_model = EasyNMT(p00) |
|
zh_en_naive_model.translate(["宁波在哪?"], source_lang="zh", target_lang = "en") |
|
''' |
|
|
|
''' |
|
from haystack.nodes import FARMReader |
|
#question_reader_save_path = "/Users/svjack/temp/model/en_zh_question_reader_save_epc_2_spo" |
|
question_reader_save_path = os.path.join(model_path, "en_zh_question_reader_save_epc_2_spo") |
|
assert os.path.exists(question_reader_save_path) |
|
en_zh_reader = FARMReader(model_name_or_path=question_reader_save_path, use_gpu=False, |
|
num_processes = 0 |
|
) |
|
''' |
|
|
|
# Build the in-memory knowledge graph used by the whole pipeline: reset the
# "tutorial_10_index" index (delete, then recreate) and import the triples from
# data/triples.ttl (path relative to the current working directory).
kg = InMemoryKnowledgeGraph(index="tutorial_10_index")
kg.delete_index()
kg.create_index()
kg.import_from_ttl_file(index="tutorial_10_index", path=Path("data") / "triples.ttl")
|
|
|
|
|
|
|
|
|
|
|
import re |
|
def transform_namespace_to_prefix_str(g):
    """Render the namespaces of ``g`` as SPARQL ``PREFIX`` declarations.

    ``g.namespaces()`` must yield pairs whose first item is the prefix and
    whose second item is the namespace URI; one ``PREFIX prefix: <uri>`` line
    is produced per pair and the lines are joined with newlines.
    """
    declarations = []
    for entry in g.namespaces():
        declarations.append("PREFIX {}: <{}>".format(entry[0], entry[1]))
    return "\n".join(declarations)
|
|
|
|
|
|
|
|
|
wiki_prefix = '''

PREFIX brick: <https://brickschema.org/schema/Brick#>

PREFIX csvw: <http://www.w3.org/ns/csvw#>

PREFIX dc: <http://purl.org/dc/elements/1.1/>

PREFIX dcat: <http://www.w3.org/ns/dcat#>

PREFIX dcmitype: <http://purl.org/dc/dcmitype/>

PREFIX dcterms: <http://purl.org/dc/terms/>

PREFIX dcam: <http://purl.org/dc/dcam/>

PREFIX doap: <http://usefulinc.com/ns/doap#>

PREFIX foaf: <http://xmlns.com/foaf/0.1/>

PREFIX odrl: <http://www.w3.org/ns/odrl/2/>

PREFIX org: <http://www.w3.org/ns/org#>

PREFIX owl: <http://www.w3.org/2002/07/owl#>

PREFIX prof: <http://www.w3.org/ns/dx/prof/>

PREFIX prov: <http://www.w3.org/ns/prov#>

PREFIX qb: <http://purl.org/linked-data/cube#>

PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>

PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

PREFIX schema: <https://schema.org/>

PREFIX sh: <http://www.w3.org/ns/shacl#>

PREFIX skos: <http://www.w3.org/2004/02/skos/core#>

PREFIX sosa: <http://www.w3.org/ns/sosa/>

PREFIX ssn: <http://www.w3.org/ns/ssn/>

PREFIX time: <http://www.w3.org/2006/time#>

PREFIX vann: <http://purl.org/vocab/vann/>

PREFIX void: <http://rdfs.org/ns/void#>

PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>

PREFIX xml: <http://www.w3.org/XML/1998/namespace>

PREFIX hp: <https://deepset.ai/harry_potter/>

'''

# Prefix names declared in wiki_prefix (e.g. "brick", "hp"); the Series index
# is the position of each declaration's line in wiki_prefix.split("\n").
_wiki_prefix_lines = pd.Series(wiki_prefix.split("\n"))
prefix_s = _wiki_prefix_lines[
    _wiki_prefix_lines.map(lambda line: line.startswith("PREFIX"))
].map(lambda line: re.findall("PREFIX (.*): <", line)[0])
del _wiki_prefix_lines

# Prefix name -> namespace URL, e.g. "hp" -> "https://deepset.ai/harry_potter/".
prefix_url_dict = {
    fields[1].replace(":", ""): fields[2].strip()[1:-1]
    for fields in (line.split(" ") for line in wiki_prefix.split("\n") if line.strip())
}

# Namespace URL -> prefix name (inverse of prefix_url_dict).
url_prefix_dict = {url: prefix for prefix, url in prefix_url_dict.items()}
|
|
|
# Pull every triple out of the knowledge graph; each cell of spo_df holds a
# dict whose "value" entry is read below.
all_triples = kg.get_all_triples()
spo_df = pd.DataFrame(all_triples)


def _compact_uri(value):
    """Replace the first namespace URL of ``url_prefix_dict`` that ``value``
    starts with (every occurrence of it) by ``"prefix:"``; otherwise return
    ``value`` unchanged."""
    for url, prefix in url_prefix_dict.items():
        if value.startswith(url):
            return value.replace(url, "{}:".format(prefix))
    return value


# s/p/o table with namespace URLs shortened to "prefix:" and percent-escapes decoded.
spo_df_simple = spo_df.copy().applymap(lambda cell: unquote(_compact_uri(cell["value"])))
|
|
|
''' |
|
#### like property in wikidata |
|
spo_df_simple["p"].map( |
|
lambda x: x[3:] if x.startswith("hp:") else np.nan |
|
).dropna().value_counts() |
|
|
|
#### others in p col (rdf:type) |
|
spo_df_simple["p"].map( |
|
lambda x: x if not x.startswith("hp:") else np.nan |
|
).dropna().value_counts() |
|
|
|
#### groupby different entity type view |
|
pd.concat( |
|
list(map( |
|
lambda t2: t2[1].head(2), |
|
list(spo_df_simple[ |
|
spo_df_simple["p"] == "rdf:type" |
|
].sort_values(by = ["o", "s"]).groupby("o")) |
|
)), axis = 0).head(30) |
|
''' |
|
|
|
|
|
|
|
|
|
|
|
# Label translation table read from data/spo_trans_total.csv (relative to the
# working directory). dict() over the rows maps column 0 -> column 1, so the
# CSV must have exactly two columns. Elsewhere it is looked up with labels that
# have "hp:" removed and "_" replaced by spaces; the values are presumably the
# Chinese translations — confirm against the CSV.
spo_trans_total_df = pd.read_csv("data/spo_trans_total.csv")
spo_trans_dict = dict(spo_trans_total_df.values.tolist())
|
''' |
|
with open("../data/spo_trans_dict.json", "w") as f: |
|
json.dump(spo_trans_dict, f) |
|
''' |
|
|
|
# Inverse translation table: translated label -> original key (if two keys
# share a translation, the later one wins).
spo_trans_back_dict = {translated: original for original, translated in spo_trans_dict.items()}
# Copy of the compacted triple table that is translated below.
spo_df_simple_keyed = spo_df_simple.copy()
|
|
|
def map_to_trans_key(src):
    """Turn an ``hp:`` value into its translation-table key.

    The value is stringified; for ``hp:`` values the prefix is dropped, single
    and double quotes are removed and underscores become spaces. Any other
    value maps to ``np.nan``.
    """
    text = str(src)
    if text.startswith("hp:"):
        stripped = text[3:]
        for quote in ('"', "'"):
            stripped = stripped.replace(quote, "")
        return stripped.replace("_", " ")
    return np.nan
|
|
|
def _translate_cell(cell):
    """Return the translation of an ``hp:`` cell if one exists, else the cell."""
    key = map_to_trans_key(cell)
    if isinstance(key, str):
        return spo_trans_dict.get(key, cell)
    return cell


# spo_df_simple_keyed with every "hp:" value that has an entry in
# spo_trans_dict replaced by its translation.
spo_df_simple_trans = spo_df_simple_keyed.applymap(_translate_cell)
|
|
|
''' |
|
pd.concat( |
|
list(map( |
|
lambda t2: t2[1].head(2), |
|
list(spo_df_simple_trans[ |
|
spo_df_simple_trans["p"] == "rdf:type" |
|
].sort_values(by = ["o", "s"]).groupby("o")) |
|
)), axis = 0).head(50) |
|
|
|
spo_df_simple_trans[ |
|
spo_df_simple_trans["s"] == "斯蒂芬-康福特" |
|
] |
|
''' |
|
|
|
# Text-to-SPARQL retriever over `kg`, loading the model stored under
# data/hp_v3.4 (relative to the working directory).
model_dir = "data/"
kgqa_retriever = Text2SparqlRetriever(knowledge_graph=kg, model_name_or_path=model_dir + "hp_v3.4")
|
|
|
def decode_query(eng_query ,kgqa_retriever, top_k = 3):
    """Beam-decode ``top_k + 2`` SPARQL candidates for an English question.

    Uses the retriever's tokenizer (``.tok``) and seq2seq model (``.model``):
    the input is truncated to 100 tokens, beam search uses
    ``max(5, top_k + 2)`` beams and outputs are at most 100 tokens long.

    Returns:
        The decoded query strings (special tokens skipped), in model order.
    """
    tokenizer = kgqa_retriever.tok
    encoded = tokenizer([eng_query], max_length=100, truncation=True, return_tensors="pt")
    n_candidates = top_k + 2
    generated = kgqa_retriever.model.generate(
        encoded["input_ids"],
        num_beams=max(5, n_candidates),
        max_length=100,
        num_return_sequences=n_candidates,
        early_stopping=True,
    )
    return [
        tokenizer.decode(sequence, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for sequence in generated
    ]
|
|
|
import re |
|
from uuid import uuid1 |
|
import jionlp as jio |
|
|
|
# Clause openers that query_to_t3 locates (via match_special_token) and cuts
# out up to the matching ")" before splitting a query body on ".".
# The leading space is part of the match.
special_match_token_list = [
    " filter(",
]
|
|
|
def fill_bk(str_):
    """Return the prefix of ``str_`` that ends with its first balanced bracket.

    Characters are scanned left to right while tracking parenthesis depth;
    once at least one "(" has been seen and the depth returns to zero, the
    string up to and including that character is returned. If that never
    happens, the whole string is returned.
    """
    depth = 0
    opened = False
    for pos, ch in enumerate(str_):
        if ch == "(":
            depth += 1
            opened = True
        elif ch == ")":
            depth -= 1
        if opened and depth == 0:
            return str_[:pos + 1]
    return str_
|
|
|
def match_special_token(query, special_match_token_list):
    """Find special clauses (e.g. ``" filter("``) in ``query``.

    For every token of ``special_match_token_list`` that occurs in ``query``,
    returns ``(token, clause)`` where ``clause`` runs from the token's first
    occurrence to its balancing ")" (see ``fill_bk``). Returns ``[]`` when no
    token occurs. Both arguments must be exactly ``str`` / ``list``.
    """
    assert type(query) is str
    assert type(special_match_token_list) is list
    present = [token for token in special_match_token_list if token in query]
    return [(token, fill_bk(query[query.find(token):])) for token in present]
|
|
|
def retrieve_sent_split(sent,
                        stops_split_pattern = r"\ "
                        ):
    """Split ``sent`` with the regex ``stops_split_pattern`` (default: one space).

    Returns ``[]`` for a blank string; otherwise the raw ``re.split`` result,
    which may contain empty pieces between consecutive separators.
    """
    if sent.strip():
        return re.split(stops_split_pattern, sent)
    return []
|
|
|
import jionlp as jio |
|
|
|
# SPARQL variable names that query_to_t3 pads with spaces so that they split
# into separate tokens.
ask_l = [
    "?answer", "?value", "?obj", "?sbj", "?s", "?x", "?a"
]
# jionlp lexicon-based recogniser tagging occurrences of the ask_l strings
# under the label "ask"; query_to_t3 reads the "text" field of its results.
ask_ner = jio.ner.LexiconNER({
    "ask": ask_l
})
|
|
|
def query_to_t3(query, filter_list = [], ask_ner = ask_ner):
    """Split the ``{...}`` body of a SPARQL query into token lists ("triples").

    Steps: pad every variable found by ``ask_ner`` with spaces; cut out
    clauses that start with an entry of ``special_match_token_list`` (up to
    the matching ")"); take the first ``{...}`` match (appending "}" if the
    query does not end with one); split it on "."; optionally keep only the
    pieces containing an entry of ``filter_list``; split each piece on single
    spaces and drop blank tokens.

    Args:
        query: SPARQL query text.
        filter_list: optional substrings; when non-empty, only pieces that
            contain at least one of them are kept. Only read, never mutated.
        ask_ner: callable returning dicts with a "text" key for each variable
            mention (defaults to the module-level ``ask_ner``).

    Returns:
        A list of token lists, or ``[]`` when no ``{...}`` group is found.
        Pieces are not guaranteed to have exactly 3 tokens.
    """
    '''

    query = query.replace("?answer", " ?answer ")

    query = query.replace("?value", " ?value ")

    query = query.replace("?obj", " ?obj ")

    query = query.replace("?sbj", " ?sbj ")

    query = query.replace("?s", " ?s ")

    query = query.replace("?x", " ?x ")

    '''
    l = ask_ner(query)
    # Distinct variable texts, longest first.
    # NOTE(review): str.replace also hits a shorter variable inside a longer
    # one (e.g. "?s" inside "?sbj") when both occur in the query.
    l = sorted(set(map(lambda x: x["text"], l)), key = len, reverse = True)
    for k in l:
        query = query.replace(k, " {} ".format(k))
    '''

    if "where" not in query and "WHERE" not in query:

        return []

    '''
    # (token, clause_text) pairs for special clauses such as " filter(...)".
    special_token_list = match_special_token(query, special_match_token_list)
    if special_token_list:
        special_token_list = list(set(map(lambda t2: t2[1] ,special_token_list)))
        # uid <-> clause lookup tables (uuid1 strings as placeholders).
        uid_special_token_dict = dict(map(lambda x: (str(uuid1()), x), special_token_list))
        special_token_uid_dict = dict(map(lambda t2: t2[::-1], uid_special_token_dict.items()))
        assert len(special_token_uid_dict) == len(uid_special_token_dict)
        # Longest clauses first.
        for k, v in sorted(special_token_uid_dict.items(), key = lambda t2: len(t2[0]), reverse = True):
            if k in query:
                # NOTE(review): the clause is replaced with "" rather than its
                # uid placeholder `v`, so the clause is dropped and the
                # uid -> clause restore loop below never matches anything.
                # Confirm whether dropping filter clauses is intended.
                query = query.replace(k, "")
    else:
        uid_special_token_dict = {}
        special_token_uid_dict = {}
    '''

    if "where" in query:

        tail = "where".join(query.split("where")[1:])

    elif "WHERE" in query:

        tail = "WHERE".join(query.split("WHERE")[1:])

    '''
    query = query.strip()
    # Append a closing brace if missing so the regex below can still match.
    if not query.endswith("}"):
        query = query + "}"
    # Greedy {...} match; "." does not cross newlines. Only the first match is used.
    tail = re.findall(r"{(.*)}", query)
    if not tail:
        return []
    else:
        tail = tail[0]
    # NOTE: splits on every ".", including any inside IRIs or literals.
    t3_list = list(map(lambda x: x.strip() ,tail.split(".")))
    t3_list_ = []
    for ele in t3_list:
        # Restore uid placeholders to their clause text.
        for k, v in uid_special_token_dict.items():
            if k in ele:
                ele = ele.replace(k, v)
        t3_list_.append(ele)
    t3_list = t3_list_
    if filter_list:
        t3_list = list(filter(lambda x:
            any(map(lambda y: y in x ,filter_list))
        , t3_list))
    # Tokenise each piece on single spaces (retrieve_sent_split default) and
    # drop blank tokens.
    t3_list = list(map(lambda x:
        list(filter(lambda y: y.strip() ,retrieve_sent_split(x)))
    , t3_list))
    return t3_list
|
|
|
def decode_property(eng_query ,kgqa_retriever, top_k = 3):
    """Collect the predicates of all 3-token triples in the decoded queries.

    Decodes SPARQL candidates with ``decode_query``, splits each with
    ``query_to_t3`` and returns the middle token of every piece that has
    exactly three tokens, in order. Returns ``[]`` when nothing qualifies.
    """
    candidate_queries = decode_query(eng_query, kgqa_retriever, top_k = top_k)
    if not candidate_queries:
        return []
    triples_per_query = [query_to_t3(candidate) for candidate in candidate_queries]
    return [
        triple[1]
        for triples in triples_per_query
        for triple in triples
        if len(triple) == 3
    ]
|
|
|
''' |
|
#### ori query decoder |
|
query = "Harry Potter live in which house?" |
|
query = "when was Stephen cornfoot born?" |
|
decode_query(query, kgqa_retriever) |
|
|
|
#### ori query decoder only maintain property part |
|
query = "Harry Potter live in which house in 1920?" |
|
query = "Harry live in where?" |
|
query = "Harry live in where?" |
|
query = "when was Stephen cornfoot born?" |
|
query = "what is Stephen's loyalty?" |
|
decode_property(query, kgqa_retriever) |
|
|
|
query = "who is the leader of Divination homework meeting?" |
|
''' |
|
|
|
def template_fullfill_reconstruct_query(
    entity_list = ["http://www.wikidata.org/entity/Q42780"],
    property_list = [
        "http://www.wikidata.org/prop/direct/P131",
        "http://www.wikidata.org/prop/direct/P150",
    ],
    generate_t3_func = lambda el, pl: pd.Series(list(product(el, pl))).map(
        lambda ep: [(ep[0], ep[1], "?a"), ("?a", ep[1], ep[0])]
    ).explode().dropna().drop_duplicates().tolist(),
):
    """Build ``select ?a {...}`` queries from entity/property combinations.

    By default, every (entity, property) pair yields the patterns
    ``entity property ?a`` and ``?a property entity``; duplicates are dropped
    keeping first-seen order. Both lists must be exactly ``list``.

    Returns:
        One query string per pattern, or ``[]`` if either list is empty.
    """
    assert type(entity_list) is list
    assert type(property_list) is list
    if not (entity_list and property_list):
        return []
    patterns = [list(pattern) for pattern in generate_t3_func(entity_list, property_list)]
    return ["select ?a {" + " ".join(pattern) + "}" for pattern in patterns]
|
|
|
''' |
|
sparql_queries_reconstruct = template_fullfill_reconstruct_query( |
|
["hp:Divination_homework_meeting"], |
|
["hp:leader"] |
|
) |
|
sparql_queries_reconstruct |
|
''' |
|
|
|
def run_sparql_queries(sparql_queries, kgqa_retriever, top_k = 3):
    """Execute queries on the retriever's KG and format the non-empty results.

    Each query goes through ``kgqa_retriever._query_kg``; only results with a
    non-empty answer are kept. If none remain, a single ``("", "")``
    placeholder is used. At most ``top_k`` are passed to
    ``kgqa_retriever.format_result``.
    """
    retriever = kgqa_retriever
    executed = (retriever._query_kg(sparql_query=query) for query in sparql_queries)
    hits = [(ans, query) for ans, query in executed if len(ans) > 0] or [("", "")]
    return [retriever.format_result(hit) for hit in hits[:top_k]]
|
|
|
''' |
|
#### one conclusion |
|
run_sparql_queries(sparql_queries_reconstruct, kgqa_retriever) |
|
''' |
|
|
|
|
|
def retrieve_et(zh_question, only_e = True):
    """Extract entity/tag mentions from a Chinese question.

    Calls ``call_entity_property_extract``. With ``only_e`` (default) returns
    its ``"E-TAG"`` entry (``[]`` if missing); otherwise the full result.
    ``zh_question`` must be exactly ``str``.
    """
    assert type(zh_question) is str
    extracted = call_entity_property_extract(zh_question)
    return extracted.get("E-TAG", []) if only_e else extracted
|
|
|
''' |
|
#### start qa server |
|
def retrieve_head(zh_question): |
|
req = requests.post( |
|
url = "http://localhost:8811/qa_downstream_process", |
|
data = { |
|
"entity": "", |
|
"question": zh_question, |
|
"context": zh_question |
|
} |
|
) |
|
output = json.loads(req.content.decode()) |
|
if "head" in output: |
|
return output["head"] |
|
return "" |
|
''' |
|
def retrieve_head(zh_question):
    """Return the ``"head"`` field from ``qa_downstream_process``, or ``""``.

    The question is passed both as question and as context, with an empty
    entity; the result must be exactly a ``dict``.
    """
    result = qa_downstream_process("", zh_question, zh_question)
    assert type(result) is dict
    return result.get("head", "")
|
|
|
''' |
|
zh_question = "谁是占卜会议的领导者?" |
|
retrieve_et(zh_question) |
|
''' |
|
|
|
def property_and_type_slice(spo_df_simple_trans, p_l = [], type_l = []):
    """Keep all rows of the subjects that match a type and/or property filter.

    Works on a copy of the s/p/o frame. With ``type_l``, keeps every row of
    subjects that have some ``o`` in ``type_l``; then, with ``p_l``, keeps
    every row of subjects that have some ``p`` in ``p_l``. Returns ``None`` as
    soon as an applied filter leaves nothing.
    """
    def _rows_of_subjects_with(frame, column, wanted):
        # All rows of the subjects that have at least one `column` value in `wanted`.
        subjects = frame[frame[column].isin(wanted)]["s"].drop_duplicates().dropna().tolist()
        return frame[frame["s"].isin(subjects)]

    sliced = spo_df_simple_trans.copy()
    for column, wanted in (("o", type_l), ("p", p_l)):
        if not wanted:
            continue
        sliced = _rows_of_subjects_with(sliced, column, wanted)
        if sliced.size == 0:
            return None
    return sliced
|
|
|
''' |
|
### Organisation_ sanple |
|
property_and_type_slice( |
|
spo_df_simple_trans, p_l = ["创立"], type_l = ["hp:Organisation_"] |
|
).sort_values(by = "s")["s"].drop_duplicates().sample(n = 30) |
|
|
|
### people sample |
|
property_and_type_slice( |
|
spo_df_simple_trans, p_l = ["出生"], type_l = ["hp:Individual_"] |
|
).sort_values(by = "s")["s"].drop_duplicates().sample(n = 30) |
|
|
|
zh_question = "谁是占卜会议的领导者?" |
|
en_question = zh_en_naive_model.translate([zh_question], source_lang="zh", target_lang = "en")[0] |
|
en_properties = decode_property(en_question, kgqa_retriever) |
|
en_properties |
|
''' |
|
|
|
# Every distinct predicate in the compacted triple table (e.g. "hp:born", "rdf:type").
all_en_p = spo_df_simple["p"].drop_duplicates().dropna().values.tolist()

# Distinct non-empty "_"-separated tokens of the "hp:" predicates, in
# first-seen order; used to prune decoded predicates before fuzzy matching.
all_en_p_tokens = list(dict.fromkeys(
    token
    for prop in all_en_p
    if prop.startswith("hp:")
    for token in prop[3:].split("_")
    if token
))

# One row per predicate except rdf:type, with its translated label in "zh_p"
# (the de-prefixed, space-separated name itself when no translation exists).
all_p_df = pd.DataFrame({"en_p": all_en_p})
# .copy() so adding "zh_p" below writes to an independent frame rather than a
# filtered view (avoids pandas' SettingWithCopyWarning).
all_p_df = all_p_df[all_p_df["en_p"] != "rdf:type"].copy()
all_p_df["zh_p"] = all_p_df["en_p"].map(
    lambda x: spo_trans_dict.get(x.replace("hp:", "").replace("_", " "), x.replace("hp:", "").replace("_", " "))
)
|
|
|
|
|
|
|
# Fixed aliases for decoded predicates: decode_property_link_to_ori returns the
# mapped property (score 99.0) before trying fuzzy matching.
decode_map_config_dict = {
    "hp:birth": 'hp:born',
    'hp:birthday': "hp:born"
}

# Pairs of properties treated as interchangeable.
# NOTE(review): not referenced anywhere else in this file — confirm whether
# it is still used.
decode_sim_config_dict = {
    'hp:ingredients': "hp:characteristics",
    "hp:characteristics": 'hp:ingredients'
}
|
|
|
def decode_property_link_to_ori(decode_property, all_en_p, all_en_p_tokens, equal_threshold = 80):
    """Link a decoded predicate to the closest existing ``hp:`` property.

    Args:
        decode_property: predicate text produced by the SPARQL decoder.
        all_en_p: every predicate present in the knowledge graph.
        all_en_p_tokens: "_"-separated tokens occurring in ``hp:`` predicates.
        equal_threshold: accepted for backward compatibility; unused.

    Returns:
        ``None`` if ``decode_property`` is not an ``hp:`` predicate, or if
        none of its tokens occur in ``all_en_p_tokens``. Otherwise a list of
        ``(property, score)`` pairs, best first:
        ``[(decode_property, 100.0)]`` on an exact hit,
        ``[(alias, 99.0)]`` for an entry of ``decode_map_config_dict``,
        or else the 10 properties with the highest ``fuzz.ratio`` against the
        predicate reduced to its known tokens.
    """
    if not decode_property.startswith("hp:"):
        return None
    if decode_property in all_en_p:
        return [(decode_property, 100.0)]
    if decode_property in decode_map_config_dict:
        return [(decode_map_config_dict[decode_property], 99.0)]

    # Keep only the tokens that occur in some real property name.
    known_tokens = [token for token in decode_property[3:].split("_") if token in all_en_p_tokens]
    if not known_tokens:
        # Nothing recognisable is left (this also covers the bare "hp:").
        # Fuzzy-matching against "hp:" would only rank the shortest
        # properties first, so report "no link" instead.
        return None
    pruned = "hp:{}".format("_".join(known_tokens))

    ranked = sorted(
        ((prop, fuzz.ratio(prop, pruned)) for prop in all_en_p),
        key = lambda t2: t2[1],
        reverse = True,
    )
    return ranked[:10]
|
|
|
''' |
|
#### minimize maintain one token sorted. |
|
decode_property_link_to_ori("hp:born", all_en_p, all_en_p_tokens, equal_threshold = 80) |
|
decode_property_link_to_ori("hp:birth", all_en_p, all_en_p_tokens, equal_threshold = 80) |
|
decode_property_link_to_ori("hp:head_of_the_assembly", all_en_p, all_en_p_tokens, equal_threshold = 80) |
|
''' |
|
|
|
|
|
def output_to_dict(output, trans_keys = ["answers"]):
    """Make a reader output dict plain: convert list entries via ``.to_dict()``.

    Every key in ``trans_keys`` has its items replaced by ``item.to_dict()``.
    Those keys come first in the result, followed by the untouched keys, each
    group in the input's order.
    """
    converted = {}
    untouched = {}
    for key, value in output.items():
        if key in trans_keys:
            converted[key] = [item.to_dict() for item in value]
        else:
            untouched[key] = value
    converted.update(untouched)
    return converted
|
|
|
def zh_question_to_p_zh_en_map(zh_question, top_k = 3):
    """Map a Chinese question to candidate KG properties and their Chinese spans.

    Pipeline:
      1. translate the question to English (``call_zh_en_naive_model``);
      2. decode SPARQL candidates and collect their predicates
         (``decode_property``), most frequent first;
      3. link each predicate to an existing property and keep its best
         candidate (``decode_property_link_to_ori``);
      4. ask ``call_en_zh_reader`` which span of the Chinese question answers
         each linked property;
      5. if ``retrieve_head`` returns a string, prepend a copy of the best row
         whose Chinese property is that head.

    Args:
        zh_question: the question in Chinese.
        top_k: forwarded to ``decode_property``.

    Returns:
        A DataFrame with columns ``en_property``, ``score``, ``zh_property``
        and ``ext_score`` (rows with an empty ``zh_property`` and duplicate
        rows removed; it may end up empty), or ``None`` when an earlier step
        yields nothing.
    """
    en_question = call_zh_en_naive_model(zh_question)
    en_properties = decode_property(en_question, kgqa_retriever, top_k = top_k)
    if not en_properties:
        return None

    # Most frequently decoded predicates first.
    en_properties_top_sort = pd.Series(en_properties).value_counts().index.tolist()
    linked = [
        decode_property_link_to_ori(prop, all_en_p, all_en_p_tokens, equal_threshold = 80)
        for prop in en_properties_top_sort
    ]
    # Keep the best (first) candidate of every non-empty candidate list.
    best_links = [cands[0] for cands in linked if hasattr(cands, "__len__") and len(cands) >= 1]
    if not best_links:
        return None

    mapped_df = pd.DataFrame(best_links)
    assert mapped_df.shape[1] == 2
    mapped_df.columns = ["en_property", "score"]

    # call_en_zh_reader is expected to return a dict whose "answers" list holds
    # dicts with "answer" and "score" keys (that is how they are read below).
    mapped_df["zh_property"] = mapped_df["en_property"].map(
        lambda x: call_en_zh_reader(x.replace("hp:", ""), zh_question)
    )
    # Keep the reader's top answer, or {} when it returned none.
    mapped_df["zh_property"] = mapped_df["zh_property"].map(
        lambda x: x["answers"][0] if x["answers"] else {}
    )
    # .copy(): the columns assigned below must go to an independent frame, not
    # a filtered view (avoids pandas' SettingWithCopyWarning).
    mapped_df = mapped_df[mapped_df["zh_property"].map(bool)].copy()
    if mapped_df.size == 0:
        return None

    mapped_df["ext_score"] = mapped_df["zh_property"].map(lambda x: x["score"])
    mapped_df["zh_property"] = mapped_df["zh_property"].map(lambda x: x["answer"])

    ask_head = retrieve_head(zh_question)
    if isinstance(ask_head, str):
        # Prepend a copy of the best row whose Chinese property is the question head.
        first_d = mapped_df.iloc[0].to_dict()
        first_d["zh_property"] = ask_head
        mapped_df = pd.DataFrame(
            [first_d] + mapped_df.apply(lambda x: x.to_dict(), axis = 1).values.tolist()
        )

    # Drop rows without a Chinese property (e.g. an empty head) and duplicates.
    return mapped_df[mapped_df["zh_property"].map(lambda x: bool(x))].drop_duplicates()
|
|
|
def search_sym_p(question_p_df, all_p_df):
    """Score every KG property against each Chinese property of the question.

    For each row of ``question_p_df`` (``zh_property``/``en_property``), a copy
    of ``all_p_df`` is tagged with those values. All copies are concatenated and
    scored with ``synonyms.compare(zh_property, zh_p)`` into ``zh_sim``.

    Args:
        question_p_df: DataFrame with ``zh_property`` and ``en_property``
            columns (as produced by ``zh_question_to_p_zh_en_map``), or None.
        all_p_df: DataFrame with at least a ``zh_p`` column.

    Returns:
        The scored frame sorted by ``zh_sim`` descending. A None or empty
        ``question_p_df`` yields an empty frame with the same columns instead
        of raising.
    """
    if question_p_df is None or len(question_p_df) == 0:
        empty = all_p_df.iloc[0:0].copy()
        for column in ("zh_property", "en_property", "zh_sim"):
            empty[column] = []
        return empty

    frames = []
    for _, row in question_p_df.iterrows():
        frame = all_p_df.copy()
        frame["zh_property"] = row["zh_property"]
        frame["en_property"] = row["en_property"]
        frames.append(frame)
    req = pd.concat(frames, axis = 0)

    # A comprehension instead of a row-wise DataFrame.apply: on an empty frame
    # apply does not return a Series, which broke this column assignment.
    req["zh_sim"] = [
        synonyms.compare(zh_property, zh_p)
        for zh_property, zh_p in zip(req["zh_property"], req["zh_p"])
    ]
    return req.sort_values(by = "zh_sim", ascending = False)
|
|
|
# Every distinct subject/object value, taken row by row as s, o, s, o, ...
all_en_ents = pd.Series(spo_df_simple[["s", "o"]].values.reshape([-1])).drop_duplicates().values.tolist()

# One row per subject/object value, with its translated label in "zh_ent"
# (the de-prefixed, space-separated name itself when no translation exists).
all_ents_df = pd.DataFrame({"en_ent": all_en_ents})
# .copy() so adding "zh_ent" below writes to an independent frame rather than a
# filtered view (avoids pandas' SettingWithCopyWarning).
all_ents_df = all_ents_df[all_ents_df["en_ent"] != "rdf:type"].copy()
all_ents_df["zh_ent"] = all_ents_df["en_ent"].map(
    lambda x: spo_trans_dict.get(x.replace("hp:", "").replace("_", " "), x.replace("hp:", "").replace("_", " "))
)
|
|
|
def search_sym_entity(entity_str, all_ents_df, use_syn = False):
    """Rank KG entities by similarity of their Chinese label to ``entity_str``.

    Args:
        entity_str: entity mention extracted from the question.
        all_ents_df: DataFrame with at least ``en_ent`` and ``zh_ent`` columns.
        use_syn: score with ``synonyms.compare`` instead of ``fuzz.ratio``.

    Returns:
        A copy of ``all_ents_df`` with ``entity_str`` and ``zh_sim`` columns
        added, sorted by ``zh_sim`` descending. An empty input yields an empty
        frame with those columns.
    """
    req = all_ents_df.copy()
    req["entity_str"] = [entity_str] * len(req)
    # A comprehension instead of a row-wise DataFrame.apply: on an empty frame
    # apply does not return a Series, which broke the "zh_sim" assignment.
    if use_syn:
        scores = [synonyms.compare(zh_ent, entity_str) for zh_ent in req["zh_ent"]]
    else:
        scores = [fuzz.ratio(zh_ent, entity_str) for zh_ent in req["zh_ent"]]
    req["zh_sim"] = scores
    return req.sort_values(by = "zh_sim", ascending = False)
|
|
|
# Sample questions tried during development; only the last assignment is used.
zh_question = "谁是占卜会议的领导者?"
zh_question = "洛林出生在哪个国家?"
zh_question = "洛林出生在哪个地方?"
zh_question = "洛林的血缘是什么?"
zh_question = "洛林的生日是什么?"
zh_question = "洛林的家族是什么?"
zh_question = "洛林的性别是什么?"
zh_question = "洛林的标题是什么?"
zh_question = "洛林的主题是什么?"
zh_question = "这个物品的特征是什么?"
zh_question = "强效祛斑药水的特征是什么?"
zh_question = "魔法学校的成立日期是什么?"
zh_question = "魔法学校的校长是谁?"

# Import-time run of the property-mapping step on the sample question.
question_p_df = zh_question_to_p_zh_en_map(zh_question)

# zh_question_to_p_zh_en_map returns None when it finds no usable property and
# can return a frame with no rows. Only run the synonym search when there is at
# least one property row, so importing this module does not fail.
if hasattr(question_p_df, "size") and question_p_df.size > 0:
    sym_p_df = search_sym_p(question_p_df, all_p_df)
else:
    sym_p_df = None
|
|
|
|
|
''' |
|
#### this can be done, all related with translate accurate |
|
entity_str = "占卜会议" |
|
search_sym_entity(entity_str, all_ents_df) |
|
|
|
#### re translate in massive times |
|
pd.Series(list(spo_trans_dict.keys())).to_csv("../data/all_consider.csv", index = False) |
|
''' |
|
|
|
|
|
''' |
|
sparql_queries_reconstruct = template_fullfill_reconstruct_query( |
|
["hp:Divination_homework_meeting"], |
|
["hp:leader"] |
|
) |
|
sparql_queries_reconstruct |
|
''' |
|
|
|
def from_zh_question_to_consider_queries(zh_question, top_k = 32, top_p_k = 5, top_e_k = 50, kgqa_retriever = kgqa_retriever,):
    """Build candidate SPARQL queries for a Chinese question and run them.

    Entity mentions come from ``retrieve_et``. Candidate properties come from
    ``zh_question_to_p_zh_en_map`` followed by ``search_sym_p``. KG entities
    are ranked per mention with ``search_sym_entity``. The best ``top_p_k``
    properties and ``top_e_k`` entities are combined by
    ``template_fullfill_reconstruct_query`` and executed with
    ``run_sparql_queries``. The chosen entities and properties are printed.

    Args:
        zh_question: question in Chinese.
        top_k: maximum number of query results kept.
        top_p_k: ``top_k`` for ``zh_question_to_p_zh_en_map`` and the number
            of properties used.
        top_e_k: number of entities used.
        kgqa_retriever: retriever that executes the queries.

    Returns:
        ``(queries, results)``, or ``None`` as soon as a stage yields nothing.
    """
    def _nothing_in(frame):
        # True for anything without a .size (e.g. None) and for empty frames.
        return not hasattr(frame, "size") or frame.size == 0

    zh_ents = retrieve_et(zh_question)
    if type(zh_ents) is not list or len(zh_ents) == 0:
        return None

    question_p_df = zh_question_to_p_zh_en_map(zh_question, top_k = top_p_k)
    if _nothing_in(question_p_df):
        return None
    sym_p_df = search_sym_p(question_p_df, all_p_df)
    if _nothing_in(sym_p_df):
        return None

    entity_frames = []
    for mention in zh_ents:
        ranked_entities = search_sym_entity(mention, all_ents_df)
        if not _nothing_in(ranked_entities):
            entity_frames.append(ranked_entities)
    if not entity_frames:
        return None
    sym_ent_df = pd.concat(entity_frames, axis = 0).sort_values(by = "zh_sim", ascending = False)

    top_p = sym_p_df["en_p"].drop_duplicates().dropna().head(top_p_k).values.tolist()
    top_e = sym_ent_df["en_ent"].drop_duplicates().dropna().head(top_e_k).values.tolist()
    print(top_e)
    print(top_p)
    if not (top_p and top_e):
        return None

    queries = template_fullfill_reconstruct_query(top_e, top_p)
    if not queries:
        return None
    return queries, run_sparql_queries(queries, kgqa_retriever, top_k = top_k)
|
|
|
def trans_output(zh_question ,output):
    """Translate query results back to Chinese and order them by relevance.

    Args:
        zh_question: the original Chinese question, used for ranking.
        output: list of result dicts (as produced by ``run_sparql_queries``);
            anything that is not a list is returned unchanged.

    Returns:
        One dict per result. Each has a translated ``"answer"`` list and a
        ``"sparql_query"`` whose ``hp:`` names are translated via
        ``spo_trans_dict``, each key only when available. Results are sorted
        by ``synonyms.compare`` similarity between ``zh_question`` and the
        translated query, most similar first.
    """
    if not isinstance(output, list):
        return output

    def _translate_label(name):
        # Translation of a KG local name, or the name with "_" -> " " if unknown.
        key = name.replace("_", " ")
        return spo_trans_dict.get(key, key)

    def single_trans(d):
        assert type(d) == type({})
        if not d:
            return d
        req = {}

        answer = d.get("answer")
        if isinstance(answer, list):
            answer = [
                _translate_label(x.split("/")[-1]) if x.startswith("https://deepset.ai/harry_potter") else x
                for x in answer
            ]

        sparql_query = d.get("prediction_meta")
        if sparql_query is not None:
            sparql_query = sparql_query.get("sparql_query")
        if isinstance(sparql_query, str):
            t3_in_query = query_to_t3(sparql_query)
            # Flatten with a comprehension: np.asarray failed on pieces of
            # unequal length. Skip a bare "hp:" token, whose empty name would
            # make str.replace insert text between every character.
            hp_names = list(dict.fromkeys(
                token[3:]
                for t3 in t3_in_query
                for token in t3
                if token.startswith("hp:") and len(token) > 3
            ))
            # Longest names first so a name is replaced before its substrings.
            for ele in sorted(hp_names, key = len, reverse = True):
                sparql_query = sparql_query.replace(ele, _translate_label(ele.split("/")[-1]))

        if answer is not None:
            req["answer"] = answer
        if sparql_query is not None:
            req["sparql_query"] = sparql_query
        return req

    def _relevance(d):
        # The original key compared against the constant " " in both branches,
        # which made the sort a no-op; compare with the translated query instead.
        if not isinstance(d, dict):
            return 0.0
        return synonyms.compare(zh_question, d.get("sparql_query", " ") or " ")

    output_trans = [single_trans(d) for d in output]
    return sorted(output_trans, key = _relevance, reverse = True)
|
|
|
def _et_parts(sparql_query):
    """Return ``(entity_text, predicate_text)`` from a query's ``{...}`` body.

    The entity text joins the first and last whitespace tokens of the body that
    are not variables (contain no "?"). The predicate text is the second token
    unless it is a variable. Each keeps only what follows its last ":" and has
    spaces removed. A missing body or missing tokens give empty strings.
    """
    bodies = re.findall("{(.*)}", sparql_query) if isinstance(sparql_query, str) else []
    tokens = bodies[0].split() if bodies else []
    ent_tokens = [tok for tok in tokens[:1] + tokens[-1:] if "?" not in tok]
    # Previously the predicate *string* was filtered character by character,
    # so a variable predicate like "?p" became "p" instead of being dropped.
    pred_tokens = [tok for tok in tokens[1:2] if "?" not in tok]

    def _local_name(parts):
        return " ".join(parts).split(":")[-1].replace(" ", "")

    return _local_name(ent_tokens), _local_name(pred_tokens)


def _score_et(df, e, t, scorer):
    """Add e_score/t_score/et_score to ``df`` using ``scorer`` and sort by et_score."""
    parts = df["sparql_query"].map(_et_parts)
    e_key = e.replace(" ", "")
    t_key = t.replace(" ", "")
    df["e_score"] = parts.map(lambda p: scorer(e_key, p[0]))
    df["t_score"] = parts.map(lambda p: scorer(t_key, p[1]))
    df["et_score"] = df[["e_score", "t_score", ]].sum(axis = 1)
    return df.sort_values(by = "et_score", ascending = False)


def ranking_output(zh_question, zh_output):
    """Rank translated query results by how well they match the question.

    Entity ("E-TAG") and tag ("T-TAG") mentions of the question (from
    ``retrieve_et``, printed) are compared with the entity and predicate of
    each result's ``sparql_query``. Results are first scored with
    ``fuzz.ratio``; if the best combined score is below 50, they are rescored
    with ``synonyms.compare``.

    Args:
        zh_question: the Chinese question.
        zh_output: list of dicts with ``"answer"`` and ``"sparql_query"``
            (as produced by ``trans_output``).

    Returns:
        A DataFrame with one row per answer (``answer`` exploded) and
        ``e_score``, ``t_score``, ``et_score`` columns, sorted by ``et_score``
        descending. Returns ``None`` if ``zh_output`` is None/empty or has no
        ``sparql_query`` to score.
    """
    if zh_output is None or len(zh_output) == 0:
        return None
    df = pd.DataFrame(zh_output)
    if "sparql_query" not in df.columns:
        return None
    if "answer" in df.columns:
        df = df.explode("answer")

    e_t_dict = retrieve_et(zh_question, only_e=False)
    e = " ".join(e_t_dict.get("E-TAG", []))
    t = " ".join(e_t_dict.get("T-TAG", []))
    print(e, t)

    # First pass: character-level fuzzy matching.
    df = _score_et(df, e, t, fuzz.ratio)
    if df["et_score"].iloc[0] >= 50:
        return df
    # Weak lexical match: fall back to synonyms-based similarity.
    return _score_et(df, e, t, synonyms.compare)
|
|
|
if __name__ == "__main__":
    # End-to-end demo: for each sample question, build and run candidate
    # queries, translate the results back to Chinese, then rank and print them.
    demo_questions = [
        "哈利波特的血缘是什么?",
        "哈利波特的生日是什么?",
        "史内普的生日是什么时候?",
        "占卜会议的领导者是谁?",
        "纽约卫生局的创立时间是什么?",
        "法兰西魔法部记录室位于哪个城市?",
        "邓布利多的出生日期是什么?",
        "哥布林叛乱发生在什么日期?",
        "决斗比赛的参与者是谁?",
        "赫敏的丈夫是谁?",
    ]
    for zh_question in demo_questions:
        output = from_zh_question_to_consider_queries(zh_question,
            top_k = 32, top_p_k = 30, top_e_k = 50
        )
        # A (queries, results) tuple means queries were built. Anything else
        # (None) means the pipeline stopped early and there is nothing to
        # rank; previously ranking_output was still called with None.
        if not isinstance(output, tuple):
            continue
        _query_list, results = output
        zh_output = trans_output(zh_question, results)
        if not zh_output:
            continue
        print(ranking_output(zh_question, zh_output))
|
|