File size: 7,667 Bytes
0a65f9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import pandas as pd
from typing_extensions import Any, List, Dict, Iterable, Optional, Tuple
from loguru import logger
from tqdm import tqdm

from .base_conversion_utils import (
    clean_query,
    build_schema_maps,
    convert_actual_code_to_modified_dict,
    convert_modified_to_actual_code_string
)
from .line_based_parsing import (
    clean_modified_dict,
    convert_to_lines,
    parse_line_based_query
)
from .schema_utils import schema_to_line_based


def modify_single_row_base_form(mongo_query: str, schema: Dict[str, Any]) -> Tuple[Any, ...]:
    """
    Round-trip one MongoDB query through its modified (schema-mapped) form.

    The query is cleaned with ``clean_query`` and converted to its modified
    form using the ``out2in`` map built from ``schema``. It is then converted
    back with ``in2out`` and cleaned again. The row is accepted only if this
    round trip reproduces the cleaned input exactly.

    Args:
        mongo_query: Raw MongoDB query string.
        schema: Schema dict. Only ``schema["collections"][0]["name"]`` is used
            as the collection name.

    Returns:
        A 6-tuple ``(mongo_query, modified_query, collection_name, in2out,
        out2in, schema)``, where ``mongo_query`` is the cleaned query. If any
        step raises, or the round trip does not match, all six elements are
        ``None``.
    """
    failure = (None, None, None, None, None, None)
    try:
        cleaned_query = clean_query(mongo_query)
        in2out, out2in = build_schema_maps(schema)
        modified_query = convert_actual_code_to_modified_dict(cleaned_query, out2in)
        # Only the first collection's name is used when rebuilding the query.
        collection_name = schema["collections"][0]["name"]
        reconstructed_query = clean_query(
            convert_modified_to_actual_code_string(modified_query, in2out, collection_name)
        )
    except Exception as exc:
        # Best-effort filter: callers drop rows that fail to convert, but
        # record the reason so conversion errors are not silently invisible.
        logger.debug(f"Base-form conversion failed: {exc!r}")
        return failure
    if reconstructed_query != cleaned_query:
        return failure
    return cleaned_query, modified_query, collection_name, in2out, out2in, schema


def modify_all_rows_base_from(mongo_queries: List[str], schemas: List[Dict[str, Any]], nl_queries: List[str], additional_infos: List[str]) -> List[Dict[str, Any]]:
    """
    Apply the base-form round trip to every (query, schema) pair.

    ``mongo_queries`` and ``schemas`` are consumed in lockstep. For each
    pair, ``modify_single_row_base_form`` is called. Pairs whose modified
    query comes back as ``None`` are skipped. Survivors are enriched with the
    natural-language query and the additional info found at the same index.

    Returns:
        One record per surviving row, with keys ``mongo_query``,
        ``natural_language_query``, ``additional_info``, ``modified_query``,
        ``collection_name``, ``in2out``, ``out2in`` and ``schema``.
    """
    indexed_pairs = enumerate(zip(mongo_queries, schemas))
    records = []
    for idx, (raw_query, raw_schema) in tqdm(indexed_pairs, total=len(mongo_queries), desc="Modifying Queries"):
        (
            cleaned_query,
            modified,
            collection,
            forward_map,
            reverse_map,
            kept_schema,
        ) = modify_single_row_base_form(raw_query, raw_schema)
        # A None modified query marks a row that failed the round trip.
        if modified is None:
            continue
        records.append(
            {
                "mongo_query": cleaned_query,
                "natural_language_query": nl_queries[idx],
                "additional_info": additional_infos[idx],
                "modified_query": modified,
                "collection_name": collection,
                "in2out": forward_map,
                "out2in": reverse_map,
                "schema": kept_schema,
            }
        )
    return records


def modify_line_based_parsing(modified_query_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Round-trip one record's modified query through the line-based format.

    ``modified_query_data["modified_query"]`` is cleaned with
    ``clean_modified_dict``, rendered with ``convert_to_lines``, and parsed
    back with ``parse_line_based_query``. If the parse reproduces the cleaned
    query, the rendered lines are stored under ``"line_based_query"``.

    Args:
        modified_query_data: A record such as those built by
            ``modify_all_rows_base_from``. It must contain a
            ``"modified_query"`` key.

    Returns:
        The same dict, mutated in place, on success. ``None`` if any step
        raises or the parsed query differs from the cleaned one.
    """
    try:
        modified_query = clean_modified_dict(modified_query_data["modified_query"])
        lines = convert_to_lines(modified_query)
        reconstructed_query = parse_line_based_query(lines)
    except Exception as exc:
        # Best-effort filter: callers drop failing records, but log why.
        logger.debug(f"Line-based parsing failed: {exc!r}")
        return None
    if modified_query != reconstructed_query:
        return None
    modified_query_data["line_based_query"] = lines
    return modified_query_data


def modify_all_line_based_parsing(modified_queries: List[Dict[str, Any]]):
    """
    Run ``modify_line_based_parsing`` on every record and keep the successes.

    Records whose result is falsy (the round trip failed) are discarded.
    Survivors are returned in their original order.
    """
    progress = tqdm(modified_queries, desc="Testing Line-based Parsing", total=len(modified_queries))
    outcomes = (modify_line_based_parsing(record) for record in progress)
    return [outcome for outcome in outcomes if outcome]


def modify_all_schema(query_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Attach a line-based rendering of each record's schema.

    For every record, ``schema_to_line_based(record["schema"])`` is stored
    under ``"line_based_schema"``. The records are mutated in place and
    returned in a new list, in the same order. Exceptions raised by
    ``schema_to_line_based`` propagate to the caller.

    Args:
        query_data: Records that each contain a ``"schema"`` key.

    Returns:
        The same record dicts, each now carrying ``"line_based_schema"``.
    """
    final_data = []
    for query in tqdm(query_data, desc="Converting Schemas to Line-based Format", total=len(query_data)):
        query["line_based_schema"] = schema_to_line_based(query["schema"])
        final_data.append(query)
    return final_data


def load_csv(file_path: str) -> pd.DataFrame:
    """
    Load a CSV file into a pandas DataFrame.

    Args:
        file_path: Path to the CSV file.

    Returns:
        The DataFrame produced by ``pd.read_csv``.

    Raises:
        Exception: Any error from ``pd.read_csv`` is logged and then
            re-raised unchanged.
    """
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        logger.error(f"Error loading CSV file: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
    logger.info(f"Loaded CSV file: {file_path}")
    return df
    

def modify_dataframe(df: pd.DataFrame) -> List[Dict[str, Any]]:
    """
    Run the full conversion pipeline over a DataFrame of queries.

    Reads the ``mongo_query``, ``schema``, ``natural_language_query`` and
    ``additional_info`` columns, then applies three stages in order:

    - ``modify_all_rows_base_from``: base-form round trip.
    - ``modify_all_line_based_parsing``: line-based round trip.
    - ``modify_all_schema``: line-based schema.

    Rows that fail either round trip are dropped.

    Args:
        df: Input frame. Each ``schema`` cell is passed through ``eval`` and
            is expected to yield a schema dict.

    Returns:
        A list of record dicts (not a DataFrame).
    """
    logger.info("Modifying DataFrame...")
    logger.debug(f"input DataFrame length: {len(df)}")
    mongo_queries = df["mongo_query"].tolist()
    # SECURITY(review): eval() executes arbitrary code taken from the CSV.
    # If the schema column only contains Python literals, ast.literal_eval
    # is a safe replacement. Confirm against the data before switching.
    schemas = df["schema"].apply(eval).tolist()
    nl_queries = df["natural_language_query"].tolist()
    additional_infos = df["additional_info"].tolist()
    modified_queries = modify_all_rows_base_from(mongo_queries, schemas, nl_queries, additional_infos)
    logger.debug(f"Modified queries length: {len(modified_queries)}")
    line_based_queries = modify_all_line_based_parsing(modified_queries)
    logger.debug(f"Line-based queries length: {len(line_based_queries)}")
    final_data = modify_all_schema(line_based_queries)
    logger.debug(f"Modified schemas length: {len(final_data)}")
    return final_data

def main(final_data: List[Dict[str, Any]], skip_indices: Iterable[int] = (746,)):
    """
    Verify that each record's line-based query rebuilds its original query.

    For each record whose position is not in ``skip_indices``:

    1. ``line_based_query`` is parsed with ``parse_line_based_query``.
    2. The result is mapped back with ``convert_modified_to_actual_code_string``,
       using the record's ``in2out`` and ``collection_name``.
    3. The output is cleaned with ``clean_query`` and compared with ``mongo_query``.

    The first mismatch is logged in detail and raises ``AssertionError``.

    Args:
        final_data: Records produced by ``modify_dataframe``.
        skip_indices: Record positions to skip. The default ``(746,)`` is the
            single index this check has always skipped (presumably a known
            non-round-tripping row in the current dataset — confirm).

    Raises:
        AssertionError: On the first record whose reconstruction differs.
        SystemExit: With code 0 once every checked record matches.
    """
    # Build the lookup once instead of on every iteration.
    skipped = frozenset(skip_indices)
    for i, record in enumerate(final_data):
        if i in skipped:
            continue
        original_query = record["mongo_query"]
        line_based_query = record["line_based_query"]
        reconstructed_modified_query = parse_line_based_query(line_based_query)
        reconstructed_original_query = convert_modified_to_actual_code_string(
            reconstructed_modified_query, record["in2out"], record["collection_name"]
        )
        if original_query == clean_query(reconstructed_original_query):
            continue

        logger.error(f"index: {i}")
        logger.error(f"Original query: {original_query}")
        logger.error(f"Reconstructed original query: {reconstructed_original_query}")
        logger.error(f"Modified query: {record['modified_query']}")
        logger.error(f"Reconstructed modified query: {reconstructed_modified_query}")
        logger.error(f"Line-based query: {line_based_query}")
        logger.warning("--------------------------------------------------")
        # Raise explicitly rather than using `assert`, so the check still
        # fires when Python runs with -O.
        raise AssertionError(f"Original query does not match reconstructed original query at index {i}")
    # SystemExit instead of the site-provided `exit()` builtin, which is not
    # guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(0)
        

if __name__ == "__main__":
    # Input dataset (a CSV file).
    csv_path = "./data_v3/data_v2.csv"
    df = load_csv(csv_path)
    final_data = modify_dataframe(df)
    # To verify that every record rebuilds its original query, call
    # main(final_data) here. Note that main terminates the process.
    logger.info(f"Final data length: {len(final_data)}")
    if final_data:
        # Guarded: indexing [0] on an empty result would raise IndexError.
        logger.debug(f"Final data type: {final_data[0]}\n\n")

    # Preview the first five records.
    for i, query_data in enumerate(final_data[:5]):
        logger.debug(f"Modified schema {i}: {query_data['line_based_schema']}")
        logger.debug(f"Line-based query {i}: {query_data['line_based_query']}")
        logger.debug(f"NL query {i}: {query_data['natural_language_query']}")
        logger.debug(f"Additional info {i}: {query_data['additional_info']}")
        print('\n\n\n\n')