Spaces:
Sleeping
Sleeping
File size: 6,537 Bytes
3e323f0 cb0a98c 2b059b0 cb0a98c 65eba49 3e323f0 65eba49 cb0a98c b826c6c cb0a98c a77544c cb0a98c b826c6c a77544c b826c6c a77544c 2b059b0 b826c6c 2b059b0 b826c6c a77544c b826c6c a77544c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
'''
Abstraction for saving data to db
'''
import pandas as pd
import table_schema as ts
from sqlalchemy import create_engine
db_url = 'sqlite:///instance/local.db'
def _validate_schema(df, schema):
'''
validate df has the same columns and data types as schema
Parameters
----------
df: pd.DataFrame
schema: dict
{column_name: data_type}
Returns
-------
bool
True if df has the same columns and data types as schema
False otherwise
'''
# check if the DataFrame has the same columns as the schema
if set(df.columns) != set(schema.keys()):
return False
# check if the data types of the columns match the schema
# TODO: ignoring type check for now
# for col, dtype in schema.items():
# if df[col].dtype != dtype:
# return False
return True
def get_all_benchmark_profile():
    '''Return every row stored in the benchmark profile table.'''
    table = ts.BENCHMARK_TABLE
    return _get_all_row(table)
def append_to_benchmark_profile(df):
    '''
    Append new entries to the benchmark profile table.

    Rows whose date equals the most recent stored date are dropped first,
    to avoid duplication caused by right fill.

    Parameters
    ----------
    df: pd.DataFrame
        rows matching ts.BENCHMARK_TABLE_SCHEMA
    '''
    most_recent_dates = get_most_recent_benchmark_profile().date
    if len(most_recent_dates) > 0:
        # .iloc[0]: positional access is safe even if the index coming
        # back from SQL does not start at label 0
        latest = most_recent_dates.iloc[0]
        df = df[df.date != latest]
    # single append path (the original duplicated this call in both
    # branches); appending an empty frame would be a no-op, so skip it
    if len(df) != 0:
        _append_df_to_db(df, ts.BENCHMARK_TABLE, ts.BENCHMARK_TABLE_SCHEMA)
def get_most_recent_benchmark_profile():
    '''Return the row(s) carrying the latest date in the benchmark profile table.'''
    table = ts.BENCHMARK_TABLE
    return _get_most_recent(table)
def get_most_recent_profile(type):
    '''
    Return the most recent entry of a profile table.

    Parameters
    ----------
    type: str
        'benchmark' selects the benchmark profile table; any other value
        selects the portfolio profile table.  (Parameter name shadows the
        builtin `type` but is kept for backward compatibility.)

    Returns
    -------
    pd.DataFrame
        row(s) with the maximum date, with `date` converted to datetime
    '''
    table_name = 'benchmark_profile' if type == 'benchmark' else 'portfolio_profile'
    # delegate to the shared helper instead of duplicating its
    # MAX(date) query and datetime conversion inline
    return _get_most_recent(table_name)
def _get_oldest(table_name, ts_column='date'):
    '''Fetch the row(s) carrying the minimum value of *ts_column* in *table_name*.'''
    sql = (
        f"SELECT * FROM {table_name} "
        f"WHERE {ts_column} = (SELECT MIN({ts_column}) FROM {table_name})"
    )
    engine = create_engine(db_url)
    with engine.connect() as connection:
        oldest = pd.read_sql(sql, con=connection)
    # normalize the timestamp column to datetime objects
    oldest[ts_column] = pd.to_datetime(oldest[ts_column])
    return oldest
def get_oldest_stocks_price():
    '''Return the earliest-timed rows of the stocks price table.'''
    return _get_oldest(ts.STOCKS_PRICE_TABLE, ts_column='time')
def get_oldest_portfolio_profile():
    '''Return the oldest entry of the portfolio profile table.'''
    return _get_oldest(ts.PORTFOLIO_TABLE)
def get_oldest_stocks_proce():
    '''
    Deprecated duplicate of get_oldest_stocks_price (name contains a
    typo: "proce").  Kept so existing callers keep working; delegate so
    there is a single implementation to maintain.
    '''
    return get_oldest_stocks_price()
def get_oldest_benchmark_profile():
    '''Return the oldest entry of the benchmark profile table.'''
    return _get_oldest(ts.BENCHMARK_TABLE)
def _get_most_recent(table_name, ts_column='date'):
    '''Fetch the row(s) carrying the maximum value of *ts_column* in *table_name*.'''
    sql = (
        f"SELECT * FROM {table_name} "
        f"WHERE {ts_column} = (SELECT MAX({ts_column}) FROM {table_name})"
    )
    engine = create_engine(db_url)
    with engine.connect() as connection:
        newest = pd.read_sql(sql, con=connection)
    # normalize the timestamp column to datetime objects
    newest[ts_column] = pd.to_datetime(newest[ts_column])
    return newest
def get_all_portfolio_profile():
    '''Return every row stored in the portfolio profile table.'''
    return _get_all_row(ts.PORTFOLIO_TABLE)
def get_most_recent_portfolio_profile():
    '''Return the row(s) carrying the latest date in the portfolio profile table.'''
    return _get_most_recent(ts.PORTFOLIO_TABLE)
def get_most_recent_stocks_price():
    '''Return the rows carrying the newest time in the stocks price table.'''
    return _get_most_recent(ts.STOCKS_PRICE_TABLE, ts_column='time')
def _append_df_to_db(df, table_name, schema):
    '''
    Validate *df* against *schema* and append its rows to *table_name*.

    Raises
    ------
    ValueError
        if df's columns do not match the schema.  (More specific than the
        original bare Exception; any existing `except Exception` handler
        still catches it, and the message is unchanged.)
    '''
    if not _validate_schema(df, schema):
        raise ValueError(f'VALIDATION_ERROR: df does not have the same schema as the table {table_name}')
    with create_engine(db_url).connect() as conn:
        df.to_sql(table_name, con=conn, if_exists='append', index=False)
def append_to_stocks_price_table(df):
    '''Validate *df* against the stocks price schema and append it.'''
    _append_df_to_db(df, ts.STOCKS_PRICE_TABLE, ts.STOCKS_PRICE_TABLE_SCHEMA)
def get_all_stocks():
    '''
    Read the full stocks details table.

    Returns
    -------
    pd.DataFrame
        one row per stock
    '''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        return pd.read_sql(ts.STOCKS_DETAILS_TABLE, con=connection)
def save_portfolio_analytic_df(df):
    '''Persist the portfolio analytic frame, replacing any prior snapshot.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        df.to_sql('analytic_p', con=connection, if_exists='replace', index=False)
def get_portfolio_analytic_df():
    '''Load the saved portfolio analytic frame from the db.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        return pd.read_sql('analytic_p', con=connection)
def save_benchmark_analytic_df(df):
    '''Persist the benchmark analytic frame, replacing any prior snapshot.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        df.to_sql('analytic_b', con=connection, if_exists='replace', index=False)
def get_benchmark_analytic_df():
    '''Load the saved benchmark analytic frame from the db.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        return pd.read_sql('analytic_b', con=connection)
def save_analytic_df(df):
    '''Persist the combined analytic frame, replacing any prior snapshot.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        df.to_sql('analytic', con=connection, if_exists='replace', index=False)
def get_analytic_df():
    '''Load the saved combined analytic frame from the db.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        return pd.read_sql('analytic', con=connection)
def _get_all_row(table_name, ts_column='date'):
    '''Read an entire table and coerce its *ts_column* to datetime.'''
    engine = create_engine(db_url)
    with engine.connect() as connection:
        rows = pd.read_sql(table_name, con=connection)
    rows[ts_column] = pd.to_datetime(rows[ts_column])
    return rows
def get_all_stocks_price():
    '''
    Return all entries in the stocks price table.

    The timestamp column of this table is `time`, not the helper's
    default 'date' (see get_oldest_stocks_price and
    get_most_recent_stocks_price), so pass it explicitly — relying on
    the default would KeyError on the missing 'date' column.
    '''
    return _get_all_row(ts.STOCKS_PRICE_TABLE, ts_column='time')
def get_stocks_price(tickers: list[str]):
    '''
    Return price rows for the given tickers from the stocks price table.

    Parameters
    ----------
    tickers: list[str]
        tickers to select; an empty list yields an empty frame that
        still carries the table's columns.

    Returns
    -------
    pd.DataFrame
        matching rows de-duplicated on (ticker, time), with `time`
        converted to datetime.
    '''
    if not tickers:
        # WHERE 1=0 returns no rows but keeps the table schema
        query = f"SELECT * FROM {ts.STOCKS_PRICE_TABLE} WHERE 1=0"
        params = None
    else:
        # bind tickers as named parameters instead of interpolating the
        # values into the SQL string, so caller-supplied strings cannot
        # inject SQL (this also removes the single-ticker special case)
        placeholders = ", ".join(f":t{i}" for i in range(len(tickers)))
        query = (
            f"SELECT * FROM {ts.STOCKS_PRICE_TABLE} "
            f"WHERE ticker IN ({placeholders})"
        )
        params = {f"t{i}": ticker for i, ticker in enumerate(tickers)}
    with create_engine(db_url).connect() as conn:
        df = pd.read_sql(query, con=conn, params=params)
    df.time = pd.to_datetime(df.time)
    # drop duplicates so each (ticker, time) pair appears once
    return df.drop_duplicates(subset=['ticker', 'time'])