Spaces:
Sleeping
Sleeping
Add income statement answer endpoint
Browse files- app/api/routers/income_statement.py +27 -0
- app/api/routers/transaction.py +3 -3
- app/service/query_rag.py +16 -0
- tests/test_transactions.py +4 -83
app/api/routers/income_statement.py
CHANGED
@@ -7,6 +7,9 @@ from app.schema.index import IncomeStatementCreateRequest, IncomeStatementRespon
|
|
7 |
from app.engine.postgresdb import get_db_session
|
8 |
from app.service.income_statement import call_llm_to_create_income_statement
|
9 |
|
|
|
|
|
|
|
10 |
income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=["income_statement"])
|
11 |
|
12 |
|
@@ -48,6 +51,30 @@ async def get_income_statements(
|
|
48 |
return result
|
49 |
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
@r.get(
|
52 |
"/report/{report_id}",
|
53 |
response_model=IncomeStatementResponse,
|
|
|
7 |
from app.engine.postgresdb import get_db_session
|
8 |
from app.service.income_statement import call_llm_to_create_income_statement
|
9 |
|
10 |
+
from app.service.query_rag import answer_query, fetch_income_statement_documents
|
11 |
+
|
12 |
+
|
13 |
income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=["income_statement"])
|
14 |
|
15 |
|
|
|
51 |
return result
|
52 |
|
53 |
|
54 |
+
@r.get(
|
55 |
+
"/answer/{user_id}/{query}",
|
56 |
+
responses={
|
57 |
+
200: {"description": "Query answered"},
|
58 |
+
500: {"description": "Internal server error"},
|
59 |
+
},
|
60 |
+
)
|
61 |
+
async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
|
62 |
+
"""
|
63 |
+
Answer queries based on income statements
|
64 |
+
"""
|
65 |
+
try:
|
66 |
+
result = await IncomeStatementModel.get_by_user(db, user_id)
|
67 |
+
if len(result) == 0:
|
68 |
+
raise HTTPException(status_code=500, detail="No transactions found for this user")
|
69 |
+
document_splits = await fetch_income_statement_documents(result)
|
70 |
+
answer = await answer_query(document_splits, query, user_id)
|
71 |
+
return answer
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
raise HTTPException(status_code=500, detail=f"/answer endpoint error: {str(e)}")
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
@r.get(
|
79 |
"/report/{report_id}",
|
80 |
response_model=IncomeStatementResponse,
|
app/api/routers/transaction.py
CHANGED
@@ -53,7 +53,7 @@ async def get_transactions(user_id: int, db: AsyncSession = Depends(get_db_sessi
|
|
53 |
|
54 |
|
55 |
@r.get(
|
56 |
-
"/
|
57 |
responses={
|
58 |
200: {"description": "Query answered"},
|
59 |
500: {"description": "Internal server error"},
|
@@ -61,7 +61,7 @@ async def get_transactions(user_id: int, db: AsyncSession = Depends(get_db_sessi
|
|
61 |
)
|
62 |
async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
|
63 |
"""
|
64 |
-
|
65 |
"""
|
66 |
try:
|
67 |
result = await TransactionModel.get_by_user(db, user_id)
|
@@ -73,4 +73,4 @@ async def answer_transactions_query(user_id: int, query: str, db: AsyncSession =
|
|
73 |
return answer
|
74 |
|
75 |
except Exception as e:
|
76 |
-
raise HTTPException(status_code=500, detail=f"
|
|
|
53 |
|
54 |
|
55 |
@r.get(
|
56 |
+
"/answer/{user_id}/{query}",
|
57 |
responses={
|
58 |
200: {"description": "Query answered"},
|
59 |
500: {"description": "Internal server error"},
|
|
|
61 |
)
|
62 |
async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
|
63 |
"""
|
64 |
+
Answer queries based on transactions.
|
65 |
"""
|
66 |
try:
|
67 |
result = await TransactionModel.get_by_user(db, user_id)
|
|
|
73 |
return answer
|
74 |
|
75 |
except Exception as e:
|
76 |
+
raise HTTPException(status_code=500, detail=f"/answer endpoint error: {str(e)}")
|
app/service/query_rag.py
CHANGED
@@ -11,6 +11,7 @@ from fastapi import HTTPException
|
|
11 |
|
12 |
from typing import List
|
13 |
from app.model.transaction import Transaction
|
|
|
14 |
|
15 |
import os
|
16 |
|
@@ -29,6 +30,21 @@ async def fetch_transaction_documents(transactions: List[Transaction]) -> List[D
|
|
29 |
raise HTTPException(status_code = 500, detail=f"fetch_transaction_documents error: {str(e)}")
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
async def answer_query(document_splits: List[Document], query: str, user_id: int) -> str:
|
33 |
"""Creates an embedding of the transactions table and then returns the answer for the given query.
|
34 |
Args:
|
|
|
11 |
|
12 |
from typing import List
|
13 |
from app.model.transaction import Transaction
|
14 |
+
from app.model.income_statement import IncomeStatement
|
15 |
|
16 |
import os
|
17 |
|
|
|
30 |
raise HTTPException(status_code = 500, detail=f"fetch_transaction_documents error: {str(e)}")
|
31 |
|
32 |
|
33 |
+
async def fetch_income_statement_documents(income_statements: List[IncomeStatement]) -> List[Document]:
|
34 |
+
try:
|
35 |
+
document = ''.join(str(row)+'\n' for row in income_statements)
|
36 |
+
|
37 |
+
page_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
|
38 |
+
pages = page_splitter.split_text(document)
|
39 |
+
|
40 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
|
41 |
+
document_splits = text_splitter.create_documents(pages)
|
42 |
+
|
43 |
+
return document_splits
|
44 |
+
except Exception as e:
|
45 |
+
raise HTTPException(status_code = 500, detail=f"fetch_income_statement_documents error: {str(e)}")
|
46 |
+
|
47 |
+
|
48 |
async def answer_query(document_splits: List[Document], query: str, user_id: int) -> str:
|
49 |
"""Creates an embedding of the transactions table and then returns the answer for the given query.
|
50 |
Args:
|
tests/test_transactions.py
CHANGED
@@ -10,100 +10,21 @@ from app.schema.index import TransactionType, TransactionCreate
|
|
10 |
from sqlalchemy.ext.asyncio import AsyncSession
|
11 |
from app.engine.postgresdb import get_db_session
|
12 |
from app.model.user import User
|
|
|
13 |
|
14 |
-
def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
|
15 |
-
return [
|
16 |
-
TransactionCreate(
|
17 |
-
user_id=user_id,
|
18 |
-
transaction_date=datetime(2022, 1, 1),
|
19 |
-
category="category",
|
20 |
-
name_description="name_description",
|
21 |
-
amount=1.0,
|
22 |
-
type=TransactionType.EXPENSE,
|
23 |
-
),
|
24 |
-
TransactionCreate(
|
25 |
-
user_id=user_id,
|
26 |
-
transaction_date=datetime(2022, 1, 2),
|
27 |
-
category="category",
|
28 |
-
name_description="name_description",
|
29 |
-
amount=2.0,
|
30 |
-
type=TransactionType.EXPENSE,
|
31 |
-
),
|
32 |
-
TransactionCreate(
|
33 |
-
user_id=user_id,
|
34 |
-
transaction_date=datetime(2022, 1, 3),
|
35 |
-
category="category",
|
36 |
-
name_description="name_description",
|
37 |
-
amount=3.0,
|
38 |
-
type=TransactionType.INCOME,
|
39 |
-
),
|
40 |
-
TransactionCreate(
|
41 |
-
user_id=user_id,
|
42 |
-
transaction_date=datetime(2022, 1, 4),
|
43 |
-
category="category",
|
44 |
-
name_description="name_description",
|
45 |
-
amount=4.0,
|
46 |
-
type=TransactionType.INCOME,
|
47 |
-
),
|
48 |
-
TransactionCreate(
|
49 |
-
user_id=user_id,
|
50 |
-
transaction_date=datetime(2022, 1, 5),
|
51 |
-
category="category",
|
52 |
-
name_description="name_description",
|
53 |
-
amount=5.0,
|
54 |
-
type=TransactionType.EXPENSE,
|
55 |
-
),
|
56 |
-
TransactionCreate(
|
57 |
-
user_id=user_id,
|
58 |
-
transaction_date=datetime(2022, 1, 6),
|
59 |
-
category="category",
|
60 |
-
name_description="name_description",
|
61 |
-
amount=6.0,
|
62 |
-
type=TransactionType.EXPENSE,
|
63 |
-
),
|
64 |
-
TransactionCreate(
|
65 |
-
user_id=user_id,
|
66 |
-
transaction_date=datetime(2022, 1, 7),
|
67 |
-
category="category",
|
68 |
-
name_description="name_description",
|
69 |
-
amount=7.0,
|
70 |
-
type=TransactionType.INCOME,
|
71 |
-
),
|
72 |
-
TransactionCreate(
|
73 |
-
user_id=user_id,
|
74 |
-
transaction_date=datetime(2022, 1, 8),
|
75 |
-
category="category",
|
76 |
-
name_description="name_description",
|
77 |
-
amount=8.0,
|
78 |
-
type=TransactionType.INCOME,
|
79 |
-
),
|
80 |
-
TransactionCreate(
|
81 |
-
user_id=user_id,
|
82 |
-
transaction_date=datetime(2022, 1, 9),
|
83 |
-
category="category",
|
84 |
-
name_description="name_description",
|
85 |
-
amount=9.0,
|
86 |
-
type=TransactionType.EXPENSE,
|
87 |
-
),
|
88 |
-
TransactionCreate(
|
89 |
-
user_id=user_id,
|
90 |
-
transaction_date=datetime(2022, 1, 10),
|
91 |
-
category="category",
|
92 |
-
name_description="name_description",
|
93 |
-
amount=10.0,
|
94 |
-
type=TransactionType.EXPENSE,
|
95 |
-
),
|
96 |
-
]
|
97 |
|
98 |
@pytest.mark.asyncio
|
99 |
async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
|
100 |
|
101 |
session_override = get_db_session_fixture
|
|
|
102 |
user = await User.create(session_override, name="user", email="email", hashed_password="password")
|
103 |
|
104 |
fake_transactions = get_fake_transactions(user.id)
|
|
|
105 |
await Transaction.bulk_create(session_override, fake_transactions)
|
106 |
|
|
|
107 |
response = client.get("/api/v1/transactions/1")
|
108 |
assert response.status_code == 200
|
109 |
assert len(response.json()) == 10
|
|
|
10 |
from sqlalchemy.ext.asyncio import AsyncSession
|
11 |
from app.engine.postgresdb import get_db_session
|
12 |
from app.model.user import User
|
13 |
+
from tests.utils import get_fake_transactions
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
@pytest.mark.asyncio
|
17 |
async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
|
18 |
|
19 |
session_override = get_db_session_fixture
|
20 |
+
# 1. Create a user
|
21 |
user = await User.create(session_override, name="user", email="email", hashed_password="password")
|
22 |
|
23 |
fake_transactions = get_fake_transactions(user.id)
|
24 |
+
# 2. Create a bunch of transactions
|
25 |
await Transaction.bulk_create(session_override, fake_transactions)
|
26 |
|
27 |
+
# 3. Verify that the transactions are returned
|
28 |
response = client.get("/api/v1/transactions/1")
|
29 |
assert response.status_code == 200
|
30 |
assert len(response.json()) == 10
|