File size: 2,898 Bytes
cb0a98c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
'''
Abstraction for saving data to db
'''
import pandas as pd
import table_schema as ts
from sqlalchemy import create_engine

# SQLAlchemy connection URL handed to create_engine() by the helpers in this
# module; it names a SQLite database file at the relative path instance/local.db.
db_url = 'sqlite:///instance/local.db'

def _validate_schema(df, schema):
    '''
    validate df has the same columns and data types as schema

    Parameters
    ----------
    df: pd.DataFrame
    schema: dict
        {column_name: data_type}

    Returns
    -------
    bool
        True if df has the same columns and data types as schema
        False otherwise
    '''

    # check if the DataFrame has the same columns as the schema
    if set(df.columns) != set(schema.keys()):
        return False
    # check if the data types of the columns match the schema
    # TODO: ignoring type check for now
    # for col, dtype in schema.items():
    #     if df[col].dtype != dtype:
    #         return False
    return True

# def append_benchmark_profile(df):
#     '''append new entry to benchmark profile table'''
#     with create_engine(db_url).connect() as conn:
#         df.to_sql(ts.BENCHMARK_PROFILE_TABLE, con=conn, if_exists='append', index=False)


def get_most_recent_profile(type):
    '''
    Return the most recent rows of the benchmark or portfolio profile table.

    Parameters
    ----------
    type: str
        'benchmark' selects the 'benchmark_profile' table; any other value
        selects 'portfolio_profile'.
        (The name shadows the builtin `type`; it is kept so existing
        keyword callers keep working.)

    Returns
    -------
    pd.DataFrame
        every row whose `date` equals the maximum `date` in the selected
        table, with the `date` column converted by pd.to_datetime
    '''
    table_name = 'benchmark_profile' if type == 'benchmark' else 'portfolio_profile'
    # Same query and datetime conversion as the shared helper (whose
    # ts_column defaults to 'date'), so delegate instead of duplicating it.
    return _get_most_recent(table_name)

def _get_most_recent(table_name, ts_column='date'):
    '''
    Return the most recent entries in a table.

    Parameters
    ----------
    table_name: str
        table to read; interpolated directly into the SQL text, so it must
        be a trusted identifier and never user input
    ts_column: str
        column used to decide recency (default 'date'); also interpolated
        directly into the SQL text

    Returns
    -------
    pd.DataFrame
        every row whose ts_column equals the maximum ts_column value in
        the table, with ts_column converted by pd.to_datetime
    '''
    query = f"SELECT * FROM {table_name} WHERE {ts_column} = (SELECT MAX({ts_column}) FROM {table_name})"
    # A new engine (and its connection pool) is built on every call, so it is
    # disposed explicitly instead of being left for the garbage collector.
    engine = create_engine(db_url)
    try:
        with engine.connect() as conn:
            df = pd.read_sql(query, con=conn)
    finally:
        engine.dispose()
    # convert the timestamp column to datetime objects
    df[ts_column] = pd.to_datetime(df[ts_column])
    return df

def get_most_recent_portfolio_profile():
    '''
    Return the most recent entries of the ts.PORTFOLIO_TABLE table.

    Returns
    -------
    pd.DataFrame
        rows whose `date` equals the latest `date` in the table; the
        `date` column holds datetime values
    '''
    # _get_most_recent already runs pd.to_datetime on the timestamp column
    # (its default ts_column is 'date'), so no second conversion is needed.
    return _get_most_recent(ts.PORTFOLIO_TABLE)

def get_most_recent_stocks_price():
    '''
    Return the most recent entries of the ts.STOCKS_PRICE_TABLE table.

    Returns
    -------
    pd.DataFrame
        rows whose `time` equals the latest `time` in the table; the
        `time` column holds datetime values
    '''
    # _get_most_recent already runs pd.to_datetime on the column named by
    # ts_column, so no second conversion is needed here.
    return _get_most_recent(ts.STOCKS_PRICE_TABLE, ts_column='time')

def _append_df_to_db(df, table_name, schema):
    '''
    Validate df against schema and append its rows to a table.

    Parameters
    ----------
    df: pd.DataFrame
        rows to append; the index is not written
    table_name: str
        destination table
    schema: dict
        {column_name: data_type} that df must match (see _validate_schema)

    Raises
    ------
    ValueError
        if df fails _validate_schema. ValueError is a subclass of
        Exception, so callers catching Exception are unaffected.
    '''
    # validation
    if not _validate_schema(df, schema):
        raise ValueError(f'VALIDATION_ERROR: df does not have the same schema as the table {table_name}')
    engine = create_engine(db_url)
    try:
        # engine.begin() commits when the block exits normally and rolls back
        # if to_sql raises, so the append never depends on implicit autocommit.
        with engine.begin() as conn:
            df.to_sql(table_name, con=conn, if_exists='append', index=False)
    finally:
        # the engine is created per call, so release its connection pool here
        engine.dispose()

def append_to_stocks_price_table(df):
    '''
    Append the rows of df to the stocks price table.

    Parameters
    ----------
    df: pd.DataFrame
        rows to append; checked against ts.STOCKS_PRICE_TABLE_SCHEMA by
        _append_df_to_db before anything is written
    '''
    _append_df_to_db(
        df,
        table_name=ts.STOCKS_PRICE_TABLE,
        schema=ts.STOCKS_PRICE_TABLE_SCHEMA,
    )

def get_all_stocks():
    '''
    Read the whole ts.STOCKS_DETAILS_TABLE table.

    Returns
    -------
    pd.DataFrame
        every row and column of the table
    '''
    # A new engine (and its connection pool) is built on every call, so it is
    # disposed explicitly instead of being left for the garbage collector.
    engine = create_engine(db_url)
    try:
        with engine.connect() as conn:
            return pd.read_sql(ts.STOCKS_DETAILS_TABLE, con=conn)
    finally:
        engine.dispose()