camphong24032002 commited on
Commit
5c36aa4
·
1 Parent(s): 41875de

fix: return exception error

Browse files
Files changed (3) hide show
  1. models/price.py +9 -1
  2. routes/data.py +3 -1
  3. services/indicator.py +55 -31
models/price.py CHANGED
@@ -1,5 +1,13 @@
1
- from pydantic import BaseModel
2
 
3
 
4
  class PricePayload(BaseModel):
5
  symbol: str
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, validator
2
 
3
 
4
  class PricePayload(BaseModel):
5
  symbol: str
6
+ count_back: int = 150
7
+ source: str = "database"
8
+
9
+ @validator("source")
10
+ def valid_source(cls, v):
11
+ if v not in ["database", "vnstock"]:
12
+ raise ValueError('source must be database or vnstock')
13
+ return v
routes/data.py CHANGED
@@ -61,7 +61,9 @@ async def update_entire_macd():
61
  status_code=status.HTTP_200_OK
62
  )
63
  async def get_price_data(payload: PricePayload) -> Sequence[dict]:
64
- price_df = Indicator.get_price(payload.symbol)
 
 
65
  if price_df is None:
66
  return [{"message": "Error"}]
67
  return json.loads(price_df.to_json(orient="records"))
 
61
  status_code=status.HTTP_200_OK
62
  )
63
  async def get_price_data(payload: PricePayload) -> Sequence[dict]:
64
+ price_df = Indicator.get_price(payload.symbol,
65
+ payload.count_back,
66
+ payload.source)
67
  if price_df is None:
68
  return [{"message": "Error"}]
69
  return json.loads(price_df.to_json(orient="records"))
services/indicator.py CHANGED
@@ -155,27 +155,53 @@ class Indicator:
155
  updating_price = False
156
 
157
  @staticmethod
158
- def get_price(symbol) -> pd.DataFrame:
 
 
159
  try:
160
  symbol = symbol.upper()
161
- uri = os.environ.get("MONGODB_URI")
162
- client = MongoClient(uri)
163
- database = client.get_database("data")
164
- collection = database.get_collection("price")
165
- result = list(collection.find())
166
- tmp_df = pd.DataFrame(result[0]["value"])
167
- lst_symbols = tmp_df["ticker"].values
168
- if symbol not in lst_symbols:
169
- return None
170
- symbol_index = np.argwhere(lst_symbols == symbol)[0][0]
171
- lst_values = []
172
- for record in result:
173
- value = record["value"][symbol_index]
174
- value["time"] = record["time"]
175
- lst_values.append(value)
176
- return pd.DataFrame(lst_values)
177
- except Exception:
178
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  @staticmethod
181
  def update_daily_rsi() -> None:
@@ -288,10 +314,7 @@ class Indicator:
288
  client = MongoClient(uri)
289
  database = client.get_database("data")
290
  collection = database.get_collection("rsi")
291
- try:
292
- records = list(collection.find())
293
- except Exception as e:
294
- print(f"Error: {e}")
295
  record_df = pd.DataFrame(records).drop(columns=["_id"])
296
  record_df = \
297
  record_df[["time", symbol]].rename(columns={symbol: "rsi"})
@@ -303,8 +326,10 @@ class Indicator:
303
  record_df["stoch_rsi_smooth_k"], smooth_d
304
  )
305
  return record_df
306
- except Exception:
307
- return None
 
 
308
 
309
  @staticmethod
310
  def stoch_rsi(rsi: pd.Series, periods: int = 14) -> pd.Series:
@@ -384,16 +409,15 @@ class Indicator:
384
  client = MongoClient(uri)
385
  database = client.get_database("data")
386
  collection = database.get_collection("macd")
387
- try:
388
- records = list(collection.find())
389
- except Exception as e:
390
- print(f"Error: {e}")
391
  record_df = pd.DataFrame(records).drop(columns=["_id"])
392
  record_df = \
393
  record_df[["time", symbol]].rename(columns={symbol: "macd"})
394
  return record_df
395
- except Exception:
396
- return None
 
 
397
 
398
  @staticmethod
399
  def get_ichimoku_cloud(
 
155
  updating_price = False
156
 
157
  @staticmethod
158
+ def get_price(symbol: str,
159
+ count_back: int = 150,
160
+ source: str = "database") -> pd.DataFrame:
161
  try:
162
  symbol = symbol.upper()
163
+ if source == "database":
164
+ uri = os.environ.get("MONGODB_URI")
165
+ client = MongoClient(uri)
166
+ database = client.get_database("data")
167
+ collection = database.get_collection("price")
168
+ result = list(collection.find())
169
+ tmp_df = pd.DataFrame(result[0]["value"])
170
+ lst_symbols = tmp_df["ticker"].values
171
+ if symbol not in lst_symbols:
172
+ return pd.DataFrame(
173
+ [{"message": "The symbol is not existed"}]
174
+ )
175
+ symbol_index = np.argwhere(lst_symbols == symbol)[0][0]
176
+ lst_values = []
177
+ for record in result:
178
+ value = record["value"][symbol_index]
179
+ value["time"] = record["time"]
180
+ lst_values.append(value)
181
+ return pd.DataFrame(lst_values)
182
+ elif source == "vnstock":
183
+ datetime_now = datetime.utcnow()
184
+ end_date = datetime_now.date()
185
+ delta = timedelta(days=count_back)
186
+ start_date = end_date - delta
187
+ start_date = start_date.strftime(DATE_FORMAT)
188
+ end_date = end_date.strftime(DATE_FORMAT)
189
+ df = longterm_ohlc_data(symbol,
190
+ start_date,
191
+ end_date,
192
+ resolution="D",
193
+ type="stock"
194
+ ).reset_index(drop=True)
195
+ df[["open", "high", "low", "close"]] = \
196
+ df[["open", "high", "low", "close"]] * 1000
197
+ # convert open, high, low, close to int
198
+ df[["open", "high", "low", "close"]] = \
199
+ df[["open", "high", "low", "close"]].astype(int)
200
+ return df[["time", "ticker", "open", "high", "low", "close"]]
201
+ except Exception as e:
202
+ return pd.DataFrame(
203
+ [{"message": f"Caught error {e}"}]
204
+ )
205
 
206
  @staticmethod
207
  def update_daily_rsi() -> None:
 
314
  client = MongoClient(uri)
315
  database = client.get_database("data")
316
  collection = database.get_collection("rsi")
317
+ records = list(collection.find())
 
 
 
318
  record_df = pd.DataFrame(records).drop(columns=["_id"])
319
  record_df = \
320
  record_df[["time", symbol]].rename(columns={symbol: "rsi"})
 
326
  record_df["stoch_rsi_smooth_k"], smooth_d
327
  )
328
  return record_df
329
+ except Exception as e:
330
+ return pd.DataFrame(
331
+ [{"message": f"Caught error {e}"}]
332
+ )
333
 
334
  @staticmethod
335
  def stoch_rsi(rsi: pd.Series, periods: int = 14) -> pd.Series:
 
409
  client = MongoClient(uri)
410
  database = client.get_database("data")
411
  collection = database.get_collection("macd")
412
+ records = list(collection.find())
 
 
 
413
  record_df = pd.DataFrame(records).drop(columns=["_id"])
414
  record_df = \
415
  record_df[["time", symbol]].rename(columns={symbol: "macd"})
416
  return record_df
417
+ except Exception as e:
418
+ return pd.DataFrame(
419
+ [{"message": f"Caught error {e}"}]
420
+ )
421
 
422
  @staticmethod
423
  def get_ichimoku_cloud(