"""Integration tests for ``db_operation.db_operator`` backed by a local SQLite file.

The expected counts in the window tests imply that
``get_stocks_between(start, end)`` includes rows dated exactly at ``end``.
"""
from datetime import datetime
from unittest import TestCase, main

from db_operation import db_operator


class test_db_operation(TestCase):
    """Exercise insert, delete and time-window queries of ``db_operator``."""

    # Database URL handed to db_operator in setUp.
    DB_URL = 'sqlite:///local_db.db'

    def setUp(self) -> None:
        """Open the database, wipe all stocks, and build two template records."""
        self.db_operator = db_operator(self.DB_URL)
        # Start every test from an empty store so the row counts are deterministic.
        self.db_operator.delete_all_stocks()
        self.stock1 = {
            "ticker": 'AAPL',
            "weight": 1.0,
            "display_name": 'Apple Inc.',
            "date": datetime(2021, 1, 1),
        }
        self.stock2 = {
            "ticker": 'MSFT',
            "weight": 1.0,
            "display_name": 'Microsoft Corporation',
            "date": datetime(2021, 1, 1),
        }

    def _add_both_at(self, date: datetime) -> None:
        """Insert copies of both template stocks stamped with ``date``."""
        # Pass fresh dicts so the templates are never mutated in place and
        # no two add_stock calls receive the same dict object.
        self.db_operator.add_stock({**self.stock1, "date": date})
        self.db_operator.add_stock({**self.stock2, "date": date})

    def test_insert(self):
        """A single inserted stock is returned with all fields intact."""
        day = datetime(2021, 1, 1)
        self.db_operator.add_stock(self.stock1)

        retrieved = self.db_operator.get_stocks_between(day, day)
        # Fail with a clear message rather than an IndexError if nothing came back.
        self.assertEqual(len(retrieved), 1)
        retrieved_stock = retrieved[0]

        self.assertEqual(retrieved_stock.ticker, 'AAPL')
        self.assertEqual(retrieved_stock.weight, 1.0)
        self.assertEqual(retrieved_stock.display_name, 'Apple Inc.')
        self.assertEqual(retrieved_stock.date, day)

    def test_delete(self):
        """delete_stocks_between removes every stock in the given window."""
        day = datetime(2021, 1, 1)
        self.db_operator.add_stock(self.stock1)
        self.db_operator.add_stock(self.stock2)
        # Guard against a vacuous pass: both rows must exist before deletion.
        self.assertEqual(len(self.db_operator.get_stocks_between(day, day)), 2)

        self.db_operator.delete_stocks_between(day, day)

        self.assertEqual(len(self.db_operator.get_stocks_between(day, day)), 0)

    def test_query_window_1d(self):
        """A one-day window returns hourly rows plus rows exactly at the end bound."""
        # 2 stocks every hour on 2021-01-01, 00:00 .. 23:00 -> 48 rows.
        for hour in range(24):
            self._add_both_at(datetime(2021, 1, 1, hour))
        # 2 more exactly at 2021-01-02 00:00 (the query's end bound) -> 50 rows.
        self._add_both_at(datetime(2021, 1, 2))

        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1), datetime(2021, 1, 2))
        # Expecting 50 means rows dated exactly at the end bound are included.
        self.assertEqual(len(retrieved_stocks), 50)

    def test_query_window_12h(self):
        """A 12-hour window over hourly data returns hours 00..12 inclusive."""
        # 2 stocks every hour on 2021-01-01, 00:00 .. 23:00.
        for hour in range(24):
            self._add_both_at(datetime(2021, 1, 1, hour))

        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1, 0), datetime(2021, 1, 1, 12))
        # Hours 0..12 -> 13 timestamps x 2 stocks = 26 rows.
        self.assertEqual(len(retrieved_stocks), 26)

    def test_query_window_1h(self):
        """A one-hour window over per-minute data returns all 60 minutes."""
        # 2 stocks every minute on 2021-01-01, 00:00 .. 00:59 -> 120 rows.
        for minute in range(60):
            self._add_both_at(datetime(2021, 1, 1, 0, minute))

        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1, 0), datetime(2021, 1, 1, 1))
        self.assertEqual(len(retrieved_stocks), 120)

    def test_query_window_30m(self):
        """A 30-minute window wider than the data returns every inserted row."""
        # 2 stocks every minute on 2021-01-01, 00:00 .. 00:19 -> 40 rows.
        for minute in range(20):
            self._add_both_at(datetime(2021, 1, 1, 0, minute))

        retrieved_stocks = self.db_operator.get_stocks_between(
            datetime(2021, 1, 1, 0), datetime(2021, 1, 1, 0, 30))
        self.assertEqual(len(retrieved_stocks), 40)


if __name__ == '__main__':
    main()