VladGeekPro Copilot committed on
Commit
144bf42
·
1 Parent(s): 7b59e3d

DeletedDBAgentAndCreatedExpensePredictor

Browse files

Co-authored-by: Copilot <copilot@github.com>

Files changed (4) hide show
  1. app.py +16 -11
  2. expense_predictor.py +102 -0
  3. requirements.txt +0 -2
  4. sql_generator.py +0 -111
app.py CHANGED
@@ -25,7 +25,7 @@ from extractors import (
25
  ExpenseUserExtractor,
26
  ExpenseAmountExtractor,
27
  )
28
- from sql_generator import generate_sql
29
 
30
 
31
  # HuggingFace Token (если нужен для моделей)
@@ -576,10 +576,10 @@ def index():
576
  """Главная страница API."""
577
  return jsonify({
578
  "status": "ok",
579
- "message": "Voice processing API is running",
580
  "endpoints": {
581
  "POST /process-audio": "Process audio file",
582
- "POST /generate-sql": "Generate SQLite SELECT query from natural language",
583
  "GET /health": "Health check",
584
  "GET /test-data": "Run text-only extraction tests"
585
  }
@@ -674,16 +674,21 @@ def process_audio():
674
  os.unlink(temp_path)
675
 
676
 
677
- @app.post("/generate-sql")
678
- def generate_sql_endpoint():
679
- """Генерирует SQL по текстовому запросу и схеме БД."""
680
  payload = parse_json_payload()
681
- query = payload.get("query") or payload.get("text") or ""
682
- limit = payload.get("limit") or 200
683
-
 
 
684
  try:
685
- sql = generate_sql(question=query, limit=int(limit))
686
- return jsonify({"sql": sql})
 
 
 
687
  except Exception as exception:
688
  return jsonify({"status": "error", "message": str(exception)}), 422
689
 
 
25
  ExpenseUserExtractor,
26
  ExpenseAmountExtractor,
27
  )
28
+ from expense_predictor import predict_expenses
29
 
30
 
31
  # HuggingFace Token (если нужен для моделей)
 
576
  """Главная страница API."""
577
  return jsonify({
578
  "status": "ok",
579
+ "message": "Expense Processing API is running",
580
  "endpoints": {
581
  "POST /process-audio": "Process audio file",
582
+ "POST /predict-expenses": "Predict next 3 expenses based on history",
583
  "GET /health": "Health check",
584
  "GET /test-data": "Run text-only extraction tests"
585
  }
 
674
  os.unlink(temp_path)
675
 
676
 
677
@app.post("/predict-expenses")
def predict_expenses_endpoint():
    """Suggest the top 3 expenses to add, based on the posted 6-month history.

    Expects a JSON body with an "expenses" list; responds with
    {"status": "ok", "predictions": [...]} or a 422 error payload.
    """
    payload = parse_json_payload()
    history = payload.get("expenses") or []

    # Reject non-list payloads up front with a clear message.
    if not isinstance(history, list):
        return jsonify({"status": "error", "message": "expenses must be a list"}), 422

    try:
        suggestions = predict_expenses(history)
        return jsonify({
            "status": "ok",
            "predictions": suggestions,
        })
    except Exception as exception:
        # API boundary: surface any prediction failure as a 422 response.
        return jsonify({"status": "error", "message": str(exception)}), 422
694
 
expense_predictor.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Expense prediction model: suggests next expenses based on 6-month history.

- Input: a list of expense records (dicts with "date", "sum",
  "supplier_id" and "user_id" keys).
- Output: up to 3 predicted expenses (date, sum, supplier, user),
  ordered by descending confidence.
"""

import logging
import statistics
from collections import defaultdict
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)

# A (supplier_id, user_id) pair must account for at least this share of
# all records to be treated as a recurring expense worth predicting.
_MIN_FREQUENCY = 0.5

# Fallback interval (days) when a candidate has only one dated record:
# assume a monthly cadence.
_DEFAULT_INTERVAL_DAYS = 30


def predict_expenses(expenses: list[dict]) -> list[dict]:
    """Predict up to 3 expenses the user is likely to add next.

    Args:
        expenses: expense history, e.g.
            [{"date": "2026-01-15", "sum": 150.50,
              "supplier_id": 5, "user_id": 1, ...}, ...]

    Returns:
        Up to three predictions sorted by descending confidence:
        [{"date": str, "sum": float, "supplier_id": int,
          "user_id": int, "confidence": float}, ...]
        Returns [] when there are fewer than two records or no
        (supplier, user) pair is frequent enough.

    Raises:
        KeyError: if a record is missing "date", "sum", "supplier_id"
            or "user_id".
        ValueError: if a "date" value is not ISO-8601 formatted.
    """
    if not expenses or len(expenses) < 2:
        logger.info("[PREDICT] Not enough records: %d",
                    len(expenses) if expenses else 0)
        return []

    total_records = len(expenses)
    logger.info("[PREDICT] Total records received: %d", total_records)

    # Group the history by the (supplier_id, user_id) pair.
    supplier_user_history: defaultdict[tuple, list[dict]] = defaultdict(list)
    for exp in expenses:
        supplier_user_history[(exp["supplier_id"], exp["user_id"])].append(exp)

    logger.info("[PREDICT] Unique (supplier, user) pairs: %d",
                len(supplier_user_history))
    for (supplier_id, user_id), records in supplier_user_history.items():
        logger.debug("[PREDICT] supplier_id=%s, user_id=%s -> %d records (%.1f%%)",
                     supplier_id, user_id, len(records),
                     len(records) / total_records * 100)

    # Keep only pairs that appear in at least 50% of all records.
    candidates = {
        key: records
        for key, records in supplier_user_history.items()
        if len(records) / total_records >= _MIN_FREQUENCY
    }
    logger.info("[PREDICT] Candidates after frequency filter: %d",
                len(candidates))
    if not candidates:
        logger.info("[PREDICT] No candidates passed the frequency filter.")
        return []

    # Analyze each candidate: average amount, cadence, last occurrence.
    predictions = []
    for (supplier_id, user_id), records in candidates.items():
        amounts = [float(r["sum"]) for r in records]
        avg_amount = statistics.mean(amounts)

        # Average interval (days) between consecutive transactions.
        dates = sorted(datetime.fromisoformat(r["date"]) for r in records)
        if len(dates) >= 2:
            intervals = [(b - a).days for a, b in zip(dates, dates[1:])]
            avg_interval = statistics.mean(intervals)
        else:
            avg_interval = _DEFAULT_INTERVAL_DAYS

        last_date = dates[-1]
        next_predicted_date = (
            last_date + timedelta(days=avg_interval)
        ).strftime("%Y-%m-%d")

        # Confidence blends amount consistency (low std dev relative to
        # the mean) with how often the pair occurs in the history.
        amount_std = statistics.stdev(amounts) if len(amounts) > 1 else 0
        consistency = (max(0, 1 - (amount_std / avg_amount))
                       if avg_amount > 0 else 0.5)
        frequency_score = min(len(records) / total_records, 1.0)
        confidence = (consistency + frequency_score) / 2

        logger.debug(
            "[PREDICT] supplier_id=%s, user_id=%s | avg_amount=%.2f, "
            "avg_interval=%.1fd, last_date=%s, next_date=%s, confidence=%.2f",
            supplier_id, user_id, avg_amount, avg_interval,
            last_date.date(), next_predicted_date, confidence,
        )

        predictions.append({
            "date": next_predicted_date,
            "sum": round(avg_amount, 2),
            "supplier_id": supplier_id,
            "user_id": user_id,
            "confidence": round(confidence, 2),
        })

    # Highest-confidence predictions first; cap at three.
    result = sorted(predictions, key=lambda p: p["confidence"], reverse=True)[:3]
    logger.info("[PREDICT] Final top %d predictions", len(result))
    return result
requirements.txt CHANGED
@@ -9,5 +9,3 @@ python-dateutil
9
  iuliia
10
  scikit-learn
11
  sentencepiece
12
- transformers==4.41.2
13
- torch==2.3.1
 
9
  iuliia
10
  scikit-learn
11
  sentencepiece
 
 
sql_generator.py DELETED
@@ -1,111 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- from dataclasses import dataclass
5
- from typing import Any
6
-
7
-
8
- # Compact Spider-style schema: only business tables, no Laravel internals.
9
- # Format: table col type , col type | table2 col type , col type
10
- # Foreign keys annotated inline for model guidance.
11
- DEFAULT_DB_SCHEMA = (
12
- "users : id int , name varchar , email varchar , created_at datetime , updated_at datetime | "
13
- "categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime | "
14
- "suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime | "
15
- "expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
16
- "debts : id int , date date , user_id int , debt_sum numeric , overpayment_id int , notes text , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime | "
17
- "overpayments : id int , user_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
18
- "paid_debts : id int , debt_id int , changed_debt_date date , paid_by_user_id int , payment_status varchar , paid_sum numeric , created_at datetime , updated_at datetime | "
19
- "expense_change_requests : id int , expense_id int , user_id int , action_type varchar , current_date date , current_user_id int , current_category_id int , current_supplier_id int , current_sum numeric , requested_date date , requested_user_id int , requested_category_id int , requested_supplier_id int , requested_sum numeric , notes text , status varchar , applied_at datetime , created_at datetime , updated_at datetime | "
20
- "expense_change_request_votes : id int , expense_change_request_id int , user_id int , vote varchar , notes text , created_at datetime , updated_at datetime"
21
- )
22
-
23
- _SQL_GENERATOR: Any | None = None
24
-
25
-
26
- @dataclass(frozen=True)
27
- class SqlGenerationRequest:
28
- question: str
29
- limit: int = 200
30
-
31
-
32
- def _get_sql_generator() -> Any:
33
- global _SQL_GENERATOR
34
-
35
- if _SQL_GENERATOR is None:
36
- from transformers import pipeline
37
-
38
- model_id = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
39
- _SQL_GENERATOR = pipeline(
40
- task="text2text-generation",
41
- model=model_id,
42
- tokenizer=model_id,
43
- device=-1,
44
- )
45
-
46
- return _SQL_GENERATOR
47
-
48
-
49
- def _build_prompt(payload: SqlGenerationRequest) -> str:
50
- # gaussalgo/T5-LM-Large-text2sql-spider is trained on Spider benchmark.
51
- # Expected format: "Question: {q} | {compact_schema}"
52
- # where schema uses pipe-separator between tables and " : " between table name and columns.
53
- return f"Question: {payload.question} | {DEFAULT_DB_SCHEMA}"
54
-
55
-
56
- def _normalize_sql(raw_sql: str, limit: int) -> str:
57
- sql = (raw_sql or "").strip()
58
- if not sql:
59
- raise ValueError("SQL model returned an empty result.")
60
-
61
- if "```" in sql:
62
- parts = [part.strip() for part in sql.split("```") if part.strip()]
63
- sql = parts[-1]
64
-
65
- upper_sql = sql.upper()
66
- sql_start = upper_sql.find("SELECT")
67
- if sql_start == -1:
68
- raise ValueError("Generated SQL is not a SELECT query.")
69
-
70
- sql = sql[sql_start:]
71
- if ";" in sql:
72
- sql = sql.split(";", 1)[0].strip()
73
-
74
- upper_sql = sql.upper()
75
- forbidden = ("INSERT ", "UPDATE ", "DELETE ", "DROP ", "ALTER ", "PRAGMA ", "ATTACH ", "CREATE ", "REPLACE ")
76
- if any(keyword in upper_sql for keyword in forbidden):
77
- raise ValueError("Generated SQL contains forbidden statements.")
78
-
79
- if not upper_sql.startswith("SELECT "):
80
- raise ValueError("Only SELECT queries are allowed.")
81
-
82
- aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
83
- has_limit = " LIMIT " in upper_sql
84
- if not has_limit and not any(marker in upper_sql for marker in aggregate_markers):
85
- sql = f"{sql} LIMIT {limit}"
86
-
87
- return sql
88
-
89
-
90
- def generate_sql(question: str, limit: int = 200) -> str:
91
- clean_question = (question or "").strip()
92
- if not clean_question:
93
- raise ValueError("Field 'query' is required.")
94
-
95
- payload = SqlGenerationRequest(
96
- question=clean_question,
97
- limit=limit,
98
- )
99
-
100
- generator = _get_sql_generator()
101
- prompt = _build_prompt(payload)
102
- result = generator(
103
- prompt,
104
- max_new_tokens=512,
105
- do_sample=False,
106
- num_beams=4,
107
- truncation=True,
108
- )
109
-
110
- generated_text = result[0].get("generated_text", "") if result else ""
111
- return _normalize_sql(generated_text, limit=payload.limit)