palexis3 commited on
Commit
107b48d
1 Parent(s): 9914661

Wire up transactions with user api to get supplied data

Browse files
app/api/routers/transaction.py CHANGED
@@ -46,3 +46,24 @@ async def get_transactions(user_id: int, db: AsyncSession = Depends(get_db_sessi
46
  }
47
  results.append(result)
48
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  }
47
  results.append(result)
48
  return results
49
+
50
+
51
+ @r.get(
52
+ "/transaction_answer/{user_id}/{query}",
53
+ responses={
54
+ 200: {"description": "Query answered"},
55
+ 500: {"description": "Internal server error"},
56
+ },
57
+ )
58
+ async def answer_transactions_query(user_id: int, query: str, db: AsyncSession = Depends(get_db_session)):
59
+ """
60
+ Retrieve all transactions.
61
+ """
62
+ result = await TransactionModel.get_by_user(db, user_id)
63
+ all_rows = result.all()
64
+ print(f"answer_transactions_query: \n type: {type(all_rows)}\nall_rows: {all_rows}")
65
+ # TODO: Pass all rows to answer_query method in transaction_query_rag and get result that will be the
66
+ # string param in TransactionRagResponse
67
+ if len(all_rows) == 0:
68
+ raise HTTPException(status_code=500, detail="No transactions found for this user")
69
+ return all_rows
app/api/routers/user.py CHANGED
@@ -6,6 +6,10 @@ from app.engine.postgresdb import get_db_session
6
  from app.schema.index import UserCreate, User as UserSchema, UserResponse, UserUpdate
7
  from app.model.user import User as UserModel
8
 
 
 
 
 
9
 
10
  user_router = r = APIRouter(prefix="/api/v1/users", tags=["users"])
11
  logger = logging.getLogger(__name__)
@@ -29,7 +33,12 @@ async def create_user(user: UserCreate, db: AsyncSession = Depends(get_db_sessio
29
 
30
  await UserModel.create(db, **user.model_dump())
31
  user = await UserModel.get(db, email=user.email)
 
 
 
 
32
  return user
 
33
  except Exception as e:
34
  raise HTTPException(status_code=500, detail=str(e))
35
 
 
6
  from app.schema.index import UserCreate, User as UserSchema, UserResponse, UserUpdate
7
  from app.model.user import User as UserModel
8
 
9
+ from app.model.transaction import Transaction as TransactionModel
10
+
11
+ from tests.utils import get_fake_transactions
12
+
13
 
14
  user_router = r = APIRouter(prefix="/api/v1/users", tags=["users"])
15
  logger = logging.getLogger(__name__)
 
33
 
34
  await UserModel.create(db, **user.model_dump())
35
  user = await UserModel.get(db, email=user.email)
36
+
37
+ fake_transactions = get_fake_transactions(user.id)
38
+ await TransactionModel.bulk_create(db, fake_transactions)
39
+
40
  return user
41
+
42
  except Exception as e:
43
  raise HTTPException(status_code=500, detail=str(e))
44
 
app/schema/index.py CHANGED
@@ -53,6 +53,8 @@ class TransactionCreate(TransactionResponse):
53
  class Transaction(TransactionResponse):
54
  user: User
55
 
 
 
56
 
57
  class FileUploadCreate(PydanticBaseModel):
58
  source: str
 
53
  class Transaction(TransactionResponse):
54
  user: User
55
 
56
+ class TransactionRagResponse(PydanticBaseModel):
57
+ answer: str
58
 
59
  class FileUploadCreate(PydanticBaseModel):
60
  source: str
app/service/transactions_query_rag.py CHANGED
@@ -12,7 +12,7 @@ import os
12
  import pandas as pd
13
  from uuid import uuid4
14
 
15
- async def answer_query(df: pd.DataFrame, query: str) -> None:
16
  """Creates an embedding of the transactions table and then returns the answer for the given query.
17
  Args:
18
  df (pd.DataFrame): DataFrame containing the transactions that a user has entered
@@ -23,7 +23,6 @@ async def answer_query(df: pd.DataFrame, query: str) -> None:
23
  """
24
  try:
25
  batch_limit = 100
26
-
27
  pinecone_api_key = os.environ['PINECONE_API_KEY']
28
  openai_api_key = os.environ['OPENAI_API_KEY']
29
  namespace = "transactionsvector"
 
12
  import pandas as pd
13
  from uuid import uuid4
14
 
15
+ async def answer_query(df: pd.DataFrame, query: str) -> str:
16
  """Creates an embedding of the transactions table and then returns the answer for the given query.
17
  Args:
18
  df (pd.DataFrame): DataFrame containing the transactions that a user has entered
 
23
  """
24
  try:
25
  batch_limit = 100
 
26
  pinecone_api_key = os.environ['PINECONE_API_KEY']
27
  openai_api_key = os.environ['OPENAI_API_KEY']
28
  namespace = "transactionsvector"
tests/test_transactions.py CHANGED
@@ -1,25 +1,109 @@
 
 
 
1
  from fastapi.testclient import TestClient
2
  import pytest
3
 
4
  from app.model.transaction import Transaction
5
- from app.model.user import User
 
6
  from sqlalchemy.ext.asyncio import AsyncSession
7
- from tests.utils import get_fake_transactions
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @pytest.mark.asyncio
11
  async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
12
 
13
  session_override = get_db_session_fixture
14
-
15
- # 1. Create a user
16
  user = await User.create(session_override, name="user", email="email", hashed_password="password")
17
 
18
- # 2. Create a bunch of transactions
19
  fake_transactions = get_fake_transactions(user.id)
20
  await Transaction.bulk_create(session_override, fake_transactions)
21
 
22
- # 3. Verify that the transactions are returned
23
  response = client.get("/api/v1/transactions/1")
24
  assert response.status_code == 200
25
  assert len(response.json()) == 10
 
1
+ from datetime import datetime
2
+ from typing import List
3
+ from fastapi import Depends
4
  from fastapi.testclient import TestClient
5
  import pytest
6
 
7
  from app.model.transaction import Transaction
8
+ from app.schema.index import TransactionType, TransactionCreate
9
+
10
  from sqlalchemy.ext.asyncio import AsyncSession
11
+ from app.engine.postgresdb import get_db_session
12
+ from app.model.user import User
13
 
14
+ def get_fake_transactions(user_id: int) -> List[TransactionCreate]:
15
+ return [
16
+ TransactionCreate(
17
+ user_id=user_id,
18
+ transaction_date=datetime(2022, 1, 1),
19
+ category="category",
20
+ name_description="name_description",
21
+ amount=1.0,
22
+ type=TransactionType.EXPENSE,
23
+ ),
24
+ TransactionCreate(
25
+ user_id=user_id,
26
+ transaction_date=datetime(2022, 1, 2),
27
+ category="category",
28
+ name_description="name_description",
29
+ amount=2.0,
30
+ type=TransactionType.EXPENSE,
31
+ ),
32
+ TransactionCreate(
33
+ user_id=user_id,
34
+ transaction_date=datetime(2022, 1, 3),
35
+ category="category",
36
+ name_description="name_description",
37
+ amount=3.0,
38
+ type=TransactionType.INCOME,
39
+ ),
40
+ TransactionCreate(
41
+ user_id=user_id,
42
+ transaction_date=datetime(2022, 1, 4),
43
+ category="category",
44
+ name_description="name_description",
45
+ amount=4.0,
46
+ type=TransactionType.INCOME,
47
+ ),
48
+ TransactionCreate(
49
+ user_id=user_id,
50
+ transaction_date=datetime(2022, 1, 5),
51
+ category="category",
52
+ name_description="name_description",
53
+ amount=5.0,
54
+ type=TransactionType.EXPENSE,
55
+ ),
56
+ TransactionCreate(
57
+ user_id=user_id,
58
+ transaction_date=datetime(2022, 1, 6),
59
+ category="category",
60
+ name_description="name_description",
61
+ amount=6.0,
62
+ type=TransactionType.EXPENSE,
63
+ ),
64
+ TransactionCreate(
65
+ user_id=user_id,
66
+ transaction_date=datetime(2022, 1, 7),
67
+ category="category",
68
+ name_description="name_description",
69
+ amount=7.0,
70
+ type=TransactionType.INCOME,
71
+ ),
72
+ TransactionCreate(
73
+ user_id=user_id,
74
+ transaction_date=datetime(2022, 1, 8),
75
+ category="category",
76
+ name_description="name_description",
77
+ amount=8.0,
78
+ type=TransactionType.INCOME,
79
+ ),
80
+ TransactionCreate(
81
+ user_id=user_id,
82
+ transaction_date=datetime(2022, 1, 9),
83
+ category="category",
84
+ name_description="name_description",
85
+ amount=9.0,
86
+ type=TransactionType.EXPENSE,
87
+ ),
88
+ TransactionCreate(
89
+ user_id=user_id,
90
+ transaction_date=datetime(2022, 1, 10),
91
+ category="category",
92
+ name_description="name_description",
93
+ amount=10.0,
94
+ type=TransactionType.EXPENSE,
95
+ ),
96
+ ]
97
 
98
  @pytest.mark.asyncio
99
  async def test_transactions(client: TestClient, get_db_session_fixture: AsyncSession) -> None:
100
 
101
  session_override = get_db_session_fixture
 
 
102
  user = await User.create(session_override, name="user", email="email", hashed_password="password")
103
 
 
104
  fake_transactions = get_fake_transactions(user.id)
105
  await Transaction.bulk_create(session_override, fake_transactions)
106
 
 
107
  response = client.get("/api/v1/transactions/1")
108
  assert response.status_code == 200
109
  assert len(response.json()) == 10