praneethys commited on
Commit
b905bd6
1 Parent(s): ad33b38

add llm service for generating income statements (#14)

Browse files

- feat: add llm service to generate income statement (6d92adc7e96771ebc55f1859275ee4206a668883)
- test: pytest for transactions (cf9085da983a5159a7de0d43b0b12c16fdc32822)

app/api/routers/file_upload.py CHANGED
@@ -1,11 +1,14 @@
1
  from typing import Annotated
2
- from fastapi import APIRouter, UploadFile
3
  from app.categorization.file_processing import process_file, save_results
4
  from app.schema.index import FileUploadCreate
5
  import asyncio
6
  import os
7
  import csv
8
 
 
 
 
9
  file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upload"])
10
 
11
  @r.post(
@@ -16,7 +19,7 @@ file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upl
16
  500: {"description": "Internal server error"},
17
  },
18
  )
19
- async def create_file(input_file: UploadFile):
20
  try:
21
  # Create directory to store all uploaded .csv files
22
  file_upload_directory_path = "data/tx_data/input"
@@ -31,10 +34,10 @@ async def create_file(input_file: UploadFile):
31
  # With the newly created file and it's path, process and save it for embedding
32
  processed_file = process_file(os.path.realpath(input_file.filename))
33
  result = await asyncio.gather(processed_file)
34
- save_results(result)
35
 
36
  except Exception:
37
  return {"message": "There was an error uploading this file. Ensure you have a .csv file with the following columns:"
38
- "\n source, date, type, category, description, amount"}
39
 
40
  return {"message": f"Successfully uploaded {input_file.filename}"}
 
1
  from typing import Annotated
2
+ from fastapi import APIRouter, UploadFile, Depends
3
  from app.categorization.file_processing import process_file, save_results
4
  from app.schema.index import FileUploadCreate
5
  import asyncio
6
  import os
7
  import csv
8
 
9
+ from app.engine.postgresdb import get_db_session
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+
12
  file_upload_router = r = APIRouter(prefix="/api/v1/file_upload", tags=["file_upload"])
13
 
14
  @r.post(
 
19
  500: {"description": "Internal server error"},
20
  },
21
  )
22
+ async def create_file(input_file: UploadFile, db: AsyncSession = Depends(get_db_session)):
23
  try:
24
  # Create directory to store all uploaded .csv files
25
  file_upload_directory_path = "data/tx_data/input"
 
34
  # With the newly created file and it's path, process and save it for embedding
35
  processed_file = process_file(os.path.realpath(input_file.filename))
36
  result = await asyncio.gather(processed_file)
37
+ await save_results(db, result)
38
 
39
  except Exception:
40
  return {"message": "There was an error uploading this file. Ensure you have a .csv file with the following columns:"
41
+ "\n transaction_date, type, category, name_description, amount"}
42
 
43
  return {"message": f"Successfully uploaded {input_file.filename}"}
app/categorization/file_processing.py CHANGED
@@ -14,6 +14,8 @@ from app.categorization.config import RESULT_OUTPUT_FILE, CATEGORY_REFERENCE_OUT
14
  from app.model.transaction import Transaction
15
  from app.schema.index import TransactionCreate
16
 
 
 
17
 
18
  # Read file and process it (e.g. categorize transactions)
19
  async def process_file(file_path: str) -> Dict[str, Union[str, pd.DataFrame]]:
@@ -71,7 +73,7 @@ def standardize_csv_file(file_path: str) -> pd.DataFrame:
71
  return tx_list
72
 
73
 
74
- async def save_results(results: List) -> None:
75
  """
76
  Merge all interim results in the input folder and write the merged results to the output file.
77
 
@@ -104,7 +106,7 @@ async def save_results(results: List) -> None:
104
  # Save to database
105
  # FIXME: get user_id from session
106
  txn_list_to_save = [TransactionCreate(**row.to_dict(), user_id=1) for _, row in tx_list.iterrows()]
107
- await Transaction.bulk_create(txn_list_to_save)
108
 
109
  new_ref_data = tx_list[["name/description", "category"]]
110
  if os.path.exists(CATEGORY_REFERENCE_OUTPUT_FILE):
 
14
  from app.model.transaction import Transaction
15
  from app.schema.index import TransactionCreate
16
 
17
+ from sqlalchemy.ext.asyncio import AsyncSession
18
+
19
 
20
  # Read file and process it (e.g. categorize transactions)
21
  async def process_file(file_path: str) -> Dict[str, Union[str, pd.DataFrame]]:
 
73
  return tx_list
74
 
75
 
76
+ async def save_results(db: AsyncSession, results: List) -> None:
77
  """
78
  Merge all interim results in the input folder and write the merged results to the output file.
79
 
 
106
  # Save to database
107
  # FIXME: get user_id from session
108
  txn_list_to_save = [TransactionCreate(**row.to_dict(), user_id=1) for _, row in tx_list.iterrows()]
109
+ await Transaction.bulk_create(db, txn_list_to_save)
110
 
111
  new_ref_data = tx_list[["name/description", "category"]]
112
  if os.path.exists(CATEGORY_REFERENCE_OUTPUT_FILE):
app/model/transaction.py CHANGED
@@ -31,7 +31,8 @@ class Transaction(Base, BaseModel):
31
 
32
  @classmethod
33
  async def bulk_create(cls: "type[Transaction]", db: AsyncSession, transactions: List[TransactionCreate]) -> None:
34
- query = sql.insert(cls).values(transactions)
 
35
  await db.execute(query)
36
  await db.commit()
37
 
 
31
 
32
  @classmethod
33
  async def bulk_create(cls: "type[Transaction]", db: AsyncSession, transactions: List[TransactionCreate]) -> None:
34
+ values = [transaction.model_dump() for transaction in transactions]
35
+ query = sql.insert(cls).values(values)
36
  await db.execute(query)
37
  await db.commit()
38
 
app/schema/index.py CHANGED
@@ -74,3 +74,7 @@ class IncomeStatementResponse(PydanticBaseModel):
74
  date_to: datetime
75
  income: dict
76
  expenses: dict
 
 
 
 
 
74
  date_to: datetime
75
  income: dict
76
  expenses: dict
77
+
78
+ class IncomeStatementLLMResponse(PydanticBaseModel):
79
+ income: dict
80
+ expenses: dict
app/service/income_statement.py CHANGED
@@ -3,13 +3,16 @@ from app.model.transaction import Transaction as TransactionModel
3
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
 
 
 
6
 
7
  async def call_llm_to_create_income_statement(payload: IncomeStatementCreate, db: AsyncSession) -> None:
8
  transactions = await TransactionModel.get_by_user_between_dates(
9
  db, payload.user_id, payload.date_from, payload.date_to
10
  )
11
 
12
- # TODO: Call LLM to generate income and expenses
13
- income = {}
14
- expenses = {}
 
15
  await IncomeStatementModel.create(db, **payload, income=income, expenses=expenses)
 
3
  from app.model.income_statement import IncomeStatement as IncomeStatementModel
4
  from sqlalchemy.ext.asyncio import AsyncSession
5
 
6
+ from app.service.llm import call_llm
7
+
8
 
9
  async def call_llm_to_create_income_statement(payload: IncomeStatementCreate, db: AsyncSession) -> None:
10
  transactions = await TransactionModel.get_by_user_between_dates(
11
  db, payload.user_id, payload.date_from, payload.date_to
12
  )
13
 
14
+ response = call_llm(transactions)
15
+ income = response.income
16
+ expenses = response.expenses
17
+
18
  await IncomeStatementModel.create(db, **payload, income=income, expenses=expenses)
app/service/llm.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from llama_index.core.settings import Settings
3
+ from llama_index.core import PromptTemplate
4
+
5
+ from app.model.transaction import Transaction
6
+ from app.schema.index import IncomeStatementLLMResponse
7
+
8
+ def income_statement_prompt (inputData: Transaction) -> PromptTemplate:
9
+ context_str = f"""
10
+ You are an accountant skilled at organizing transactions from multiple different bank
11
+ accounts and credit card statements to prepare an income statement.
12
+
13
+ Input data has the following format: transaction_date, type, category, name_description, amount.
14
+ The <IN> tag is prepended to the input data as follows:
15
+ ```<IN>
16
+ {inputData}
17
+ ```
18
+
19
+ Your task is to prepare an income statement. The income statement's output should be in a json format.
20
+ An example of the expected output is as follows with the <OUT> tag. Note that not all categories of the transactions
21
+ have been listed below. Use below <OUT> tag as a reference in preparing the income statement.
22
+ ```<OUT>
23
+ {
24
+ "REVENUE": {
25
+ "Gross Sales": 73351.11,
26
+ "Other Income": 0,
27
+ "Balance Dec 2022": 3987.39,
28
+ },
29
+ "EXPENSES": {
30
+ "Advertising": 0,
31
+ "Commissions": 0,
32
+ "Insurance": 0,
33
+ "Memberships": 0,
34
+ "Utilities": 0,
35
+ }
36
+ }
37
+ ```
38
+
39
+ """
40
+ prompt = PromptTemplate(context_str, template_var_mappings={"inputData": inputData})
41
+ logging.info(f"Prompt: {prompt}")
42
+
43
+ async def call_llm(prompt: PromptTemplate) -> str:
44
+ llm = Settings.llm.copy()
45
+ prompt = income_statement_prompt()
46
+ llm.check_prompts(prompt)
47
+
48
+ llm.system_prompt = prompt
49
+
50
+ output = await llm.astructured_predict(output_cls=IncomeStatementLLMResponse, prompt=prompt)
51
+
52
+ logging.info(f"Output: {output}")
53
+ return output
pyproject.toml CHANGED
@@ -52,6 +52,9 @@ version = "0.2.2"
52
  [tool.black]
53
  line-length = 119
54
 
 
 
 
55
  [build-system]
56
  requires = [ "poetry-core" ]
57
  build-backend = "poetry.core.masonry.api"
 
52
  [tool.black]
53
  line-length = 119
54
 
55
+ [tool.pytest.ini_options]
56
+ asyncio_mode = "auto"
57
+
58
  [build-system]
59
  requires = [ "poetry-core" ]
60
  build-backend = "poetry.core.masonry.api"
tests/conftest.py CHANGED
@@ -37,7 +37,7 @@ def client(app):
37
  yield c
38
 
39
 
40
- @pytest.mark.asyncio(scope="session")
41
  def event_loop(request):
42
  loop = asyncio.get_event_loop_policy().new_event_loop()
43
  yield loop
@@ -52,9 +52,9 @@ async def connection_test(test_db, event_loop):
52
  pg_db = test_db.dbname
53
  pg_password = test_db.password
54
 
55
- with DatabaseJanitor(pg_user, pg_host, pg_port, pg_db, test_db.version, pg_password):
56
- connection_str = f"postgresql+asyncpg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
57
- sessionmanager.init(connection_str)
58
  yield
59
  await sessionmanager.close()
60
 
@@ -73,3 +73,9 @@ async def session_override(app, connection_test):
73
  yield session
74
 
75
  app.dependency_overrides[get_db_session] = get_db_session_override
 
 
 
 
 
 
 
37
  yield c
38
 
39
 
40
+ @pytest.fixture(scope="session")
41
  def event_loop(request):
42
  loop = asyncio.get_event_loop_policy().new_event_loop()
43
  yield loop
 
52
  pg_db = test_db.dbname
53
  pg_password = test_db.password
54
 
55
+ with DatabaseJanitor(user=pg_user, host=pg_host, port=pg_port, dbname=pg_db, version=test_db.version, password=pg_password):
56
+ connection_str = f"postgresql+psycopg://{pg_user}:@{pg_host}:{pg_port}/{pg_db}"
57
+ sessionmanager.init(connection_str, {"echo": True, "future": True})
58
  yield
59
  await sessionmanager.close()
60
 
 
73
  yield session
74
 
75
  app.dependency_overrides[get_db_session] = get_db_session_override
76
+
77
+
78
+ @pytest.fixture(scope="function", autouse=True)
79
+ async def get_db_session_fixture():
80
+ async with sessionmanager.session() as session:
81
+ yield session
tests/test_transactions.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import List
3
+ from fastapi import Depends
4
+ from fastapi.testclient import TestClient
5
+ import pytest
6
+
7
+ from app.model.transaction import Transaction
8
+ from app.schema.index import TransactionType, TransactionCreate
9
+
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+ from app.engine.postgresdb import get_db_session
12
+ from app.model.user import User
13
+
14
+ def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
15
+ return [
16
+ TransactionCreate(
17
+ user_id=user_id,
18
+ transaction_date=datetime(2022, 1, 1),
19
+ category="category",
20
+ name_description="name_description",
21
+ amount=1.0,
22
+ type=TransactionType.EXPENSE,
23
+ ),
24
+ TransactionCreate(
25
+ user_id=user_id,
26
+ transaction_date=datetime(2022, 1, 2),
27
+ category="category",
28
+ name_description="name_description",
29
+ amount=2.0,
30
+ type=TransactionType.EXPENSE,
31
+ ),
32
+ TransactionCreate(
33
+ user_id=user_id,
34
+ transaction_date=datetime(2022, 1, 3),
35
+ category="category",
36
+ name_description="name_description",
37
+ amount=3.0,
38
+ type=TransactionType.INCOME,
39
+ ),
40
+ TransactionCreate(
41
+ user_id=user_id,
42
+ transaction_date=datetime(2022, 1, 4),
43
+ category="category",
44
+ name_description="name_description",
45
+ amount=4.0,
46
+ type=TransactionType.INCOME,
47
+ ),
48
+ TransactionCreate(
49
+ user_id=user_id,
50
+ transaction_date=datetime(2022, 1, 5),
51
+ category="category",
52
+ name_description="name_description",
53
+ amount=5.0,
54
+ type=TransactionType.EXPENSE,
55
+ ),
56
+ TransactionCreate(
57
+ user_id=user_id,
58
+ transaction_date=datetime(2022, 1, 6),
59
+ category="category",
60
+ name_description="name_description",
61
+ amount=6.0,
62
+ type=TransactionType.EXPENSE,
63
+ ),
64
+ TransactionCreate(
65
+ user_id=user_id,
66
+ transaction_date=datetime(2022, 1, 7),
67
+ category="category",
68
+ name_description="name_description",
69
+ amount=7.0,
70
+ type=TransactionType.INCOME,
71
+ ),
72
+ TransactionCreate(
73
+ user_id=user_id,
74
+ transaction_date=datetime(2022, 1, 8),
75
+ category="category",
76
+ name_description="name_description",
77
+ amount=8.0,
78
+ type=TransactionType.INCOME,
79
+ ),
80
+ TransactionCreate(
81
+ user_id=user_id,
82
+ transaction_date=datetime(2022, 1, 9),
83
+ category="category",
84
+ name_description="name_description",
85
+ amount=9.0,
86
+ type=TransactionType.EXPENSE,
87
+ ),
88
+ TransactionCreate(
89
+ user_id=user_id,
90
+ transaction_date=datetime(2022, 1, 10),
91
+ category="category",
92
+ name_description="name_description",
93
+ amount=10.0,
94
+ type=TransactionType.EXPENSE,
95
+ ),
96
+ ]
97
+
98
+ @pytest.mark.asyncio
99
+ async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
100
+
101
+ session_override = get_db_session_fixture
102
+ user = await User.create(session_override, name="user", email="email", hashed_password="password")
103
+
104
+ fake_transactions = get_fake_transactions(user.id)
105
+ await Transaction.bulk_create(session_override, fake_transactions)
106
+
107
+ response = client.get("/api/v1/transactions/1")
108
+ assert response.status_code == 200
109
+ assert len(response.json()) == 10