add llm service for generating income statements

#14
app/api/routers/file_upload.py CHANGED
@@ -1,11 +1,14 @@
1
  from typing import Annotated
2
- from fastapi import APIRouter, UploadFile
3
  from app.categorization.file_processing import process_file, save_results
4
  from app.schema.index import FileUploadCreate
5
  import asyncio
6
  import os
7
  import csv
8
 
 
 
 
9
  file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upload"])
10
 
11
  @r.post(
@@ -16,7 +19,7 @@ file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upl
16
  500: {"description": "Internal server error"},
17
  },
18
  )
19
- async def create_file(input_file: UploadFile):
20
  try:
21
  # Create directory to store all uploaded .csv files
22
  file_upload_directory_path = "data/tx_data/input"
@@ -31,10 +34,10 @@ async def create_file(input_file: UploadFile):
31
  # With the newly created file and it's path, process and save it for embedding
32
  processed_file = process_file(os.path.realpath(input_file.filename))
33
  result = await asyncio.gather(processed_file)
34
- save_results(result)
35
 
36
  except Exception:
37
  return {"message": "There was an error uploading this file. Ensure you have a .csv file with the following columns:"
38
- "\n source, date, type, category, description, amount"}
39
 
40
  return {"message": f"Successfully uploaded {input_file.filename}"}
 
1
  from typing import Annotated
2
+ from fastapi import APIRouter, UploadFile, Depends
3
  from app.categorization.file_processing import process_file, save_results
4
  from app.schema.index import FileUploadCreate
5
  import asyncio
6
  import os
7
  import csv
8
 
9
+ from app.engine.postgresdb import get_db_session
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+
12
  file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upload"])
13
 
14
  @r.post(
 
19
  500: {"description": "Internal server error"},
20
  },
21
  )
22
+ async def create_file(input_file: UploadFile, db: AsyncSession = Depends(get_db_session)):
23
  try:
24
  # Create directory to store all uploaded .csv files
25
  file_upload_directory_path = "data/tx_data/input"
 
34
  # With the newly created file and it's path, process and save it for embedding
35
  processed_file = process_file(os.path.realpath(input_file.filename))
36
  result = await asyncio.gather(processed_file)
37
+ await save_results(db, result)
38
 
39
  except Exception:
40
  return {"message": "There was an error uploading this file. Ensure you have a .csv file with the following columns:"
41
+ "\n transaction_date, type, category, name_description, amount"}
42
 
43
  return {"message": f"Successfully uploaded {input_file.filename}"}
app/api/routers/income_statement.py CHANGED
@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
  from app.model.transaction import Transaction as TransactionModel
5
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
6
- from app.schema.index import IncomeStatementCreate, IncomeStatementResponse
7
  from app.engine.postgresdb import get_db_session
8
  from app.service.income_statement import call_llm_to_create_income_statement
9
 
@@ -18,7 +18,7 @@ income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=
18
  500: {"description": "Internal server error"},
19
  },
20
  )
21
- async def create_income_statement(payload: IncomeStatementCreate, db: AsyncSession = Depends(get_db_session)) -> None:
22
  try:
23
  await call_llm_to_create_income_statement(payload, db)
24
 
@@ -27,7 +27,7 @@ async def create_income_statement(payload: IncomeStatementCreate, db: AsyncSessi
27
 
28
 
29
  @r.get(
30
- "/{user_id}",
31
  response_model=List[IncomeStatementResponse],
32
  responses={
33
  200: {"description": "New user created"},
@@ -43,14 +43,13 @@ async def get_income_statements(
43
  Retrieve all income statements.
44
  """
45
  result = await IncomeStatementModel.get_by_user(db, user_id)
46
- all_rows = result.all()
47
- if len(all_rows) == 0:
48
  raise HTTPException(status_code=status.HTTP_204_NO_CONTENT, detail="No income statements found for this user")
49
- return all_rows
50
 
51
 
52
  @r.get(
53
- "/{report_id}",
54
  response_model=IncomeStatementResponse,
55
  responses={
56
  200: {"description": "Income statement found"},
 
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
  from app.model.transaction import Transaction as TransactionModel
5
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
6
+ from app.schema.index import IncomeStatementCreateRequest, IncomeStatementResponse
7
  from app.engine.postgresdb import get_db_session
8
  from app.service.income_statement import call_llm_to_create_income_statement
9
 
 
18
  500: {"description": "Internal server error"},
19
  },
20
  )
21
+ async def create_income_statement(payload: IncomeStatementCreateRequest, db: AsyncSession = Depends(get_db_session)) -> None:
22
  try:
23
  await call_llm_to_create_income_statement(payload, db)
24
 
 
27
 
28
 
29
  @r.get(
30
+ "/user/{user_id}",
31
  response_model=List[IncomeStatementResponse],
32
  responses={
33
  200: {"description": "New user created"},
 
43
  Retrieve all income statements.
44
  """
45
  result = await IncomeStatementModel.get_by_user(db, user_id)
46
+ if len(result) == 0:
 
47
  raise HTTPException(status_code=status.HTTP_204_NO_CONTENT, detail="No income statements found for this user")
48
+ return result
49
 
50
 
51
  @r.get(
52
+ "/report/{report_id}",
53
  response_model=IncomeStatementResponse,
54
  responses={
55
  200: {"description": "Income statement found"},
app/api/routers/user.py CHANGED
@@ -27,7 +27,7 @@ async def create_user(user: UserCreate, db: AsyncSession = Depends(get_db_sessio
27
  if db_user and not db_user.is_deleted:
28
  raise HTTPException(status_code=409, detail="User already exists")
29
 
30
- await UserModel.create(db, **user.dict())
31
  user = await UserModel.get(db, email=user.email)
32
  return user
33
  except Exception as e:
@@ -64,7 +64,7 @@ async def update_user(email: str, user_payload: UserUpdate, db: AsyncSession = D
64
  user = await UserModel.get(db, email=email)
65
  if not user:
66
  raise HTTPException(status_code=404, detail="User not found")
67
- await UserModel.update(db, id=user.id, **user_payload.dict())
68
  user = await UserModel.get(db, email=email)
69
  return user
70
  except Exception as e:
 
27
  if db_user and not db_user.is_deleted:
28
  raise HTTPException(status_code=409, detail="User already exists")
29
 
30
+ await UserModel.create(db, **user.model_dump())
31
  user = await UserModel.get(db, email=user.email)
32
  return user
33
  except Exception as e:
 
64
  user = await UserModel.get(db, email=email)
65
  if not user:
66
  raise HTTPException(status_code=404, detail="User not found")
67
+ await UserModel.update(db, id=user.id, **user_payload.model_dump())
68
  user = await UserModel.get(db, email=email)
69
  return user
70
  except Exception as e:
app/categorization/file_processing.py CHANGED
@@ -14,6 +14,8 @@ from app.categorization.config import RESULT_OUTPUT_FILE, CATEGORY_REFERENCE_OUT
14
  from app.model.transaction import Transaction
15
  from app.schema.index import TransactionCreate
16
 
 
 
17
 
18
  # Read file and process it (e.g. categorize transactions)
19
  async def process_file(file_path: str) -> Dict[str, Union[str, pd.DataFrame]]:
@@ -71,7 +73,7 @@ def standardize_csv_file(file_path: str) -> pd.DataFrame:
71
  return tx_list
72
 
73
 
74
- async def save_results(results: List) -> None:
75
  """
76
  Merge all interim results in the input folder and write the merged results to the output file.
77
 
@@ -104,7 +106,7 @@ async def save_results(results: List) -> None:
104
  # Save to database
105
  # FIXME: get user_id from session
106
  txn_list_to_save = [TransactionCreate(**row.to_dict(), user_id=1) for _, row in tx_list.iterrows()]
107
- await Transaction.bulk_create(txn_list_to_save)
108
 
109
  new_ref_data = tx_list[["name/description", "category"]]
110
  if os.path.exists(CATEGORY_REFERENCE_OUTPUT_FILE):
 
14
  from app.model.transaction import Transaction
15
  from app.schema.index import TransactionCreate
16
 
17
+ from sqlalchemy.ext.asyncio import AsyncSession
18
+
19
 
20
  # Read file and process it (e.g. categorize transactions)
21
  async def process_file(file_path: str) -> Dict[str, Union[str, pd.DataFrame]]:
 
73
  return tx_list
74
 
75
 
76
+ async def save_results(db: AsyncSession, results: List) -> None:
77
  """
78
  Merge all interim results in the input folder and write the merged results to the output file.
79
 
 
106
  # Save to database
107
  # FIXME: get user_id from session
108
  txn_list_to_save = [TransactionCreate(**row.to_dict(), user_id=1) for _, row in tx_list.iterrows()]
109
+ await Transaction.bulk_create(db, txn_list_to_save)
110
 
111
  new_ref_data = tx_list[["name/description", "category"]]
112
  if os.path.exists(CATEGORY_REFERENCE_OUTPUT_FILE):
app/model/income_statement.py CHANGED
@@ -8,6 +8,7 @@ from sqlalchemy.dialects.postgresql import JSON
8
 
9
  from app.model.base import BaseModel
10
  from app.engine.postgresdb import Base
 
11
 
12
 
13
  class IncomeStatement(Base, BaseModel):
@@ -21,19 +22,22 @@ class IncomeStatement(Base, BaseModel):
21
  user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
22
  user = relationship("User", back_populates="income_statements")
23
 
 
 
 
24
  @classmethod
25
- def create(cls: "type[IncomeStatement]", db: AsyncSession, **kwargs) -> "IncomeStatement":
26
- query = sql.insert(cls).values(**kwargs)
27
- income_statements = db.execute(query)
28
- income_statement = income_statements.first()
29
- db.commit()
30
  return income_statement
31
 
32
  @classmethod
33
  async def get_by_user(cls: "type[IncomeStatement]", db: AsyncSession, user_id: int) -> "List[IncomeStatement]":
34
  query = sql.select(cls).where(cls.user_id == user_id)
35
  income_statements = await db.scalars(query)
36
- return income_statements
37
 
38
  @classmethod
39
  async def get(cls: "type[IncomeStatement]", db: AsyncSession, id: int) -> "IncomeStatement":
 
8
 
9
  from app.model.base import BaseModel
10
  from app.engine.postgresdb import Base
11
+ from app.schema.index import IncomeStatement as IncomeStatementSchema
12
 
13
 
14
  class IncomeStatement(Base, BaseModel):
 
22
  user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
23
  user = relationship("User", back_populates="income_statements")
24
 
25
+ def __str__(self) -> str:
26
+ return f"IncomeStatement(id={self.id}, user_id={self.user_id}, date_from={self.date_from}, date_to={self.date_to}, income={self.income}, expenses={self.expenses})"
27
+
28
  @classmethod
29
+ async def create(cls: "type[IncomeStatement]", db: AsyncSession, **kwargs: IncomeStatementSchema) -> "IncomeStatement":
30
+ income_statement = cls(**kwargs)
31
+ db.add(income_statement)
32
+ await db.commit()
33
+ await db.refresh(income_statement)
34
  return income_statement
35
 
36
  @classmethod
37
  async def get_by_user(cls: "type[IncomeStatement]", db: AsyncSession, user_id: int) -> "List[IncomeStatement]":
38
  query = sql.select(cls).where(cls.user_id == user_id)
39
  income_statements = await db.scalars(query)
40
+ return income_statements.all()
41
 
42
  @classmethod
43
  async def get(cls: "type[IncomeStatement]", db: AsyncSession, id: int) -> "IncomeStatement":
app/model/transaction.py CHANGED
@@ -21,6 +21,9 @@ class Transaction(Base, BaseModel):
21
  user_id = mapped_column(ForeignKey("users.id"))
22
  user = relationship("User", back_populates="transactions")
23
 
 
 
 
24
  @classmethod
25
  async def create(cls: "type[Transaction]", db: AsyncSession, **kwargs) -> "Transaction":
26
  query = sql.insert(cls).values(**kwargs)
@@ -31,7 +34,8 @@ class Transaction(Base, BaseModel):
31
 
32
  @classmethod
33
  async def bulk_create(cls: "type[Transaction]", db: AsyncSession, transactions: List[TransactionCreate]) -> None:
34
- query = sql.insert(cls).values(transactions)
 
35
  await db.execute(query)
36
  await db.commit()
37
 
@@ -55,4 +59,4 @@ class Transaction(Base, BaseModel):
55
  ) -> "List[Transaction]":
56
  query = sql.select(cls).where(cls.user_id == user_id).where(cls.transaction_date.between(start_date, end_date))
57
  transactions = await db.scalars(query)
58
- return transactions
 
21
  user_id = mapped_column(ForeignKey("users.id"))
22
  user = relationship("User", back_populates="transactions")
23
 
24
+ def __str__(self) -> str:
25
+ return f"{self.transaction_date}, {self.category}, {self.name_description}, {self.amount}, {self.type}"
26
+
27
  @classmethod
28
  async def create(cls: "type[Transaction]", db: AsyncSession, **kwargs) -> "Transaction":
29
  query = sql.insert(cls).values(**kwargs)
 
34
 
35
  @classmethod
36
  async def bulk_create(cls: "type[Transaction]", db: AsyncSession, transactions: List[TransactionCreate]) -> None:
37
+ values = [transaction.model_dump() for transaction in transactions]
38
+ query = sql.insert(cls).values(values)
39
  await db.execute(query)
40
  await db.commit()
41
 
 
59
  ) -> "List[Transaction]":
60
  query = sql.select(cls).where(cls.user_id == user_id).where(cls.transaction_date.between(start_date, end_date))
61
  transactions = await db.scalars(query)
62
+ return transactions.all()
app/schema/index.py CHANGED
@@ -1,6 +1,7 @@
1
  from enum import Enum
2
  from datetime import datetime
3
- from typing import List, Optional
 
4
 
5
  from app.schema.base import BaseModel, PydanticBaseModel
6
 
@@ -62,15 +63,23 @@ class FileUploadCreate(PydanticBaseModel):
62
  type: str
63
 
64
 
65
- class IncomeStatementCreate(PydanticBaseModel):
66
  user_id: int
67
  date_from: datetime
68
  date_to: datetime
69
 
70
 
71
- class IncomeStatementResponse(PydanticBaseModel):
72
- id: int
73
- date_from: datetime
74
- date_to: datetime
75
- income: dict
76
- expenses: dict
 
 
 
 
 
 
 
 
 
1
  from enum import Enum
2
  from datetime import datetime
3
+ from typing import Dict, List
4
+ from typing_extensions import TypedDict
5
 
6
  from app.schema.base import BaseModel, PydanticBaseModel
7
 
 
63
  type: str
64
 
65
 
66
+ class IncomeStatementCreateRequest(PydanticBaseModel):
67
  user_id: int
68
  date_from: datetime
69
  date_to: datetime
70
 
71
 
72
+ class IncomeStatementDetail(TypedDict):
73
+ total: float
74
+ category_totals: List[Dict[str, str | float]]
75
+
76
+ class IncomeStatementLLMResponse(PydanticBaseModel):
77
+ income: IncomeStatementDetail
78
+ expenses: IncomeStatementDetail
79
+
80
+
81
+ class IncomeStatement(IncomeStatementCreateRequest, IncomeStatementLLMResponse):
82
+ pass
83
+
84
+ class IncomeStatementResponse(IncomeStatement):
85
+ id: int
app/service/income_statement.py CHANGED
@@ -1,15 +1,37 @@
1
- from app.schema.index import IncomeStatementCreate
 
2
  from app.model.transaction import Transaction as TransactionModel
3
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
 
 
6
 
7
- async def call_llm_to_create_income_statement(payload: IncomeStatementCreate, db: AsyncSession) -> None:
 
8
  transactions = await TransactionModel.get_by_user_between_dates(
9
  db, payload.user_id, payload.date_from, payload.date_to
10
  )
11
 
12
- # TODO: Call LLM to generate income and expenses
13
- income = {}
14
- expenses = {}
15
- await IncomeStatementModel.create(db, **payload, income=income, expenses=expenses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from app.schema.index import IncomeStatement, IncomeStatementCreateRequest
3
  from app.model.transaction import Transaction as TransactionModel
4
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
5
  from sqlalchemy.ext.asyncio import AsyncSession
6
 
7
+ from app.service.llm import call_llm
8
 
9
+
10
+ async def call_llm_to_create_income_statement(payload: IncomeStatementCreateRequest, db: AsyncSession) -> None:
11
  transactions = await TransactionModel.get_by_user_between_dates(
12
  db, payload.user_id, payload.date_from, payload.date_to
13
  )
14
 
15
+ if not transactions:
16
+ print("No transactions found")
17
+ return
18
+
19
+ response = await call_llm(transactions)
20
+
21
+ income = response.dict()['income']
22
+ expenses = response.dict()['expenses']
23
+
24
+ try:
25
+ income_statement_create_payload = IncomeStatement(
26
+ user_id=payload.user_id,
27
+ date_from=payload.date_from,
28
+ date_to=payload.date_to,
29
+ income=income,
30
+ expenses=expenses,
31
+ )
32
+
33
+ income_statement = await IncomeStatementModel.create(db, **income_statement_create_payload.model_dump())
34
+ print(f"Income statement created: {income_statement}")
35
+ except Exception as e:
36
+ print(e)
37
+ raise e
app/service/llm.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from langchain_core.prompts import ChatPromptTemplate
3
+ from langchain_core.runnables.base import RunnableSequence
4
+ from langchain_openai import OpenAI
5
+ from langchain.globals import set_llm_cache
6
+
7
+ from app.model.transaction import Transaction
8
+ from app.schema.index import IncomeStatementLLMResponse
9
+ from config.index import config as env
10
+
11
+ from langchain_core.output_parsers import PydanticOutputParser
12
+
13
+ set_llm_cache(None)
14
+
15
+ def income_statement_prompt () -> ChatPromptTemplate:
16
+ context_str = """
17
+ You are an accountant skilled at organizing transactions from multiple different bank
18
+ accounts and credit card statements to prepare an income statement.
19
+
20
+ Input data is in the below csv format:
21
+ transaction_date, category, name_description, amount, type\n
22
+ {input_data_csv}
23
+
24
+ Your task is to prepare an income statement. The output should be in the following format: {format_instructions}
25
+
26
+ """
27
+ prompt = ChatPromptTemplate.from_template(context_str)
28
+ return prompt
29
+
30
+ async def call_llm(inputData: List[Transaction]) -> str:
31
+ input_data_csv = '\n'.join(str(x) for x in inputData)
32
+
33
+ output_parser = PydanticOutputParser(pydantic_object=IncomeStatementLLMResponse)
34
+
35
+ prompt = income_statement_prompt().partial(format_instructions=output_parser.get_format_instructions())
36
+
37
+ llm = OpenAI(name='Income Statement Generation Bot',
38
+ api_key=env.OPENAI_API_KEY,
39
+ # cache=True,
40
+ temperature=0.7,
41
+ verbose=True)
42
+
43
+ try:
44
+ runnable_chain = RunnableSequence(prompt, llm, output_parser)
45
+ except Exception as e:
46
+ print(f"runnable_chain error: {str(e)}")
47
+ raise e
48
+
49
+ try:
50
+ output_chunks = runnable_chain.invoke({"input_data_csv": input_data_csv})
51
+ return output_chunks
52
+
53
+ except Exception as e:
54
+ print(f"runnable_chain.invoke error: {str(e)}")
55
+ raise e
pyproject.toml CHANGED
@@ -52,6 +52,9 @@ version = "0.2.2"
52
  [tool.black]
53
  line-length = 119
54
 
 
 
 
55
  [build-system]
56
  requires = [ "poetry-core" ]
57
  build-backend = "poetry.core.masonry.api"
 
52
  [tool.black]
53
  line-length = 119
54
 
55
+ [tool.pytest.ini_options]
56
+ asyncio_mode = "auto"
57
+
58
  [build-system]
59
  requires = [ "poetry-core" ]
60
  build-backend = "poetry.core.masonry.api"
tests/conftest.py CHANGED
@@ -37,7 +37,7 @@ def client(app):
37
  yield c
38
 
39
 
40
- @pytest.mark.asyncio(scope="session")
41
  def event_loop(request):
42
  loop = asyncio.get_event_loop_policy().new_event_loop()
43
  yield loop
@@ -52,9 +52,11 @@ async def connection_test(test_db, event_loop):
52
  pg_db = test_db.dbname
53
  pg_password = test_db.password
54
 
55
- with DatabaseJanitor(pg_user, pg_host, pg_port, pg_db, test_db.version, pg_password):
56
- connection_str = f"postgresql+asyncpg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
57
- sessionmanager.init(connection_str)
 
 
58
  yield
59
  await sessionmanager.close()
60
 
@@ -73,3 +75,9 @@ async def session_override(app, connection_test):
73
  yield session
74
 
75
  app.dependency_overrides[get_db_session] = get_db_session_override
 
 
 
 
 
 
 
37
  yield c
38
 
39
 
40
+ @pytest.fixture(scope="session")
41
  def event_loop(request):
42
  loop = asyncio.get_event_loop_policy().new_event_loop()
43
  yield loop
 
52
  pg_db = test_db.dbname
53
  pg_password = test_db.password
54
 
55
+ with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
56
+ connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
57
+ sessionmanager.init(connection_str,
58
+ # {"echo": True, "future": True}
59
+ )
60
  yield
61
  await sessionmanager.close()
62
 
 
75
  yield session
76
 
77
  app.dependency_overrides[get_db_session] = get_db_session_override
78
+
79
+
80
+ @pytest.fixture(scope="function", autouse=True)
81
+ async def get_db_session_fixture():
82
+ async with sessionmanager.session() as session:
83
+ yield session
tests/test_income_statement.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+ import pytest
3
+
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from app.model.transaction import Transaction
6
+ from app.model.user import User
7
+ from tests.utils import get_fake_transactions
8
+
9
+ @pytest.mark.asyncio
10
+ async def test_income_statement(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
11
+ session_override = get_db_session_fixture
12
+
13
+ # 1. Create a user
14
+ user = await User.create(session_override, name="user", email="email", hashed_password="password")
15
+
16
+ # 2. Create a bunch of transactions
17
+ fake_transactions = get_fake_transactions(user.id)
18
+ await Transaction.bulk_create(session_override, fake_transactions)
19
+
20
+ # 3. Create an income statement
21
+ min_date = min(t.transaction_date for t in fake_transactions)
22
+ max_date = max(t.transaction_date for t in fake_transactions)
23
+
24
+ print(f"min_date: {min_date}, max_date: {max_date}")
25
+ response = client.post(
26
+ "/api/v1/income_statement",
27
+ json={
28
+ "user_id": 1,
29
+ "date_from": str(min_date),
30
+ "date_to": str(max_date),
31
+ },
32
+ )
33
+
34
+ assert response.status_code == 200
35
+
36
+ # 4. Verify that the income statement matches the transactions
37
+ response = client.get(f"/api/v1/income_statement/user/1")
38
+ print(response.json())
39
+ assert response.status_code == 200
40
+ assert response.json()[0].get("income")
41
+ assert response.json()[0].get("expenses")
42
+
43
+ report_id = response.json()[0].get("id")
44
+
45
+ # # 5. Verify that the income statement can be retrieved
46
+ if report_id is not None:
47
+ response = client.get(f"/api/v1/income_statement/report/{report_id}")
48
+ assert response.status_code == 200
49
+ assert response.json().get("income")
50
+ assert response.json().get("expenses")
51
+ assert response.json().get("id") == report_id
tests/test_transactions.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+ import pytest
3
+
4
+ from app.model.transaction import Transaction
5
+ from app.model.user import User
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+ from tests.utils import get_fake_transactions
8
+
9
+
10
+ @pytest.mark.asyncio
11
+ async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
12
+
13
+ session_override = get_db_session_fixture
14
+
15
+ # 1. Create a user
16
+ user = await User.create(session_override, name="user", email="email", hashed_password="password")
17
+
18
+ # 2. Create a bunch of transactions
19
+ fake_transactions = get_fake_transactions(user.id)
20
+ await Transaction.bulk_create(session_override, fake_transactions)
21
+
22
+ # 3. Verify that the transactions are returned
23
+ response = client.get("/api/v1/transactions/1")
24
+ assert response.status_code == 200
25
+ assert len(response.json()) == 10
tests/utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import List
3
+ from app.schema.index import TransactionCreate, TransactionType
4
+
5
+
6
+ def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
7
+ return [
8
+ TransactionCreate(
9
+ user_id=user_id,
10
+ transaction_date=datetime(2022, 1, 1),
11
+ category="category1",
12
+ name_description="name_description",
13
+ amount=1.0,
14
+ type=TransactionType.EXPENSE,
15
+ ),
16
+ TransactionCreate(
17
+ user_id=user_id,
18
+ transaction_date=datetime(2022, 1, 2),
19
+ category="category2",
20
+ name_description="name_description",
21
+ amount=2.0,
22
+ type=TransactionType.EXPENSE,
23
+ ),
24
+ TransactionCreate(
25
+ user_id=user_id,
26
+ transaction_date=datetime(2022, 1, 3),
27
+ category="category3",
28
+ name_description="name_description",
29
+ amount=3.0,
30
+ type=TransactionType.INCOME,
31
+ ),
32
+ TransactionCreate(
33
+ user_id=user_id,
34
+ transaction_date=datetime(2022, 1, 4),
35
+ category="category1",
36
+ name_description="name_description",
37
+ amount=4.0,
38
+ type=TransactionType.INCOME,
39
+ ),
40
+ TransactionCreate(
41
+ user_id=user_id,
42
+ transaction_date=datetime(2022, 1, 5),
43
+ category="category2",
44
+ name_description="name_description",
45
+ amount=5.0,
46
+ type=TransactionType.EXPENSE,
47
+ ),
48
+ TransactionCreate(
49
+ user_id=user_id,
50
+ transaction_date=datetime(2022, 1, 6),
51
+ category="category3",
52
+ name_description="name_description",
53
+ amount=6.0,
54
+ type=TransactionType.EXPENSE,
55
+ ),
56
+ TransactionCreate(
57
+ user_id=user_id,
58
+ transaction_date=datetime(2022, 1, 7),
59
+ category="category1",
60
+ name_description="name_description",
61
+ amount=7.0,
62
+ type=TransactionType.INCOME,
63
+ ),
64
+ TransactionCreate(
65
+ user_id=user_id,
66
+ transaction_date=datetime(2022, 1, 8),
67
+ category="category2",
68
+ name_description="name_description",
69
+ amount=8.0,
70
+ type=TransactionType.INCOME,
71
+ ),
72
+ TransactionCreate(
73
+ user_id=user_id,
74
+ transaction_date=datetime(2022, 1, 9),
75
+ category="category3",
76
+ name_description="name_description",
77
+ amount=9.0,
78
+ type=TransactionType.EXPENSE,
79
+ ),
80
+ TransactionCreate(
81
+ user_id=user_id,
82
+ transaction_date=datetime(2022, 1, 10),
83
+ category="category1",
84
+ name_description="name_description",
85
+ amount=10.0,
86
+ type=TransactionType.EXPENSE,
87
+ ),
88
+ ]