Spaces:
Runtime error
Runtime error
huggingface112
commited on
Commit
•
cb0a98c
1
Parent(s):
c8b5a38
abstract db operation to db_operation.py
Browse files- db_operation.py +85 -0
db_operation.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Abstraction for saving data to db
|
3 |
+
'''
|
4 |
+
import pandas as pd
|
5 |
+
import table_schema as ts
|
6 |
+
from sqlalchemy import create_engine
|
7 |
+
|
8 |
+
db_url = 'sqlite:///instance/local.db'
|
9 |
+
|
10 |
+
def _validate_schema(df, schema):
|
11 |
+
'''
|
12 |
+
validate df has the same columns and data types as schema
|
13 |
+
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
df: pd.DataFrame
|
17 |
+
schema: dict
|
18 |
+
{column_name: data_type}
|
19 |
+
|
20 |
+
Returns
|
21 |
+
-------
|
22 |
+
bool
|
23 |
+
True if df has the same columns and data types as schema
|
24 |
+
False otherwise
|
25 |
+
'''
|
26 |
+
|
27 |
+
# check if the DataFrame has the same columns as the schema
|
28 |
+
if set(df.columns) != set(schema.keys()):
|
29 |
+
return False
|
30 |
+
# check if the data types of the columns match the schema
|
31 |
+
# TODO: ignoring type check for now
|
32 |
+
# for col, dtype in schema.items():
|
33 |
+
# if df[col].dtype != dtype:
|
34 |
+
# return False
|
35 |
+
return True
|
36 |
+
|
37 |
+
# def append_benchmark_profile(df):
|
38 |
+
# '''append new entry to benchmark profile table'''
|
39 |
+
# with create_engine(db_url).connect() as conn:
|
40 |
+
# df.to_sql(ts.BENCHMARK_PROFILE_TABLE, con=conn, if_exists='append', index=False)
|
41 |
+
|
42 |
+
|
43 |
+
def get_most_recent_profile(type):
|
44 |
+
table_name = 'benchmark_profile' if type == 'benchmark' else 'portfolio_profile'
|
45 |
+
query = f"SELECT * FROM {table_name} WHERE date = (SELECT MAX(date) FROM {table_name})"
|
46 |
+
with create_engine(db_url).connect() as conn:
|
47 |
+
df = pd.read_sql(query, con=conn)
|
48 |
+
# convert date to datetime object
|
49 |
+
df['date'] = pd.to_datetime(df['date'])
|
50 |
+
return df
|
51 |
+
|
52 |
+
def _get_most_recent(table_name, ts_column='date'):
|
53 |
+
'''return the most recent entry in the table'''
|
54 |
+
query = f"SELECT * FROM {table_name} WHERE {ts_column} = (SELECT MAX({ts_column}) FROM {table_name})"
|
55 |
+
with create_engine(db_url).connect() as conn:
|
56 |
+
df = pd.read_sql(query, con=conn)
|
57 |
+
# convert date to datetime object
|
58 |
+
df[ts_column] = pd.to_datetime(df[ts_column])
|
59 |
+
return df
|
60 |
+
|
61 |
+
def get_most_recent_portfolio_profile():
|
62 |
+
df = _get_most_recent(ts.PORTFOLIO_TABLE)
|
63 |
+
df['date'] = pd.to_datetime(df['date'])
|
64 |
+
return df
|
65 |
+
|
66 |
+
def get_most_recent_stocks_price():
|
67 |
+
df = _get_most_recent(ts.STOCKS_PRICE_TABLE, ts_column='time')
|
68 |
+
df['time'] = pd.to_datetime(df['time'])
|
69 |
+
return df
|
70 |
+
|
71 |
+
def _append_df_to_db(df, table_name, schema):
|
72 |
+
# validation
|
73 |
+
if not _validate_schema(df, schema):
|
74 |
+
raise Exception(f'VALIDATION_ERROR: df does not have the same schema as the table {table_name}')
|
75 |
+
with create_engine(db_url).connect() as conn:
|
76 |
+
df.to_sql(table_name, con=conn, if_exists='append', index=False)
|
77 |
+
|
78 |
+
def append_to_stocks_price_table(df):
|
79 |
+
'''append new entry to stocks price table'''
|
80 |
+
_append_df_to_db(df, ts.STOCKS_PRICE_TABLE, ts.STOCKS_PRICE_TABLE_SCHEMA)
|
81 |
+
|
82 |
+
def get_all_stocks():
|
83 |
+
with create_engine(db_url).connect() as conn:
|
84 |
+
all_stocks = pd.read_sql(ts.STOCKS_DETAILS_TABLE, con=conn)
|
85 |
+
return all_stocks
|