File size: 6,537 Bytes
3e323f0
cb0a98c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b059b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb0a98c
 
 
 
 
 
 
 
 
65eba49
 
 
 
 
 
 
3e323f0
 
 
 
65eba49
 
 
 
 
 
 
 
 
 
 
 
cb0a98c
 
 
 
 
 
 
 
b826c6c
 
 
 
 
cb0a98c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a77544c
 
 
 
 
 
 
 
cb0a98c
 
b826c6c
a77544c
 
 
 
 
 
 
 
 
 
 
b826c6c
a77544c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b059b0
b826c6c
 
2b059b0
b826c6c
 
 
 
 
 
 
 
 
 
 
 
a77544c
 
 
 
 
 
 
b826c6c
 
 
a77544c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
'''
Database access layer: helpers for saving pandas DataFrames to, and loading
them from, the local SQLite database.
'''
import pandas as pd
from sqlalchemy import create_engine, text

import table_schema as ts

db_url = 'sqlite:///instance/local.db'

def _validate_schema(df, schema):
    '''
    validate df has the same columns and data types as schema

    Parameters
    ----------
    df: pd.DataFrame
    schema: dict
        {column_name: data_type}

    Returns
    -------
    bool
        True if df has the same columns and data types as schema
        False otherwise
    '''

    # check if the DataFrame has the same columns as the schema
    if set(df.columns) != set(schema.keys()):
        return False
    # check if the data types of the columns match the schema
    # TODO: ignoring type check for now
    # for col, dtype in schema.items():
    #     if df[col].dtype != dtype:
    #         return False
    return True


def get_all_benchmark_profile():
    '''Fetch every entry of the benchmark profile table as a DataFrame.'''
    benchmark_df = _get_all_row(ts.BENCHMARK_TABLE)
    return benchmark_df

def append_to_benchmark_profile(df):
    '''
    Append new entries to the benchmark profile table.

    Rows whose date equals the most recent date already stored are dropped
    first, so duplication caused by right fill does not re-insert that day.

    Parameters
    ----------
    df: pd.DataFrame
        rows matching ts.BENCHMARK_TABLE_SCHEMA, including a `date` column
    '''
    most_recent_dates = get_most_recent_benchmark_profile().date
    if len(most_recent_dates) > 0:
        # positional access via iloc: a Series' index labels are not
        # guaranteed to start at 0, so `[0]` (label lookup) could KeyError
        latest_date = most_recent_dates.iloc[0]
        df = df[df.date != latest_date]
        # nothing left after de-duplication -> nothing to write
        if len(df) == 0:
            return
    _append_df_to_db(df, ts.BENCHMARK_TABLE, ts.BENCHMARK_TABLE_SCHEMA)

def get_most_recent_benchmark_profile():
    '''Return the row(s) carrying the latest date in the benchmark profile table.'''
    return _get_most_recent(ts.BENCHMARK_TABLE)
def get_most_recent_profile(type):
    '''
    Return the most recent row(s) of one of the two profile tables.

    Parameters
    ----------
    type: str
        'benchmark' selects the benchmark profile table; any other value
        selects the portfolio profile table.
        NOTE: the parameter shadows the builtin `type`; the name is kept
        for backward compatibility with existing callers.

    Returns
    -------
    pd.DataFrame
        row(s) with the maximum date, `date` parsed to datetime
    '''
    table_name = 'benchmark_profile' if type == 'benchmark' else 'portfolio_profile'
    # delegate to the shared helper instead of duplicating its query logic
    return _get_most_recent(table_name)

def _get_oldest(table_name, ts_column='date'):
    '''Return the row(s) of `table_name` holding the minimum `ts_column` value.'''
    query = (
        f"SELECT * FROM {table_name} "
        f"WHERE {ts_column} = (SELECT MIN({ts_column}) FROM {table_name})"
    )
    engine = create_engine(db_url)
    with engine.connect() as conn:
        oldest = pd.read_sql(query, con=conn)
    # normalise the timestamp column to datetime objects
    oldest[ts_column] = pd.to_datetime(oldest[ts_column])
    return oldest

def get_oldest_stocks_price():
    '''Return the earliest-time row(s) of the stocks price table.'''
    # this table keys its timestamps on `time`, not the default `date`
    return _get_oldest(ts.STOCKS_PRICE_TABLE, ts_column='time')

def get_oldest_portfolio_profile():
    '''Return the oldest entry of the portfolio profile table.'''
    return _get_oldest(ts.PORTFOLIO_TABLE)

def get_oldest_stocks_proce():
    '''
    Return the earliest-time row(s) of the stocks price table.

    NOTE(review): the name has a typo ("proce") and duplicates
    get_oldest_stocks_price; it is kept as-is for backward compatibility
    with existing callers and now simply delegates.
    '''
    return get_oldest_stocks_price()

def get_oldest_benchmark_profile():
    '''Return the oldest entry of the benchmark profile table.'''
    return _get_oldest(ts.BENCHMARK_TABLE)

def _get_most_recent(table_name, ts_column='date'):
    '''Return the row(s) of `table_name` holding the maximum `ts_column` value.'''
    query = (
        f"SELECT * FROM {table_name} "
        f"WHERE {ts_column} = (SELECT MAX({ts_column}) FROM {table_name})"
    )
    engine = create_engine(db_url)
    with engine.connect() as conn:
        latest = pd.read_sql(query, con=conn)
    # normalise the timestamp column to datetime objects
    latest[ts_column] = pd.to_datetime(latest[ts_column])
    return latest
    
def get_all_portfolio_profile():
    '''Fetch every entry of the portfolio profile table as a DataFrame.'''
    return _get_all_row(ts.PORTFOLIO_TABLE)

def get_most_recent_portfolio_profile():
    '''Return the row(s) carrying the latest date in the portfolio profile table.'''
    return _get_most_recent(ts.PORTFOLIO_TABLE)

def get_most_recent_stocks_price():
    '''Return the latest-time row(s) of the stocks price table.'''
    # this table keys its timestamps on `time`, not the default `date`
    return _get_most_recent(ts.STOCKS_PRICE_TABLE, ts_column='time')

def _append_df_to_db(df, table_name, schema):
    '''
    Append df to `table_name` after validating its columns against schema.

    Parameters
    ----------
    df: pd.DataFrame
    table_name: str
    schema: dict
        {column_name: data_type}

    Raises
    ------
    ValueError
        when df's columns do not match schema. ValueError subclasses
        Exception, so existing `except Exception` callers keep working;
        the message text is unchanged.
    '''
    if not _validate_schema(df, schema):
        # a specific exception type instead of the generic Exception
        raise ValueError(f'VALIDATION_ERROR: df does not have the same schema as the table {table_name}')
    with create_engine(db_url).connect() as conn:
        df.to_sql(table_name, con=conn, if_exists='append', index=False)

def append_to_stocks_price_table(df):
    '''Append df to the stocks price table (schema is validated first).'''
    _append_df_to_db(df, ts.STOCKS_PRICE_TABLE, ts.STOCKS_PRICE_TABLE_SCHEMA)

def get_all_stocks():
    '''
    Load the full stocks details table.

    Returns
    -------
    pd.DataFrame
        every row of the stocks details table
    '''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        return pd.read_sql(ts.STOCKS_DETAILS_TABLE, con=conn)
def save_portfolio_analytic_df(df):
    '''Overwrite the portfolio analytic table (`analytic_p`) with df.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        df.to_sql('analytic_p', con=conn, if_exists='replace', index=False)

def get_portfolio_analytic_df():
    '''Load the portfolio analytic table (`analytic_p`) as a DataFrame.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        return pd.read_sql('analytic_p', con=conn)


def save_benchmark_analytic_df(df):
    '''Overwrite the benchmark analytic table (`analytic_b`) with df.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        df.to_sql('analytic_b', con=conn, if_exists='replace', index=False)

def get_benchmark_analytic_df():
    '''Load the benchmark analytic table (`analytic_b`) as a DataFrame.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        return pd.read_sql('analytic_b', con=conn)

def save_analytic_df(df):
    '''Overwrite the combined analytic table (`analytic`) with df.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        df.to_sql('analytic', con=conn, if_exists='replace', index=False)
def get_analytic_df():
    '''Load the combined analytic table (`analytic`) as a DataFrame.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        return pd.read_sql('analytic', con=conn)
def _get_all_row(table_name, ts_column='date'):
    '''Load every row of `table_name`, parsing `ts_column` to datetime.'''
    engine = create_engine(db_url)
    with engine.connect() as conn:
        all_rows = pd.read_sql(table_name, con=conn)
    all_rows[ts_column] = pd.to_datetime(all_rows[ts_column])
    return all_rows

def get_all_stocks_price():
    '''
    Return all entries in the stocks price table.

    Returns
    -------
    pd.DataFrame
        every price row, with `time` parsed to datetime
    '''
    # the stocks price table keys its timestamps on `time`, not the
    # helper's default `date` (see get_oldest_stocks_price /
    # get_most_recent_stocks_price); the default would KeyError here
    return _get_all_row(ts.STOCKS_PRICE_TABLE, ts_column='time')

def get_stocks_price(tickers: list[str]):
    '''
    Return price rows for the given tickers from the stocks price table.

    Parameters
    ----------
    tickers: list[str]
        tickers to select; an empty list yields an empty DataFrame that
        still carries the table's schema

    Returns
    -------
    pd.DataFrame
        matching rows with `time` parsed to datetime and exact
        (ticker, time) duplicates dropped
    '''
    if len(tickers) == 0:
        # impossible predicate keeps the table's schema on an empty result
        query = text(f"SELECT * FROM {ts.STOCKS_PRICE_TABLE} WHERE 1=0")
        params = {}
    else:
        # bind each ticker as a named parameter instead of interpolating
        # values into the SQL string — avoids quoting bugs / SQL injection
        # and handles the single-ticker case uniformly
        placeholders = ", ".join(f":t{i}" for i in range(len(tickers)))
        query = text(
            f"SELECT * FROM {ts.STOCKS_PRICE_TABLE} WHERE ticker IN ({placeholders})"
        )
        params = {f"t{i}": ticker for i, ticker in enumerate(tickers)}
    with create_engine(db_url).connect() as conn:
        df = pd.read_sql(query, con=conn, params=params)
        df.time = pd.to_datetime(df.time)
        # the table may contain the same (ticker, time) row more than once
        return df.drop_duplicates(subset=['ticker', 'time'])