Spaces:
Sleeping
Sleeping
fix income statement endpoints
#16
by
praneethys
- opened
- app/api/routers/income_statement.py +6 -7
- app/api/routers/user.py +2 -2
- app/model/income_statement.py +10 -6
- app/model/transaction.py +4 -1
- app/schema/index.py +15 -10
- app/service/income_statement.py +25 -6
- app/service/llm.py +42 -40
- tests/conftest.py +3 -1
- tests/test_income_statement.py +51 -0
- tests/test_transactions.py +6 -90
- tests/utils.py +88 -0
app/api/routers/income_statement.py
CHANGED
@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|
3 |
from sqlalchemy.ext.asyncio import AsyncSession
|
4 |
from app.model.transaction import Transaction as TransactionModel
|
5 |
from app.model.income_statement import IncomeStatement as IncomeStatementModel
|
6 |
-
from app.schema.index import
|
7 |
from app.engine.postgresdb import get_db_session
|
8 |
from app.service.income_statement import call_llm_to_create_income_statement
|
9 |
|
@@ -18,7 +18,7 @@ income_statement_router = r = APIRouter(prefix="/api/v1/income_statement", tags=
|
|
18 |
500: {"description": "Internal server error"},
|
19 |
},
|
20 |
)
|
21 |
-
async def create_income_statement(payload:
|
22 |
try:
|
23 |
await call_llm_to_create_income_statement(payload, db)
|
24 |
|
@@ -27,7 +27,7 @@ async def create_income_statement(payload: IncomeStatementCreate, db: AsyncSessi
|
|
27 |
|
28 |
|
29 |
@r.get(
|
30 |
-
"/{user_id}",
|
31 |
response_model=List[IncomeStatementResponse],
|
32 |
responses={
|
33 |
200: {"description": "New user created"},
|
@@ -43,14 +43,13 @@ async def get_income_statements(
|
|
43 |
Retrieve all income statements.
|
44 |
"""
|
45 |
result = await IncomeStatementModel.get_by_user(db, user_id)
|
46 |
-
|
47 |
-
if len(all_rows) == 0:
|
48 |
raise HTTPException(status_code=status.HTTP_204_NO_CONTENT, detail="No income statements found for this user")
|
49 |
-
return
|
50 |
|
51 |
|
52 |
@r.get(
|
53 |
-
"/{report_id}",
|
54 |
response_model=IncomeStatementResponse,
|
55 |
responses={
|
56 |
200: {"description": "Income statement found"},
|
|
|
3 |
from sqlalchemy.ext.asyncio import AsyncSession
|
4 |
from app.model.transaction import Transaction as TransactionModel
|
5 |
from app.model.income_statement import IncomeStatement as IncomeStatementModel
|
6 |
+
from app.schema.index import IncomeStatementCreateRequest, IncomeStatementResponse
|
7 |
from app.engine.postgresdb import get_db_session
|
8 |
from app.service.income_statement import call_llm_to_create_income_statement
|
9 |
|
|
|
18 |
500: {"description": "Internal server error"},
|
19 |
},
|
20 |
)
|
21 |
+
async def create_income_statement(payload: IncomeStatementCreateRequest, db: AsyncSession = Depends(get_db_session)) -> None:
|
22 |
try:
|
23 |
await call_llm_to_create_income_statement(payload, db)
|
24 |
|
|
|
27 |
|
28 |
|
29 |
@r.get(
|
30 |
+
"/user/{user_id}",
|
31 |
response_model=List[IncomeStatementResponse],
|
32 |
responses={
|
33 |
200: {"description": "New user created"},
|
|
|
43 |
Retrieve all income statements.
|
44 |
"""
|
45 |
result = await IncomeStatementModel.get_by_user(db, user_id)
|
46 |
+
if len(result) == 0:
|
|
|
47 |
raise HTTPException(status_code=status.HTTP_204_NO_CONTENT, detail="No income statements found for this user")
|
48 |
+
return result
|
49 |
|
50 |
|
51 |
@r.get(
|
52 |
+
"/report/{report_id}",
|
53 |
response_model=IncomeStatementResponse,
|
54 |
responses={
|
55 |
200: {"description": "Income statement found"},
|
app/api/routers/user.py
CHANGED
@@ -27,7 +27,7 @@ async def create_user(user: UserCreate, db: AsyncSession = Depends(get_db_sessio
|
|
27 |
if db_user and not db_user.is_deleted:
|
28 |
raise HTTPException(status_code=409, detail="User already exists")
|
29 |
|
30 |
-
await UserModel.create(db, **user.
|
31 |
user = await UserModel.get(db, email=user.email)
|
32 |
return user
|
33 |
except Exception as e:
|
@@ -64,7 +64,7 @@ async def update_user(email: str, user_payload: UserUpdate, db: AsyncSession = D
|
|
64 |
user = await UserModel.get(db, email=email)
|
65 |
if not user:
|
66 |
raise HTTPException(status_code=404, detail="User not found")
|
67 |
-
await UserModel.update(db, id=user.id, **user_payload.
|
68 |
user = await UserModel.get(db, email=email)
|
69 |
return user
|
70 |
except Exception as e:
|
|
|
27 |
if db_user and not db_user.is_deleted:
|
28 |
raise HTTPException(status_code=409, detail="User already exists")
|
29 |
|
30 |
+
await UserModel.create(db, **user.model_dump())
|
31 |
user = await UserModel.get(db, email=user.email)
|
32 |
return user
|
33 |
except Exception as e:
|
|
|
64 |
user = await UserModel.get(db, email=email)
|
65 |
if not user:
|
66 |
raise HTTPException(status_code=404, detail="User not found")
|
67 |
+
await UserModel.update(db, id=user.id, **user_payload.model_dump())
|
68 |
user = await UserModel.get(db, email=email)
|
69 |
return user
|
70 |
except Exception as e:
|
app/model/income_statement.py
CHANGED
@@ -8,6 +8,7 @@ from sqlalchemy.dialects.postgresql import JSON
|
|
8 |
|
9 |
from app.model.base import BaseModel
|
10 |
from app.engine.postgresdb import Base
|
|
|
11 |
|
12 |
|
13 |
class IncomeStatement(Base, BaseModel):
|
@@ -21,19 +22,22 @@ class IncomeStatement(Base, BaseModel):
|
|
21 |
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
22 |
user = relationship("User", back_populates="income_statements")
|
23 |
|
|
|
|
|
|
|
24 |
@classmethod
|
25 |
-
def create(cls: "type[IncomeStatement]", db: AsyncSession, **kwargs) -> "IncomeStatement":
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
db.
|
30 |
return income_statement
|
31 |
|
32 |
@classmethod
|
33 |
async def get_by_user(cls: "type[IncomeStatement]", db: AsyncSession, user_id: int) -> "List[IncomeStatement]":
|
34 |
query = sql.select(cls).where(cls.user_id == user_id)
|
35 |
income_statements = await db.scalars(query)
|
36 |
-
return income_statements
|
37 |
|
38 |
@classmethod
|
39 |
async def get(cls: "type[IncomeStatement]", db: AsyncSession, id: int) -> "IncomeStatement":
|
|
|
8 |
|
9 |
from app.model.base import BaseModel
|
10 |
from app.engine.postgresdb import Base
|
11 |
+
from app.schema.index import IncomeStatement as IncomeStatementSchema
|
12 |
|
13 |
|
14 |
class IncomeStatement(Base, BaseModel):
|
|
|
22 |
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
23 |
user = relationship("User", back_populates="income_statements")
|
24 |
|
25 |
+
def __str__(self) -> str:
|
26 |
+
return f"IncomeStatement(id={self.id}, user_id={self.user_id}, date_from={self.date_from}, date_to={self.date_to}, income={self.income}, expenses={self.expenses})"
|
27 |
+
|
28 |
@classmethod
|
29 |
+
async def create(cls: "type[IncomeStatement]", db: AsyncSession, **kwargs: IncomeStatementSchema) -> "IncomeStatement":
|
30 |
+
income_statement = cls(**kwargs)
|
31 |
+
db.add(income_statement)
|
32 |
+
await db.commit()
|
33 |
+
await db.refresh(income_statement)
|
34 |
return income_statement
|
35 |
|
36 |
@classmethod
|
37 |
async def get_by_user(cls: "type[IncomeStatement]", db: AsyncSession, user_id: int) -> "List[IncomeStatement]":
|
38 |
query = sql.select(cls).where(cls.user_id == user_id)
|
39 |
income_statements = await db.scalars(query)
|
40 |
+
return income_statements.all()
|
41 |
|
42 |
@classmethod
|
43 |
async def get(cls: "type[IncomeStatement]", db: AsyncSession, id: int) -> "IncomeStatement":
|
app/model/transaction.py
CHANGED
@@ -21,6 +21,9 @@ class Transaction(Base, BaseModel):
|
|
21 |
user_id = mapped_column(ForeignKey("users.id"))
|
22 |
user = relationship("User", back_populates="transactions")
|
23 |
|
|
|
|
|
|
|
24 |
@classmethod
|
25 |
async def create(cls: "type[Transaction]", db: AsyncSession, **kwargs) -> "Transaction":
|
26 |
query = sql.insert(cls).values(**kwargs)
|
@@ -56,4 +59,4 @@ class Transaction(Base, BaseModel):
|
|
56 |
) -> "List[Transaction]":
|
57 |
query = sql.select(cls).where(cls.user_id == user_id).where(cls.transaction_date.between(start_date, end_date))
|
58 |
transactions = await db.scalars(query)
|
59 |
-
return transactions
|
|
|
21 |
user_id = mapped_column(ForeignKey("users.id"))
|
22 |
user = relationship("User", back_populates="transactions")
|
23 |
|
24 |
+
def __str__(self) -> str:
|
25 |
+
return f"{self.transaction_date}, {self.category}, {self.name_description}, {self.amount}, {self.type}"
|
26 |
+
|
27 |
@classmethod
|
28 |
async def create(cls: "type[Transaction]", db: AsyncSession, **kwargs) -> "Transaction":
|
29 |
query = sql.insert(cls).values(**kwargs)
|
|
|
59 |
) -> "List[Transaction]":
|
60 |
query = sql.select(cls).where(cls.user_id == user_id).where(cls.transaction_date.between(start_date, end_date))
|
61 |
transactions = await db.scalars(query)
|
62 |
+
return transactions.all()
|
app/schema/index.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from enum import Enum
|
2 |
from datetime import datetime
|
3 |
-
from typing import
|
|
|
4 |
|
5 |
from app.schema.base import BaseModel, PydanticBaseModel
|
6 |
|
@@ -62,19 +63,23 @@ class FileUploadCreate(PydanticBaseModel):
|
|
62 |
type: str
|
63 |
|
64 |
|
65 |
-
class
|
66 |
user_id: int
|
67 |
date_from: datetime
|
68 |
date_to: datetime
|
69 |
|
70 |
|
71 |
-
class
|
72 |
-
|
73 |
-
|
74 |
-
date_to: datetime
|
75 |
-
income: dict
|
76 |
-
expenses: dict
|
77 |
|
78 |
class IncomeStatementLLMResponse(PydanticBaseModel):
|
79 |
-
income:
|
80 |
-
expenses:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from enum import Enum
|
2 |
from datetime import datetime
|
3 |
+
from typing import Dict, List
|
4 |
+
from typing_extensions import TypedDict
|
5 |
|
6 |
from app.schema.base import BaseModel, PydanticBaseModel
|
7 |
|
|
|
63 |
type: str
|
64 |
|
65 |
|
66 |
+
class IncomeStatementCreateRequest(PydanticBaseModel):
|
67 |
user_id: int
|
68 |
date_from: datetime
|
69 |
date_to: datetime
|
70 |
|
71 |
|
72 |
+
class IncomeStatementDetail(TypedDict):
|
73 |
+
total: float
|
74 |
+
category_totals: List[Dict[str, str | float]]
|
|
|
|
|
|
|
75 |
|
76 |
class IncomeStatementLLMResponse(PydanticBaseModel):
|
77 |
+
income: IncomeStatementDetail
|
78 |
+
expenses: IncomeStatementDetail
|
79 |
+
|
80 |
+
|
81 |
+
class IncomeStatement(IncomeStatementCreateRequest, IncomeStatementLLMResponse):
|
82 |
+
pass
|
83 |
+
|
84 |
+
class IncomeStatementResponse(IncomeStatement):
|
85 |
+
id: int
|
app/service/income_statement.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
|
|
|
2 |
from app.model.transaction import Transaction as TransactionModel
|
3 |
from app.model.income_statement import IncomeStatement as IncomeStatementModel
|
4 |
from sqlalchemy.ext.asyncio import AsyncSession
|
@@ -6,13 +7,31 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
6 |
from app.service.llm import call_llm
|
7 |
|
8 |
|
9 |
-
async def call_llm_to_create_income_statement(payload:
|
10 |
transactions = await TransactionModel.get_by_user_between_dates(
|
11 |
db, payload.user_id, payload.date_from, payload.date_to
|
12 |
)
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
|
18 |
-
await
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from app.schema.index import IncomeStatement, IncomeStatementCreateRequest
|
3 |
from app.model.transaction import Transaction as TransactionModel
|
4 |
from app.model.income_statement import IncomeStatement as IncomeStatementModel
|
5 |
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
7 |
from app.service.llm import call_llm
|
8 |
|
9 |
|
10 |
+
async def call_llm_to_create_income_statement(payload: IncomeStatementCreateRequest, db: AsyncSession) -> None:
|
11 |
transactions = await TransactionModel.get_by_user_between_dates(
|
12 |
db, payload.user_id, payload.date_from, payload.date_to
|
13 |
)
|
14 |
|
15 |
+
if not transactions:
|
16 |
+
print("No transactions found")
|
17 |
+
return
|
18 |
|
19 |
+
response = await call_llm(transactions)
|
20 |
+
|
21 |
+
income = response.dict()['income']
|
22 |
+
expenses = response.dict()['expenses']
|
23 |
+
|
24 |
+
try:
|
25 |
+
income_statement_create_payload = IncomeStatement(
|
26 |
+
user_id=payload.user_id,
|
27 |
+
date_from=payload.date_from,
|
28 |
+
date_to=payload.date_to,
|
29 |
+
income=income,
|
30 |
+
expenses=expenses,
|
31 |
+
)
|
32 |
+
|
33 |
+
income_statement = await IncomeStatementModel.create(db, **income_statement_create_payload.model_dump())
|
34 |
+
print(f"Income statement created: {income_statement}")
|
35 |
+
except Exception as e:
|
36 |
+
print(e)
|
37 |
+
raise e
|
app/service/llm.py
CHANGED
@@ -1,53 +1,55 @@
|
|
1 |
-
import
|
2 |
-
from
|
3 |
-
from
|
|
|
|
|
4 |
|
5 |
from app.model.transaction import Transaction
|
6 |
from app.schema.index import IncomeStatementLLMResponse
|
|
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
10 |
You are an accountant skilled at organizing transactions from multiple different bank
|
11 |
accounts and credit card statements to prepare an income statement.
|
12 |
|
13 |
-
Input data
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
Your task is to prepare an income statement. The income statement's output should be in a json format.
|
20 |
-
An example of the expected output is as follows with the <OUT> tag. Note that not all categories of the transactions
|
21 |
-
have been listed below. Use below <OUT> tag as a reference in preparing the income statement.
|
22 |
-
```<OUT>
|
23 |
-
{
|
24 |
-
"REVENUE": {
|
25 |
-
"Gross Sales": 73351.11,
|
26 |
-
"Other Income": 0,
|
27 |
-
"Balance Dec 2022": 3987.39,
|
28 |
-
},
|
29 |
-
"EXPENSES": {
|
30 |
-
"Advertising": 0,
|
31 |
-
"Commissions": 0,
|
32 |
-
"Insurance": 0,
|
33 |
-
"Memberships": 0,
|
34 |
-
"Utilities": 0,
|
35 |
-
}
|
36 |
-
}
|
37 |
-
```
|
38 |
|
39 |
"""
|
40 |
-
prompt =
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
47 |
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
3 |
+
from langchain_core.runnables.base import RunnableSequence
|
4 |
+
from langchain_openai import OpenAI
|
5 |
+
from langchain.globals import set_llm_cache
|
6 |
|
7 |
from app.model.transaction import Transaction
|
8 |
from app.schema.index import IncomeStatementLLMResponse
|
9 |
+
from config.index import config as env
|
10 |
|
11 |
+
from langchain_core.output_parsers import PydanticOutputParser
|
12 |
+
|
13 |
+
set_llm_cache(None)
|
14 |
+
|
15 |
+
def income_statement_prompt () -> ChatPromptTemplate:
|
16 |
+
context_str = """
|
17 |
You are an accountant skilled at organizing transactions from multiple different bank
|
18 |
accounts and credit card statements to prepare an income statement.
|
19 |
|
20 |
+
Input data is in the below csv format:
|
21 |
+
transaction_date, category, name_description, amount, type\n
|
22 |
+
{input_data_csv}
|
23 |
+
|
24 |
+
Your task is to prepare an income statement. The output should be in the following format: {format_instructions}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
"""
|
27 |
+
prompt = ChatPromptTemplate.from_template(context_str)
|
28 |
+
return prompt
|
29 |
+
|
30 |
+
async def call_llm(inputData: List[Transaction]) -> str:
|
31 |
+
input_data_csv = '\n'.join(str(x) for x in inputData)
|
32 |
+
|
33 |
+
output_parser = PydanticOutputParser(pydantic_object=IncomeStatementLLMResponse)
|
34 |
+
|
35 |
+
prompt = income_statement_prompt().partial(format_instructions=output_parser.get_format_instructions())
|
36 |
|
37 |
+
llm = OpenAI(name='Income Statement Generation Bot',
|
38 |
+
api_key=env.OPENAI_API_KEY,
|
39 |
+
# cache=True,
|
40 |
+
temperature=0.7,
|
41 |
+
verbose=True)
|
42 |
|
43 |
+
try:
|
44 |
+
runnable_chain = RunnableSequence(prompt, llm, output_parser)
|
45 |
+
except Exception as e:
|
46 |
+
print(f"runnable_chain error: {str(e)}")
|
47 |
+
raise e
|
48 |
|
49 |
+
try:
|
50 |
+
output_chunks = runnable_chain.invoke({"input_data_csv": input_data_csv})
|
51 |
+
return output_chunks
|
52 |
|
53 |
+
except Exception as e:
|
54 |
+
print(f"runnable_chain.invoke error: {str(e)}")
|
55 |
+
raise e
|
tests/conftest.py
CHANGED
@@ -54,7 +54,9 @@ async def connection_test(test_db, event_loop):
|
|
54 |
|
55 |
with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
|
56 |
connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
|
57 |
-
sessionmanager.init(connection_str,
|
|
|
|
|
58 |
yield
|
59 |
await sessionmanager.close()
|
60 |
|
|
|
54 |
|
55 |
with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
|
56 |
connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
|
57 |
+
sessionmanager.init(connection_str,
|
58 |
+
# {"echo": True, "future": True}
|
59 |
+
)
|
60 |
yield
|
61 |
await sessionmanager.close()
|
62 |
|
tests/test_income_statement.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.testclient import TestClient
|
2 |
+
import pytest
|
3 |
+
|
4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
+
from app.model.transaction import Transaction
|
6 |
+
from app.model.user import User
|
7 |
+
from tests.utils import get_fake_transactions
|
8 |
+
|
9 |
+
@pytest.mark.asyncio
|
10 |
+
async def test_income_statement(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
|
11 |
+
session_override = get_db_session_fixture
|
12 |
+
|
13 |
+
# 1. Create a user
|
14 |
+
user = await User.create(session_override, name="user", email="email", hashed_password="password")
|
15 |
+
|
16 |
+
# 2. Create a bunch of transactions
|
17 |
+
fake_transactions = get_fake_transactions(user.id)
|
18 |
+
await Transaction.bulk_create(session_override, fake_transactions)
|
19 |
+
|
20 |
+
# 3. Create an income statement
|
21 |
+
min_date = min(t.transaction_date for t in fake_transactions)
|
22 |
+
max_date = max(t.transaction_date for t in fake_transactions)
|
23 |
+
|
24 |
+
print(f"min_date: {min_date}, max_date: {max_date}")
|
25 |
+
response = client.post(
|
26 |
+
"/api/v1/income_statement",
|
27 |
+
json={
|
28 |
+
"user_id": 1,
|
29 |
+
"date_from": str(min_date),
|
30 |
+
"date_to": str(max_date),
|
31 |
+
},
|
32 |
+
)
|
33 |
+
|
34 |
+
assert response.status_code == 200
|
35 |
+
|
36 |
+
# 4. Verify that the income statement matches the transactions
|
37 |
+
response = client.get(f"/api/v1/income_statement/user/1")
|
38 |
+
print(response.json())
|
39 |
+
assert response.status_code == 200
|
40 |
+
assert response.json()[0].get("income")
|
41 |
+
assert response.json()[0].get("expenses")
|
42 |
+
|
43 |
+
report_id = response.json()[0].get("id")
|
44 |
+
|
45 |
+
# # 5. Verify that the income statement can be retrieved
|
46 |
+
if report_id is not None:
|
47 |
+
response = client.get(f"/api/v1/income_statement/report/{report_id}")
|
48 |
+
assert response.status_code == 200
|
49 |
+
assert response.json().get("income")
|
50 |
+
assert response.json().get("expenses")
|
51 |
+
assert response.json().get("id") == report_id
|
tests/test_transactions.py
CHANGED
@@ -1,109 +1,25 @@
|
|
1 |
-
from datetime import datetime
|
2 |
-
from typing import List
|
3 |
-
from fastapi import Depends
|
4 |
from fastapi.testclient import TestClient
|
5 |
import pytest
|
6 |
|
7 |
from app.model.transaction import Transaction
|
8 |
-
from app.schema.index import TransactionType, TransactionCreate
|
9 |
-
|
10 |
-
from sqlalchemy.ext.asyncio import AsyncSession
|
11 |
-
from app.engine.postgresdb import get_db_session
|
12 |
from app.model.user import User
|
|
|
|
|
13 |
|
14 |
-
def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
|
15 |
-
return [
|
16 |
-
TransactionCreate(
|
17 |
-
user_id=user_id,
|
18 |
-
transaction_date=datetime(2022, 1, 1),
|
19 |
-
category="category",
|
20 |
-
name_description="name_description",
|
21 |
-
amount=1.0,
|
22 |
-
type=TransactionType.EXPENSE,
|
23 |
-
),
|
24 |
-
TransactionCreate(
|
25 |
-
user_id=user_id,
|
26 |
-
transaction_date=datetime(2022, 1, 2),
|
27 |
-
category="category",
|
28 |
-
name_description="name_description",
|
29 |
-
amount=2.0,
|
30 |
-
type=TransactionType.EXPENSE,
|
31 |
-
),
|
32 |
-
TransactionCreate(
|
33 |
-
user_id=user_id,
|
34 |
-
transaction_date=datetime(2022, 1, 3),
|
35 |
-
category="category",
|
36 |
-
name_description="name_description",
|
37 |
-
amount=3.0,
|
38 |
-
type=TransactionType.INCOME,
|
39 |
-
),
|
40 |
-
TransactionCreate(
|
41 |
-
user_id=user_id,
|
42 |
-
transaction_date=datetime(2022, 1, 4),
|
43 |
-
category="category",
|
44 |
-
name_description="name_description",
|
45 |
-
amount=4.0,
|
46 |
-
type=TransactionType.INCOME,
|
47 |
-
),
|
48 |
-
TransactionCreate(
|
49 |
-
user_id=user_id,
|
50 |
-
transaction_date=datetime(2022, 1, 5),
|
51 |
-
category="category",
|
52 |
-
name_description="name_description",
|
53 |
-
amount=5.0,
|
54 |
-
type=TransactionType.EXPENSE,
|
55 |
-
),
|
56 |
-
TransactionCreate(
|
57 |
-
user_id=user_id,
|
58 |
-
transaction_date=datetime(2022, 1, 6),
|
59 |
-
category="category",
|
60 |
-
name_description="name_description",
|
61 |
-
amount=6.0,
|
62 |
-
type=TransactionType.EXPENSE,
|
63 |
-
),
|
64 |
-
TransactionCreate(
|
65 |
-
user_id=user_id,
|
66 |
-
transaction_date=datetime(2022, 1, 7),
|
67 |
-
category="category",
|
68 |
-
name_description="name_description",
|
69 |
-
amount=7.0,
|
70 |
-
type=TransactionType.INCOME,
|
71 |
-
),
|
72 |
-
TransactionCreate(
|
73 |
-
user_id=user_id,
|
74 |
-
transaction_date=datetime(2022, 1, 8),
|
75 |
-
category="category",
|
76 |
-
name_description="name_description",
|
77 |
-
amount=8.0,
|
78 |
-
type=TransactionType.INCOME,
|
79 |
-
),
|
80 |
-
TransactionCreate(
|
81 |
-
user_id=user_id,
|
82 |
-
transaction_date=datetime(2022, 1, 9),
|
83 |
-
category="category",
|
84 |
-
name_description="name_description",
|
85 |
-
amount=9.0,
|
86 |
-
type=TransactionType.EXPENSE,
|
87 |
-
),
|
88 |
-
TransactionCreate(
|
89 |
-
user_id=user_id,
|
90 |
-
transaction_date=datetime(2022, 1, 10),
|
91 |
-
category="category",
|
92 |
-
name_description="name_description",
|
93 |
-
amount=10.0,
|
94 |
-
type=TransactionType.EXPENSE,
|
95 |
-
),
|
96 |
-
]
|
97 |
|
98 |
@pytest.mark.asyncio
|
99 |
async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
|
100 |
|
101 |
session_override = get_db_session_fixture
|
|
|
|
|
102 |
user = await User.create(session_override, name="user", email="email", hashed_password="password")
|
103 |
|
|
|
104 |
fake_transactions = get_fake_transactions(user.id)
|
105 |
await Transaction.bulk_create(session_override, fake_transactions)
|
106 |
|
|
|
107 |
response = client.get("/api/v1/transactions/1")
|
108 |
assert response.status_code == 200
|
109 |
assert len(response.json()) == 10
|
|
|
|
|
|
|
|
|
1 |
from fastapi.testclient import TestClient
|
2 |
import pytest
|
3 |
|
4 |
from app.model.transaction import Transaction
|
|
|
|
|
|
|
|
|
5 |
from app.model.user import User
|
6 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
7 |
+
from tests.utils import get_fake_transactions
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
@pytest.mark.asyncio
|
11 |
async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
|
12 |
|
13 |
session_override = get_db_session_fixture
|
14 |
+
|
15 |
+
# 1. Create a user
|
16 |
user = await User.create(session_override, name="user", email="email", hashed_password="password")
|
17 |
|
18 |
+
# 2. Create a bunch of transactions
|
19 |
fake_transactions = get_fake_transactions(user.id)
|
20 |
await Transaction.bulk_create(session_override, fake_transactions)
|
21 |
|
22 |
+
# 3. Verify that the transactions are returned
|
23 |
response = client.get("/api/v1/transactions/1")
|
24 |
assert response.status_code == 200
|
25 |
assert len(response.json()) == 10
|
tests/utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from typing import List
|
3 |
+
from app.schema.index import TransactionCreate, TransactionType
|
4 |
+
|
5 |
+
|
6 |
+
def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
|
7 |
+
return [
|
8 |
+
TransactionCreate(
|
9 |
+
user_id=user_id,
|
10 |
+
transaction_date=datetime(2022, 1, 1),
|
11 |
+
category="category1",
|
12 |
+
name_description="name_description",
|
13 |
+
amount=1.0,
|
14 |
+
type=TransactionType.EXPENSE,
|
15 |
+
),
|
16 |
+
TransactionCreate(
|
17 |
+
user_id=user_id,
|
18 |
+
transaction_date=datetime(2022, 1, 2),
|
19 |
+
category="category2",
|
20 |
+
name_description="name_description",
|
21 |
+
amount=2.0,
|
22 |
+
type=TransactionType.EXPENSE,
|
23 |
+
),
|
24 |
+
TransactionCreate(
|
25 |
+
user_id=user_id,
|
26 |
+
transaction_date=datetime(2022, 1, 3),
|
27 |
+
category="category3",
|
28 |
+
name_description="name_description",
|
29 |
+
amount=3.0,
|
30 |
+
type=TransactionType.INCOME,
|
31 |
+
),
|
32 |
+
TransactionCreate(
|
33 |
+
user_id=user_id,
|
34 |
+
transaction_date=datetime(2022, 1, 4),
|
35 |
+
category="category1",
|
36 |
+
name_description="name_description",
|
37 |
+
amount=4.0,
|
38 |
+
type=TransactionType.INCOME,
|
39 |
+
),
|
40 |
+
TransactionCreate(
|
41 |
+
user_id=user_id,
|
42 |
+
transaction_date=datetime(2022, 1, 5),
|
43 |
+
category="category2",
|
44 |
+
name_description="name_description",
|
45 |
+
amount=5.0,
|
46 |
+
type=TransactionType.EXPENSE,
|
47 |
+
),
|
48 |
+
TransactionCreate(
|
49 |
+
user_id=user_id,
|
50 |
+
transaction_date=datetime(2022, 1, 6),
|
51 |
+
category="category3",
|
52 |
+
name_description="name_description",
|
53 |
+
amount=6.0,
|
54 |
+
type=TransactionType.EXPENSE,
|
55 |
+
),
|
56 |
+
TransactionCreate(
|
57 |
+
user_id=user_id,
|
58 |
+
transaction_date=datetime(2022, 1, 7),
|
59 |
+
category="category1",
|
60 |
+
name_description="name_description",
|
61 |
+
amount=7.0,
|
62 |
+
type=TransactionType.INCOME,
|
63 |
+
),
|
64 |
+
TransactionCreate(
|
65 |
+
user_id=user_id,
|
66 |
+
transaction_date=datetime(2022, 1, 8),
|
67 |
+
category="category2",
|
68 |
+
name_description="name_description",
|
69 |
+
amount=8.0,
|
70 |
+
type=TransactionType.INCOME,
|
71 |
+
),
|
72 |
+
TransactionCreate(
|
73 |
+
user_id=user_id,
|
74 |
+
transaction_date=datetime(2022, 1, 9),
|
75 |
+
category="category3",
|
76 |
+
name_description="name_description",
|
77 |
+
amount=9.0,
|
78 |
+
type=TransactionType.EXPENSE,
|
79 |
+
),
|
80 |
+
TransactionCreate(
|
81 |
+
user_id=user_id,
|
82 |
+
transaction_date=datetime(2022, 1, 10),
|
83 |
+
category="category1",
|
84 |
+
name_description="name_description",
|
85 |
+
amount=10.0,
|
86 |
+
type=TransactionType.EXPENSE,
|
87 |
+
),
|
88 |
+
]
|