# scripts/experiment_table_qa.py
"""
Experiment: Local Table QA with TAPAS/TAPEX models on financial data.
Runs a table QA model locally on CPU to answer questions about data
from the PostgreSQL financial database.
Supports two architectures:
- TAPAS (google/tapas-*): cell selection + aggregation
- TAPEX (microsoft/tapex-*): seq2seq text generation (BART-based)
Run:
uv run --extra experiment python scripts/experiment_table_qa.py
uv run --extra experiment python scripts/experiment_table_qa.py --model google/tapas-small-finetuned-wtq
uv run --extra experiment python scripts/experiment_table_qa.py --model microsoft/tapex-base-finetuned-wtq
"""
import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path

# Add project root to path so we can import src.*
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import pandas as pd

from src.db.connection import get_connection

DEFAULT_MODEL = "google/tapas-mini-finetuned-wtq"
QUERY = """
SELECT
transaction_date::text AS date,
transaction_description AS description,
category_name AS category,
entry_amount::text AS amount,
account_name AS account
FROM v_transaction_details
WHERE category_name IS NOT NULL
ORDER BY transaction_date DESC
LIMIT 15
"""
QUESTIONS = [
    "What is the total amount?",
    "Which category has the highest amount?",
    "How many transactions are there?",
]
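
# Note: *-finetuned-wtq checkpoints were trained on WikiTableQuestions, so they
# expect short, single-table factoid questions like the ones above; multi-step
# or free-form questions tend to degrade on the mini/base variants.
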
RESULTS_FILE = Path(__file__).resolve().parent.parent / "eval_cases" / "table_qa_results.jsonl"

def is_tapex(model_name: str) -> bool:
    return "tapex" in model_name.lower()

def fetch_table() -> pd.DataFrame:
    print("[STEP 1] Connecting to PostgreSQL...")
    with get_connection() as conn:
        print(f" -> Connected to: {conn.dsn}")
        print(f" -> Executing query:\n{QUERY.strip()}")
        with conn.cursor() as cur:
            cur.execute(QUERY)
            columns = [desc[0] for desc in cur.description]
            rows = cur.fetchall()
    print(f" -> Raw result: {len(rows)} rows, {len(columns)} columns")
    print(f" -> Columns: {columns}")
    # Build DataFrame from dict of lists; pandas 2.x defaults to object dtype.
    # The TAPAS tokenizer mutates cells via iloc with its internal Cell namedtuple,
    # which requires object dtype (incompatible with pandas 3.0 StringDtype).
    data = {col: [str(row[i]) for row in rows] for i, col in enumerate(columns)}
    df = pd.DataFrame.from_dict(data)
    print(f" -> Dtypes (should be object):\n{df.dtypes.to_string()}")
    print(f" -> Sample row: {dict(df.iloc[0])}")
    print()
    return df
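
# For reference, a minimal offline stand-in with the same schema (hypothetical
# helper, not used by main): building from lists of Python str keeps every
# column at object dtype, which is what the TAPAS tokenizer requires.
def sample_table() -> pd.DataFrame:
    return pd.DataFrame.from_dict({
        "date": ["2024-01-02", "2024-01-01"],
        "description": ["Grocery store", "Monthly rent"],
        "category": ["Groceries", "Housing"],
        "amount": ["-54.20", "-1200.00"],
        "account": ["Checking", "Checking"],
    })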

def load_tapas(model_name):
    from transformers import pipeline

    tqa = pipeline("table-question-answering", model=model_name, device=-1)
    param_count = sum(p.numel() for p in tqa.model.parameters())
    print(" -> Architecture: TAPAS (cell selection + aggregation)")
    print(f" -> Model class: {type(tqa.model).__name__}")
    print(f" -> Tokenizer: {type(tqa.tokenizer).__name__}")
    print(f" -> Model params: {param_count:,}")
    return tqa, param_count

def load_tapex(model_name):
    from transformers import BartForConditionalGeneration, TapexTokenizer

    tokenizer = TapexTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    param_count = sum(p.numel() for p in model.parameters())
    print(" -> Architecture: TAPEX (seq2seq text generation, BART-based)")
    print(f" -> Model class: {type(model).__name__}")
    print(f" -> Tokenizer: {type(tokenizer).__name__}")
    print(f" -> Model params: {param_count:,}")
    return (tokenizer, model), param_count

def run_tapas(tqa, table, query):
    result = tqa(table=table, query=query)
    return {
        "answer": result["answer"],
        "cells": result.get("cells", []),
        "aggregator": result.get("aggregator", "NONE"),
    }
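
# Note: for aggregations the pipeline does not compute the number itself; the
# returned answer is the aggregator plus the selected cells (e.g. "SUM > 12.50,
# 3.20"). A sketch of caller-side post-processing (hypothetical helper; assumes
# the selected cells parse as numbers):
def aggregate_cells(cells, aggregator):
    if aggregator == "COUNT":
        return len(cells)
    values = [float(c.replace(",", "")) for c in cells]
    if aggregator == "SUM":
        return sum(values)
    if aggregator == "AVERAGE":
        return sum(values) / len(values)
    return None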

def run_tapex(tapex_pair, table, query):
    tokenizer, model = tapex_pair
    encoding = tokenizer(table=table, query=query, return_tensors="pt", truncation=True)
    print(f" -> Input token count: {encoding['input_ids'].shape[1]}")
    outputs = model.generate(**encoding, max_new_tokens=50)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    answer = decoded[0] if decoded else ""
    return {
        "answer": answer,
        "cells": [],
        "aggregator": "seq2seq",
    }
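
# For context: TapexTokenizer flattens the whole DataFrame into one lowercased
# sequence before encoding, roughly "col : date | description | ... row 1 : ...
# row 2 : ..." (per the TAPEX paper; exact formatting may vary by version).
# With truncation=True, tables that exceed BART's 1024-token context are cut
# off, so answers about later rows can silently degrade.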

def main():
    parser = argparse.ArgumentParser(description="Table QA experiment (TAPAS/TAPEX)")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="HuggingFace model name")
    args = parser.parse_args()
    model_name = args.model
    use_tapex = is_tapex(model_name)

    run_results = {
        "timestamp": datetime.now().isoformat(),
        "model": model_name,
        "architecture": "tapex" if use_tapex else "tapas",
        "questions": [],
    }

    # --- Data fetching ---
    print("=" * 60)
    print("[STEP 1] Fetching data from PostgreSQL")
    print("=" * 60)
    table = fetch_table()

    # --- Display table ---
    print("=" * 60)
    print("TABLE (15 most recent transactions)")
    print("=" * 60)
    print(table.to_string(index=False))
    print()

    # --- Model loading ---
    print("=" * 60)
    print("[STEP 2] Loading model")
    print("=" * 60)
    print(f" -> Model: {model_name}")
    print(" -> Device: CPU")
    print()
    t0 = time.time()
    if use_tapex:
        model_obj, param_count = load_tapex(model_name)
    else:
        model_obj, param_count = load_tapas(model_name)
    load_time = time.time() - t0
    print(f" -> Load time: {load_time:.2f}s")
    print()
    run_results["params"] = param_count
    run_results["load_time_s"] = round(load_time, 2)

    # --- Q&A ---
    run_fn = run_tapex if use_tapex else run_tapas
    print("=" * 60)
    print("[STEP 3] Running Table QA inference")
    print("=" * 60)
    for i, q in enumerate(QUESTIONS, 1):
        print(f"\n--- Question {i}/{len(QUESTIONS)} ---")
        print(f" -> Input query: {q!r}")
        print(f" -> Input table shape: {table.shape}")
        t0 = time.time()
        result = run_fn(model_obj, table, q)
        inference_time = time.time() - t0
        print(f" -> Answer: {result['answer']}")
        print(f" -> Cells: {result.get('cells', [])}")
        print(f" -> Aggregator: {result.get('aggregator', 'N/A')}")
        print(f" -> Inference time: {inference_time:.3f}s")
        run_results["questions"].append({
            "query": q,
            "answer": result["answer"],
            "cells": result.get("cells", []),
            "aggregator": result.get("aggregator", "N/A"),
            "inference_time_s": round(inference_time, 3),
        })
    print()

    # --- Save results ---
    RESULTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_FILE, "a") as f:
        f.write(json.dumps(run_results, ensure_ascii=False) + "\n")
    print(f"Results appended to: {RESULTS_FILE}")
    print("=" * 60)
    print("Experiment complete.")
    print("=" * 60)

if __name__ == "__main__":
    main()