VladGeekPro Copilot committed
Commit · 144bf42
Parent(s): 7b59e3d
DeletedDBAgentAndCreatedExpensePredictor
Co-authored-by: Copilot <copilot@github.com>
- app.py +16 -11
- expense_predictor.py +102 -0
- requirements.txt +0 -2
- sql_generator.py +0 -111
app.py
CHANGED
@@ -25,7 +25,7 @@ from extractors import (
     ExpenseUserExtractor,
     ExpenseAmountExtractor,
 )
-from …
+from expense_predictor import predict_expenses
 
 
 # HuggingFace Token (if needed for the models)
@@ -576,10 +576,10 @@ def index():
     """API home page."""
     return jsonify({
         "status": "ok",
-        "message": "…
+        "message": "Expense Processing API is running",
         "endpoints": {
             "POST /process-audio": "Process audio file",
-            "POST /…
+            "POST /predict-expenses": "Predict next 3 expenses based on history",
             "GET /health": "Health check",
             "GET /test-data": "Run text-only extraction tests"
         }
@@ -674,16 +674,21 @@ def process_audio():
         os.unlink(temp_path)
 
 
-@app.post("/…
-def …
-    """…
+@app.post("/predict-expenses")
+def predict_expenses_endpoint():
+    """Predicts top 3 expenses user should add based on 6-month history."""
     payload = parse_json_payload()
-…
-…
-…
+    expenses = payload.get("expenses") or []
+
+    if not isinstance(expenses, list):
+        return jsonify({"status": "error", "message": "expenses must be a list"}), 422
+
     try:
-…
-        return jsonify({…
+        predictions = predict_expenses(expenses)
+        return jsonify({
+            "status": "ok",
+            "predictions": predictions
+        })
     except Exception as exception:
         return jsonify({"status": "error", "message": str(exception)}), 422
 
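A usage sketch for the new route (not part of the diff above): the base URL, port, and sample payload are assumptions, but the request/response shape follows the endpoint as committed.

# Hypothetical client call to POST /predict-expenses.
import requests

sample_expenses = [
    {"date": "2026-01-05", "sum": 120.0, "supplier_id": 5, "user_id": 1},
    {"date": "2026-02-05", "sum": 118.5, "supplier_id": 5, "user_id": 1},
    {"date": "2026-03-06", "sum": 121.0, "supplier_id": 5, "user_id": 1},
]

response = requests.post(
    "http://localhost:7860/predict-expenses",  # assumed local Space URL/port
    json={"expenses": sample_expenses},
)
print(response.json())  # expected shape: {"status": "ok", "predictions": [...]}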
expense_predictor.py
ADDED
@@ -0,0 +1,102 @@
+"""
+Expense prediction model: suggests next expenses based on 6-month history.
+- Input: JSON array of 300 expense records
+- Output: Top 3 predicted expenses (date, sum, supplier, user)
+"""
+
+from datetime import datetime, timedelta
+from collections import defaultdict
+from typing import Optional
+import statistics
+
+
+def predict_expenses(expenses: list[dict]) -> list[dict]:
+    """
+    Predict top 3 expenses user should add.
+
+    Input: [{"date": "2026-01-15", "sum": 150.50, "supplier_id": 5, "user_id": 1, ...}, ...]
+    Output: [{"date": str, "sum": float, "supplier_id": int, "user_id": int, "confidence": float}, ...]
+    """
+    if not expenses or len(expenses) < 2:
+        print(f"[PREDICT] Not enough records: {len(expenses) if expenses else 0}")
+        return []
+
+    # Group by (supplier_id, user_id)
+    supplier_user_history = defaultdict(list)
+    supplier_freq = defaultdict(int)
+    total_records = len(expenses)
+
+    print(f"[PREDICT] Total records received: {total_records}")
+    for i, exp in enumerate(expenses):
+        print(f"[PREDICT] [{i+1}] date={exp.get('date')}, sum={exp.get('sum')}, supplier_id={exp.get('supplier_id')}, user_id={exp.get('user_id')}")
+
+    for exp in expenses:
+        key = (exp["supplier_id"], exp["user_id"])
+        supplier_user_history[key].append(exp)
+        supplier_freq[key] += 1
+
+    print(f"[PREDICT] Unique (supplier, user) pairs: {len(supplier_user_history)}")
+    for key, count in supplier_freq.items():
+        pct = count / total_records * 100
+        print(f"[PREDICT] supplier_id={key[0]}, user_id={key[1]} → {count} records ({pct:.1f}%)")
+
+    # Filter: frequency > 50% over 6 months
+    candidates = {
+        key: records
+        for key, records in supplier_user_history.items()
+        if supplier_freq[key] / total_records >= 0.5
+    }
+
+    print(f"[PREDICT] Candidates after >50% filter: {len(candidates)}")
+
+    if not candidates:
+        print("[PREDICT] No candidates passed the frequency filter. Returning empty.")
+        return []
+
+    # Analyze each candidate: avg amount, interval, last date
+    predictions = []
+
+    for (supplier_id, user_id), records in candidates.items():
+        amounts = [float(r["sum"]) for r in records]
+        avg_amount = statistics.mean(amounts)
+
+        # Calculate interval between transactions (days)
+        dates = sorted([datetime.fromisoformat(r["date"]) for r in records])
+        if len(dates) >= 2:
+            intervals = [(dates[i+1] - dates[i]).days for i in range(len(dates) - 1)]
+            avg_interval = statistics.mean(intervals)
+        else:
+            avg_interval = 30  # default monthly
+
+        last_date = dates[-1]
+        next_predicted_date = (last_date + timedelta(days=avg_interval)).strftime("%Y-%m-%d")
+
+        # Confidence: higher if more consistent (lower std dev)
+        amount_std = statistics.stdev(amounts) if len(amounts) > 1 else 0
+        consistency = max(0, 1 - (amount_std / avg_amount)) if avg_amount > 0 else 0.5
+        frequency_score = min(supplier_freq[(supplier_id, user_id)] / total_records, 1.0)
+        confidence = (consistency + frequency_score) / 2
+
+        print(
+            f"[PREDICT] supplier_id={supplier_id}, user_id={user_id} | "
+            f"avg_amount={avg_amount:.2f}, avg_interval={avg_interval:.1f}d, "
+            f"last_date={last_date.date()}, next_date={next_predicted_date}, "
+            f"consistency={consistency:.2f}, freq_score={frequency_score:.2f}, confidence={confidence:.2f}"
+        )
+
+        predictions.append({
+            "date": next_predicted_date,
+            "sum": round(avg_amount, 2),
+            "supplier_id": supplier_id,
+            "user_id": user_id,
+            "confidence": round(confidence, 2)
+        })
+
+    # Return top 3 by confidence
+    result = sorted(predictions, key=lambda x: x["confidence"], reverse=True)[:3]
+
+    print(f"[PREDICT] Final top {len(result)} predictions:")
+    for i, p in enumerate(result, 1):
+        print(f"[PREDICT] #{i}: supplier_id={p['supplier_id']}, user_id={p['user_id']}, date={p['date']}, sum={p['sum']}, confidence={p['confidence']}")
+
+    return result
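For a quick local sanity check of the module, a minimal sketch (the sample records are invented): with three near-identical monthly payments the relative spread of amounts is tiny, so the confidence score lands close to 1.0.

# Hypothetical direct call to expense_predictor.predict_expenses.
from expense_predictor import predict_expenses

history = [
    {"date": "2026-01-05", "sum": 120.0, "supplier_id": 5, "user_id": 1},
    {"date": "2026-02-05", "sum": 118.5, "supplier_id": 5, "user_id": 1},
    {"date": "2026-03-06", "sum": 121.0, "supplier_id": 5, "user_id": 1},
]

# One (supplier_id, user_id) pair holds 100% of the records, so it passes the
# >=50% frequency filter; intervals are 31 and 29 days (average 30), so the
# predicted date is 2026-04-05, sum ~119.83, confidence ~0.99.
print(predict_expenses(history))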
requirements.txt
CHANGED
@@ -9,5 +9,3 @@ python-dateutil
 iuliia
 scikit-learn
 sentencepiece
-transformers==4.41.2
-torch==2.3.1
sql_generator.py
DELETED
@@ -1,111 +0,0 @@
-from __future__ import annotations
-
-import os
-from dataclasses import dataclass
-from typing import Any
-
-
-# Compact Spider-style schema: only business tables, no Laravel internals.
-# Format: table col type , col type | table2 col type , col type
-# Foreign keys annotated inline for model guidance.
-DEFAULT_DB_SCHEMA = (
-    "users : id int , name varchar , email varchar , created_at datetime , updated_at datetime | "
-    "categories : id int , name varchar , slug varchar , notes text , created_at datetime , updated_at datetime | "
-    "suppliers : id int , name varchar , slug varchar , category_id int , created_at datetime , updated_at datetime | "
-    "expenses : id int , user_id int , date date , category_id int , supplier_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
-    "debts : id int , date date , user_id int , debt_sum numeric , overpayment_id int , notes text , payment_status varchar , partial_sum numeric , date_paid date , created_at datetime , updated_at datetime | "
-    "overpayments : id int , user_id int , sum numeric , notes text , created_at datetime , updated_at datetime | "
-    "paid_debts : id int , debt_id int , changed_debt_date date , paid_by_user_id int , payment_status varchar , paid_sum numeric , created_at datetime , updated_at datetime | "
-    "expense_change_requests : id int , expense_id int , user_id int , action_type varchar , current_date date , current_user_id int , current_category_id int , current_supplier_id int , current_sum numeric , requested_date date , requested_user_id int , requested_category_id int , requested_supplier_id int , requested_sum numeric , notes text , status varchar , applied_at datetime , created_at datetime , updated_at datetime | "
-    "expense_change_request_votes : id int , expense_change_request_id int , user_id int , vote varchar , notes text , created_at datetime , updated_at datetime"
-)
-
-_SQL_GENERATOR: Any | None = None
-
-
-@dataclass(frozen=True)
-class SqlGenerationRequest:
-    question: str
-    limit: int = 200
-
-
-def _get_sql_generator() -> Any:
-    global _SQL_GENERATOR
-
-    if _SQL_GENERATOR is None:
-        from transformers import pipeline
-
-        model_id = os.getenv("SQL_MODEL", "gaussalgo/T5-LM-Large-text2sql-spider")
-        _SQL_GENERATOR = pipeline(
-            task="text2text-generation",
-            model=model_id,
-            tokenizer=model_id,
-            device=-1,
-        )
-
-    return _SQL_GENERATOR
-
-
-def _build_prompt(payload: SqlGenerationRequest) -> str:
-    # gaussalgo/T5-LM-Large-text2sql-spider is trained on Spider benchmark.
-    # Expected format: "Question: {q} | {compact_schema}"
-    # where schema uses pipe-separator between tables and " : " between table name and columns.
-    return f"Question: {payload.question} | {DEFAULT_DB_SCHEMA}"
-
-
-def _normalize_sql(raw_sql: str, limit: int) -> str:
-    sql = (raw_sql or "").strip()
-    if not sql:
-        raise ValueError("SQL model returned an empty result.")
-
-    if "```" in sql:
-        parts = [part.strip() for part in sql.split("```") if part.strip()]
-        sql = parts[-1]
-
-    upper_sql = sql.upper()
-    sql_start = upper_sql.find("SELECT")
-    if sql_start == -1:
-        raise ValueError("Generated SQL is not a SELECT query.")
-
-    sql = sql[sql_start:]
-    if ";" in sql:
-        sql = sql.split(";", 1)[0].strip()
-
-    upper_sql = sql.upper()
-    forbidden = ("INSERT ", "UPDATE ", "DELETE ", "DROP ", "ALTER ", "PRAGMA ", "ATTACH ", "CREATE ", "REPLACE ")
-    if any(keyword in upper_sql for keyword in forbidden):
-        raise ValueError("Generated SQL contains forbidden statements.")
-
-    if not upper_sql.startswith("SELECT "):
-        raise ValueError("Only SELECT queries are allowed.")
-
-    aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
-    has_limit = " LIMIT " in upper_sql
-    if not has_limit and not any(marker in upper_sql for marker in aggregate_markers):
-        sql = f"{sql} LIMIT {limit}"
-
-    return sql
-
-
-def generate_sql(question: str, limit: int = 200) -> str:
-    clean_question = (question or "").strip()
-    if not clean_question:
-        raise ValueError("Field 'query' is required.")
-
-    payload = SqlGenerationRequest(
-        question=clean_question,
-        limit=limit,
-    )
-
-    generator = _get_sql_generator()
-    prompt = _build_prompt(payload)
-    result = generator(
-        prompt,
-        max_new_tokens=512,
-        do_sample=False,
-        num_beams=4,
-        truncation=True,
-    )
-
-    generated_text = result[0].get("generated_text", "") if result else ""
-    return _normalize_sql(generated_text, limit=payload.limit)