portfolio_management / pipeline /test_db_peration.py
huggingface112's picture
move files to normal tracking except .db
976166f
raw
history blame
4.28 kB
from unittest import TestCase, main
from db_operation import db_operator
from datetime import datetime
class test_db_operation(TestCase):
    """Integration tests for ``db_operator`` stock insert/delete/window queries.

    Every test starts from an empty stock table (``setUp`` calls
    ``delete_all_stocks``) and empties it again afterwards (``tearDown``).
    The expected counts in the window tests assume ``get_stocks_between``
    treats BOTH the start and the end timestamp as inclusive.
    """

    # NOTE(review): this is a real on-disk SQLite file and every test wipes
    # all stocks from it -- confirm local_db.db is a throwaway test database.
    DB_URL = 'sqlite:///local_db.db'

    def setUp(self) -> None:
        self.db_operator = db_operator(self.DB_URL)
        self.db_operator.delete_all_stocks()
        # Row templates. Tests insert copies (see _add_both) so these dicts
        # are never mutated or handed to add_stock by reference.
        self.stock1 = {
            "ticker": 'AAPL',
            "weight": 1.0,
            "display_name": 'Apple Inc.',
            "date": datetime(2021, 1, 1),
        }
        self.stock2 = {
            "ticker": 'MSFT',
            "weight": 1.0,
            "display_name": 'Microsoft Corporation',
            "date": datetime(2021, 1, 1),
        }

    def tearDown(self) -> None:
        # Leave the database empty so rows created here never outlive the test.
        self.db_operator.delete_all_stocks()

    def _add_both(self, when: datetime) -> None:
        """Insert one fresh copy each of ``stock1`` and ``stock2`` dated *when*."""
        for template in (self.stock1, self.stock2):
            self.db_operator.add_stock({**template, "date": when})

    def test_insert(self):
        self.db_operator.add_stock(dict(self.stock1))
        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1),
            datetime(2021, 1, 1))
        # Check the count first so a failed insert reports a clear assertion
        # failure instead of an IndexError on [0].
        self.assertEqual(len(retrieved_stocks), 1)
        retrieved_stock = retrieved_stocks[0]
        self.assertEqual(retrieved_stock.ticker, 'AAPL')
        self.assertEqual(retrieved_stock.weight, 1.0)
        self.assertEqual(retrieved_stock.display_name, 'Apple Inc.')
        self.assertEqual(retrieved_stock.date, datetime(2021, 1, 1))

    def test_delete(self):
        day = datetime(2021, 1, 1)
        self._add_both(day)
        # Precondition: both rows are visible, so the empty result below proves
        # the delete worked rather than that the inserts silently failed.
        self.assertEqual(len(self.db_operator.get_stocks_between(day, day)), 2)
        self.db_operator.delete_stocks_between(day, day)
        retrieved_stocks = self.db_operator.get_stocks_between(day, day)
        self.assertEqual(len(retrieved_stocks), 0)

    def test_query_window_1d(self):
        # Two stocks at every hour of 2021-01-01 (00:00 through 23:00) ...
        for hour in range(24):
            self._add_both(datetime(2021, 1, 1, hour))
        # ... plus two exactly at 2021-01-02 00:00, the (inclusive) query end.
        self._add_both(datetime(2021, 1, 2))
        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1),
            datetime(2021, 1, 2))
        self.assertEqual(len(retrieved_stocks), 2 * 24 + 2)

    def test_query_window_12h(self):
        # Two stocks at every hour of 2021-01-01 (00:00 through 23:00).
        for hour in range(24):
            self._add_both(datetime(2021, 1, 1, hour))
        # Query 00:00..12:00 inclusive -> hours 0..12 = 13 timestamps.
        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1, 0),
            datetime(2021, 1, 1, 12))
        self.assertEqual(len(retrieved_stocks), 2 * 13)

    def test_query_window_1h(self):
        # Two stocks at every minute from 00:00 through 00:59 on 2021-01-01.
        for minute in range(60):
            self._add_both(datetime(2021, 1, 1, 0, minute))
        # Query 00:00..01:00 inclusive; nothing was inserted at 01:00.
        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1, 0),
            datetime(2021, 1, 1, 1))
        self.assertEqual(len(retrieved_stocks), 2 * 60)

    def test_query_window_30m(self):
        # Two stocks at every minute from 00:00 through 00:19 on 2021-01-01.
        for minute in range(20):
            self._add_both(datetime(2021, 1, 1, 0, minute))
        # Query 00:00..00:30 covers every inserted row.
        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1, 0),
            datetime(2021, 1, 1, 0, 30))
        self.assertEqual(len(retrieved_stocks), 2 * 20)
if __name__ == "__main__":
    # Run this module's TestCase classes through unittest's CLI runner
    # (module="__main__" is unittest.main's default, spelled out here).
    main(module="__main__")