Praneeth Yerrapragada commited on
Commit
93bae48
1 Parent(s): cf9085d

chore(pr/14): fix income statement endpoints, pytests

Browse files
app/api/routers/income_statement.py CHANGED
@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
  from app.model.transaction import Transaction as TransactionModel
5
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
6
- from app.schema.index import IncomeStatementCreate, IncomeStatementResponse
7
  from app.engine.postgresdb import get_db_session
8
  from app.service.income_statement import call_llm_to_create_income_statement
9
 
@@ -18,7 +18,7 @@ income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=
18
  500: {"description": "Internal server error"},
19
  },
20
  )
21
- async def create_income_statement(payload: IncomeStatementCreate, db: AsyncSession = Depends(get_db_session)) -> None:
22
  try:
23
  await call_llm_to_create_income_statement(payload, db)
24
 
@@ -27,7 +27,7 @@ async def create_income_statement(payload: IncomeStatementCreate, db: AsyncSessi
27
 
28
 
29
  @r.get(
30
- "/{user_id}",
31
  response_model=List[IncomeStatementResponse],
32
  responses={
33
  200: {"description": "New user created"},
@@ -43,14 +43,13 @@ async def get_income_statements(
43
  Retrieve all income statements.
44
  """
45
  result = await IncomeStatementModel.get_by_user(db, user_id)
46
- all_rows = result.all()
47
- if len(all_rows) == 0:
48
  raise HTTPException(status_code=status.HTTP_204_NO_CONTENT, detail="No income statements found for this user")
49
- return all_rows
50
 
51
 
52
  @r.get(
53
- "/{report_id}",
54
  response_model=IncomeStatementResponse,
55
  responses={
56
  200: {"description": "Income statement found"},
 
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
  from app.model.transaction import Transaction as TransactionModel
5
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
6
+ from app.schema.index import IncomeStatementCreateRequest, IncomeStatementResponse
7
  from app.engine.postgresdb import get_db_session
8
  from app.service.income_statement import call_llm_to_create_income_statement
9
 
 
18
  500: {"description": "Internal server error"},
19
  },
20
  )
21
+ async def create_income_statement(payload: IncomeStatementCreateRequest, db: AsyncSession = Depends(get_db_session)) -> None:
22
  try:
23
  await call_llm_to_create_income_statement(payload, db)
24
 
 
27
 
28
 
29
  @r.get(
30
+ "/user/{user_id}",
31
  response_model=List[IncomeStatementResponse],
32
  responses={
33
  200: {"description": "New user created"},
 
43
  Retrieve all income statements.
44
  """
45
  result = await IncomeStatementModel.get_by_user(db, user_id)
46
+ if len(result) == 0:
 
47
  raise HTTPException(status_code=status.HTTP_204_NO_CONTENT, detail="No income statements found for this user")
48
+ return result
49
 
50
 
51
  @r.get(
52
+ "/report/{report_id}",
53
  response_model=IncomeStatementResponse,
54
  responses={
55
  200: {"description": "Income statement found"},
app/api/routers/user.py CHANGED
@@ -27,7 +27,7 @@ async def create_user(user: UserCreate, db: AsyncSession = Depends(get_db_sessio
27
  if db_user and not db_user.is_deleted:
28
  raise HTTPException(status_code=409, detail="User already exists")
29
 
30
- await UserModel.create(db, **user.dict())
31
  user = await UserModel.get(db, email=user.email)
32
  return user
33
  except Exception as e:
@@ -64,7 +64,7 @@ async def update_user(email: str, user_payload: UserUpdate, db: AsyncSession = D
64
  user = await UserModel.get(db, email=email)
65
  if not user:
66
  raise HTTPException(status_code=404, detail="User not found")
67
- await UserModel.update(db, id=user.id, **user_payload.dict())
68
  user = await UserModel.get(db, email=email)
69
  return user
70
  except Exception as e:
 
27
  if db_user and not db_user.is_deleted:
28
  raise HTTPException(status_code=409, detail="User already exists")
29
 
30
+ await UserModel.create(db, **user.model_dump())
31
  user = await UserModel.get(db, email=user.email)
32
  return user
33
  except Exception as e:
 
64
  user = await UserModel.get(db, email=email)
65
  if not user:
66
  raise HTTPException(status_code=404, detail="User not found")
67
+ await UserModel.update(db, id=user.id, **user_payload.model_dump())
68
  user = await UserModel.get(db, email=email)
69
  return user
70
  except Exception as e:
app/model/income_statement.py CHANGED
@@ -8,6 +8,7 @@ from sqlalchemy.dialects.postgresql import JSON
8
 
9
  from app.model.base import BaseModel
10
  from app.engine.postgresdb import Base
 
11
 
12
 
13
  class IncomeStatement(Base, BaseModel):
@@ -21,19 +22,22 @@ class IncomeStatement(Base, BaseModel):
21
  user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
22
  user = relationship("User", back_populates="income_statements")
23
 
 
 
 
24
  @classmethod
25
- def create(cls: "type[IncomeStatement]", db: AsyncSession, **kwargs) -> "IncomeStatement":
26
- query = sql.insert(cls).values(**kwargs)
27
- income_statements = db.execute(query)
28
- income_statement = income_statements.first()
29
- db.commit()
30
  return income_statement
31
 
32
  @classmethod
33
  async def get_by_user(cls: "type[IncomeStatement]", db: AsyncSession, user_id: int) -> "List[IncomeStatement]":
34
  query = sql.select(cls).where(cls.user_id == user_id)
35
  income_statements = await db.scalars(query)
36
- return income_statements
37
 
38
  @classmethod
39
  async def get(cls: "type[IncomeStatement]", db: AsyncSession, id: int) -> "IncomeStatement":
 
8
 
9
  from app.model.base import BaseModel
10
  from app.engine.postgresdb import Base
11
+ from app.schema.index import IncomeStatement as IncomeStatementSchema
12
 
13
 
14
  class IncomeStatement(Base, BaseModel):
 
22
  user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
23
  user = relationship("User", back_populates="income_statements")
24
 
25
+ def __str__(self) -> str:
26
+ return f"IncomeStatement(id={self.id}, user_id={self.user_id}, date_from={self.date_from}, date_to={self.date_to}, income={self.income}, expenses={self.expenses})"
27
+
28
  @classmethod
29
+ async def create(cls: "type[IncomeStatement]", db: AsyncSession, **kwargs: IncomeStatementSchema) -> "IncomeStatement":
30
+ income_statement = cls(**kwargs)
31
+ db.add(income_statement)
32
+ await db.commit()
33
+ await db.refresh(income_statement)
34
  return income_statement
35
 
36
  @classmethod
37
  async def get_by_user(cls: "type[IncomeStatement]", db: AsyncSession, user_id: int) -> "List[IncomeStatement]":
38
  query = sql.select(cls).where(cls.user_id == user_id)
39
  income_statements = await db.scalars(query)
40
+ return income_statements.all()
41
 
42
  @classmethod
43
  async def get(cls: "type[IncomeStatement]", db: AsyncSession, id: int) -> "IncomeStatement":
app/model/transaction.py CHANGED
@@ -21,6 +21,9 @@ class Transaction(Base, BaseModel):
21
  user_id = mapped_column(ForeignKey("users.id"))
22
  user = relationship("User", back_populates="transactions")
23
 
 
 
 
24
  @classmethod
25
  async def create(cls: "type[Transaction]", db: AsyncSession, **kwargs) -> "Transaction":
26
  query = sql.insert(cls).values(**kwargs)
@@ -56,4 +59,4 @@ class Transaction(Base, BaseModel):
56
  ) -> "List[Transaction]":
57
  query = sql.select(cls).where(cls.user_id == user_id).where(cls.transaction_date.between(start_date, end_date))
58
  transactions = await db.scalars(query)
59
- return transactions
 
21
  user_id = mapped_column(ForeignKey("users.id"))
22
  user = relationship("User", back_populates="transactions")
23
 
24
+ def __str__(self) -> str:
25
+ return f"{self.transaction_date}, {self.category}, {self.name_description}, {self.amount}, {self.type}"
26
+
27
  @classmethod
28
  async def create(cls: "type[Transaction]", db: AsyncSession, **kwargs) -> "Transaction":
29
  query = sql.insert(cls).values(**kwargs)
 
59
  ) -> "List[Transaction]":
60
  query = sql.select(cls).where(cls.user_id == user_id).where(cls.transaction_date.between(start_date, end_date))
61
  transactions = await db.scalars(query)
62
+ return transactions.all()
app/schema/index.py CHANGED
@@ -1,6 +1,7 @@
1
  from enum import Enum
2
  from datetime import datetime
3
- from typing import List, Optional
 
4
 
5
  from app.schema.base import BaseModel, PydanticBaseModel
6
 
@@ -62,19 +63,23 @@ class FileUploadCreate(PydanticBaseModel):
62
  type: str
63
 
64
 
65
- class IncomeStatementCreate(PydanticBaseModel):
66
  user_id: int
67
  date_from: datetime
68
  date_to: datetime
69
 
70
 
71
- class IncomeStatementResponse(PydanticBaseModel):
72
- id: int
73
- date_from: datetime
74
- date_to: datetime
75
- income: dict
76
- expenses: dict
77
 
78
  class IncomeStatementLLMResponse(PydanticBaseModel):
79
- income: dict
80
- expenses: dict
 
 
 
 
 
 
 
 
1
  from enum import Enum
2
  from datetime import datetime
3
+ from typing import Dict, List
4
+ from typing_extensions import TypedDict
5
 
6
  from app.schema.base import BaseModel, PydanticBaseModel
7
 
 
63
  type: str
64
 
65
 
66
+ class IncomeStatementCreateRequest(PydanticBaseModel):
67
  user_id: int
68
  date_from: datetime
69
  date_to: datetime
70
 
71
 
72
+ class IncomeStatementDetail(TypedDict):
73
+ total: float
74
+ category_totals: List[Dict[str, str | float]]
 
 
 
75
 
76
  class IncomeStatementLLMResponse(PydanticBaseModel):
77
+ income: IncomeStatementDetail
78
+ expenses: IncomeStatementDetail
79
+
80
+
81
+ class IncomeStatement(IncomeStatementCreateRequest, IncomeStatementLLMResponse):
82
+ pass
83
+
84
+ class IncomeStatementResponse(IncomeStatement):
85
+ id: int
app/service/income_statement.py CHANGED
@@ -1,4 +1,5 @@
1
- from app.schema.index import IncomeStatementCreate
 
2
  from app.model.transaction import Transaction as TransactionModel
3
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
4
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -6,13 +7,31 @@ from sqlalchemy.ext.asyncio import AsyncSession
6
  from app.service.llm import call_llm
7
 
8
 
9
- async def call_llm_to_create_income_statement(payload: IncomeStatementCreate, db: AsyncSession) -> None:
10
  transactions = await TransactionModel.get_by_user_between_dates(
11
  db, payload.user_id, payload.date_from, payload.date_to
12
  )
13
 
14
- response = call_llm(transactions)
15
- income = response.income
16
- expenses = response.expenses
17
 
18
- await IncomeStatementModel.create(db, **payload, income=income, expenses=expenses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from app.schema.index import IncomeStatement, IncomeStatementCreateRequest
3
  from app.model.transaction import Transaction as TransactionModel
4
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
5
  from sqlalchemy.ext.asyncio import AsyncSession
 
7
  from app.service.llm import call_llm
8
 
9
 
10
+ async def call_llm_to_create_income_statement(payload: IncomeStatementCreateRequest, db: AsyncSession) -> None:
11
  transactions = await TransactionModel.get_by_user_between_dates(
12
  db, payload.user_id, payload.date_from, payload.date_to
13
  )
14
 
15
+ if not transactions:
16
+ print("No transactions found")
17
+ return
18
 
19
+ response = await call_llm(transactions)
20
+
21
+ income = response.dict()['income']
22
+ expenses = response.dict()['expenses']
23
+
24
+ try:
25
+ income_statement_create_payload = IncomeStatement(
26
+ user_id=payload.user_id,
27
+ date_from=payload.date_from,
28
+ date_to=payload.date_to,
29
+ income=income,
30
+ expenses=expenses,
31
+ )
32
+
33
+ income_statement = await IncomeStatementModel.create(db, **income_statement_create_payload.model_dump())
34
+ print(f"Income statement created: {income_statement}")
35
+ except Exception as e:
36
+ print(e)
37
+ raise e
app/service/llm.py CHANGED
@@ -1,53 +1,55 @@
1
- import logging
2
- from llama_index.core.settings import Settings
3
- from llama_index.core import PromptTemplate
 
 
4
 
5
  from app.model.transaction import Transaction
6
  from app.schema.index import IncomeStatementLLMResponse
 
7
 
8
- def income_statement_prompt (inputData: Transaction) -> PromptTemplate:
9
- context_str = f"""
 
 
 
 
10
  You are an accountant skilled at organizing transactions from multiple different bank
11
  accounts and credit card statements to prepare an income statement.
12
 
13
- Input data has the following format: transaction_date, type, category, name_description, amount.
14
- The <IN> tag is prepended to the input data as follows:
15
- ```<IN>
16
- {inputData}
17
- ```
18
-
19
- Your task is to prepare an income statement. The income statement's output should be in a json format.
20
- An example of the expected output is as follows with the <OUT> tag. Note that not all categories of the transactions
21
- have been listed below. Use below <OUT> tag as a reference in preparing the income statement.
22
- ```<OUT>
23
- {
24
- "REVENUE": {
25
- "Gross Sales": 73351.11,
26
- "Other Income": 0,
27
- "Balance Dec 2022": 3987.39,
28
- },
29
- "EXPENSES": {
30
- "Advertising": 0,
31
- "Commissions": 0,
32
- "Insurance": 0,
33
- "Memberships": 0,
34
- "Utilities": 0,
35
- }
36
- }
37
- ```
38
 
39
  """
40
- prompt = PromptTemplate(context_str, template_var_mappings={"inputData": inputData})
41
- logging.info(f"Prompt: {prompt}")
 
 
 
 
 
 
 
42
 
43
- async def call_llm(prompt: PromptTemplate) -> str:
44
- llm = Settings.llm.copy()
45
- prompt = income_statement_prompt()
46
- llm.check_prompts(prompt)
 
47
 
48
- llm.system_prompt = prompt
 
 
 
 
49
 
50
- output = await llm.astructured_predict(output_cls=IncomeStatementLLMResponse, prompt=prompt)
 
 
51
 
52
- logging.info(f"Output: {output}")
53
- return output
 
 
1
+ from typing import List
2
+ from langchain_core.prompts import ChatPromptTemplate
3
+ from langchain_core.runnables.base import RunnableSequence
4
+ from langchain_openai import OpenAI
5
+ from langchain.globals import set_llm_cache
6
 
7
  from app.model.transaction import Transaction
8
  from app.schema.index import IncomeStatementLLMResponse
9
+ from config.index import config as env
10
 
11
+ from langchain_core.output_parsers import PydanticOutputParser
12
+
13
+ set_llm_cache(None)
14
+
15
+ def income_statement_prompt () -> ChatPromptTemplate:
16
+ context_str = """
17
  You are an accountant skilled at organizing transactions from multiple different bank
18
  accounts and credit card statements to prepare an income statement.
19
 
20
+ Input data is in the below csv format:
21
+ transaction_date, category, name_description, amount, type\n
22
+ {input_data_csv}
23
+
24
+ Your task is to prepare an income statement. The output should be in the following format: {format_instructions}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  """
27
+ prompt = ChatPromptTemplate.from_template(context_str)
28
+ return prompt
29
+
30
+ async def call_llm(inputData: List[Transaction]) -> str:
31
+ input_data_csv = '\n'.join(str(x) for x in inputData)
32
+
33
+ output_parser = PydanticOutputParser(pydantic_object=IncomeStatementLLMResponse)
34
+
35
+ prompt = income_statement_prompt().partial(format_instructions=output_parser.get_format_instructions())
36
 
37
+ llm = OpenAI(name='Income Statement Generation Bot',
38
+ api_key=env.OPENAI_API_KEY,
39
+ # cache=True,
40
+ temperature=0.7,
41
+ verbose=True)
42
 
43
+ try:
44
+ runnable_chain = RunnableSequence(prompt, llm, output_parser)
45
+ except Exception as e:
46
+ print(f"runnable_chain error: {str(e)}")
47
+ raise e
48
 
49
+ try:
50
+ output_chunks = runnable_chain.invoke({"input_data_csv": input_data_csv})
51
+ return output_chunks
52
 
53
+ except Exception as e:
54
+ print(f"runnable_chain.invoke error: {str(e)}")
55
+ raise e
tests/conftest.py CHANGED
@@ -54,7 +54,9 @@ async def connection_test(test_db, event_loop):
54
 
55
  with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
56
  connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
57
- sessionmanager.init(connection_str, {"echo": True, "future": True})
 
 
58
  yield
59
  await sessionmanager.close()
60
 
 
54
 
55
  with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
56
  connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
57
+ sessionmanager.init(connection_str,
58
+ # {"echo": True, "future": True}
59
+ )
60
  yield
61
  await sessionmanager.close()
62
 
tests/test_income_statement.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+ import pytest
3
+
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from app.model.transaction import Transaction
6
+ from app.model.user import User
7
+ from tests.utils import get_fake_transactions
8
+
9
+ @pytest.mark.asyncio
10
+ async def test_income_statement(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
11
+ session_override = get_db_session_fixture
12
+
13
+ # 1. Create a user
14
+ user = await User.create(session_override, name="user", email="email", hashed_password="password")
15
+
16
+ # 2. Create a bunch of transactions
17
+ fake_transactions = get_fake_transactions(user.id)
18
+ await Transaction.bulk_create(session_override, fake_transactions)
19
+
20
+ # 3. Create an income statement
21
+ min_date = min(t.transaction_date for t in fake_transactions)
22
+ max_date = max(t.transaction_date for t in fake_transactions)
23
+
24
+ print(f"min_date: {min_date}, max_date: {max_date}")
25
+ response = client.post(
26
+ "/api/v1/income_statement",
27
+ json={
28
+ "user_id": 1,
29
+ "date_from": str(min_date),
30
+ "date_to": str(max_date),
31
+ },
32
+ )
33
+
34
+ assert response.status_code == 200
35
+
36
+ # 4. Verify that the income statement matches the transactions
37
+ response = client.get(f"/api/v1/income_statement/user/1")
38
+ print(response.json())
39
+ assert response.status_code == 200
40
+ assert response.json()[0].get("income")
41
+ assert response.json()[0].get("expenses")
42
+
43
+ report_id = response.json()[0].get("id")
44
+
45
+ # # 5. Verify that the income statement can be retrieved
46
+ if report_id is not None:
47
+ response = client.get(f"/api/v1/income_statement/report/{report_id}")
48
+ assert response.status_code == 200
49
+ assert response.json().get("income")
50
+ assert response.json().get("expenses")
51
+ assert response.json().get("id") == report_id
tests/test_transactions.py CHANGED
@@ -1,109 +1,25 @@
1
- from datetime import datetime
2
- from typing import List
3
- from fastapi import Depends
4
  from fastapi.testclient import TestClient
5
  import pytest
6
 
7
  from app.model.transaction import Transaction
8
- from app.schema.index import TransactionType, TransactionCreate
9
-
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
- from app.engine.postgresdb import get_db_session
12
  from app.model.user import User
 
 
13
 
14
- def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
15
- return [
16
- TransactionCreate(
17
- user_id=user_id,
18
- transaction_date=datetime(2022, 1, 1),
19
- category="category",
20
- name_description="name_description",
21
- amount=1.0,
22
- type=TransactionType.EXPENSE,
23
- ),
24
- TransactionCreate(
25
- user_id=user_id,
26
- transaction_date=datetime(2022, 1, 2),
27
- category="category",
28
- name_description="name_description",
29
- amount=2.0,
30
- type=TransactionType.EXPENSE,
31
- ),
32
- TransactionCreate(
33
- user_id=user_id,
34
- transaction_date=datetime(2022, 1, 3),
35
- category="category",
36
- name_description="name_description",
37
- amount=3.0,
38
- type=TransactionType.INCOME,
39
- ),
40
- TransactionCreate(
41
- user_id=user_id,
42
- transaction_date=datetime(2022, 1, 4),
43
- category="category",
44
- name_description="name_description",
45
- amount=4.0,
46
- type=TransactionType.INCOME,
47
- ),
48
- TransactionCreate(
49
- user_id=user_id,
50
- transaction_date=datetime(2022, 1, 5),
51
- category="category",
52
- name_description="name_description",
53
- amount=5.0,
54
- type=TransactionType.EXPENSE,
55
- ),
56
- TransactionCreate(
57
- user_id=user_id,
58
- transaction_date=datetime(2022, 1, 6),
59
- category="category",
60
- name_description="name_description",
61
- amount=6.0,
62
- type=TransactionType.EXPENSE,
63
- ),
64
- TransactionCreate(
65
- user_id=user_id,
66
- transaction_date=datetime(2022, 1, 7),
67
- category="category",
68
- name_description="name_description",
69
- amount=7.0,
70
- type=TransactionType.INCOME,
71
- ),
72
- TransactionCreate(
73
- user_id=user_id,
74
- transaction_date=datetime(2022, 1, 8),
75
- category="category",
76
- name_description="name_description",
77
- amount=8.0,
78
- type=TransactionType.INCOME,
79
- ),
80
- TransactionCreate(
81
- user_id=user_id,
82
- transaction_date=datetime(2022, 1, 9),
83
- category="category",
84
- name_description="name_description",
85
- amount=9.0,
86
- type=TransactionType.EXPENSE,
87
- ),
88
- TransactionCreate(
89
- user_id=user_id,
90
- transaction_date=datetime(2022, 1, 10),
91
- category="category",
92
- name_description="name_description",
93
- amount=10.0,
94
- type=TransactionType.EXPENSE,
95
- ),
96
- ]
97
 
98
  @pytest.mark.asyncio
99
  async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
100
 
101
  session_override = get_db_session_fixture
 
 
102
  user = await User.create(session_override, name="user", email="email", hashed_password="password")
103
 
 
104
  fake_transactions = get_fake_transactions(user.id)
105
  await Transaction.bulk_create(session_override, fake_transactions)
106
 
 
107
  response = client.get("/api/v1/transactions/1")
108
  assert response.status_code == 200
109
  assert len(response.json()) == 10
 
 
 
 
1
  from fastapi.testclient import TestClient
2
  import pytest
3
 
4
  from app.model.transaction import Transaction
 
 
 
 
5
  from app.model.user import User
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+ from tests.utils import get_fake_transactions
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @pytest.mark.asyncio
11
  async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
12
 
13
  session_override = get_db_session_fixture
14
+
15
+ # 1. Create a user
16
  user = await User.create(session_override, name="user", email="email", hashed_password="password")
17
 
18
+ # 2. Create a bunch of transactions
19
  fake_transactions = get_fake_transactions(user.id)
20
  await Transaction.bulk_create(session_override, fake_transactions)
21
 
22
+ # 3. Verify that the transactions are returned
23
  response = client.get("/api/v1/transactions/1")
24
  assert response.status_code == 200
25
  assert len(response.json()) == 10
tests/utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import List
3
+ from app.schema.index import TransactionCreate, TransactionType
4
+
5
+
6
+ def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
7
+ return [
8
+ TransactionCreate(
9
+ user_id=user_id,
10
+ transaction_date=datetime(2022, 1, 1),
11
+ category="category1",
12
+ name_description="name_description",
13
+ amount=1.0,
14
+ type=TransactionType.EXPENSE,
15
+ ),
16
+ TransactionCreate(
17
+ user_id=user_id,
18
+ transaction_date=datetime(2022, 1, 2),
19
+ category="category2",
20
+ name_description="name_description",
21
+ amount=2.0,
22
+ type=TransactionType.EXPENSE,
23
+ ),
24
+ TransactionCreate(
25
+ user_id=user_id,
26
+ transaction_date=datetime(2022, 1, 3),
27
+ category="category3",
28
+ name_description="name_description",
29
+ amount=3.0,
30
+ type=TransactionType.INCOME,
31
+ ),
32
+ TransactionCreate(
33
+ user_id=user_id,
34
+ transaction_date=datetime(2022, 1, 4),
35
+ category="category1",
36
+ name_description="name_description",
37
+ amount=4.0,
38
+ type=TransactionType.INCOME,
39
+ ),
40
+ TransactionCreate(
41
+ user_id=user_id,
42
+ transaction_date=datetime(2022, 1, 5),
43
+ category="category2",
44
+ name_description="name_description",
45
+ amount=5.0,
46
+ type=TransactionType.EXPENSE,
47
+ ),
48
+ TransactionCreate(
49
+ user_id=user_id,
50
+ transaction_date=datetime(2022, 1, 6),
51
+ category="category3",
52
+ name_description="name_description",
53
+ amount=6.0,
54
+ type=TransactionType.EXPENSE,
55
+ ),
56
+ TransactionCreate(
57
+ user_id=user_id,
58
+ transaction_date=datetime(2022, 1, 7),
59
+ category="category1",
60
+ name_description="name_description",
61
+ amount=7.0,
62
+ type=TransactionType.INCOME,
63
+ ),
64
+ TransactionCreate(
65
+ user_id=user_id,
66
+ transaction_date=datetime(2022, 1, 8),
67
+ category="category2",
68
+ name_description="name_description",
69
+ amount=8.0,
70
+ type=TransactionType.INCOME,
71
+ ),
72
+ TransactionCreate(
73
+ user_id=user_id,
74
+ transaction_date=datetime(2022, 1, 9),
75
+ category="category3",
76
+ name_description="name_description",
77
+ amount=9.0,
78
+ type=TransactionType.EXPENSE,
79
+ ),
80
+ TransactionCreate(
81
+ user_id=user_id,
82
+ transaction_date=datetime(2022, 1, 10),
83
+ category="category1",
84
+ name_description="name_description",
85
+ amount=10.0,
86
+ type=TransactionType.EXPENSE,
87
+ ),
88
+ ]