palexis3 commited on
Commit
ebcf5e0
1 Parent(s): 417abff

Add income statement answer endpoint

Browse files
app/api/routers/income_statement.py CHANGED
@@ -7,6 +7,9 @@ from app.schema.index import IncomeStatementCreateRequest, IncomeStatementRespon
7
  from app.engine.postgresdb import get_db_session
8
  from app.service.income_statement import call_llm_to_create_income_statement
9
 
 
 
 
10
  income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=["income_statement"])
11
 
12
 
@@ -48,6 +51,30 @@ async def get_income_statements(
48
  return result
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @r.get(
52
  "/report/{report_id}",
53
  response_model=IncomeStatementResponse,
 
7
  from app.engine.postgresdb import get_db_session
8
  from app.service.income_statement import call_llm_to_create_income_statement
9
 
10
+ from app.service.query_rag import answer_query, fetch_income_statement_documents
11
+
12
+
13
  income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=["income_statement"])
14
 
15
 
 
51
  return result
52
 
53
 
54
+ @r.get(
55
+ "/answer/{user_id}/{query}",
56
+ responses={
57
+ 200: {"description": "Query answered"},
58
+ 500: {"description": "Internal server error"},
59
+ },
60
+ )
61
+ async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
62
+ """
63
+ Answer queries based on income statements
64
+ """
65
+ try:
66
+ result = await IncomeStatementModel.get_by_user(db, user_id)
67
+ if len(result) == 0:
68
+ raise HTTPException(status_code=500, detail="No transactions found for this user")
69
+ document_splits = await fetch_income_statement_documents(result)
70
+ answer = await answer_query(document_splits, query, user_id)
71
+ return answer
72
+
73
+ except Exception as e:
74
+ raise HTTPException(status_code=500, detail=f"/answer endpoint error: {str(e)}")
75
+
76
+
77
+
78
  @r.get(
79
  "/report/{report_id}",
80
  response_model=IncomeStatementResponse,
app/api/routers/transaction.py CHANGED
@@ -53,7 +53,7 @@ async def get_transactions(user_id: int, db: AsyncSession = Depends(get_db_sessi
53
 
54
 
55
  @r.get(
56
- "/transaction_answer/{user_id}/{query}",
57
  responses={
58
  200: {"description": "Query answered"},
59
  500: {"description": "Internal server error"},
@@ -61,7 +61,7 @@ async def get_transactions(user_id: int, db: AsyncSession = Depends(get_db_sessi
61
  )
62
  async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
63
  """
64
- Retrieve all transactions.
65
  """
66
  try:
67
  result = await TransactionModel.get_by_user(db, user_id)
@@ -73,4 +73,4 @@ async def answer_transactions_query(user_id: int, query: str, db: AsyncSession =
73
  return answer
74
 
75
  except Exception as e:
76
- raise HTTPException(status_code=500, detail=f"answer_transactions_query error: {str(e)}")
 
53
 
54
 
55
  @r.get(
56
+ "/answer/{user_id}/{query}",
57
  responses={
58
  200: {"description": "Query answered"},
59
  500: {"description": "Internal server error"},
 
61
  )
62
  async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
63
  """
64
+ Answer queries based on transactions.
65
  """
66
  try:
67
  result = await TransactionModel.get_by_user(db, user_id)
 
73
  return answer
74
 
75
  except Exception as e:
76
+ raise HTTPException(status_code=500, detail=f"/answer endpoint error: {str(e)}")
app/service/query_rag.py CHANGED
@@ -11,6 +11,7 @@ from fastapi import HTTPException
11
 
12
  from typing import List
13
  from app.model.transaction import Transaction
 
14
 
15
  import os
16
 
@@ -29,6 +30,21 @@ async def fetch_transaction_documents(transactions: List[Transaction]) -> List[D
29
  raise HTTPException(status_code = 500, detail=f"fetch_transaction_documents error: {str(e)}")
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  async def answer_query(document_splits: List[Document], query: str, user_id: int) -> str:
33
  """Creates an embedding of the transactions table and then returns the answer for the given query.
34
  Args:
 
11
 
12
  from typing import List
13
  from app.model.transaction import Transaction
14
+ from app.model.income_statement import IncomeStatement
15
 
16
  import os
17
 
 
30
  raise HTTPException(status_code = 500, detail=f"fetch_transaction_documents error: {str(e)}")
31
 
32
 
33
+ async def fetch_income_statement_documents(income_statements: List[IncomeStatement]) -> List[Document]:
34
+ try:
35
+ document = ''.join(str(row)+'\n' for row in income_statements)
36
+
37
+ page_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
38
+ pages = page_splitter.split_text(document)
39
+
40
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
41
+ document_splits = text_splitter.create_documents(pages)
42
+
43
+ return document_splits
44
+ except Exception as e:
45
+ raise HTTPException(status_code = 500, detail=f"fetch_income_statement_documents error: {str(e)}")
46
+
47
+
48
  async def answer_query(document_splits: List[Document], query: str, user_id: int) -> str:
49
  """Creates an embedding of the transactions table and then returns the answer for the given query.
50
  Args:
tests/test_transactions.py CHANGED
@@ -10,100 +10,21 @@ from app.schema.index import TransactionType, TransactionCreate
10
  from sqlalchemy.ext.asyncio import AsyncSession
11
  from app.engine.postgresdb import get_db_session
12
  from app.model.user import User
 
13
 
14
- def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
15
- return [
16
- TransactionCreate(
17
- user_id=user_id,
18
- transaction_date=datetime(2022, 1, 1),
19
- category="category",
20
- name_description="name_description",
21
- amount=1.0,
22
- type=TransactionType.EXPENSE,
23
- ),
24
- TransactionCreate(
25
- user_id=user_id,
26
- transaction_date=datetime(2022, 1, 2),
27
- category="category",
28
- name_description="name_description",
29
- amount=2.0,
30
- type=TransactionType.EXPENSE,
31
- ),
32
- TransactionCreate(
33
- user_id=user_id,
34
- transaction_date=datetime(2022, 1, 3),
35
- category="category",
36
- name_description="name_description",
37
- amount=3.0,
38
- type=TransactionType.INCOME,
39
- ),
40
- TransactionCreate(
41
- user_id=user_id,
42
- transaction_date=datetime(2022, 1, 4),
43
- category="category",
44
- name_description="name_description",
45
- amount=4.0,
46
- type=TransactionType.INCOME,
47
- ),
48
- TransactionCreate(
49
- user_id=user_id,
50
- transaction_date=datetime(2022, 1, 5),
51
- category="category",
52
- name_description="name_description",
53
- amount=5.0,
54
- type=TransactionType.EXPENSE,
55
- ),
56
- TransactionCreate(
57
- user_id=user_id,
58
- transaction_date=datetime(2022, 1, 6),
59
- category="category",
60
- name_description="name_description",
61
- amount=6.0,
62
- type=TransactionType.EXPENSE,
63
- ),
64
- TransactionCreate(
65
- user_id=user_id,
66
- transaction_date=datetime(2022, 1, 7),
67
- category="category",
68
- name_description="name_description",
69
- amount=7.0,
70
- type=TransactionType.INCOME,
71
- ),
72
- TransactionCreate(
73
- user_id=user_id,
74
- transaction_date=datetime(2022, 1, 8),
75
- category="category",
76
- name_description="name_description",
77
- amount=8.0,
78
- type=TransactionType.INCOME,
79
- ),
80
- TransactionCreate(
81
- user_id=user_id,
82
- transaction_date=datetime(2022, 1, 9),
83
- category="category",
84
- name_description="name_description",
85
- amount=9.0,
86
- type=TransactionType.EXPENSE,
87
- ),
88
- TransactionCreate(
89
- user_id=user_id,
90
- transaction_date=datetime(2022, 1, 10),
91
- category="category",
92
- name_description="name_description",
93
- amount=10.0,
94
- type=TransactionType.EXPENSE,
95
- ),
96
- ]
97
 
98
  @pytest.mark.asyncio
99
  async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
100
 
101
  session_override = get_db_session_fixture
 
102
  user = await User.create(session_override, name="user", email="email", hashed_password="password")
103
 
104
  fake_transactions = get_fake_transactions(user.id)
 
105
  await Transaction.bulk_create(session_override, fake_transactions)
106
 
 
107
  response = client.get("/api/v1/transactions/1")
108
  assert response.status_code == 200
109
  assert len(response.json()) == 10
 
10
  from sqlalchemy.ext.asyncio import AsyncSession
11
  from app.engine.postgresdb import get_db_session
12
  from app.model.user import User
13
+ from tests.utils import get_fake_transactions
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  @pytest.mark.asyncio
17
  async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
18
 
19
  session_override = get_db_session_fixture
20
+ # 1. Create a user
21
  user = await User.create(session_override, name="user", email="email", hashed_password="password")
22
 
23
  fake_transactions = get_fake_transactions(user.id)
24
+ # 2. Create a bunch of transactions
25
  await Transaction.bulk_create(session_override, fake_transactions)
26
 
27
+ # 3. Verify that the transactions are returned
28
  response = client.get("/api/v1/transactions/1")
29
  assert response.status_code == 200
30
  assert len(response.json()) == 10