Spaces:
Sleeping
Sleeping
praneethys
committed on
Commit
•
b905bd6
1
Parent(s):
ad33b38
add llm service for generating income statements (#14)
Browse files- feat: add llm service to generate income statement (6d92adc7e96771ebc55f1859275ee4206a668883)
- test: pytest for transactions (cf9085da983a5159a7de0d43b0b12c16fdc32822)
- app/api/routers/file_upload.py +7 -4
- app/categorization/file_processing.py +4 -2
- app/model/transaction.py +2 -1
- app/schema/index.py +4 -0
- app/service/income_statement.py +6 -3
- app/service/llm.py +53 -0
- pyproject.toml +3 -0
- tests/conftest.py +10 -4
- tests/test_transactions.py +109 -0
app/api/routers/file_upload.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
from typing import Annotated
|
2 |
-
from fastapi import APIRouter, UploadFile
|
3 |
from app.categorization.file_processing import process_file, save_results
|
4 |
from app.schema.index import FileUploadCreate
|
5 |
import asyncio
|
6 |
import os
|
7 |
import csv
|
8 |
|
|
|
|
|
|
|
9 |
file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upload"])
|
10 |
|
11 |
@r.post(
|
@@ -16,7 +19,7 @@ file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upl
|
|
16 |
500: {"description": "Internal server error"},
|
17 |
},
|
18 |
)
|
19 |
-
async def create_file(input_file: UploadFile):
|
20 |
try:
|
21 |
# Create directory to store all uploaded .csv files
|
22 |
file_upload_directory_path = "data/tx_data/input"
|
@@ -31,10 +34,10 @@ async def create_file(input_file: UploadFile):
|
|
31 |
# With the newly created file and it's path, process and save it for embedding
|
32 |
processed_file = process_file(os.path.realpath(input_file.filename))
|
33 |
result = await asyncio.gather(processed_file)
|
34 |
-
save_results(result)
|
35 |
|
36 |
except Exception:
|
37 |
return {"message": "There was an error uploading this file. Ensure you have a .csv file with the following columns:"
|
38 |
-
"\n
|
39 |
|
40 |
return {"message": f"Successfully uploaded {input_file.filename}"}
|
|
|
1 |
from typing import Annotated
|
2 |
+
from fastapi import APIRouter, UploadFile, Depends
|
3 |
from app.categorization.file_processing import process_file, save_results
|
4 |
from app.schema.index import FileUploadCreate
|
5 |
import asyncio
|
6 |
import os
|
7 |
import csv
|
8 |
|
9 |
+
from app.engine.postgresdb import get_db_session
|
10 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
11 |
+
|
12 |
file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upload"])
|
13 |
|
14 |
@r.post(
|
|
|
19 |
500: {"description": "Internal server error"},
|
20 |
},
|
21 |
)
|
22 |
+
async def create_file(input_file: UploadFile, db: AsyncSession = Depends(get_db_session)):
|
23 |
try:
|
24 |
# Create directory to store all uploaded .csv files
|
25 |
file_upload_directory_path = "data/tx_data/input"
|
|
|
34 |
# With the newly created file and it's path, process and save it for embedding
|
35 |
processed_file = process_file(os.path.realpath(input_file.filename))
|
36 |
result = await asyncio.gather(processed_file)
|
37 |
+
await save_results(db, result)
|
38 |
|
39 |
except Exception:
|
40 |
return {"message": "There was an error uploading this file. Ensure you have a .csv file with the following columns:"
|
41 |
+
"\n transaction_date, type, category, name_description, amount"}
|
42 |
|
43 |
return {"message": f"Successfully uploaded {input_file.filename}"}
|
app/categorization/file_processing.py
CHANGED
@@ -14,6 +14,8 @@ from app.categorization.config import RESULT_OUTPUT_FILE, CATEGORY_REFERENCE_OUT
|
|
14 |
from app.model.transaction import Transaction
|
15 |
from app.schema.index import TransactionCreate
|
16 |
|
|
|
|
|
17 |
|
18 |
# Read file and process it (e.g. categorize transactions)
|
19 |
async def process_file(file_path: str) -> Dict[str, Union[str, pd.DataFrame]]:
|
@@ -71,7 +73,7 @@ def standardize_csv_file(file_path: str) -> pd.DataFrame:
|
|
71 |
return tx_list
|
72 |
|
73 |
|
74 |
-
async def save_results(results: List) -> None:
|
75 |
"""
|
76 |
Merge all interim results in the input folder and write the merged results to the output file.
|
77 |
|
@@ -104,7 +106,7 @@ async def save_results(results: List) -> None:
|
|
104 |
# Save to database
|
105 |
# FIXME: get user_id from session
|
106 |
txn_list_to_save = [TransactionCreate(**row.to_dict(), user_id=1) for _, row in tx_list.iterrows()]
|
107 |
-
await Transaction.bulk_create(txn_list_to_save)
|
108 |
|
109 |
new_ref_data = tx_list[["name/description", "category"]]
|
110 |
if os.path.exists(CATEGORY_REFERENCE_OUTPUT_FILE):
|
|
|
14 |
from app.model.transaction import Transaction
|
15 |
from app.schema.index import TransactionCreate
|
16 |
|
17 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
18 |
+
|
19 |
|
20 |
# Read file and process it (e.g. categorize transactions)
|
21 |
async def process_file(file_path: str) -> Dict[str, Union[str, pd.DataFrame]]:
|
|
|
73 |
return tx_list
|
74 |
|
75 |
|
76 |
+
async def save_results(db: AsyncSession, results: List) -> None:
|
77 |
"""
|
78 |
Merge all interim results in the input folder and write the merged results to the output file.
|
79 |
|
|
|
106 |
# Save to database
|
107 |
# FIXME: get user_id from session
|
108 |
txn_list_to_save = [TransactionCreate(**row.to_dict(), user_id=1) for _, row in tx_list.iterrows()]
|
109 |
+
await Transaction.bulk_create(db, txn_list_to_save)
|
110 |
|
111 |
new_ref_data = tx_list[["name/description", "category"]]
|
112 |
if os.path.exists(CATEGORY_REFERENCE_OUTPUT_FILE):
|
app/model/transaction.py
CHANGED
@@ -31,7 +31,8 @@ class Transaction(Base, BaseModel):
|
|
31 |
|
32 |
@classmethod
|
33 |
async def bulk_create(cls: "type[Transaction]", db: AsyncSession, transactions: List[TransactionCreate]) -> None:
|
34 |
-
|
|
|
35 |
await db.execute(query)
|
36 |
await db.commit()
|
37 |
|
|
|
31 |
|
32 |
@classmethod
async def bulk_create(cls: "type[Transaction]", db: AsyncSession, transactions: List[TransactionCreate]) -> None:
    """Insert all given transactions with one INSERT statement and commit.

    Args:
        db: Async session used to execute the statement and commit.
        transactions: Validated transaction payloads; each is converted to a
            plain dict with ``model_dump()`` before insertion.
    """
    # Nothing to insert: skip the round-trip instead of handing an empty
    # values list to ``insert()``, which would not describe any real row.
    if not transactions:
        return

    values = [transaction.model_dump() for transaction in transactions]
    query = sql.insert(cls).values(values)
    await db.execute(query)
    await db.commit()
|
38 |
|
app/schema/index.py
CHANGED
@@ -74,3 +74,7 @@ class IncomeStatementResponse(PydanticBaseModel):
|
|
74 |
date_to: datetime
|
75 |
income: dict
|
76 |
expenses: dict
|
|
|
|
|
|
|
|
|
|
74 |
date_to: datetime
|
75 |
income: dict
|
76 |
expenses: dict
|
77 |
+
|
78 |
+
class IncomeStatementLLMResponse(PydanticBaseModel):
    """Structured output requested from the LLM when building an income statement.

    Passed as ``output_cls`` to the LLM's structured-predict call; the caller
    reads ``income`` and ``expenses`` off the parsed result.
    """

    # Income section of the statement, as returned by the LLM.
    income: dict
    # Expenses section of the statement, as returned by the LLM.
    expenses: dict
|
app/service/income_statement.py
CHANGED
@@ -3,13 +3,16 @@ from app.model.transaction import Transaction as TransactionModel
|
|
3 |
from app.model.income_statement import IncomeStatement as IncomeStatementModel
|
4 |
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
|
|
|
|
|
6 |
|
7 |
async def call_llm_to_create_income_statement(payload: IncomeStatementCreate, db: AsyncSession) -> None:
|
8 |
transactions = await TransactionModel.get_by_user_between_dates(
|
9 |
db, payload.user_id, payload.date_from, payload.date_to
|
10 |
)
|
11 |
|
12 |
-
|
13 |
-
income =
|
14 |
-
expenses =
|
|
|
15 |
await IncomeStatementModel.create(db, **payload, income=income, expenses=expenses)
|
|
|
3 |
from app.model.income_statement import IncomeStatement as IncomeStatementModel
|
4 |
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
|
6 |
+
from app.service.llm import call_llm
|
7 |
+
|
8 |
|
9 |
async def call_llm_to_create_income_statement(payload: IncomeStatementCreate, db: AsyncSession) -> None:
    """Generate an income statement with the LLM and persist it.

    Loads the user's transactions between ``payload.date_from`` and
    ``payload.date_to``, asks the LLM to summarise them into income and
    expenses, and stores the result via ``IncomeStatementModel.create``.

    Args:
        payload: Request data; ``user_id``, ``date_from`` and ``date_to`` are read from it.
        db: Async database session used for the query and the insert.
    """
    transactions = await TransactionModel.get_by_user_between_dates(
        db, payload.user_id, payload.date_from, payload.date_to
    )

    # call_llm is a coroutine function: without ``await`` we would be reading
    # ``.income`` / ``.expenses`` off a coroutine object instead of the response.
    response = await call_llm(transactions)

    # NOTE(review): assumes IncomeStatementCreate is a pydantic model (it is
    # accessed by attribute above), so it has to be dumped to a dict before
    # ``**`` unpacking — confirm against app/schema/index.py.
    await IncomeStatementModel.create(
        db,
        **payload.model_dump(),
        income=response.income,
        expenses=response.expenses,
    )
|
app/service/llm.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from llama_index.core.settings import Settings
|
3 |
+
from llama_index.core import PromptTemplate
|
4 |
+
|
5 |
+
from app.model.transaction import Transaction
|
6 |
+
from app.schema.index import IncomeStatementLLMResponse
|
7 |
+
|
8 |
+
def income_statement_prompt(inputData: Transaction) -> PromptTemplate:
    """Build the income-statement prompt with the transaction data bound in.

    Args:
        inputData: Transaction data to embed in the prompt. It is rendered
            with ``str()``, so a single transaction or a collection of them
            is accepted.

    Returns:
        A ``PromptTemplate`` whose only variable, ``inputData``, is already
        bound, so it can be formatted without further arguments.
    """
    # Deliberately NOT an f-string: ``{inputData}`` is the single template
    # variable, filled in by the PromptTemplate. The braces of the JSON
    # example are doubled so that ``str.format``-style rendering emits them
    # literally instead of treating them as replacement fields.
    template = """
You are an accountant skilled at organizing transactions from multiple different bank
accounts and credit card statements to prepare an income statement.

Input data has the following format: transaction_date, type, category, name_description, amount.
The <IN> tag is prepended to the input data as follows:
```<IN>
{inputData}
```

Your task is to prepare an income statement. The income statement's output should be in a json format.
An example of the expected output is as follows with the <OUT> tag. Note that not all categories of the transactions
have been listed below. Use below <OUT> tag as a reference in preparing the income statement.
```<OUT>
{{
    "REVENUE": {{
        "Gross Sales": 73351.11,
        "Other Income": 0,
        "Balance Dec 2022": 3987.39
    }},
    "EXPENSES": {{
        "Advertising": 0,
        "Commissions": 0,
        "Insurance": 0,
        "Memberships": 0,
        "Utilities": 0
    }}
}}
```

"""
    # Bind the data as a partial so callers do not have to pass it again.
    prompt = PromptTemplate(template).partial_format(inputData=str(inputData))
    logging.info("Prompt: %s", prompt)
    return prompt


async def call_llm(prompt: object) -> IncomeStatementLLMResponse:
    """Ask the configured LLM for a structured income statement.

    Args:
        prompt: Either a ready-made ``PromptTemplate`` or the raw transaction
            data; raw data is turned into a prompt with
            ``income_statement_prompt``.

    Returns:
        The LLM output parsed into an ``IncomeStatementLLMResponse``.
    """
    # Callers pass the transactions themselves, so build the prompt from them
    # unless a prepared template was supplied.
    if not isinstance(prompt, PromptTemplate):
        prompt = income_statement_prompt(prompt)

    # The prompt is handed straight to astructured_predict; it is not also set
    # as ``llm.system_prompt``, which expects a plain string.
    # NOTE(review): an earlier ``llm.check_prompts(prompt)`` call was dropped —
    # it appears to be an internal validator on the LLM class rather than a
    # public prompt check; confirm against the installed llama_index version.
    llm = Settings.llm

    output = await llm.astructured_predict(output_cls=IncomeStatementLLMResponse, prompt=prompt)

    logging.info("Output: %s", output)
    return output
|
pyproject.toml
CHANGED
@@ -52,6 +52,9 @@ version = "0.2.2"
|
|
52 |
[tool.black]
|
53 |
line-length = 119
|
54 |
|
|
|
|
|
|
|
55 |
[build-system]
|
56 |
requires = [ "poetry-core" ]
|
57 |
build-backend = "poetry.core.masonry.api"
|
|
|
52 |
[tool.black]
|
53 |
line-length = 119
|
54 |
|
55 |
+
[tool.pytest.ini_options]
|
56 |
+
asyncio_mode = "auto"
|
57 |
+
|
58 |
[build-system]
|
59 |
requires = [ "poetry-core" ]
|
60 |
build-backend = "poetry.core.masonry.api"
|
tests/conftest.py
CHANGED
@@ -37,7 +37,7 @@ def client(app):
|
|
37 |
yield c
|
38 |
|
39 |
|
40 |
-
@pytest.
|
41 |
def event_loop(request):
|
42 |
loop = asyncio.get_event_loop_policy().new_event_loop()
|
43 |
yield loop
|
@@ -52,9 +52,9 @@ async def connection_test(test_db, event_loop):
|
|
52 |
pg_db = test_db.dbname
|
53 |
pg_password = test_db.password
|
54 |
|
55 |
-
with DatabaseJanitor(pg_user, pg_host, pg_port, pg_db, test_db.version, pg_password):
|
56 |
-
connection_str = f"postgresql+
|
57 |
-
sessionmanager.init(connection_str)
|
58 |
yield
|
59 |
await sessionmanager.close()
|
60 |
|
@@ -73,3 +73,9 @@ async def session_override(app, connection_test):
|
|
73 |
yield session
|
74 |
|
75 |
app.dependency_overrides[get_db_session] = get_db_session_override
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
yield c
|
38 |
|
39 |
|
40 |
+
@pytest.fixture(scope="session")
|
41 |
def event_loop(request):
|
42 |
loop = asyncio.get_event_loop_policy().new_event_loop()
|
43 |
yield loop
|
|
|
52 |
pg_db = test_db.dbname
|
53 |
pg_password = test_db.password
|
54 |
|
55 |
+
with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
|
56 |
+
connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
|
57 |
+
sessionmanager.init(connection_str, {"echo": True, "future": True})
|
58 |
yield
|
59 |
await sessionmanager.close()
|
60 |
|
|
|
73 |
yield session
|
74 |
|
75 |
app.dependency_overrides[get_db_session] = get_db_session_override
|
76 |
+
|
77 |
+
|
78 |
+
@pytest.fixture(scope="function", autouse=True)
async def get_db_session_fixture():
    """Yield a session opened from ``sessionmanager`` for the duration of one test."""
    async with sessionmanager.session() as db_session:
        yield db_session
|
tests/test_transactions.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from typing import List
|
3 |
+
from fastapi import Depends
|
4 |
+
from fastapi.testclient import TestClient
|
5 |
+
import pytest
|
6 |
+
|
7 |
+
from app.model.transaction import Transaction
|
8 |
+
from app.schema.index import TransactionType, TransactionCreate
|
9 |
+
|
10 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
11 |
+
from app.engine.postgresdb import get_db_session
|
12 |
+
from app.model.user import User
|
13 |
+
|
14 |
+
def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
    """Return ten fake transactions owned by ``user_id``.

    Transaction *k* (1..10) is dated 2022-01-*k* and has amount ``float(k)``.
    Types come in pairs: days 1-2 expense, 3-4 income, 5-6 expense,
    7-8 income, 9-10 expense.
    """
    fake_transactions = []
    for day in range(1, 11):
        # Pairs of days alternate between expense and income, starting with expense.
        in_income_pair = ((day - 1) // 2) % 2 == 1
        fake_transactions.append(
            TransactionCreate(
                user_id=user_id,
                transaction_date=datetime(2022, 1, day),
                category="category",
                name_description="name_description",
                amount=float(day),
                type=TransactionType.INCOME if in_income_pair else TransactionType.EXPENSE,
            )
        )
    return fake_transactions
|
97 |
+
|
98 |
+
@pytest.mark.asyncio
async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
    """Bulk-created transactions for a user are returned by the transactions endpoint."""
    db_session = get_db_session_fixture
    user = await User.create(db_session, name="user", email="email", hashed_password="password")

    fake_transactions = get_fake_transactions(user.id)
    await Transaction.bulk_create(db_session, fake_transactions)

    # Use the id of the user created above rather than assuming it is 1.
    # NOTE(review): presumes the path parameter is the user id — confirm against the router.
    response = client.get(f"/api/v1/transactions/{user.id}")
    assert response.status_code == 200
    # Expect exactly the rows inserted above, tied to the fixture size.
    assert len(response.json()) == len(fake_transactions)
|