sofhiaazzhr committed on
Commit
73b7fe3
·
1 Parent(s): 36ffff4

make executors self-contained, remove redundant pre-filter

Browse files
src/query/executors/tabular.py CHANGED
@@ -1,13 +1,9 @@
1
  """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
2
 
3
- Receives sheet-level RetrievalResults from SchemaRetriever (each result
4
- represents a relevant sheet, with its full column list available via
5
- data.column_names in metadata).
6
-
7
  Flow:
8
  1. Group RetrievalResult chunks by (document_id, sheet_name).
9
  2. Per group: download Parquet from Azure Blob → pandas DataFrame.
10
- 3. Build schema context from full DataFrame columns + sample values.
11
  4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output.
12
  5. Pandas runs the operation; retry up to 3x on error with feedback to LLM.
13
  6. Fallback to raw rows if all retries fail.
@@ -50,12 +46,8 @@ IMPORTANT rules:
50
  - For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==).
51
  - For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value.
52
  Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}]
53
- - IMPORTANT: When the question uses "or" / "atau" between values of the same column, you MUST use or_filters (NOT filters).
54
- or_filters applies OR logic: rows matching ANY condition are kept.
55
- filters applies AND logic: rows must match ALL conditions.
56
- Example: "(status FAILED or REVERSED) AND payment_channel=Tokopedia" →
57
- or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}]
58
- filters=[{{"col":"payment_channel","value":"Tokopedia","op":"eq"}}]
59
  - For groupby with a pre-filter (e.g. count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col.
60
 
61
  Schema:
@@ -85,9 +77,6 @@ class TabularOperation(BaseModel):
85
 
86
  def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series:
87
  numeric = pd.to_numeric(df[col], errors="coerce")
88
- coerced_nulls = numeric.isnull() & df[col].notna()
89
- if coerced_nulls.any():
90
- logger.warning("numeric coercion introduced NaN", col=col, count=int(coerced_nulls.sum()))
91
  if operator == "eq":
92
  return df[col].astype(str) == str(value)
93
  elif operator == "ne":
@@ -104,23 +93,7 @@ def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> p
104
 
105
 
106
  def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame:
107
- numeric = pd.to_numeric(df[col], errors="coerce")
108
- coerced_nulls = numeric.isnull() & df[col].notna()
109
- if coerced_nulls.any():
110
- logger.warning("numeric coercion introduced NaN", col=col, count=int(coerced_nulls.sum()))
111
- if operator == "eq":
112
- return df[df[col].astype(str) == str(value)]
113
- elif operator == "ne":
114
- return df[df[col].astype(str) != str(value)]
115
- elif operator == "gt":
116
- return df[numeric > float(value)]
117
- elif operator == "gte":
118
- return df[numeric >= float(value)]
119
- elif operator == "lt":
120
- return df[numeric < float(value)]
121
- elif operator == "lte":
122
- return df[numeric <= float(value)]
123
- raise ValueError(f"Unknown operator: {operator}")
124
 
125
 
126
  def _build_schema_context(df: pd.DataFrame) -> str:
@@ -181,15 +154,9 @@ def _apply_operation(df: pd.DataFrame, op: TabularOperation, limit: int) -> pd.D
181
  raise ValueError(f"sort requires sort_col, got {op}")
182
  return df.sort_values(op.sort_col, ascending=op.ascending).head(limit)
183
  elif op.operation == "aggregate":
184
- if not op.agg_func:
185
- raise ValueError(f"aggregate requires agg_func, got {op}")
186
- if op.agg_func == "count":
187
- if not op.value_col:
188
- return pd.DataFrame([{"column_name": c, "dtype": str(df[c].dtype)} for c in df.columns])
189
- return pd.DataFrame([{"count": int(df[op.value_col].count()), "operation": "count"}])
190
- if not op.value_col:
191
- raise ValueError(f"aggregate requires value_col for {op.agg_func}, got {op}")
192
- funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max"}
193
  value = getattr(df[op.value_col], funcs[op.agg_func])()
194
  return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}])
195
  else: # "raw"
@@ -279,7 +246,6 @@ class TabularExecutor(BaseExecutor):
279
  )
280
  return None
281
 
282
- # Each group runs independently — cross-file JOIN is out of scope for v1.
283
  gathered = await asyncio.gather(*[
284
  _process_group(doc_id, sheet_name, info)
285
  for (doc_id, sheet_name), info in groups.items()
@@ -313,7 +279,7 @@ class TabularExecutor(BaseExecutor):
313
  prev_error = str(e)
314
  logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error)
315
 
316
- # Fallback: return raw rows (all columns — chat.py caps rows at 20 before LLM)
317
  logger.warning("tabular agent failed after retries, returning raw rows")
318
  return df.head(limit)
319
 
 
1
  """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
2
 
 
 
 
 
3
  Flow:
4
  1. Group RetrievalResult chunks by (document_id, sheet_name).
5
  2. Per group: download Parquet from Azure Blob → pandas DataFrame.
6
+ 3. Build schema context from DataFrame columns + sample values.
7
  4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output.
8
  5. Pandas runs the operation; retry up to 3x on error with feedback to LLM.
9
  6. Fallback to raw rows if all retries fail.
 
46
  - For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==).
47
  - For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value.
48
  Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}]
49
+ - For OR conditions on a column (e.g. value is A or B), use or_filters. Combine with filters for mixed AND+OR logic.
50
+ Example: (status=FAILED OR status=REVERSED) AND payment_channel=X → or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}], filters=[{{"col":"payment_channel","value":"X","op":"eq"}}]
 
 
 
 
51
  - For groupby with a pre-filter (e.g. count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col.
52
 
53
  Schema:
 
77
 
78
  def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series:
79
  numeric = pd.to_numeric(df[col], errors="coerce")
 
 
 
80
  if operator == "eq":
81
  return df[col].astype(str) == str(value)
82
  elif operator == "ne":
 
93
 
94
 
95
  def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame:
96
+ return df[_get_filter_mask(df, col, value, operator)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  def _build_schema_context(df: pd.DataFrame) -> str:
 
154
  raise ValueError(f"sort requires sort_col, got {op}")
155
  return df.sort_values(op.sort_col, ascending=op.ascending).head(limit)
156
  elif op.operation == "aggregate":
157
+ if not op.value_col or not op.agg_func:
158
+ raise ValueError(f"aggregate requires value_col and agg_func, got {op}")
159
+ funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max", "count": "count"}
 
 
 
 
 
 
160
  value = getattr(df[op.value_col], funcs[op.agg_func])()
161
  return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}])
162
  else: # "raw"
 
246
  )
247
  return None
248
 
 
249
  gathered = await asyncio.gather(*[
250
  _process_group(doc_id, sheet_name, info)
251
  for (doc_id, sheet_name), info in groups.items()
 
279
  prev_error = str(e)
280
  logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error)
281
 
282
+ # Fallback: return raw rows
283
  logger.warning("tabular agent failed after retries, returning raw rows")
284
  return df.head(limit)
285
 
src/query/query_executor.py CHANGED
@@ -22,19 +22,9 @@ class QueryExecutor:
22
  question: str,
23
  limit: int = 100,
24
  ) -> list[QueryResult]:
25
- db_results = [r for r in results if r.source_type == "database"]
26
- tabular_results = [
27
- r for r in results
28
- if r.source_type == "document"
29
- and r.metadata.get("data", {}).get("file_type") in ("csv", "xlsx")
30
- ]
31
-
32
- async def _empty() -> list[QueryResult]:
33
- return []
34
-
35
  batches = await asyncio.gather(
36
- db_executor.execute(db_results, user_id, db, question, limit) if db_results else _empty(),
37
- tabular_executor.execute(tabular_results, user_id, db, question, limit) if tabular_results else _empty(),
38
  return_exceptions=True,
39
  )
40
 
 
22
  question: str,
23
  limit: int = 100,
24
  ) -> list[QueryResult]:
 
 
 
 
 
 
 
 
 
 
25
  batches = await asyncio.gather(
26
+ db_executor.execute(results, user_id, db, question, limit),
27
+ tabular_executor.execute(results, user_id, db, question, limit),
28
  return_exceptions=True,
29
  )
30