ishaq101 and sofhiaazzhr committed
Commit 8802920 · 1 Parent(s): 52999bc

Make a query for tabular (XLSX and CSV) (#14)


- [NOTICKET] add software to gitignore (d43ecb180036339cf287e93f1d9116bd6eff9b9d)
- [NOTICKET] add pyarrow (e50eadc82e160012e3319267f5bfe084cd9034d4)
- [KM-515][document] Make Query for Tabular Type (XLSX & CSV) (695ca0a154a68077c51499ed83b7507b988be065)
- [KM-455][document] decided methods retrieval for document (8c9cc79223eb1d96c8622d129219a83a4ba2500b)


Co-authored-by: Sofhia Az-Zahra <sofhiaazzhr@users.noreply.huggingface.co>

.gitignore CHANGED
@@ -39,4 +39,7 @@ playground_create_user.py
 API_CONTRACT.md
 context_engineering/
 sample_file/
-test_tesseract.py
+test_tesseract.py
+
+# Windows binaries — installed via apt in Docker instead
+software/
pyproject.toml CHANGED
@@ -90,6 +90,7 @@ dependencies = [
     "pdf2image>=1.17.0",
     "pytesseract>=0.3.13",
     "pypdf2>=3.0.1",
+    "pyarrow>=24.0.0",
 ]

 [project.optional-dependencies]
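
The new pyarrow dependency backs the Parquet path used below: sheets are stored as Parquet blobs and pandas delegates Parquet I/O to pyarrow. A minimal roundtrip sketch with a throwaway DataFrame (nothing here comes from the repo):

import io

import pandas as pd  # pandas uses pyarrow as its Parquet engine

df = pd.DataFrame({"status": ["SUCCESS", "FAILED"], "amount_paid": [250000, 100000]})

# Serialize to Parquet in memory, the same shape a blob upload/download sees.
buf = io.BytesIO()
df.to_parquet(buf, engine="pyarrow")
buf.seek(0)

restored = pd.read_parquet(buf, engine="pyarrow")
print(restored.equals(df))  # True: schema and values survive the roundtrip
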
src/query/executors/tabular.py CHANGED
@@ -1,39 +1,311 @@
 """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
 
 Flow:
-1. Group RetrievalResult chunks by document_id.
-2. For each document: download bytes from Azure Blob -> read with pandas.
-3. Filter DataFrame to relevant columns identified by retrieval.
-4. Return QueryResult per document.
+1. Group RetrievalResult chunks by (document_id, sheet_name).
+2. Per group: download Parquet from Azure Blob → pandas DataFrame.
+3. Build schema context from DataFrame columns + sample values.
+4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output.
+5. Pandas runs the operation; retry up to 3x on error with feedback to LLM.
+6. Fallback to raw rows if all retries fail.
+7. Return QueryResult per group.
 """
+import asyncio
+from typing import Literal, TypedDict
 
+import pandas as pd
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_openai import AzureChatOpenAI
+from pydantic import BaseModel
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from src.config.settings import settings
+from src.knowledge.parquet_service import download_parquet
 from src.middlewares.logging import get_logger
 from src.query.base import BaseExecutor, QueryResult
 from src.rag.base import RetrievalResult
 
 logger = get_logger("tabular_executor")
 
+
+class _GroupInfo(TypedDict):
+    columns: list[str]
+    filename: str
+    file_type: str
+
+
 _TABULAR_FILE_TYPES = ("csv", "xlsx")
+_MAX_RETRIES = 3
+
+_SYSTEM_PROMPT = """\
+You are a data analyst. Given a DataFrame schema and a user question, \
+decide which pandas operation to perform.
+
+IMPORTANT rules:
+- Use ONLY the exact column names as written in the schema below. Never translate or rename them.
+- For top_n: always set value_col to the column to sort by. Do NOT use sort_col for top_n.
+- For sort: use sort_col for the column to sort by.
+- For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==).
+- For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value.
+  Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}]
+- IMPORTANT: When the question uses "or" / "atau" between values of the same column, you MUST use or_filters (NOT filters).
+  or_filters applies OR logic: rows matching ANY condition are kept.
+  filters applies AND logic: rows must match ALL conditions.
+  Example: "(status FAILED or REVERSED) AND payment_channel=Tokopedia" →
+  or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}]
+  filters=[{{"col":"payment_channel","value":"Tokopedia","op":"eq"}}]
+- For groupby with a pre-filter (e.g. count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col.
+
+Schema:
+{schema}
+
+{error_section}"""
+
+
+class TabularOperation(BaseModel):
+    operation: Literal[
+        "filter", "groupby_sum", "groupby_avg", "groupby_count",
+        "top_n", "sort", "aggregate", "raw"
+    ]
+    group_col: str | None = None  # for groupby_*
+    value_col: str | None = None  # for groupby_*, top_n, aggregate
+    filter_col: str | None = None  # for single filter
+    filter_value: str | None = None  # for single filter
+    filter_operator: Literal["eq", "ne", "gt", "gte", "lt", "lte"] = "eq"  # for single filter
+    filters: list[dict] | None = None  # for multi-condition AND: [{"col": ..., "value": ..., "op": ...}]
+    or_filters: list[dict] | None = None  # for OR conditions, applied before AND filters
+    sort_col: str | None = None  # for sort
+    ascending: bool = True  # for sort
+    n: int | None = None  # for top_n
+    agg_func: Literal["sum", "avg", "min", "max", "count"] | None = None  # for aggregate
+    reasoning: str
+
+
+def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series:
+    numeric = pd.to_numeric(df[col], errors="coerce")
+    if operator == "eq":
+        return df[col].astype(str) == str(value)
+    elif operator == "ne":
+        return df[col].astype(str) != str(value)
+    elif operator == "gt":
+        return numeric > float(value)
+    elif operator == "gte":
+        return numeric >= float(value)
+    elif operator == "lt":
+        return numeric < float(value)
+    elif operator == "lte":
+        return numeric <= float(value)
+    raise ValueError(f"Unknown operator: {operator}")
+
+
+def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame:
+    numeric = pd.to_numeric(df[col], errors="coerce")
+    if operator == "eq":
+        return df[df[col].astype(str) == str(value)]
+    elif operator == "ne":
+        return df[df[col].astype(str) != str(value)]
+    elif operator == "gt":
+        return df[numeric > float(value)]
+    elif operator == "gte":
+        return df[numeric >= float(value)]
+    elif operator == "lt":
+        return df[numeric < float(value)]
+    elif operator == "lte":
+        return df[numeric <= float(value)]
+    raise ValueError(f"Unknown operator: {operator}")
+
+
+def _build_schema_context(df: pd.DataFrame) -> str:
+    lines = []
+    for col in df.columns:
+        sample = df[col].dropna().head(3).tolist()
+        lines.append(f"- {col} ({df[col].dtype}): sample values: {sample}")
+    return "\n".join(lines)
+
+
+def _apply_operation(df: pd.DataFrame, op: TabularOperation, limit: int) -> pd.DataFrame:
+    if op.operation == "groupby_sum":
+        if not op.group_col or not op.value_col:
+            raise ValueError(f"groupby_sum requires group_col and value_col, got {op}")
+        return df.groupby(op.group_col)[op.value_col].sum().reset_index().nlargest(limit, op.value_col)
+    elif op.operation == "groupby_avg":
+        if not op.group_col or not op.value_col:
+            raise ValueError(f"groupby_avg requires group_col and value_col, got {op}")
+        return df.groupby(op.group_col)[op.value_col].mean().reset_index().nlargest(limit, op.value_col)
+    elif op.operation == "groupby_count":
+        if not op.group_col:
+            raise ValueError(f"groupby_count requires group_col, got {op}")
+        df_filtered = df.copy()
+        if op.or_filters:
+            or_mask = pd.Series([False] * len(df_filtered), index=df_filtered.index)
+            for f in op.or_filters:
+                or_mask = or_mask | _get_filter_mask(df_filtered, f["col"], f["value"], f.get("op", "eq"))
+            df_filtered = df_filtered[or_mask]
+        if op.filters:
+            for f in op.filters:
+                df_filtered = _apply_single_filter(df_filtered, f["col"], f["value"], f.get("op", "eq"))
+        elif op.filter_col and op.filter_value is not None:
+            df_filtered = _apply_single_filter(df_filtered, op.filter_col, op.filter_value, op.filter_operator)
+        return df_filtered.groupby(op.group_col).size().reset_index(name="count").nlargest(limit, "count")
+    elif op.operation == "filter":
+        result = df.copy()
+        if op.or_filters:
+            or_mask = pd.Series([False] * len(result), index=result.index)
+            for f in op.or_filters:
+                or_mask = or_mask | _get_filter_mask(result, f["col"], f["value"], f.get("op", "eq"))
+            result = result[or_mask]
+        if op.filters:
+            for f in op.filters:
+                result = _apply_single_filter(result, f["col"], f["value"], f.get("op", "eq"))
+        elif op.filter_col and op.filter_value is not None and not op.or_filters:
+            result = _apply_single_filter(result, op.filter_col, op.filter_value, op.filter_operator)
+        elif not op.or_filters and not op.filters and (not op.filter_col or op.filter_value is None):
+            raise ValueError(f"filter requires filter_col/filter_value or filters or or_filters, got {op}")
+        return result.head(limit)
+    elif op.operation == "top_n":
+        col = op.value_col or op.sort_col
+        if not col:
+            raise ValueError(f"top_n requires value_col, got {op}")
+        n = op.n or limit
+        return df.nlargest(n, col)
+    elif op.operation == "sort":
+        if not op.sort_col:
+            raise ValueError(f"sort requires sort_col, got {op}")
+        return df.sort_values(op.sort_col, ascending=op.ascending).head(limit)
+    elif op.operation == "aggregate":
+        if not op.value_col or not op.agg_func:
+            raise ValueError(f"aggregate requires value_col and agg_func, got {op}")
+        funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max", "count": "count"}
+        value = getattr(df[op.value_col], funcs[op.agg_func])()
+        return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}])
+    else:  # "raw"
+        return df.head(limit)
 
 
 class TabularExecutor(BaseExecutor):
+    def __init__(self) -> None:
+        self._llm = AzureChatOpenAI(
+            azure_deployment=settings.azureai_deployment_name_4o,
+            openai_api_version=settings.azureai_api_version_4o,
+            azure_endpoint=settings.azureai_endpoint_url_4o,
+            api_key=settings.azureai_api_key_4o,
+            temperature=0,
+        )
+        self._prompt = ChatPromptTemplate.from_messages([
+            ("system", _SYSTEM_PROMPT),
+            ("human", "{question}"),
+        ])
+        self._chain = self._prompt | self._llm.with_structured_output(TabularOperation)
+
     async def execute(
         self,
         results: list[RetrievalResult],
         user_id: str,
-        db: AsyncSession,
+        _db: AsyncSession,
+        question: str,
         limit: int = 100,
     ) -> list[QueryResult]:
-        # TODO: implement
-        # 1. filter results where source_type == "document" and file_type in _TABULAR_FILE_TYPES
-        # 2. group by document_id -> list of column_names
-        # 3. per group: look up Document by document_id -> get blob_name
-        # 4. blob_storage.download_file(blob_name) -> pd.read_csv / pd.read_excel
-        # 5. df[relevant_columns].head(limit) -> rows as list[dict]
-        # 6. return QueryResult per document
-        raise NotImplementedError
+        tabular = [
+            r for r in results
+            if r.metadata.get("data", {}).get("file_type") in _TABULAR_FILE_TYPES
+        ]
+
+        if not tabular:
+            return []
+
+        # Group by (document_id, sheet_name) → collect relevant column names
+        groups: dict[tuple[str, str | None], _GroupInfo] = {}
+        for r in tabular:
+            data = r.metadata.get("data", {})
+            doc_id = data.get("document_id")
+            if not doc_id:
+                continue
+            sheet_name = data.get("sheet_name")  # None for CSV
+            col_name = data.get("column_name")
+            filename = data.get("filename", "")
+            file_type = data.get("file_type", "")
+
+            key = (doc_id, sheet_name)
+            if key not in groups:
+                groups[key] = {
+                    "columns": [],
+                    "filename": filename,
+                    "file_type": file_type,
+                }
+            if col_name and col_name not in groups[key]["columns"]:
+                groups[key]["columns"].append(col_name)
+
+        async def _process_group(
+            doc_id: str, sheet_name: str | None, info: _GroupInfo
+        ) -> QueryResult | None:
+            try:
+                df = await download_parquet(user_id, doc_id, sheet_name)
+                df_result = await self._query_with_agent(df, question, limit)
+
+                table_label = info["filename"]
+                if sheet_name:
+                    table_label += f" / sheet: {sheet_name}"
+
+                logger.info(
+                    "tabular query complete",
+                    document_id=doc_id,
+                    sheet=sheet_name,
+                    file_type=info["file_type"],
+                    rows=len(df_result),
+                    columns=len(df_result.columns),
+                )
+                return QueryResult(
+                    source_type="document",
+                    source_id=doc_id,
+                    table_or_file=table_label,
+                    columns=list(df_result.columns),
+                    rows=df_result.to_dict(orient="records"),
+                    row_count=len(df_result),
+                )
+            except Exception as e:
+                logger.error(
+                    "tabular query failed",
+                    document_id=doc_id,
+                    sheet=sheet_name,
+                    error=str(e),
+                )
+                return None
+
+        gathered = await asyncio.gather(*[
+            _process_group(doc_id, sheet_name, info)
+            for (doc_id, sheet_name), info in groups.items()
+        ])
+        return [r for r in gathered if r is not None]
+
+    async def _query_with_agent(
+        self, df: pd.DataFrame, question: str, limit: int
+    ) -> pd.DataFrame:
+        schema_ctx = _build_schema_context(df)
+        prev_error = ""
+
+        for attempt in range(_MAX_RETRIES):
+            error_section = (
+                f"Previous attempt failed: {prev_error}\nFix the issue."
+                if prev_error else ""
+            )
+            try:
+                op: TabularOperation = await self._chain.ainvoke({
+                    "schema": schema_ctx,
+                    "error_section": error_section,
+                    "question": question,
+                })
+                logger.info(
+                    "tabular operation decided",
+                    operation=op.operation,
+                    reasoning=op.reasoning,
+                )
+                return _apply_operation(df, op, limit)
+            except Exception as e:
+                prev_error = str(e)
+                logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error)
+
+        # Fallback: return raw rows
+        logger.warning("tabular agent failed after retries, returning raw rows")
+        return df.head(limit)
 
 
 tabular_executor = TabularExecutor()
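
A quick sanity check of the dispatcher above, runnable against the module this diff adds (import path assumed from the diff header). It mirrors the system prompt's OR-then-AND semantics: or_filters keeps rows matching ANY condition before group_col counts per group:

import pandas as pd

from src.query.executors.tabular import TabularOperation, _apply_operation

df = pd.DataFrame({
    "status": ["SUCCESS", "FAILED", "REVERSED", "FAILED"],
    "payment_channel": ["Tokopedia", "Tokopedia", "Shopee", "Tokopedia"],
    "amount_paid": [250000, 100000, 50000, 300000],
})

# "How many FAILED or REVERSED transactions per channel?"
op = TabularOperation(
    operation="groupby_count",
    group_col="payment_channel",
    or_filters=[
        {"col": "status", "value": "FAILED", "op": "eq"},
        {"col": "status", "value": "REVERSED", "op": "eq"},
    ],
    reasoning="count failed/reversed rows per channel",
)

print(_apply_operation(df, op, limit=100))
#   payment_channel  count
# 1       Tokopedia      2
# 0          Shopee      1
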
src/rag/retrievers/document.py CHANGED
@@ -1,32 +1,154 @@
-"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).
+"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular)."""
 
-TEAMMATE: implement retrieve() below.
-Strategy: MMR (amax_marginal_relevance_search) + score threshold to avoid returning
-near-identical chunks from the same PDF page.
-Filter: source_type="document" AND data->>'file_type' NOT IN ('csv', 'xlsx')
-"""
+from langchain_postgres import PGVector
+from langchain_postgres.vectorstores import DistanceStrategy
+from langchain_openai import AzureOpenAIEmbeddings
+from sqlalchemy import text
 
+from src.config.settings import settings
+from src.db.postgres.connection import _pgvector_engine
 from src.db.postgres.vector_store import get_vector_store
 from src.middlewares.logging import get_logger
 from src.rag.base import BaseRetriever, RetrievalResult
 
 logger = get_logger("document_retriever")
 
-_SCORE_THRESHOLD = 0.45  # discard chunks with cosine distance above this
+# Change this one line to switch retrieval method
+# Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
+_RETRIEVAL_METHOD = "mmr"
+
+_TABULAR_TYPES = {"csv", "xlsx"}
+_FETCH_K = 20
+_LAMBDA_MULT = 0.5
+_COLLECTION_NAME = "document_embeddings"
+
+_embeddings = AzureOpenAIEmbeddings(
+    azure_deployment=settings.azureai_deployment_name_embedding,
+    openai_api_version=settings.azureai_api_version_embedding,
+    azure_endpoint=settings.azureai_endpoint_url_embedding,
+    api_key=settings.azureai_api_key_embedding,
+)
+
+_euclidean_store = PGVector(
+    embeddings=_embeddings,
+    connection=_pgvector_engine,
+    collection_name=_COLLECTION_NAME,
+    distance_strategy=DistanceStrategy.EUCLIDEAN,
+    use_jsonb=True,
+    async_mode=True,
+    create_extension=False,
+)
+
+_ip_store = PGVector(
+    embeddings=_embeddings,
+    connection=_pgvector_engine,
+    collection_name=_COLLECTION_NAME,
+    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
+    use_jsonb=True,
+    async_mode=True,
+    create_extension=False,
+)
+
+_MANHATTAN_SQL = text("""
+    SELECT
+        lpe.document,
+        lpe.cmetadata,
+        lpe.embedding <+> CAST(:embedding AS vector) AS distance
+    FROM langchain_pg_embedding lpe
+    JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
+    WHERE lpc.name = :collection
+      AND lpe.cmetadata->>'user_id' = :user_id
+      AND lpe.cmetadata->>'source_type' = 'document'
+    ORDER BY distance ASC
+    LIMIT :k
+""")
 
 
 class DocumentRetriever(BaseRetriever):
-    def __init__(self):
+    def __init__(self) -> None:
         self.vector_store = get_vector_store()
 
     async def retrieve(
         self, query: str, user_id: str, k: int = 5
     ) -> list[RetrievalResult]:
-        # TODO (teammate): implement MMR retrieval for prose documents
-        # Filter: {"user_id": user_id, "source_type": "document"}
-        # then post-filter to exclude file_type in ("csv", "xlsx")
-        logger.info("document retriever not yet implemented — returning empty")
-        return []
+        filter_ = {"user_id": user_id, "source_type": "document"}
+        fetch_k = k + len(_TABULAR_TYPES)
+
+        if _RETRIEVAL_METHOD == "manhattan":
+            return await self._retrieve_manhattan(query, user_id, k, fetch_k)
+
+        if _RETRIEVAL_METHOD == "mmr":
+            docs = await self.vector_store.amax_marginal_relevance_search(
+                query=query,
+                k=fetch_k,
+                fetch_k=_FETCH_K,
+                lambda_mult=_LAMBDA_MULT,
+                filter=filter_,
+            )
+            cosine = await self.vector_store.asimilarity_search_with_score(
+                query=query, k=fetch_k, filter=filter_,
+            )
+            score_map = {doc.page_content: score for doc, score in cosine}
+            docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]
+        elif _RETRIEVAL_METHOD == "euclidean":
+            docs_with_scores = await _euclidean_store.asimilarity_search_with_score(
+                query=query, k=fetch_k, filter=filter_,
+            )
+        elif _RETRIEVAL_METHOD == "inner_product":
+            docs_with_scores = await _ip_store.asimilarity_search_with_score(
+                query=query, k=fetch_k, filter=filter_,
+            )
+        else:  # cosine
+            docs_with_scores = await self.vector_store.asimilarity_search_with_score(
+                query=query, k=fetch_k, filter=filter_,
+            )
+
+        results = []
+        for doc, score in docs_with_scores:
+            file_type = doc.metadata.get("data", {}).get("file_type", "")
+            if file_type not in _TABULAR_TYPES:
+                results.append(RetrievalResult(
+                    content=doc.page_content,
+                    metadata=doc.metadata,
+                    score=score,
+                    source_type="document",
+                ))
+                if len(results) == k:
+                    break
+
+        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
+        return results
+
+    async def _retrieve_manhattan(
+        self, query: str, user_id: str, k: int, fetch_k: int
+    ) -> list[RetrievalResult]:
+        query_vector = await _embeddings.aembed_query(query)
+        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
+
+        async with _pgvector_engine.connect() as conn:
+            result = await conn.execute(_MANHATTAN_SQL, {
+                "embedding": vector_str,
+                "collection": _COLLECTION_NAME,
+                "user_id": user_id,
+                "k": fetch_k,
+            })
+            rows = result.fetchall()
+
+        results = []
+        for row in rows:
+            file_type = row.cmetadata.get("data", {}).get("file_type", "")
+            if file_type not in _TABULAR_TYPES:
+                results.append(RetrievalResult(
+                    content=row.document,
+                    metadata=row.cmetadata,
+                    score=float(row.distance),
+                    source_type="document",
+                ))
+                if len(results) == k:
+                    break
+
+        logger.info("retrieved chunks", method="manhattan", count=len(results))
+        return results
 
 
 document_retriever = DocumentRetriever()
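
The manhattan branch above drops to raw SQL because it relies on pgvector's <+> operator, which computes L1 (Manhattan) distance: the sum of absolute coordinate differences, lower meaning closer (hence ORDER BY distance ASC). A toy arithmetic check with made-up 3-d vectors in place of real embeddings:

# pgvector: a <+> b  ==  sum(|a_i - b_i|)
query_vec = [1.0, 2.0, 3.0]
doc_vec = [4.0, 0.0, 3.0]

l1 = sum(abs(q - d) for q, d in zip(query_vec, doc_vec))
print(l1)  # 5.0 = |1-4| + |2-0| + |3-3|
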
uv.lock CHANGED
@@ -47,6 +47,7 @@ dependencies = [
     { name = "prometheus-client" },
     { name = "psycopg", extra = ["binary", "pool"] },
     { name = "psycopg2" },
+    { name = "pyarrow" },
     { name = "pydantic" },
     { name = "pydantic-settings" },
     { name = "pymongo" },
@@ -127,6 +128,7 @@ requires-dist = [
     { name = "prometheus-client", specifier = "==0.21.1" },
     { name = "psycopg", extras = ["binary", "pool"], specifier = "==3.2.3" },
     { name = "psycopg2", specifier = ">=2.9.11" },
+    { name = "pyarrow", specifier = ">=24.0.0" },
     { name = "pydantic", specifier = "==2.10.3" },
     { name = "pydantic-settings", specifier = "==2.7.0" },
     { name = "pymongo", specifier = ">=4.14.0" },
@@ -2400,6 +2402,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b5/bf/635fbe5dd10ed200afbbfbe98f8602829252ca1cce81cc48fb25ed8dadc0/psycopg2-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:e03e4a6dbe87ff81540b434f2e5dc2bddad10296db5eea7bdc995bf5f4162938", size = 2713969, upload-time = "2025-10-10T11:10:15.946Z" },
 ]
 
+[[package]]
+name = "pyarrow"
+version = "24.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/91/13/13e1069b351bdc3881266e11147ffccf687505dbb0ea74036237f5d454a5/pyarrow-24.0.0.tar.gz", hash = "sha256:85fe721a14dd823aca09127acbb06c3ca723efbd436c004f16bca601b04dcc83", size = 1180261, upload-time = "2026-04-21T10:51:25.837Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b4/a9/9686d9f07837f91f775e8932659192e02c74f9d8920524b480b85212cc68/pyarrow-24.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:6233c9ed9ab9d1db47de57d9753256d9dcffbf42db341576099f0fd9f6bf4810", size = 34981559, upload-time = "2026-04-21T10:47:22.17Z" },
+    { url = "https://files.pythonhosted.org/packages/80/b6/0ddf0e9b6ead3474ab087ae598c76b031fc45532bf6a63f3a553440fb258/pyarrow-24.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:f7616236ec1bc2b15bfdec22a71ab38851c86f8f05ff64f379e1278cf20c634a", size = 36663654, upload-time = "2026-04-21T10:47:28.315Z" },
+    { url = "https://files.pythonhosted.org/packages/7c/3b/926382efe8ce27ba729071d3566ade6dfb86bdf112f366000196b2f5780a/pyarrow-24.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1617043b99bd33e5318ae18eb2919af09c71322ef1ca46566cdafc6e6712fb66", size = 45679394, upload-time = "2026-04-21T10:47:34.821Z" },
+    { url = "https://files.pythonhosted.org/packages/b3/7a/829f7d9dfd37c207206081d6dad474d81dde29952401f07f2ba507814818/pyarrow-24.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6165461f55ef6314f026de6638d661188e3455d3ec49834556a0ebbdbace18bb", size = 48863122, upload-time = "2026-04-21T10:47:42.056Z" },
+    { url = "https://files.pythonhosted.org/packages/5f/e8/f88ce625fe8babaae64e8db2d417c7653adb3019b08aae85c5ed787dc816/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3b13dedfe76a0ad2d1d859b0811b53827a4e9d93a0bcb05cf59333ab4980cc7e", size = 49376032, upload-time = "2026-04-21T10:47:48.967Z" },
+    { url = "https://files.pythonhosted.org/packages/36/7a/82c363caa145fff88fb475da50d3bf52bb024f61917be5424c3392eaf878/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:25ea65d868eb04015cd18e6df2fbe98f07e5bda2abefabcb88fce39a947716f6", size = 51929490, upload-time = "2026-04-21T10:47:55.981Z" },
+    { url = "https://files.pythonhosted.org/packages/66/1c/e3e72c8014ad2743ca64a701652c733cc5cbcee15c0463a32a8c55518d9e/pyarrow-24.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:295f0a7f2e242dabd513737cf076007dc5b2d59237e3eca37b05c0c6446f3826", size = 27355660, upload-time = "2026-04-21T10:48:01.718Z" },
+]
+
 [[package]]
 name = "pyasn1"
 version = "0.6.3"
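
For completeness, a hypothetical end-to-end call into the new executor. The RetrievalResult fields and the metadata["data"] keys mirror what execute() reads in the diff; the IDs are invented, and running it assumes the Azure OpenAI settings and a stored Parquet for the document actually exist:

import asyncio

from src.query.executors.tabular import tabular_executor
from src.rag.base import RetrievalResult

async def main() -> None:
    chunk = RetrievalResult(
        content="amount_paid",  # hypothetical chunk text
        metadata={"data": {
            "document_id": "doc-1",  # hypothetical
            "sheet_name": "Sheet1",  # None for CSV sources
            "column_name": "amount_paid",
            "filename": "payments.xlsx",
            "file_type": "xlsx",
        }},
        score=0.2,
        source_type="document",
    )
    results = await tabular_executor.execute(
        [chunk],
        user_id="user-123",  # hypothetical
        _db=None,  # accepted but unused by the tabular path
        question="total amount_paid per payment_channel",
        limit=50,
    )
    for qr in results:
        print(qr.table_or_file, qr.columns, qr.row_count)

asyncio.run(main())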