huggingface112 commited on
Commit
cb0a98c
1 Parent(s): c8b5a38

abstract db operation to db_operation.py

Browse files
Files changed (1) hide show
  1. db_operation.py +85 -0
db_operation.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Abstraction for saving data to db
3
+ '''
4
+ import pandas as pd
5
+ import table_schema as ts
6
+ from sqlalchemy import create_engine
7
+
8
+ db_url = 'sqlite:///instance/local.db'
9
+
10
+ def _validate_schema(df, schema):
11
+ '''
12
+ validate df has the same columns and data types as schema
13
+
14
+ Parameters
15
+ ----------
16
+ df: pd.DataFrame
17
+ schema: dict
18
+ {column_name: data_type}
19
+
20
+ Returns
21
+ -------
22
+ bool
23
+ True if df has the same columns and data types as schema
24
+ False otherwise
25
+ '''
26
+
27
+ # check if the DataFrame has the same columns as the schema
28
+ if set(df.columns) != set(schema.keys()):
29
+ return False
30
+ # check if the data types of the columns match the schema
31
+ # TODO: ignoring type check for now
32
+ # for col, dtype in schema.items():
33
+ # if df[col].dtype != dtype:
34
+ # return False
35
+ return True
36
+
37
+ # def append_benchmark_profile(df):
38
+ # '''append new entry to benchmark profile table'''
39
+ # with create_engine(db_url).connect() as conn:
40
+ # df.to_sql(ts.BENCHMARK_PROFILE_TABLE, con=conn, if_exists='append', index=False)
41
+
42
+
43
+ def get_most_recent_profile(type):
44
+ table_name = 'benchmark_profile' if type == 'benchmark' else 'portfolio_profile'
45
+ query = f"SELECT * FROM {table_name} WHERE date = (SELECT MAX(date) FROM {table_name})"
46
+ with create_engine(db_url).connect() as conn:
47
+ df = pd.read_sql(query, con=conn)
48
+ # convert date to datetime object
49
+ df['date'] = pd.to_datetime(df['date'])
50
+ return df
51
+
52
+ def _get_most_recent(table_name, ts_column='date'):
53
+ '''return the most recent entry in the table'''
54
+ query = f"SELECT * FROM {table_name} WHERE {ts_column} = (SELECT MAX({ts_column}) FROM {table_name})"
55
+ with create_engine(db_url).connect() as conn:
56
+ df = pd.read_sql(query, con=conn)
57
+ # convert date to datetime object
58
+ df[ts_column] = pd.to_datetime(df[ts_column])
59
+ return df
60
+
61
+ def get_most_recent_portfolio_profile():
62
+ df = _get_most_recent(ts.PORTFOLIO_TABLE)
63
+ df['date'] = pd.to_datetime(df['date'])
64
+ return df
65
+
66
+ def get_most_recent_stocks_price():
67
+ df = _get_most_recent(ts.STOCKS_PRICE_TABLE, ts_column='time')
68
+ df['time'] = pd.to_datetime(df['time'])
69
+ return df
70
+
71
+ def _append_df_to_db(df, table_name, schema):
72
+ # validation
73
+ if not _validate_schema(df, schema):
74
+ raise Exception(f'VALIDATION_ERROR: df does not have the same schema as the table {table_name}')
75
+ with create_engine(db_url).connect() as conn:
76
+ df.to_sql(table_name, con=conn, if_exists='append', index=False)
77
+
78
+ def append_to_stocks_price_table(df):
79
+ '''append new entry to stocks price table'''
80
+ _append_df_to_db(df, ts.STOCKS_PRICE_TABLE, ts.STOCKS_PRICE_TABLE_SCHEMA)
81
+
82
+ def get_all_stocks():
83
+ with create_engine(db_url).connect() as conn:
84
+ all_stocks = pd.read_sql(ts.STOCKS_DETAILS_TABLE, con=conn)
85
+ return all_stocks