|
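# Rule-based NL2SQL pipeline for Chinese TableQA: a JointBERT slot/intent model
# extracts (column, op, value) condition triples from a question, and keyword
# heuristics choose the target column, aggregate function, and AND/OR connector.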
main_path = "." |
|
|
|
# Consolidated imports (duplicates removed); argparse is imported explicitly
# because it is used below instead of relying on a star import to provide it.
import argparse
import ast
import json
import math
import os
import re
import shutil
import sqlite3
import sys
import urllib.request
from ast import literal_eval
from collections import defaultdict, namedtuple
from copy import deepcopy
from functools import partial, reduce
from itertools import combinations, product
from time import sleep
from urllib.parse import quote

import jieba
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import sqlite_utils
import torch
from icecream import ic
from nltk import pos_tag, word_tokenize
from pyarrow.filesystem import LocalFileSystem
from torch import nn
from torch.nn import functional, init
from torch.nn.utils import rnn as rnn_utils
|
|
|
|
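# Label-id -> SQL token maps shared by the whole pipeline:
# comparison operators, aggregate functions, and condition connectors.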
|
op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="} |
|
|
|
agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"} |
|
|
|
conn_sql_dict = {0:"", 1:"and", 2:"or"} |
|
|
|
|
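# Keyword heuristics: connector words that join multiple conditions, date/number
# patterns used to detect datetime-like columns, aggregate trigger words, and
# the question words used to score candidate answer columns.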
|
and_kws = ("且", "而且", "并且", "和", "当中", "同时") |
|
or_kws = ("或", "或者",) |
|
conn_kws = and_kws + or_kws |
|
|
|
pattern_list = [r"[年月.\-\d]+", r"[年月\d]+", r"[年个月\d]+", r"[年月日\d]+"]
|
|
|
time_kws = ("什么时候", "时间", "时候") |
|
|
|
sum_count_high_kws = ('多少个', '有几个', '总共', '总和', '一共', '总数')
mean_kws = ('平均数', '均值', '平均值', '平均')
max_kws = ('最大', '最多', '最大值', '最高')
min_kws = ('最少', '最小值', '最小', '最低')
sum_count_low_kws = ('个', '总共', '总和', '加', '总', '一共', '和', '哪些', '查', '数量', '数', '几', '多少', '多大', '总数')
max_special_kws = ("以上", "大于")
min_special_kws = ("以下", "小于")
|
|
|
qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个") |
|
|
|
only_kws_columns = {"城市": "=="} |
|
|
|
|
|
|
|
|
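# Make the JointBERT-master checkout importable and pull in its model, trainer,
# and data utilities (the star imports supply helpers such as MODEL_CLASSES,
# get_intent_labels, get_slot_labels, and load_tokenizer used below).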
|
jointbert_path = os.path.join(main_path, "JointBERT-master") |
|
sys.path.append(jointbert_path) |
|
|
|
|
|
from model.modeling_jointbert import JointBERT |
|
from model.modeling_jointbert import * |
|
from trainer import * |
|
from main import * |
|
from data_loader import * |
|
|
|
|
|
pred_parser = argparse.ArgumentParser() |
|
|
|
pred_parser.add_argument("--input_file", default="conds_pred/seq.in", type=str, help="Input file for prediction") |
|
pred_parser.add_argument("--output_file", default="conds_pred/sample_pred_out.txt", type=str, help="Output file for prediction") |
|
|
|
pred_parser.add_argument("--model_dir", default= os.path.join(main_path ,"data/bert"), type=str, help="Path to save, load model") |
|
|
|
|
|
|
|
pred_parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction") |
|
pred_parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") |
|
|
|
|
|
pred_parser_config_dict = {
    action.option_strings[0].replace("--", ""): action.default
    for action in pred_parser._actions
    if action.option_strings[0] != "-h"
}

# The namedtuple *class* is used as a mutable attribute bag (attributes are
# reassigned later), so set the defaults with setattr instead of exec.
pred_parser_namedtuple = namedtuple("pred_parser_config", pred_parser_config_dict.keys())
for k, v in pred_parser_config_dict.items():
    setattr(pred_parser_namedtuple, k, v)
|
|
|
|
|
from predict import * |
|
|
|
|
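# Build an args namespace from the saved training config, pick a device,
# and load the fine-tuned JointBERT checkpoint for inference.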
|
pred_config = pred_parser_namedtuple |
|
args = get_args(pred_config) |
|
device = get_device(pred_config) |
|
|
|
args_parser_namedtuple = namedtuple("args_config", args.keys())
for k, v in args.items():
    setattr(args_parser_namedtuple, k, v)
|
|
|
|
|
args = args_parser_namedtuple |
|
|
|
|
|
args.data_dir = os.path.join(main_path, "data") |
|
|
|
pred_model = MODEL_CLASSES["bert"][1].from_pretrained(
    os.path.join(main_path, "data/bert"),
    args=args,
    intent_label_lst=get_intent_labels(args),
    slot_label_lst=get_slot_labels(args))
|
|
|
pred_model.to(device) |
|
pred_model.eval() |
|
|
|
intent_label_lst = get_intent_labels(args) |
|
slot_label_lst = get_slot_labels(args) |
|
pad_token_label_id = args.ignore_index |
|
tokenizer = load_tokenizer(args) |
|
|
|
|
|
|
|
|
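# Run JointBERT on one question: each character becomes a token, slot logits
# are decoded to BIO tags (B/I-HEADER, B/I-VALUE) and rendered inline as
# "[char:TAG]", and the predicted intent label is returned with the tagged text.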
|
def predict_single_sent(question): |
|
text = " ".join(list(question)) |
|
batch = convert_input_file_to_tensor_dataset([text.split(" ")], pred_config, args, tokenizer, pad_token_label_id).tensors |
|
batch = tuple(t.to(device) for t in batch) |
|
inputs = {"input_ids": batch[0], |
|
"attention_mask": batch[1], |
|
"intent_label_ids": None, |
|
"slot_labels_ids": None} |
|
inputs["token_type_ids"] = batch[2] |
|
outputs = pred_model(**inputs) |
|
_, (intent_logits, slot_logits) = outputs[:2] |
|
intent_preds = intent_logits.detach().cpu().numpy() |
|
slot_preds = slot_logits.detach().cpu().numpy() |
|
intent_preds = np.argmax(intent_preds, axis=1) |
|
slot_preds = np.argmax(slot_preds, axis=2) |
|
all_slot_label_mask = batch[3].detach().cpu().numpy() |
|
slot_label_map = {i: label for i, label in enumerate(slot_label_lst)} |
|
slot_preds_list = [[] for _ in range(slot_preds.shape[0])] |
|
for i in range(slot_preds.shape[0]): |
|
for j in range(slot_preds.shape[1]): |
|
if all_slot_label_mask[i, j] != pad_token_label_id: |
|
slot_preds_list[i].append(slot_label_map[slot_preds[i][j]]) |
|
pred_l = [] |
|
    for words, slot_pred_labels, intent_pred in zip([text.split(" ")], slot_preds_list, intent_preds):
        line = ""
        for word, pred in zip(words, slot_pred_labels):
            if pred == 'O':
                line = line + word + " "
            else:
                line = line + "[{}:{}] ".format(word, pred)
        pred_l.append((line, intent_label_lst[intent_pred]))
|
return pred_l[0] |
|
|
|
|
|
|
|
|
|
|
|
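# Recursively decompose a question: tag it with JointBERT, peel off the
# outermost [HEADER]/[VALUE] span as one (header, intent, value) triple, re-tag
# the remainder, and return (condition triples, residual question text,
# leftover relation string).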
|
def recurrent_extract(question): |
|
def filter_relation(text): |
|
|
|
kws = conn_kws |
|
req = text |
|
for kw in sorted(kws, key= lambda x: len(x))[::-1]: |
|
req = req.replace(kw, "") |
|
return req |
|
def produce_plain_text(text): |
|
|
|
kws = ["[", "]", " ", ":B-HEADER", ":I-HEADER", ":B-VALUE", ":I-VALUE"] |
|
plain_text = text |
|
for kw in kws: |
|
plain_text = plain_text.replace(kw, "") |
|
return plain_text |
|
def find_min_commmon_strings(c): |
|
|
|
common_strings = list(filter(lambda x: type(x) == type("") , |
|
map(lambda t2: t2[0] |
|
if t2[0] in t2[1] |
|
else (t2[1] |
|
if t2[1] in t2[0] |
|
else (t2[0], t2[1])),combinations(c, 2)))) |
|
req = set([]) |
|
while c: |
|
ele = c.pop() |
|
if all(map(lambda cc: cc not in ele, common_strings)): |
|
req.add(ele) |
|
req = req.union(set(common_strings)) |
|
return set(filter(lambda x: x, req)) |
|
def extract_scope(scope_text): |
|
def find_max_in(plain_text ,b_chars, i_chars): |
|
chars = "".join(b_chars + i_chars) |
|
while chars and chars not in plain_text: |
|
chars = chars[:-1] |
|
return chars |
|
b_header_chars = re.findall(r"([\w\W]):B-HEADER", scope_text) |
|
i_header_chars = re.findall(r"([\w\W]):I-HEADER", scope_text) |
|
b_value_chars = re.findall(r"([\w\W]):B-VALUE", scope_text) |
|
i_value_chars = re.findall(r"([\w\W]):I-VALUE", scope_text) |
|
if len(b_header_chars) != 1 or len(b_value_chars) != 1: |
|
return None |
|
plain_text = produce_plain_text(scope_text) |
|
header = find_max_in(plain_text, b_header_chars, i_header_chars) |
|
value = find_max_in(plain_text, b_value_chars, i_value_chars) |
|
if (not header) or (not value): |
|
return None |
|
return (header, value) |
|
def find_scope(text): |
|
start_index = text.find("[") |
|
end_index = text.rfind("]") |
|
if start_index == -1 or end_index == -1: |
|
return text |
|
scope_text = text[start_index: end_index + 1] |
|
res_text = filter_relation(text.replace(scope_text, "")).replace(" ", "").strip() |
|
return (scope_text, res_text) |
|
def produce_all_attribute_remove(req): |
|
if not req: |
|
return None |
|
string_or_t2 = find_scope(req[-1][0]) |
|
assert type(string_or_t2) in [type(""), type((1,))] |
|
if type(string_or_t2) == type(""): |
|
return string_or_t2 |
|
else: |
|
return string_or_t2[-1] |
|
def extract_all_attribute(req): |
|
extract_list = list(map(lambda t2: (t2[0][0], t2[1], t2[0][1]) , |
|
filter(lambda x: x[0] , |
|
map(lambda tt2_t2: (extract_scope(tt2_t2[0][0]), tt2_t2[1]) , |
|
filter(lambda t2_t2: "HEADER" in t2_t2[0][0] and "VALUE" in t2_t2[0][0] , |
|
filter(lambda string_or_t2_t2: type(string_or_t2_t2[0]) == type((1,)), |
|
map(lambda tttt2: (find_scope(tttt2[0]), tttt2[1]), |
|
req))))))) |
|
return extract_list |
|
def extract_attributes_relation_string(plain_text, all_attributes, res): |
|
if not all_attributes: |
|
return plain_text.replace(res if res else "", "") |
|
def replace_by_one_l_r(text ,t3): |
|
l, _, r = t3 |
|
|
|
l0, l1 = l, l |
|
r0, r1 = r, r |
|
while l0 and l0 not in text: |
|
l0 = l0[:-1] |
|
while l1 and l1 not in text: |
|
l1 = l1[1:] |
|
while r0 and r0 not in text: |
|
r0 = r0[:-1] |
|
while r1 and r1 not in text: |
|
r1 = r1[1:] |
|
if not l or not r: |
|
return text |
|
|
|
conclusion = set([]) |
|
        for l_, r_ in product([l0, l1], [r0, r1]):
            # escape the fragments so regex metacharacters in cell values cannot break the pattern
            l_r_conclusion = re.findall("({}.*?{})".format(re.escape(l_), re.escape(r_)), text)
            r_l_conclusion = re.findall("({}.*?{})".format(re.escape(r_), re.escape(l_)), text)
            conclusion = conclusion.union(set(l_r_conclusion + r_l_conclusion))
|
|
|
|
|
|
|
conclusion_filtered = find_min_commmon_strings(conclusion) |
|
|
|
conclusion = conclusion_filtered |
|
req_text = text |
|
for c in conclusion: |
|
req_text = req_text.replace(c, "") |
|
return req_text |
|
req_text_ = plain_text |
|
for t3 in all_attributes: |
|
req_text_ = replace_by_one_l_r(req_text_, t3) |
|
return req_text_.replace(res, "") |
|
req = [] |
|
t2 = predict_single_sent(question) |
|
req.append(t2) |
|
while "[" in t2[0]: |
|
scope = find_scope(t2[0]) |
|
if type(scope) == type(""): |
|
break |
|
else: |
|
assert type(scope) == type((1,)) |
|
scope_text, res_text = scope |
|
|
|
t2 = predict_single_sent(res_text) |
|
req.append(t2) |
|
req = list(filter(lambda tt2: "HEADER" in tt2[0] and "VALUE" in tt2[0] , req)) |
|
res = produce_all_attribute_remove(req) |
|
|
|
all_attributes = extract_all_attribute(req) |
|
|
|
|
|
return all_attributes, res, extract_attributes_relation_string(produce_plain_text(req[0][0] if req else ""), all_attributes, res) |
|
|
|
|
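# Keep decomposing the leftover relation string until a condition is found
# or the residual stops changing.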
|
def rec_more_time(decomp): |
|
assert type(decomp) == type((1,)) and len(decomp) == 3 |
|
assert not decomp[0] |
|
res, relation_string = decomp[1:] |
|
new_decomp = recurrent_extract(relation_string) |
|
|
|
if not new_decomp[0] and new_decomp[1] != decomp[1]: |
|
return rec_more_time(new_decomp) |
|
return (new_decomp[0], res, new_decomp[1]) |
|
|
|
|
|
|
|
|
|
|
|
|
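# Yield (table DataFrame, matching questions DataFrame) pairs from the
# TableQA-master train split.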
|
train_path = os.path.join(main_path, "data/TableQA-master/train") |
|
def data_loader(table_json_path = os.path.join(train_path ,"train.tables.json"), |
|
json_path = os.path.join(train_path ,"train.json"), |
|
req_table_num = 1): |
|
assert os.path.exists(table_json_path) |
|
assert os.path.exists(json_path) |
|
json_df = pd.read_json(json_path, lines = True) |
|
all_tables = pd.read_json(table_json_path, lines = True) |
|
if req_table_num is not None: |
|
assert type(req_table_num) == type(0) and req_table_num > 0 and req_table_num <= all_tables.shape[0] |
|
else: |
|
req_table_num = all_tables.shape[0] |
|
for i in range(req_table_num): |
|
|
|
|
|
one_table_s = all_tables.iloc[i] |
|
one_table_df = pd.DataFrame(one_table_s["rows"], columns = one_table_s["header"]) |
|
yield one_table_df, json_df[json_df["table_id"] == one_table_s["id"]] |
|
|
|
|
|
|
|
|
|
def findMaxSubString(str1, str2): |
|
""" |
|
""" |
|
maxSub = 0 |
|
maxSubString = "" |
|
|
|
str1_len = len(str1) |
|
str2_len = len(str2) |
|
|
|
for i in range(str1_len): |
|
str1_pos = i |
|
for j in range(str2_len): |
|
str2_pos = j |
|
str1_pos = i |
|
if str1[str1_pos] != str2[str2_pos]: |
|
continue |
|
else: |
|
while (str1_pos < str1_len) and (str2_pos < str2_len): |
|
if str1[str1_pos] == str2[str2_pos]: |
|
str1_pos = str1_pos + 1 |
|
str2_pos = str2_pos + 1 |
|
else: |
|
break |
|
|
|
sub_len = str2_pos - j |
|
if maxSub < sub_len: |
|
maxSub = sub_len |
|
maxSubString = str2[j:str2_pos] |
|
return maxSubString |
|
|
|
|
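# Module-level copy of the helper nested in recurrent_extract: reduce a set
# of strings to its minimal elements under the substring relation.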
|
def find_min_commmon_strings(c): |
|
|
|
common_strings = list(filter(lambda x: type(x) == type("") , |
|
map(lambda t2: t2[0] |
|
if t2[0] in t2[1] |
|
else (t2[1] |
|
if t2[1] in t2[0] |
|
else (t2[0], t2[1])),combinations(c, 2)))) |
|
req = set([]) |
|
while c: |
|
ele = c.pop() |
|
if all(map(lambda cc: cc not in ele, common_strings)): |
|
req.add(ele) |
|
req = req.union(set(common_strings)) |
|
return set(filter(lambda x: x, req)) |
|
|
|
|
|
|
|
|
|
|
|
|
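# Pull date-like fragments (years/months/digits) out of every cell and return
# them as a flat list in the same order as the input cells.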
|
def time_template_extractor(rows_filtered, pattern = r"[年月.\-\d]+"):
|
|
|
re_words = re.compile(pattern) |
|
nest_collection = pd.DataFrame(rows_filtered).applymap(lambda x: tuple(sorted(list(re.findall(re_words, x))))).values.tolist() |
|
def flatten_collection(c): |
|
if not c: |
|
return c |
|
if type(c[0]) == type(""): |
|
return c |
|
else: |
|
c = list(c) |
|
return flatten_collection(reduce(lambda a, b: a + b, map(list ,c))) |
|
return flatten_collection(nest_collection) |
|
|
|
|
|
|
|
|
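# A column counts as datetime-like when, for at least `threshold` of its rows,
# the extractor reproduces the cell verbatim; the reduce variant ORs the
# verdicts of one extractor per pattern in pattern_list.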
|
def justify_column_as_datetime(df, threshold = 0.8, time_template_extractor = lambda x: x): |
|
object_columns = list(map(lambda tt2: tt2[0] ,filter(lambda t2: t2[1].name == "object" ,dict(df.dtypes).items()))) |
|
time_columns = [] |
|
for col in object_columns: |
|
input_ = df[[col]].applymap(lambda x: "~" if type(x) != type("") else x) |
|
output_ = time_template_extractor(input_.values.tolist()) |
|
input_ = input_.iloc[:, 0].values.tolist() |
|
time_evidence_cnt = sum(map(lambda t2: t2[0].strip() == t2[1].strip() and t2[0] and t2[1] and t2[0] != "~" and t2[1] != "~",zip(input_, output_))) |
|
if time_evidence_cnt > 0 and time_evidence_cnt / df.shape[0] >= threshold: |
|
|
|
time_columns.append(col) |
|
return time_columns |
|
|
|
def justify_column_as_datetime_reduce(df, threshold = 0.8, time_template_extractor_list = list(map(lambda p: partial(time_template_extractor, pattern = p), pattern_list))): |
|
return sorted(reduce(lambda a, b: a.union(b) ,map(lambda func: set(justify_column_as_datetime(df, threshold, func)), time_template_extractor_list))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
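# Choose the column the question asks about: prefer a datetime column when the
# residual contains a time keyword, otherwise pick the header whose longest
# common substring with the residual maximizes both the header-side and
# question-side overlap ratios.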
|
def choose_question_column(decomp, header, df): |
|
assert type(decomp) == type((1,)) and type(header) == type([]) |
|
|
|
time_columns = justify_column_as_datetime_reduce(df) |
|
_, res, _ = decomp |
|
|
|
if type(res) != type(""): |
|
return None |
|
|
|
|
|
|
|
|
|
if any(map(lambda t_kw: t_kw in res, time_kws)): |
|
if len(time_columns) == 1: |
|
return time_columns[0] |
|
else: |
|
|
sort_list = sorted(map(lambda t_col: (t_col ,len(findMaxSubString(t_col, res)) / len(t_col)), time_columns), |
|
key = lambda t2: t2[1])[::-1] |
|
if sort_list: |
|
if sort_list[0]: |
|
return sort_list[0][0] |
|
return None |
|
|
|
c_res_common_dict = dict(filter(lambda t2: t2[1] ,map(lambda c: (c ,findMaxSubString(c, res)), header))) |
|
common_ratio_c_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(t2[0])), c_res_common_dict.items())) |
|
common_ratio_res_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(res)), c_res_common_dict.items())) |
|
|
|
|
|
|
|
|
|
if not common_ratio_c_dict or not common_ratio_res_dict: |
|
return None |
|
|
|
dict_0_max_key = sorted(common_ratio_c_dict.items(), key = lambda t2: t2[1])[::-1][0][0] |
|
dict_1_max_key = sorted(common_ratio_res_dict.items(), key = lambda t2: t2[1])[::-1][0][0] |
|
return dict_0_max_key if dict_0_max_key == dict_1_max_key else None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
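# Heuristic aggregate classifier over the tokenized residual question.
# Returns an agg_sql_dict id: 1 AVG, 2 MAX, 3 MIN, 0 none, and 6 for an
# ambiguous COUNT/SUM that cat6_to_45_by_column_type later resolves to
# 4 (COUNT) or 5 (SUM) by column dtype. Header words are stripped first so
# that keywords inside column names do not trigger false positives.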
|
def simple_label_func(s, drop_header = True): |
|
    text_tokens = s.question_cut
    header = list(map(lambda x: x[:x.find("(")] if (not x.startswith("(") and x.endswith(")")) else x, s.header.split(",")))
|
|
|
|
|
|
|
|
fit_collection = sum_count_high_kws + mean_kws + max_kws + min_kws |
|
fit_header = [] |
|
for c in header: |
|
for kw in fit_collection: |
|
if kw in c: |
|
start_idx = c.find(kw) |
|
end_idx = start_idx + len(kw) |
|
fit_header.append(c[start_idx: end_idx]) |
|
|
|
if not drop_header: |
|
header = [] |
|
fit_header = [] |
|
|
|
input_ = "".join(text_tokens) |
|
for c in header + fit_header: |
|
if c in fit_collection: |
|
continue |
|
input_ = input_.replace(c, "") |
|
c0, c1 = c, c |
|
while c0 and c0 not in fit_collection and len(c0) >= 4: |
|
c0 = c0[1:] |
|
if c0 in fit_collection: |
|
break |
|
input_ = input_.replace(c0, "") |
|
while c1 and c1 not in fit_collection and len(c1) >= 4: |
|
c1 = c1[:-1] |
|
if c1 in fit_collection: |
|
break |
|
input_ = input_.replace(c1, "") |
|
|
|
|
|
text_tokens = list(jieba.cut(input_)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cat_6_collection_high_level = sum_count_high_kws |
|
if any(map(lambda high_level_token: high_level_token in "".join(text_tokens), cat_6_collection_high_level)): |
|
return 6 |
|
|
|
|
|
if any(map(lambda kw: kw in text_tokens, mean_kws)): |
|
return 1 |
|
if any(map(lambda kw: kw in text_tokens, max_kws)): |
|
return 2 |
|
if any(map(lambda kw: kw in text_tokens, min_kws)): |
|
return 3 |
|
|
|
|
|
|
|
cat_6_collection = sum_count_low_kws |
|
if any(map(lambda kw: kw in text_tokens, cat_6_collection)): |
|
return 6 |
|
if any(map(lambda token: "几" in token, text_tokens)): |
|
return 6 |
|
|
|
|
|
if any(map(lambda kw: kw in text_tokens, max_special_kws)): |
|
return 2 |
|
if any(map(lambda kw: kw in text_tokens, min_special_kws)): |
|
return 3 |
|
|
|
|
|
return 0 |
|
|
|
|
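# Stricter variant used as a gate: only high-confidence keywords (count/sum,
# mean, max, min) fire; anything else returns 0.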
|
def simple_special_func(s, drop_header = True): |
|
    text_tokens = s.question_cut
    header = list(map(lambda x: x[:x.find("(")] if (not x.startswith("(") and x.endswith(")")) else x, s.header.split(",")))
|
|
|
|
|
|
|
fit_collection = sum_count_high_kws + mean_kws + max_kws + min_kws |
|
fit_header = [] |
|
for c in header: |
|
for kw in fit_collection: |
|
if kw in c: |
|
start_idx = c.find(kw) |
|
end_idx = start_idx + len(kw) |
|
fit_header.append(c[start_idx: end_idx]) |
|
|
|
input_ = "".join(text_tokens) |
|
if not drop_header: |
|
header = [] |
|
fit_header = [] |
|
|
|
for c in header + fit_header: |
|
if c in fit_collection: |
|
continue |
|
input_ = input_.replace(c, "") |
|
c0, c1 = c, c |
|
while c0 and c0 not in fit_collection and len(c0) >= 4: |
|
c0 = c0[1:] |
|
if c0 in fit_collection: |
|
break |
|
input_ = input_.replace(c0, "") |
|
while c1 and c1 not in fit_collection and len(c1) >= 4: |
|
c1 = c1[:-1] |
|
if c1 in fit_collection: |
|
break |
|
input_ = input_.replace(c1, "") |
|
|
|
|
|
text_tokens = list(jieba.cut(input_)) |
|
|
|
|
|
|
|
|
|
|
|
cat_6_collection_high_level = sum_count_high_kws |
|
if any(map(lambda high_level_token: high_level_token in "".join(text_tokens), cat_6_collection_high_level)): |
|
return 6 |
|
|
|
|
|
if any(map(lambda kw: kw in text_tokens, mean_kws)): |
|
return 1 |
|
if any(map(lambda kw: kw in text_tokens, max_kws)): |
|
return 2 |
|
if any(map(lambda kw: kw in text_tokens, min_kws)): |
|
return 3 |
|
|
|
return 0 |
|
|
|
|
|
def simple_total_label_func(s): |
|
    is_special = simple_special_func(s) != 0
|
if not is_special: |
|
return 0 |
|
return simple_label_func(s) |
|
|
|
|
|
|
|
|
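# Split the condition part of a question on the strongest connector keyword and
# run recurrent_extract on each piece. With extract_return=True this returns an
# (n, 3) array of condition triples; with extract_return=False it returns the
# raw condition strings plus the "and"/"or" connector tag.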
|
def split_by_cond(s, extract_return = True): |
|
def recurrent_extract_cond(text): |
|
|
|
|
|
return np.asarray(list(recurrent_extract(text)[0])) |
|
|
|
question = s["question"] |
|
res = s["rec_decomp"][1] |
|
if question is None: |
|
question = "" |
|
if res is None: |
|
res = "" |
|
|
|
common_res = findMaxSubString(question, res) |
|
|
|
|
|
cond_kws = conn_kws |
|
condition_part = question.replace(common_res, "") |
|
fit_kws = set([]) |
|
for kw in cond_kws: |
|
if kw in condition_part and not condition_part.startswith(kw) and not condition_part.endswith(kw): |
|
fit_kws.add(kw) |
|
if not fit_kws: |
|
will_return = ([condition_part.replace(" ", "") + " " + common_res], "") |
|
if extract_return: |
|
|
|
will_return = (np.asarray(list(map(recurrent_extract_cond, will_return[0]))) , will_return[1]) |
|
|
|
|
|
input_ = list(filter(lambda x: x.size ,map(np.asarray ,will_return[0].tolist()))) |
|
if input_: |
|
will_return = (np.concatenate(input_, axis = 0), will_return[1]) |
|
else: |
|
will_return = (np.empty((0, 3)), will_return[1]) |
|
|
|
will_return = will_return[0].reshape((-1, 3)) if will_return[0].size else np.empty((0, 3)) |
|
return will_return |
|
|
|
fit_kw = sorted(fit_kws, key = len)[::-1][0] |
|
condition_part_splits = condition_part.split(fit_kw) |
|
|
|
if fit_kw in or_kws: |
|
fit_kw = "or" |
|
|
|
elif fit_kw in and_kws: |
|
fit_kw = "and" |
|
else: |
|
fit_kw = "" |
|
|
|
will_return = (list(map(lambda cond_: cond_.replace(" ", "") + " " + common_res, condition_part_splits)), fit_kw) |
|
if extract_return: |
|
|
|
will_return = (np.asarray(list(map(recurrent_extract_cond, will_return[0]))), will_return[1]) |
|
|
|
|
|
input_ = list(filter(lambda x: x.size ,map(np.asarray ,will_return[0].tolist()))) |
|
if input_: |
|
will_return = (np.concatenate(input_, axis = 0), will_return[1]) |
|
else: |
|
will_return = (np.empty((0, 3)), will_return[1]) |
|
|
|
will_return = will_return[0].reshape((-1, 3)) if will_return[0].size else np.empty((0, 3)) |
|
|
|
return will_return |
|
|
|
|
|
|
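# Map raw condition triples onto real table columns and drop those that cannot
# hold: the column must exist, the value's type must match the column dtype
# (checked against actual cell values when condition_fit_num > 0), and
# conflicting comparisons on the same column are reduced to the most plausible one.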
|
def filter_total_conds(s, df, condition_fit_num = 0): |
|
assert condition_fit_num >= 0 and type(condition_fit_num) == type(0) |
|
df = df.copy() |
|
|
|
|
|
df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x) |
|
def justify_column_as_float(s): |
|
if "float" in str(s.dtype): |
|
return True |
|
if all(s.map(type).map(lambda tx: "float" in str(tx))): |
|
return True |
|
return False |
|
|
|
float_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_float, axis = 0).to_dict().items()))) |
|
for f_col in float_cols: |
|
df[f_col] = df[f_col].astype(np.float64) |
|
|
|
|
|
header = df.columns.tolist() |
|
units_cols = filter(lambda c: "(" in c and c.endswith(")"), df.columns.tolist()) |
|
if not float_cols: |
|
float_discribe_df = pd.DataFrame() |
|
else: |
|
float_discribe_df = df[float_cols].describe() |
|
|
|
    def call_eval(val):
        try:
            return literal_eval(val)
        except (ValueError, SyntaxError, TypeError):
            return val
|
|
|
|
|
def find_cond_col(res, header): |
|
|
|
c_res_common_dict = dict(filter(lambda t2: t2[1] ,map(lambda c: (c ,findMaxSubString(c, res)), header))) |
|
|
|
common_ratio_c_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(t2[0])), c_res_common_dict.items())) |
|
common_ratio_res_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(res)), c_res_common_dict.items())) |
|
|
|
if not common_ratio_c_dict or not common_ratio_res_dict: |
|
return None |
|
|
|
dict_0_max_key = sorted(common_ratio_c_dict.items(), key = lambda t2: t2[1])[::-1][0][0] |
|
dict_1_max_key = sorted(common_ratio_res_dict.items(), key = lambda t2: t2[1])[::-1][0][0] |
|
return dict_0_max_key if dict_0_max_key == dict_1_max_key else None |
|
|
|
|
|
|
|
def filter_cond_col(cond_t3): |
|
assert type(cond_t3) == type((1,)) and len(cond_t3) == 3 |
|
col, _, value = cond_t3 |
|
|
|
if type(value) == type(""): |
|
value = call_eval(value) |
|
|
|
if col not in df.columns.tolist(): |
|
return False |
|
|
|
|
|
if col in float_cols and type(value) not in (type(0), type(0.0)): |
|
return False |
|
|
|
if col not in float_cols and type(value) in (type(0), type(0.0)): |
|
return False |
|
|
|
|
|
if col not in float_cols and type(value) not in (type(0), type(0.0)): |
|
if type(value) == type(""): |
|
if value not in df[col].tolist(): |
|
return False |
|
|
|
if type(value) in (type(0), type(0.0)): |
|
if col in float_discribe_df.columns.tolist(): |
|
if condition_fit_num > 0: |
|
if value >= float_discribe_df[col].loc["min"] and value <= float_discribe_df[col].loc["max"]: |
|
return True |
|
else: |
|
return False |
|
else: |
|
assert condition_fit_num == 0 |
|
return True |
|
|
|
if condition_fit_num > 0: |
|
if value in df[col].tolist(): |
|
return True |
|
else: |
|
return False |
|
else: |
|
assert condition_fit_num == 0 |
|
return True |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
def same_column_cond_filter(cond_list, sort_stats = "mean"): |
|
|
|
if len(cond_list) == len(set(map(lambda t3: t3[0] ,cond_list))): |
|
return cond_list |
|
|
|
req = defaultdict(list) |
|
for t3 in cond_list: |
|
req[t3[0]].append(t3[1:]) |
|
|
|
def t2_list_sort(col_name ,t2_list): |
|
if not t2_list: |
|
return None |
|
t2 = None |
|
if col_name in float_cols: |
|
t2 = sorted(t2_list, key = lambda t2: np.abs(t2[1] - float_discribe_df[col_name].loc[sort_stats]))[0] |
|
else: |
|
if all(map(lambda t2: type(t2[1]) == type("") ,t2_list)): |
|
col_val_cnt_df = df[col_name].value_counts().reset_index() |
|
col_val_cnt_df.columns = ["val", "cnt"] |
|
|
|
|
|
|
|
t2_list_map_to_column_val = list(filter(lambda x: x[1] ,map(lambda t2: (t2[0] ,find_cond_col(t2[1], list(set(col_val_cnt_df["val"].values.tolist())))), t2_list))) |
|
if t2_list_map_to_column_val: |
|
|
|
t2 = sorted(t2_list_map_to_column_val, key = lambda t2: -1 * len(t2[1]))[0] |
|
if t2 is None and t2_list: |
|
t2 = t2_list[0] |
|
return t2 |
|
|
|
cond_list_filtered = list(map(lambda ttt2: (ttt2[0], ttt2[1][0], ttt2[1][1]) , |
|
filter(lambda tt2: tt2[1] ,map(lambda t2: (t2[0] ,t2_list_sort(t2[0], t2[1])), req.items())))) |
|
|
|
return cond_list_filtered |
|
|
|
|
|
total_conds_map_to_column = list(map(lambda t3: (find_cond_col(t3[0], header), t3[1], t3[2]), s["total_conds"])) |
|
total_conds_map_to_column_filtered = list(filter(filter_cond_col, total_conds_map_to_column)) |
|
total_conds_map_to_column_filtered = sorted(set(map(lambda t3: (t3[0], t3[1], call_eval(t3[2]) if type(t3[2]) == type("") else t3[2]), total_conds_map_to_column_filtered))) |
|
|
|
|
|
cp_cond_list = list(filter(lambda t3: t3[1] in (">", "<"), total_conds_map_to_column_filtered)) |
|
eq_cond_list = list(filter(lambda t3: t3[1] in ("==", "!="), total_conds_map_to_column_filtered)) |
|
|
|
cp_cond_list_filtered = same_column_cond_filter(cp_cond_list) |
|
|
|
|
|
return cp_cond_list_filtered + eq_cond_list |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
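# Scan string-valued cells (restricted to only_kws_columns) for literal
# mentions inside the question and turn each hit into an extra
# (column, relation, value) condition candidate.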
|
def augment_kw_in_question(question_df, df, only_kws_in_string = []): |
|
|
|
question_df = question_df.copy() |
|
|
|
|
|
    def call_eval(val):
        try:
            return literal_eval(val)
        except (ValueError, SyntaxError, TypeError):
            return val
|
|
|
df = df.copy() |
|
|
|
df = df.applymap(call_eval) |
|
|
|
|
|
df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x) |
|
def justify_column_as_float(s): |
|
if "float" in str(s.dtype): |
|
return True |
|
if all(s.map(type).map(lambda tx: "float" in str(tx))): |
|
return True |
|
return False |
|
|
|
float_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_float, axis = 0).to_dict().items()))) |
|
|
|
|
|
def justify_column_as_kw(s): |
|
if all(s.map(type).map(lambda tx: "str" in str(tx))): |
|
return True |
|
return False |
|
|
|
obj_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_kw, axis = 0).to_dict().items()))) |
|
obj_cols = list(set(obj_cols).difference(set(float_cols))) |
|
if only_kws_columns: |
|
obj_cols = list(set(obj_cols).intersection(set(only_kws_columns.keys()))) |
|
|
|
|
|
|
|
|
|
kw_augmented_df = df[obj_cols].copy() |
|
|
|
|
|
def extract_question_kws(question): |
|
if not kw_augmented_df.size: |
|
return [] |
|
req = defaultdict(set) |
|
for ridx, r in kw_augmented_df.iterrows(): |
|
for k, v in dict(r).items(): |
|
if v in question: |
|
pattern = "\w?{}\w?".format(v) |
|
all_match = re.findall(pattern, question) |
|
|
|
|
|
key = "{}~{}".format(k, v) |
|
req[key] = req[key].union(set(all_match)) |
|
|
|
|
|
|
|
|
|
if only_kws_in_string: |
|
req = list(map(lambda tt2: tt2[0] ,filter(lambda t2: sum(map(lambda kw: sum(map(lambda t: kw in t ,t2[1])), only_kws_in_string)), req.items()))) |
|
else: |
|
req = list(set(req.keys())) |
|
|
|
def req_to_t3(req_string, relation = "=="): |
|
assert "~" in req_string |
|
left, right = req_string.split("~") |
|
if left in only_kws_columns: |
|
relation = only_kws_columns[left] |
|
return (left, relation, right) |
|
|
|
if not req: |
|
return [] |
|
|
|
return list(map(req_to_t3, req)) |
|
|
|
|
|
|
|
question_df["question_kw_conds"] = question_df["question"].map(extract_question_kws) |
|
return question_df |
|
|
|
|
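# Fallback target-column choice: delete all matched condition columns/values
# and connector words from the question, then score what is left against the
# headers with choose_question_column.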
|
def choose_question_column_by_rm_conds(s, df): |
|
question = s.question |
|
total_conds_filtered = s.total_conds_filtered |
|
|
|
cond_kws = conn_kws |
|
stopwords = ("是", ) |
|
|
|
def construct_res(question): |
|
for k, _, v in total_conds_filtered: |
|
if "(" in k: |
|
k = k[:k.find("(")] |
|
|
|
question = question.replace(str(k), "") |
|
question = question.replace(str(v), "") |
|
for w in cond_kws + stopwords: |
|
question = question.replace(w, "") |
|
return question |
|
|
|
res = construct_res(question) |
|
decomp = (None, res, None) |
|
return choose_question_column(decomp, df.columns.tolist(), df) |
|
|
|
|
|
def split_qst_by_kw(question, kw = "的"): |
|
return question.split(kw) |
|
|
|
|
|
|
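# Split the question on "的" and the connector words, keep fragments that
# contain a question word, and return them shortest-first as candidate
# residuals for column scoring.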
|
def choose_res_by_kws(question): |
|
|
|
question = question.replace(" ", "") |
|
|
|
kws = ("的",) + conn_kws |
|
kws = list(kws) |
|
def qst_kw_filter(text): |
|
|
|
if any(map(lambda kw: kw in text, qst_kws)): |
|
return True |
|
return False |
|
|
|
kws_cp = deepcopy(kws) |
|
qst_c = set(question.split(",")) |
|
    while kws:
        kw = kws.pop()
        qst_c = qst_c.union(set(filter(qst_kw_filter,
                                       reduce(lambda a, b: a + b,
                                              map(lambda q: split_qst_by_kw(q, kw), qst_c)))))
|
|
|
|
|
|
|
|
|
return sorted(filter(lambda x: x and (x not in kws_cp) and qst_kw_filter(x) ,qst_c), key = len) |
|
|
|
|
|
|
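# Resolve the ambiguous class 6 from the keyword classifier: COUNT (4) for
# text columns, SUM (5) for float columns, 0 when the target column is unknown.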
|
def cat6_to_45_by_column_type(s, df): |
|
agg_pred = s.agg_pred |
|
if agg_pred != 6: |
|
return agg_pred |
|
question_column = s.question_column |
|
|
|
    def call_eval(val):
        try:
            return literal_eval(val)
        except (ValueError, SyntaxError, TypeError):
            return val
|
|
|
df = df.copy() |
|
|
|
df = df.applymap(call_eval) |
|
|
|
|
|
df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x) |
|
def justify_column_as_float(s): |
|
if "float" in str(s.dtype): |
|
return True |
|
if all(s.map(type).map(lambda tx: "float" in str(tx))): |
|
return True |
|
return False |
|
|
|
float_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_float, axis = 0).to_dict().items()))) |
|
|
|
|
|
|
|
|
|
|
obj_cols = df.columns.tolist() |
|
obj_cols = list(set(obj_cols).difference(set(float_cols))) |
|
|
|
|
|
assert len(obj_cols) + len(float_cols) == df.shape[1] |
|
|
|
if question_column in obj_cols: |
|
return 4 |
|
elif question_column in float_cols: |
|
return 5 |
|
else: |
|
return 0 |
|
|
|
|
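# End-to-end decomposition for one table: JointBERT conditions, keyword
# conditions, and the connector split are merged and filtered, the target
# column and aggregate are chosen, and (optionally) only the final predicted
# fields are returned.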
|
def full_before_cat_decomp(df, question_df, only_req_columns = False): |
|
df, question_df = df.copy(), question_df.copy() |
|
first_train_question_extract_df = pd.DataFrame(question_df["question"].map(lambda question: (question, recurrent_extract(question))).tolist()) |
|
first_train_question_extract_df.columns = ["question", "decomp"] |
|
|
|
first_train_question_extract_df = augment_kw_in_question(first_train_question_extract_df, df) |
|
|
|
first_train_question_extract_df["rec_decomp"] = first_train_question_extract_df["decomp"].map(lambda decomp: decomp if decomp[0] else rec_more_time(decomp)) |
|
|
|
first_train_question_extract_df["question_cut"] = first_train_question_extract_df["rec_decomp"].map(lambda t3: jieba.cut(t3[1]) if t3[1] is not None else []).map(list) |
|
first_train_question_extract_df["header"] = ",".join(df.columns.tolist()) |
|
first_train_question_extract_df["question_column_res"] = first_train_question_extract_df["rec_decomp"].map(lambda decomp: choose_question_column(decomp, df.columns.tolist(), df)) |
|
|
|
|
|
first_train_question_extract_df["agg_res_pred"] = first_train_question_extract_df.apply(simple_total_label_func, axis = 1) |
|
first_train_question_extract_df["question_cut"] = first_train_question_extract_df["question"].map(jieba.cut).map(list) |
|
first_train_question_extract_df["agg_qst_pred"] = first_train_question_extract_df.apply(simple_total_label_func, axis = 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
first_train_question_extract_df["agg_pred"] = first_train_question_extract_df.apply(lambda s: max(s.agg_res_pred, s.agg_qst_pred), axis = 1) |
|
|
|
|
|
first_train_question_extract_df["conds"] = first_train_question_extract_df["rec_decomp"].map(lambda x: x[0]) |
|
first_train_question_extract_df["split_conds"] = first_train_question_extract_df.apply(split_by_cond, axis = 1).values.tolist() |
|
first_train_question_extract_df["conn_pred"] = first_train_question_extract_df.apply(lambda s: split_by_cond(s, extract_return=False), axis = 1).map(lambda x: x[-1]).values.tolist() |
|
|
|
first_train_question_extract_df["total_conds"] = first_train_question_extract_df.apply(lambda s: list(set(map(tuple,s["question_kw_conds"] + s["conds"] + s["split_conds"].tolist()))), axis = 1).values.tolist() |
|
first_train_question_extract_df["total_conds_filtered"] = first_train_question_extract_df.apply(lambda s: filter_total_conds(s, df), axis = 1).values.tolist() |
|
|
|
|
|
|
|
|
|
|
|
first_train_question_extract_df["question_column_qst"] = first_train_question_extract_df["question"].map(choose_res_by_kws).map(lambda res_list: list(filter(lambda x: x ,map(lambda res: choose_question_column((None, res, None), df.columns.tolist(), df), res_list)))).map(lambda x: x[0] if x else None) |
|
first_train_question_extract_df["question_column"] = first_train_question_extract_df.apply(lambda s: s.question_column_res if s.question_column_res is not None else s.question_column_qst, axis = 1) |
|
|
|
|
|
|
|
|
|
|
first_train_question_extract_df["agg_pred"] = first_train_question_extract_df.apply(lambda s: cat6_to_45_by_column_type(s, df), axis = 1).map(lambda x: agg_sql_dict[x]) |
|
if only_req_columns: |
|
return first_train_question_extract_df[["question", |
|
"total_conds_filtered", |
|
"conn_pred", |
|
"question_column", |
|
"agg_pred" |
|
]].copy() |
|
|
|
return first_train_question_extract_df.copy() |
|
|
|
|
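# Smoke tests: run the full pipeline on a few tables from the train split and
# inspect the predicted conditions, connector, question column, and aggregate.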
|
if __name__ == "__main__": |
|
|
|
req = list(data_loader(req_table_num=None)) |
|
|
|
|
|
train_df, _ = req[2] |
|
|
question = "哪些股票的收盘价大于20?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
|
|
|
|
question = "哪个股票代码市值最高?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
question = "市值的最大值是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "EPS大于0的股票有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "EPS大于0且周涨跌大于5的平均市值是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
train_df, _ = req[5] |
|
|
question = "产能小于20、销量大于40而且市场占有率超过1的公司有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
|
|
question = "产能小于20而且销量大于40而且市场占有率超过1的公司有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
train_df, _ = req[6] |
|
|
|
|
question = "投资评级为维持的名称有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
train_df["公司"] = train_df["名称"] |
|
question = "投资评级为维持的公司有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "投资评级为维持而且变动为增持的公司有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "投资评级为维持或者变动为增持的公司有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "投资评级为维持或者变动为增持的平均收盘价是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
train_df, _ = req[7] |
|
|
question = "宁波的一手房每周交易数据上周成交量是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "一手房每周交易数据为宁波上周成交量是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
|
|
train_df["城市"] = train_df["一手房每周交易数据"] |
|
question = "一手房每周交易数据为宁波上周成交量是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
question = "王翔知道宁波一手房的当月累计成交量是多少吗?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "王翔知道上周成交量大于50的最大同比当月是多少吗?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
train_df, _ = req[9] |
|
|
|
|
cols = train_df.columns.tolist() |
|
cols[-1] = "周跌幅(%)" |
|
train_df.columns = cols |
|
question = "周涨幅大于7的涨股有哪些?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
|
|
question = "周涨幅大于7的涨股总数是多少?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|
|
|
|
question = "周涨幅大于7的涨股总共有多少个?" |
|
qs_df = pd.DataFrame([[question]], columns = ["question"]) |
|
ic(question) |
|
ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True)) |
|
|