Spaces:
Runtime error
Runtime error
File size: 10,497 Bytes
976166f c8b5a38 588011f 976166f 48b892d 976166f 48b892d 976166f 48b892d 976166f 48b892d 976166f 48b892d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 |
'''
Helper functions that call the jqdatasdk API: authentication, security
details, industry/sector mapping, stock prices and index weights.
'''
from dotenv import load_dotenv
from datetime import datetime, timedelta
import jqdatasdk as jq
import pandas as pd
from typing import List, Optional
from sqlalchemy import create_engine
import table_schema as ts
import os
import utils
# Load variables from a .env file (if present) into os.environ before the
# jqdatasdk credentials are read below.
load_dotenv()
# jqdatasdk login used by auth_api; None when the variable is not set.
user_name = os.getenv('JQDATA_USER')
password = os.getenv('JQDATA_PASSWORD')
# SQLAlchemy URL of the local SQLite database (relative path).
db_url = 'sqlite:///instance/local.db'
def auth_api(func):
    """
    Decorator for functions that call the jqdatasdk API.

    Before every call, log in with the module-level ``user_name`` /
    ``password`` when ``jq.is_auth()`` reports no active session, then call
    ``func`` with the same arguments and return its result.

    ``functools.wraps`` keeps the decorated function's ``__name__``,
    ``__doc__`` and ``__wrapped__``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Authenticate lazily, only when no session is active.
        if not jq.is_auth():
            jq.auth(user_name, password)
        return func(*args, **kwargs)

    return wrapper
# Industry name (as joined by the callers from jq.get_industry results)
# -> aggregated sector used by this module.
_SECTOR_MAPPING = {
    '电气设备I': '工业',
    '建筑装饰I': '工业',
    '交通运输I': '工业',
    '机械设备I': '工业',
    '国防军工I': '工业',
    '综合I': '工业',
    '电子I': '信息与通信',
    '计算机I': '信息与通信',
    '通信I': '信息与通信',
    '传媒I': '信息与通信',
    '纺织服装I': '消费',
    '家用电器I': '消费',
    '汽车I': '消费',
    '休闲服务I': '消费',
    '商业贸易I': '消费',
    '食品饮料I': '消费',
    '美容护理I': '消费',
    '农林牧渔I': '消费',
    '钢铁I': '原料与能源',
    '建筑材料I': '原料与能源',
    '有色金属I': '原料与能源',
    '化工I': '原料与能源',
    '轻工制造I': '原料与能源',
    '煤炭I': '原料与能源',
    '石油石化I': '原料与能源',
    '采掘I': '原料与能源',
    '医药生物I': '医药卫生',
    '公用事业I': '公用事业',
    '环保I': '公用事业',
    '房地产I': '金融与地产',
    '银行I': '金融与地产',
    '非银金融I': '金融与地产',
}


def aggregate_sector(input: str) -> Optional[str]:
    '''
    Map a space-separated string of industry names to an aggregated sector.

    The tokens of ``input`` (split on single spaces) are looked up in
    ``_SECTOR_MAPPING`` in order, and the first one that is mapped wins.
    ``None`` is returned when no token is mapped, so industries missing
    from the mapping can be spotted later.

    Parameters
    ----------
    input : str
        industry names joined by single spaces, e.g. "银行I"

    Returns
    -------
    Optional[str] -- the aggregated sector, or None if no token is mapped
    '''
    # ``input`` shadows the builtin but is kept for keyword-argument callers.
    return next(
        (_SECTOR_MAPPING[name]
         for name in input.split(" ")
         if name in _SECTOR_MAPPING),
        None,
    )
@auth_api
def get_all_stock_info() -> tuple[pd.DataFrame, List[str]]:
    '''
    Fetch every security listed by ``jq.get_all_securities``.

    The security code is moved from the index into a ``ticker`` column and
    the frame gets a fresh integer index.

    Return
    ------
    tuple: tuple(pd.DataFrame, List[str])
        DataFrame -- display_name | name | start_date | end_date | type | ticker,
                     or None if the API call failed
        List[str] -- error messages (empty on success)
    '''
    messages: List[str] = []
    try:
        securities = jq.get_all_securities()
        securities['ticker'] = securities.index
        securities = securities.reset_index(drop=True)
    except Exception as exc:
        messages.append(f'get_all_stock_info\n{exc}')
        return None, messages
    return securities, messages
@auth_api
def add_detail_to_stocks(df: pd.DataFrame) -> List[str]:
    """
    Fill missing sector, aggregate_sector, display_name and name in place.

    Values already known for a ticker are first forward-filled within that
    ticker's rows. Remaining gaps are filled from the jqdatasdk API:
    ``jq.get_industry`` for sector / aggregate_sector and
    ``jq.get_security_info`` for display_name / name. ``df`` is modified in
    place.

    Args: pd.DataFrame
        ticker | date | weight | sector | aggregate_sector | display_name | name
    Returns: List[str], error messages (empty when everything succeeded)
    """
    error = []
    # Reuse values already present for the same ticker before calling the API.
    df[['sector', 'aggregate_sector']] = df.groupby(
        'ticker')[['sector', 'aggregate_sector']].ffill()
    df[['display_name', 'name']] = df.groupby(
        'ticker')[['display_name', 'name']].ffill()
    not_have_sector = list(
        df[df['aggregate_sector'].isnull()]['ticker'].unique())
    not_have_name = list(df[df['name'].isnull()]['ticker'].unique())

    # sector and aggregate sector
    if not_have_sector:
        try:
            industries = jq.get_industry(security=not_have_sector)

            def _sector_of(row):
                if not pd.isna(row.sector):
                    return row.sector
                # Rows whose ticker was not looked up (or got no answer)
                # keep their value instead of raising KeyError.
                ticker_industries = industries.get(row.ticker)
                if ticker_industries is None:
                    return row.sector
                return " ".join(value['industry_name']
                                for value in ticker_industries.values())

            def _aggregate_of(row):
                if not pd.isna(row.aggregate_sector):
                    return row.aggregate_sector
                # An unresolved sector is still NaN/None; it cannot be split.
                if isinstance(row.sector, str):
                    return aggregate_sector(row.sector)
                return row.aggregate_sector

            df['sector'] = df.apply(_sector_of, axis=1)
            df['aggregate_sector'] = df.apply(_aggregate_of, axis=1)
            missing = [t for t in not_have_sector if t not in industries]
            if missing:
                error.append(f'No industry information for\n{missing}')
        except Exception as e:
            # No single ticker is in scope here; report the whole batch.
            error.append(
                f'Error on creaet_sector_information\n{not_have_sector}\n{e}')

    # display_name and name, one ticker at a time so one failure does not
    # skip the remaining tickers
    for ticker in not_have_name:
        try:
            detail = jq.get_security_info(ticker)
            rows = df['ticker'] == ticker
            # Assign with a single .loc[rows, col]; chained indexing such as
            # df.loc[rows]['col'] = ... writes to a temporary copy.
            df.loc[rows, 'display_name'] = detail.display_name
            df.loc[rows, 'name'] = detail.name
        except Exception as e:
            error.append(f'Error on get display_name and name\n{ticker}\n{e}')
    return error
@auth_api
def update_portfolio_profile(stocks: List[dict], current_p: pd.DataFrame = None) -> tuple[pd.DataFrame, List[str]]:
    """Create or update a portfolio profile and return it as a time series.

    Parameters
    ----------
    stocks : List[{ticker: Str, shares: float, date:datetime}]
        stock records to add to the profile
    current_p : pd.DataFrame, optional
        current portfolio profile, default is None

    Returns
    -------
    updated_profile : pd.DataFrame
        ticker | date | weight | sector | aggregate_sector | display_name | name
    error : List[str]
        a list of error message
    """
    error = []
    profile_df = pd.DataFrame(stocks)
    # ``sector`` is not taken from the details table; add_detail_to_stocks
    # fills it.
    profile_df['sector'] = None
    # add display_name, name and aggregate_sector from the local details table
    try:
        with create_engine(db_url).connect() as conn:
            info_df = pd.read_sql_table(ts.STOCKS_DETAILS_TABLE, conn)
        profile_df = pd.merge(
            profile_df,
            info_df[['display_name', 'ticker', 'name', 'aggregate_sector']],
            on='ticker', how='left')
    except Exception as e:
        error.append(f'create_portfolio \n{e}')
    # Create the remaining detail columns only after the merge. A
    # pre-existing ``aggregate_sector`` would collide with the merged one and
    # be split into ``aggregate_sector_x``/``_y``. The columns are also needed
    # when the lookup above failed, because add_detail_to_stocks reads them.
    for column in ('aggregate_sector', 'display_name', 'name'):
        if column not in profile_df.columns:
            profile_df[column] = None
    # get sector information (add_detail_to_stocks mutates profile_df)
    error.extend(add_detail_to_stocks(profile_df))
    # concatenate to the existing profile if there is one
    if current_p is not None:
        # NOTE(review): with keep='last', rows from ``current_p`` win over the
        # new ``stocks`` rows for the same (ticker, date) — confirm intended.
        profile_df = pd.concat([profile_df, current_p], ignore_index=True)
        profile_df.drop_duplicates(
            subset=['ticker', 'date'], keep='last', inplace=True)
        profile_df.reset_index(drop=True, inplace=True)
    return profile_df, error
@auth_api
def get_all_stocks_detail():
    '''
    Get a DataFrame describing every security from ``jq.get_all_securities``.

    Returns
    -------
    pd.DataFrame
        the columns returned by ``jq.get_all_securities`` plus
        ticker | sector | aggregate_sector, where ``sector`` is the
        space-joined industry names from ``jq.get_industry`` and
        ``aggregate_sector`` is derived from it by ``aggregate_sector``
    '''
    detail_df = jq.get_all_securities()
    detail_df['ticker'] = detail_df.index
    detail_df.reset_index(drop=True, inplace=True)
    industry_info = jq.get_industry(detail_df.ticker.to_list())

    def _join_industries(ticker):
        # A ticker without industry data yields "" (aggregate_sector -> None)
        # instead of a KeyError that would abort the whole call.
        industries = industry_info.get(ticker, {})
        return " ".join(value['industry_name'] for value in industries.values())

    detail_df['sector'] = detail_df['ticker'].map(_join_industries)
    detail_df['aggregate_sector'] = detail_df['sector'].map(aggregate_sector)
    return detail_df
@auth_api
def get_api_usage():
    """Return the result of ``jq.get_query_count()`` (API usage) unchanged."""
    usage = jq.get_query_count()
    return usage
@auth_api
def get_stocks_price(profile: pd.DataFrame, start_date: datetime, end_date: datetime, frequency='daily') -> tuple[pd.DataFrame, List[str]]:
    """
    Return the prices of every ticker in a portfolio profile over a period.

    If the earliest ``profile.date`` is before ``start_date``, the request
    starts at that earliest date instead.

    Arguments:
        profile {pd.DataFrame} -- ticker | date | weight | sector | aggregate_sector | display_name | name
        start_date {datetime} -- start date of the period, inclusive
        end_date {datetime} -- end date of the period, inclusive
        frequency {str} -- resolution of the price, default is daily
    Returns: Tuple(pd.DataFrame, List[str])
        pd.DataFrame -- ticker date open close high low volume money,
                        or None if the request failed
        error_message {list} -- a list of error message
    """
    error_message = []
    start_str = start_date.strftime('%Y-%m-%d')
    end_str = end_date.strftime('%Y-%m-%d')
    earliest = profile.date.min()
    if earliest < start_date:
        # handle benchmark that doesn't have weight on the exact date
        start_str = earliest.strftime('%Y-%m-%d')
    # A profile is a time series, so a ticker appears once per date; request
    # each security only once (order of first appearance kept).
    ticker = profile['ticker'].drop_duplicates().to_list()
    try:
        data = jq.get_price(ticker, start_date=start_str,
                            end_date=end_str, frequency=frequency)
        data.rename(columns={'time': 'date', 'code': "ticker"}, inplace=True)
        return data, error_message
    except Exception as e:
        error_message.append(f'Error when fetching {ticker} \n {e}')
        return None, error_message
@auth_api
def fetch_stocks_price(**params):
    '''
    Request stock prices through ``jq.get_price``.

    All keyword arguments (e.g. start_date, end_date, frequency or count) are
    forwarded unchanged to ``jq.get_price``. The ``code`` column of the
    result is renamed to ``ticker``.
    '''
    return jq.get_price(**params).rename(columns={'code': 'ticker'})
@auth_api
def fetch_benchmark_profile(start_date: datetime, end_date: datetime, delta_time=timedelta(days=30), benchmark="000905.XSHG"):
    '''
    Fetch the constituent weights of an index from start_date to end_date.

    Weights are requested at start_date, start_date + delta_time, ... up to
    and including end_date.

    Parameters
    ----------
    start_date : datetime
        start date of the period, inclusive
    end_date : datetime
        end date of the period, inclusive
    delta_time : timedelta, optional
        sampling step (must be positive); the default is 30 days since the
        jq api only updates index weight once every month
    benchmark : str, optional
        index code passed to ``jq.get_index_weights``

    Returns
    -------
    pd.DataFrame
        the columns returned by ``jq.get_index_weights`` with ``date``
        renamed to ``time`` (converted to datetime) plus ``ticker``;
        duplicate (ticker, time) rows are dropped, keeping the last

    Raises
    ------
    Exception
        if end_date is before start_date
    ValueError
        if delta_time is not positive, or no weights could be fetched
    '''
    if end_date < start_date:
        raise Exception('end_date must be greater than start_date')
    if delta_time <= timedelta(0):
        # A non-positive step would never reach end_date.
        raise ValueError('delta_time must be positive')
    results = []
    current = start_date
    # ``<=`` keeps end_date inclusive and makes start_date == end_date
    # fetch one snapshot instead of none.
    while current <= end_date:
        date_str = current.strftime('%Y-%m-%d')
        try:
            results.append(jq.get_index_weights(benchmark, date=date_str))
        except Exception as e:
            # best effort: a missing snapshot is reported and skipped
            print(f'Error when fetching {benchmark}\n'
                  f'update on {date_str} is missing\n'
                  f'{e}')
        current += delta_time
    if not results:
        raise ValueError(
            f'no index weights fetched for {benchmark} between '
            f'{start_date:%Y-%m-%d} and {end_date:%Y-%m-%d}')
    update_df = pd.concat(results)
    update_df['ticker'] = update_df.index
    update_df['date'] = pd.to_datetime(update_df['date'])
    update_df.rename({'date': 'time'}, inplace=True, axis=1)
    # remove duplicate rows
    update_df = update_df.drop_duplicates(
        subset=['ticker', 'time'], keep='last')
    update_df.reset_index(drop=True, inplace=True)
    return update_df
|