Spaces:
Runtime error
Runtime error
update app
Browse files- Ashare_data_.py +0 -347
- Inference_datapipe_.py +0 -155
- app.py +351 -1
- requirement → requirements.txt +0 -0
Ashare_data_.py
DELETED
@@ -1,347 +0,0 @@
|
|
1 |
-
import akshare as ak
|
2 |
-
import pandas as pd
|
3 |
-
import os
|
4 |
-
import csv
|
5 |
-
import re
|
6 |
-
import time
|
7 |
-
import math
|
8 |
-
import json
|
9 |
-
import random
|
10 |
-
from datasets import Dataset
|
11 |
-
import datasets
|
12 |
-
|
13 |
-
# os.chdir("/Users/mac/Desktop/FinGPT_Forecasting_Project/")
|
14 |
-
# print(os.getcwd())
|
15 |
-
|
16 |
-
start_date = "20230201"
|
17 |
-
end_date = "20240101"
|
18 |
-
|
19 |
-
# ------------------------------------------------------------------------------
|
20 |
-
# Data Aquisition
|
21 |
-
# ------------------------------------------------------------------------------
|
22 |
-
|
23 |
-
# get return
|
24 |
-
def get_return(symbol, adjust="qfq"):
|
25 |
-
"""
|
26 |
-
Get stock return data.
|
27 |
-
|
28 |
-
Args:
|
29 |
-
symbol: str
|
30 |
-
A-share market stock symbol
|
31 |
-
adjust: str ("qfq", "hfq")
|
32 |
-
price ajustment
|
33 |
-
default = "qfq" 前复权
|
34 |
-
|
35 |
-
Return:
|
36 |
-
weekly forward filled return data
|
37 |
-
"""
|
38 |
-
|
39 |
-
# load data
|
40 |
-
return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
|
41 |
-
|
42 |
-
# process timestamp
|
43 |
-
return_data["日期"] = pd.to_datetime(return_data["日期"])
|
44 |
-
return_data.set_index("日期", inplace=True)
|
45 |
-
|
46 |
-
# resample and filled with forward data
|
47 |
-
weekly_data = return_data["收盘"].resample("W").ffill()
|
48 |
-
weekly_returns = weekly_data.pct_change()[1:]
|
49 |
-
weekly_start_prices = weekly_data[:-1]
|
50 |
-
weekly_end_prices = weekly_data[1:]
|
51 |
-
weekly_data = pd.DataFrame({
|
52 |
-
'起始日期': weekly_start_prices.index,
|
53 |
-
'起始价': weekly_start_prices.values,
|
54 |
-
'结算日期': weekly_end_prices.index,
|
55 |
-
'结算价': weekly_end_prices.values,
|
56 |
-
'周收益': weekly_returns.values
|
57 |
-
})
|
58 |
-
weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
|
59 |
-
|
60 |
-
return weekly_data
|
61 |
-
def return_transform(ret):
|
62 |
-
|
63 |
-
up_down = '涨' if ret >= 0 else '跌'
|
64 |
-
integer = math.ceil(abs(100 * ret))
|
65 |
-
if integer == 0:
|
66 |
-
return "平"
|
67 |
-
|
68 |
-
return up_down + (str(integer) if integer <= 5 else '5+')
|
69 |
-
|
70 |
-
# get basics
|
71 |
-
def get_basic(symbol, data):
|
72 |
-
"""
|
73 |
-
Get and match basic data to news dataframe.
|
74 |
-
|
75 |
-
Args:
|
76 |
-
symbol: str
|
77 |
-
A-share market stock symbol
|
78 |
-
data: DataFrame
|
79 |
-
dated news data
|
80 |
-
|
81 |
-
Return:
|
82 |
-
financial news dataframe with matched basic_financial info
|
83 |
-
"""
|
84 |
-
key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']
|
85 |
-
|
86 |
-
# load quarterly basic data
|
87 |
-
basic_quarter_financials = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
|
88 |
-
basic_fin_dict = basic_quarter_financials.to_dict("index")
|
89 |
-
basic_fin_list = [dict([(key, val) for key, val in basic_fin_dict[i].items() if (key in key_financials) and val]) for i in range(len(basic_fin_dict))]
|
90 |
-
|
91 |
-
# match basic financial data to news dataframe
|
92 |
-
matched_basic_fin = []
|
93 |
-
for i, row in data.iterrows():
|
94 |
-
|
95 |
-
newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
|
96 |
-
|
97 |
-
matched_basic = {}
|
98 |
-
for basic in basic_fin_list:
|
99 |
-
# match the most current financial report
|
100 |
-
if basic["报告期"] < newsweek_enddate:
|
101 |
-
matched_basic = basic
|
102 |
-
break
|
103 |
-
matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
|
104 |
-
|
105 |
-
data['基本面'] = matched_basic_fin
|
106 |
-
|
107 |
-
return data
|
108 |
-
|
109 |
-
def raw_financial_data(symbol, with_basics = True):
|
110 |
-
|
111 |
-
# get return data from API
|
112 |
-
data = get_return(symbol=symbol)
|
113 |
-
|
114 |
-
# get news data from local
|
115 |
-
file_name = "news_data" + symbol + ".csv"
|
116 |
-
news_df = pd.read_csv("HS300_news_data20240118/"+file_name, index_col=0)
|
117 |
-
news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
|
118 |
-
news_df.sort_values(by=["发布时间"], inplace=True)
|
119 |
-
|
120 |
-
# match weekly news for return data
|
121 |
-
news_list = []
|
122 |
-
for a, row in data.iterrows():
|
123 |
-
week_start_date = row['起始日期'].strftime('%Y-%m-%d')
|
124 |
-
week_end_date = row['结算日期'].strftime('%Y-%m-%d')
|
125 |
-
print(symbol, ': ', week_start_date, ' - ', week_end_date)
|
126 |
-
|
127 |
-
weekly_news = news_df.loc[(news_df["发布时间"]>week_start_date) & (news_df["发布时间"]<week_end_date)]
|
128 |
-
|
129 |
-
weekly_news = [
|
130 |
-
{
|
131 |
-
"发布时间": n["发布时间"].strftime('%Y%m%d'),
|
132 |
-
"新闻标题": n['新闻标题'],
|
133 |
-
"新闻内容": n['新闻内容'],
|
134 |
-
} for a, n in weekly_news.iterrows()
|
135 |
-
]
|
136 |
-
news_list.append(json.dumps(weekly_news,ensure_ascii=False))
|
137 |
-
|
138 |
-
data["新闻"] = news_list
|
139 |
-
|
140 |
-
if with_basics:
|
141 |
-
data = get_basic(symbol=symbol, data=data)
|
142 |
-
# data.to_csv(symbol+start_date+"_"+end_date+".csv")
|
143 |
-
else:
|
144 |
-
data['新闻'] = [json.dumps({})] * len(data)
|
145 |
-
# data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
|
146 |
-
|
147 |
-
return data
|
148 |
-
|
149 |
-
# ------------------------------------------------------------------------------
|
150 |
-
# Prompt Generation
|
151 |
-
# ------------------------------------------------------------------------------
|
152 |
-
|
153 |
-
# SYSTEM_PROMPT = "你是一个经验丰富的股票市场分析师。你的任务是根据过去几周的相关新闻和基本财务状况,列出公司的积极发展和潜在担忧,然后对公司未来一周的股价变化提供分析和预测。" \
|
154 |
-
# "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
|
155 |
-
SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
|
156 |
-
"你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
|
157 |
-
|
158 |
-
def get_company_prompt_new(symbol):
|
159 |
-
try:
|
160 |
-
company_profile = dict(ak.stock_individual_info_em(symbol).values)
|
161 |
-
except:
|
162 |
-
print("Company Info Request Time Out! Please wait and retry.")
|
163 |
-
company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")
|
164 |
-
|
165 |
-
template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
|
166 |
-
"\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"
|
167 |
-
|
168 |
-
formatted_profile = template.format(**company_profile)
|
169 |
-
stockname = company_profile['股票简称']
|
170 |
-
return formatted_profile, stockname
|
171 |
-
|
172 |
-
def map_return_label(return_lb):
|
173 |
-
"""
|
174 |
-
Map abbrev in the raw data
|
175 |
-
Example:
|
176 |
-
涨1 -- 上涨1%
|
177 |
-
跌2 -- 下跌2%
|
178 |
-
平 -- 股价持平
|
179 |
-
"""
|
180 |
-
|
181 |
-
lb = return_lb.replace('涨', '上涨')
|
182 |
-
lb = lb.replace('跌', '下跌')
|
183 |
-
lb = lb.replace('平', '股价持平')
|
184 |
-
lb = lb.replace('1', '0-1%')
|
185 |
-
lb = lb.replace('2', '1-2%')
|
186 |
-
lb = lb.replace('3', '2-3%')
|
187 |
-
lb = lb.replace('4', '3-4%')
|
188 |
-
if lb.endswith('+'):
|
189 |
-
lb = lb.replace('5+', '超过5%')
|
190 |
-
else:
|
191 |
-
lb = lb.replace('5', '4-5%')
|
192 |
-
|
193 |
-
return lb
|
194 |
-
|
195 |
-
# check news quality
|
196 |
-
def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
|
197 |
-
try:
|
198 |
-
# check content avalability
|
199 |
-
if not (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
|
200 |
-
return False
|
201 |
-
# check highly duplicated news
|
202 |
-
# (assume the duplicated contents happened adjacent)
|
203 |
-
|
204 |
-
elif str(last_n['新闻内容'])=='nan':
|
205 |
-
return True
|
206 |
-
elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20*repeat_rate or len(set(n['新闻标题']) & set(last_n['新闻标题']))/len(last_n['新闻标题']) > repeat_rate:
|
207 |
-
return False
|
208 |
-
|
209 |
-
else:
|
210 |
-
return True
|
211 |
-
except TypeError:
|
212 |
-
print(n)
|
213 |
-
print(last_n)
|
214 |
-
raise Exception("Check Error")
|
215 |
-
|
216 |
-
def get_prompt_by_row_new(stock, row):
|
217 |
-
"""
|
218 |
-
Generate prompt for each row in the raw data
|
219 |
-
Args:
|
220 |
-
stock: str
|
221 |
-
stock name
|
222 |
-
row: pandas.Series
|
223 |
-
Return:
|
224 |
-
head: heading prompt
|
225 |
-
news: news info
|
226 |
-
basics: basic financial info
|
227 |
-
"""
|
228 |
-
|
229 |
-
week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
|
230 |
-
week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
|
231 |
-
term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
|
232 |
-
chg = map_return_label(row['简化周收益'])
|
233 |
-
head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
|
234 |
-
week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)
|
235 |
-
|
236 |
-
news = json.loads(row["新闻"])
|
237 |
-
|
238 |
-
left, right = 0, 0
|
239 |
-
filtered_news = []
|
240 |
-
while left < len(news):
|
241 |
-
n = news[left]
|
242 |
-
|
243 |
-
if left == 0:
|
244 |
-
# check first news quality
|
245 |
-
if (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
|
246 |
-
filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
|
247 |
-
left += 1
|
248 |
-
|
249 |
-
else:
|
250 |
-
news_check = check_news_quality(n, last_n = news[right], week_end_date= week_end_date, repeat_rate=0.5)
|
251 |
-
if news_check:
|
252 |
-
filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
|
253 |
-
left += 1
|
254 |
-
right += 1
|
255 |
-
|
256 |
-
|
257 |
-
basics = json.loads(row['基本面'])
|
258 |
-
if basics:
|
259 |
-
basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
|
260 |
-
stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
|
261 |
-
else:
|
262 |
-
basics = "[金融基本面]:\n\n 无金融基本面记录"
|
263 |
-
|
264 |
-
return head, filtered_news, basics
|
265 |
-
|
266 |
-
def sample_news(news, k=5):
|
267 |
-
"""
|
268 |
-
Ramdomly select past news.
|
269 |
-
|
270 |
-
Args:
|
271 |
-
news:
|
272 |
-
newslist in the timerange
|
273 |
-
k: int
|
274 |
-
the number of selected news
|
275 |
-
"""
|
276 |
-
return [news[i] for i in sorted(random.sample(range(len(news)), k))]
|
277 |
-
|
278 |
-
def get_all_prompts_new(symbol, min_past_week=1, max_past_weeks=2, with_basics=True):
|
279 |
-
"""
|
280 |
-
Generate prompt. The prompt consists of news from past weeks, basics financial information, and weekly return.
|
281 |
-
History news in the prompt is chosen from past weeks range from min_past_week to max_past_week,
|
282 |
-
and there is a number constraint on ramdomly selected data (default: up to 5).
|
283 |
-
|
284 |
-
Args:
|
285 |
-
symbol: str
|
286 |
-
stock ticker
|
287 |
-
min_past_week: int
|
288 |
-
max_past_week: int
|
289 |
-
with_basics: bool
|
290 |
-
If true, add basic infomation to the prompt
|
291 |
-
|
292 |
-
Return:
|
293 |
-
Prompts for the daterange
|
294 |
-
"""
|
295 |
-
|
296 |
-
# Load Data
|
297 |
-
df = raw_financial_data(symbol, with_basics=with_basics)
|
298 |
-
|
299 |
-
company_prompt, stock = get_company_prompt_new(symbol)
|
300 |
-
|
301 |
-
prev_rows = []
|
302 |
-
all_prompts = []
|
303 |
-
|
304 |
-
for row_idx, row in df.iterrows():
|
305 |
-
|
306 |
-
prompt = ""
|
307 |
-
|
308 |
-
# judge for available history news
|
309 |
-
if len(prev_rows) >= min_past_week:
|
310 |
-
|
311 |
-
# randomly set retrieve data of past weeks
|
312 |
-
# idx = min(random.choice(range(min_past_week, max_past_weeks+1)), len(prev_rows))
|
313 |
-
idx = min(max_past_weeks, len(prev_rows))
|
314 |
-
for i in range(-idx, 0):
|
315 |
-
# Add Head
|
316 |
-
prompt += "\n" + prev_rows[i][0]
|
317 |
-
# Add History News (with numbers constraint)
|
318 |
-
sampled_news = sample_news(
|
319 |
-
prev_rows[i][1],
|
320 |
-
min(3, len(prev_rows[i][1]))
|
321 |
-
)
|
322 |
-
if sampled_news:
|
323 |
-
prompt += "\n".join(sampled_news)
|
324 |
-
else:
|
325 |
-
prompt += "无有关新闻报告"
|
326 |
-
|
327 |
-
head, news, basics = get_prompt_by_row_new(stock, row)
|
328 |
-
|
329 |
-
prev_rows.append((head, news, basics))
|
330 |
-
|
331 |
-
if len(prev_rows) > max_past_weeks:
|
332 |
-
prev_rows.pop(0)
|
333 |
-
|
334 |
-
# set this to make sure there is history news for each considered date
|
335 |
-
if not prompt:
|
336 |
-
continue
|
337 |
-
|
338 |
-
prediction = map_return_label(row['简化周收益'])
|
339 |
-
|
340 |
-
prompt = company_prompt + '\n' + prompt + '\n' + basics
|
341 |
-
|
342 |
-
prompt += f"\n\n基于在{row['起始日期'].strftime('%Y-%m-%d')}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
|
343 |
-
f"那么让我们假设你对于下一周({row['起始日期'].strftime('%Y-%m-%d')}至{row['结算日期'].strftime('%Y-%m-%d')})的预测是{prediction}。提供一个总结分析来支持你的预测。预测结果需要从你最后的分析中推断出来,因此不作为你分析的基础因素。"
|
344 |
-
|
345 |
-
all_prompts.append(prompt.strip())
|
346 |
-
|
347 |
-
return all_prompts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Inference_datapipe_.py
DELETED
@@ -1,155 +0,0 @@
|
|
1 |
-
# Inference Data
|
2 |
-
# get company news online
|
3 |
-
from datetime import date
|
4 |
-
import akshare as ak
|
5 |
-
import pandas as pd
|
6 |
-
from datetime import date, datetime, timedelta
|
7 |
-
from Ashare_data import *
|
8 |
-
|
9 |
-
#default symbol
|
10 |
-
symbol = "600519"
|
11 |
-
B_INST, E_INST = "[INST]", "[/INST]"
|
12 |
-
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
13 |
-
|
14 |
-
def get_curday():
|
15 |
-
|
16 |
-
return date.today().strftime("%Y%m%d")
|
17 |
-
|
18 |
-
def n_weeks_before(date_string, n, format = "%Y%m%d"):
|
19 |
-
|
20 |
-
date = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7*n)
|
21 |
-
|
22 |
-
return date.strftime(format=format)
|
23 |
-
|
24 |
-
|
25 |
-
def get_news(symbol, max_page = 3):
|
26 |
-
|
27 |
-
df_list = []
|
28 |
-
for page in range(1, max_page):
|
29 |
-
|
30 |
-
try:
|
31 |
-
df_list.append(ak.stock_news_em(symbol, page))
|
32 |
-
except KeyError:
|
33 |
-
print(str(symbol) + "pages obtained for symbol: " + page)
|
34 |
-
break
|
35 |
-
|
36 |
-
news_df = pd.concat(df_list, ignore_index=True)
|
37 |
-
return news_df
|
38 |
-
|
39 |
-
# get return
|
40 |
-
def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
|
41 |
-
"""
|
42 |
-
date = "yyyymmdd"
|
43 |
-
"""
|
44 |
-
|
45 |
-
# load data
|
46 |
-
return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
|
47 |
-
|
48 |
-
# process timestamp
|
49 |
-
return_data["日期"] = pd.to_datetime(return_data["日期"])
|
50 |
-
return_data.set_index("日期", inplace=True)
|
51 |
-
|
52 |
-
# resample and filled with forward data
|
53 |
-
weekly_data = return_data["收盘"].resample("W").ffill()
|
54 |
-
weekly_returns = weekly_data.pct_change()[1:]
|
55 |
-
weekly_start_prices = weekly_data[:-1]
|
56 |
-
weekly_end_prices = weekly_data[1:]
|
57 |
-
weekly_data = pd.DataFrame({
|
58 |
-
'起始日期': weekly_start_prices.index,
|
59 |
-
'起始价': weekly_start_prices.values,
|
60 |
-
'结算日期': weekly_end_prices.index,
|
61 |
-
'结算价': weekly_end_prices.values,
|
62 |
-
'周收益': weekly_returns.values
|
63 |
-
})
|
64 |
-
weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
|
65 |
-
# check enddate
|
66 |
-
if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
|
67 |
-
weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
|
68 |
-
|
69 |
-
return weekly_data
|
70 |
-
|
71 |
-
# get basics
|
72 |
-
def cur_financial_data(symbol, start_date, end_date, with_basics = True):
|
73 |
-
|
74 |
-
# get data
|
75 |
-
data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
|
76 |
-
|
77 |
-
news_df = get_news(symbol=symbol)
|
78 |
-
news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
|
79 |
-
news_df.sort_values(by=["发布时间"], inplace=True)
|
80 |
-
|
81 |
-
# match weekly news for return data
|
82 |
-
news_list = []
|
83 |
-
for a, row in data.iterrows():
|
84 |
-
week_start_date = row['起始日期'].strftime('%Y-%m-%d')
|
85 |
-
week_end_date = row['结算日期'].strftime('%Y-%m-%d')
|
86 |
-
print(symbol, ': ', week_start_date, ' - ', week_end_date)
|
87 |
-
|
88 |
-
weekly_news = news_df.loc[(news_df["发布时间"]>week_start_date) & (news_df["发布时间"]<week_end_date)]
|
89 |
-
|
90 |
-
weekly_news = [
|
91 |
-
{
|
92 |
-
"发布时间": n["发布时间"].strftime('%Y%m%d'),
|
93 |
-
"新闻标题": n['新闻标题'],
|
94 |
-
"新闻内容": n['新闻内容'],
|
95 |
-
} for a, n in weekly_news.iterrows()
|
96 |
-
]
|
97 |
-
news_list.append(json.dumps(weekly_news,ensure_ascii=False))
|
98 |
-
|
99 |
-
data["新闻"] = news_list
|
100 |
-
|
101 |
-
if with_basics:
|
102 |
-
data = get_basic(symbol=symbol, data=data)
|
103 |
-
# data.to_csv(symbol+start_date+"_"+end_date+".csv")
|
104 |
-
else:
|
105 |
-
data['新闻'] = [json.dumps({})] * len(data)
|
106 |
-
# data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
|
107 |
-
|
108 |
-
return data
|
109 |
-
|
110 |
-
def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):
|
111 |
-
|
112 |
-
end_date = get_curday()
|
113 |
-
start_date = n_weeks_before(end_date, weeks_before)
|
114 |
-
|
115 |
-
company_prompt, stock = get_company_prompt_new(symbol)
|
116 |
-
data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)
|
117 |
-
|
118 |
-
prev_rows = []
|
119 |
-
|
120 |
-
for row_idx, row in data.iterrows():
|
121 |
-
head, news, basics = get_prompt_by_row_new(symbol, row)
|
122 |
-
prev_rows.append((head, news, basics))
|
123 |
-
|
124 |
-
prompt = ""
|
125 |
-
for i in range(-len(prev_rows), 0):
|
126 |
-
prompt += "\n" + prev_rows[i][0]
|
127 |
-
sampled_news = sample_news(
|
128 |
-
prev_rows[i][1],
|
129 |
-
min(max_news_perweek, len(prev_rows[i][1]))
|
130 |
-
)
|
131 |
-
if sampled_news:
|
132 |
-
prompt += "\n".join(sampled_news)
|
133 |
-
else:
|
134 |
-
prompt += "No relative news reported."
|
135 |
-
|
136 |
-
next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
|
137 |
-
end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
|
138 |
-
period = "{}至{}".format(end_date, next_date)
|
139 |
-
|
140 |
-
if with_basics:
|
141 |
-
basics = prev_rows[-1][2]
|
142 |
-
else:
|
143 |
-
basics = "[金融基本面]:\n\n 无金融基本面记录"
|
144 |
-
|
145 |
-
info = company_prompt + '\n' + prompt + '\n' + basics
|
146 |
-
|
147 |
-
new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
|
148 |
-
prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
|
149 |
-
f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST
|
150 |
-
|
151 |
-
return info, prompt
|
152 |
-
|
153 |
-
if __name__ == "__main__":
|
154 |
-
info, pt = get_all_prompts_online(symbol=symbol)
|
155 |
-
print(pt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,10 +1,18 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
3 |
from peft import PeftModel
|
4 |
-
import torch
|
5 |
from Ashare_data import *
|
6 |
from Inference_datapipe import *
|
7 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
# load model
|
10 |
model = "meta-llama/Llama-2-7b-chat-hf"
|
@@ -19,6 +27,348 @@ model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")
|
|
19 |
|
20 |
model = model.eval()
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def ask(symbol, weeks_before):
|
24 |
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
3 |
from peft import PeftModel
|
|
|
4 |
from Ashare_data import *
|
5 |
from Inference_datapipe import *
|
6 |
import re
|
7 |
+
import akshare as ak
|
8 |
+
import pandas as pd
|
9 |
+
import random
|
10 |
+
import json
|
11 |
+
import requests
|
12 |
+
import math
|
13 |
+
from datetime import date
|
14 |
+
from datetime import date, datetime, timedelta
|
15 |
+
|
16 |
|
17 |
# load model
|
18 |
model = "meta-llama/Llama-2-7b-chat-hf"
|
|
|
27 |
|
28 |
model = model.eval()
|
29 |
|
30 |
+
# Inference Data
|
31 |
+
# get company news online
|
32 |
+
|
33 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
34 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
35 |
+
SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
|
36 |
+
"你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
|
37 |
+
|
38 |
+
# ------------------------------------------------------------------------------
|
39 |
+
# Utils
|
40 |
+
# ------------------------------------------------------------------------------
|
41 |
+
def get_curday():
|
42 |
+
|
43 |
+
return date.today().strftime("%Y%m%d")
|
44 |
+
|
45 |
+
def n_weeks_before(date_string, n, format = "%Y%m%d"):
|
46 |
+
|
47 |
+
date = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7*n)
|
48 |
+
|
49 |
+
return date.strftime(format=format)
|
50 |
+
|
51 |
+
def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
|
52 |
+
try:
|
53 |
+
# check content avalability
|
54 |
+
if not (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
|
55 |
+
return False
|
56 |
+
# check highly duplicated news
|
57 |
+
# (assume the duplicated contents happened adjacent)
|
58 |
+
|
59 |
+
elif str(last_n['新闻内容'])=='nan':
|
60 |
+
return True
|
61 |
+
elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20*repeat_rate or len(set(n['新闻标题']) & set(last_n['新闻标题']))/len(last_n['新闻标题']) > repeat_rate:
|
62 |
+
return False
|
63 |
+
|
64 |
+
else:
|
65 |
+
return True
|
66 |
+
except TypeError:
|
67 |
+
print(n)
|
68 |
+
print(last_n)
|
69 |
+
raise Exception("Check Error")
|
70 |
+
|
71 |
+
def sample_news(news, k=5):
|
72 |
+
|
73 |
+
return [news[i] for i in sorted(random.sample(range(len(news)), k))]
|
74 |
+
|
75 |
+
def return_transform(ret):
|
76 |
+
|
77 |
+
up_down = '涨' if ret >= 0 else '跌'
|
78 |
+
integer = math.ceil(abs(100 * ret))
|
79 |
+
if integer == 0:
|
80 |
+
return "平"
|
81 |
+
|
82 |
+
return up_down + (str(integer) if integer <= 5 else '5+')
|
83 |
+
|
84 |
+
def map_return_label(return_lb):
|
85 |
+
|
86 |
+
lb = return_lb.replace('涨', '上涨')
|
87 |
+
lb = lb.replace('跌', '下跌')
|
88 |
+
lb = lb.replace('平', '股价持平')
|
89 |
+
lb = lb.replace('1', '0-1%')
|
90 |
+
lb = lb.replace('2', '1-2%')
|
91 |
+
lb = lb.replace('3', '2-3%')
|
92 |
+
lb = lb.replace('4', '3-4%')
|
93 |
+
if lb.endswith('+'):
|
94 |
+
lb = lb.replace('5+', '超过5%')
|
95 |
+
else:
|
96 |
+
lb = lb.replace('5', '4-5%')
|
97 |
+
|
98 |
+
return lb
|
99 |
+
# ------------------------------------------------------------------------------
|
100 |
+
# Get data from website
|
101 |
+
# ------------------------------------------------------------------------------
|
102 |
+
def stock_news_em(symbol: str = "300059", page = 1) -> pd.DataFrame:
|
103 |
+
|
104 |
+
url = "https://search-api-web.eastmoney.com/search/jsonp"
|
105 |
+
params = {
|
106 |
+
"cb": "jQuery3510875346244069884_1668256937995",
|
107 |
+
"param": '{"uid":"",'
|
108 |
+
+ f'"keyword":"{symbol}"'
|
109 |
+
+ ',"type":["cmsArticleWebOld"],"client":"web","clientType":"web","clientVersion":"curr","param":{"cmsArticleWebOld":{"searchScope":"default","sort":"default",' + f'"pageIndex":{page}'+ ',"pageSize":100,"preTag":"<em>","postTag":"</em>"}}}',
|
110 |
+
"_": "1668256937996",
|
111 |
+
}
|
112 |
+
r = requests.get(url, params=params)
|
113 |
+
data_text = r.text
|
114 |
+
data_json = json.loads(
|
115 |
+
data_text.strip("jQuery3510875346244069884_1668256937995(")[:-1]
|
116 |
+
)
|
117 |
+
temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
|
118 |
+
temp_df.rename(
|
119 |
+
columns={
|
120 |
+
"date": "发布时间",
|
121 |
+
"mediaName": "文章来源",
|
122 |
+
"code": "-",
|
123 |
+
"title": "新闻标题",
|
124 |
+
"content": "新闻内容",
|
125 |
+
"url": "新闻链接",
|
126 |
+
"image": "-",
|
127 |
+
},
|
128 |
+
inplace=True,
|
129 |
+
)
|
130 |
+
temp_df["关键词"] = symbol
|
131 |
+
temp_df = temp_df[
|
132 |
+
[
|
133 |
+
"关键词",
|
134 |
+
"新闻标题",
|
135 |
+
"新闻内容",
|
136 |
+
"发布时间",
|
137 |
+
"文章来源",
|
138 |
+
"新闻链接",
|
139 |
+
]
|
140 |
+
]
|
141 |
+
temp_df["新闻标题"] = (
|
142 |
+
temp_df["新闻标题"]
|
143 |
+
.str.replace(r"\(<em>", "", regex=True)
|
144 |
+
.str.replace(r"</em>\)", "", regex=True)
|
145 |
+
)
|
146 |
+
temp_df["新闻标题"] = (
|
147 |
+
temp_df["新闻标题"]
|
148 |
+
.str.replace(r"<em>", "", regex=True)
|
149 |
+
.str.replace(r"</em>", "", regex=True)
|
150 |
+
)
|
151 |
+
temp_df["新闻内容"] = (
|
152 |
+
temp_df["新闻内容"]
|
153 |
+
.str.replace(r"\(<em>", "", regex=True)
|
154 |
+
.str.replace(r"</em>\)", "", regex=True)
|
155 |
+
)
|
156 |
+
temp_df["新闻内容"] = (
|
157 |
+
temp_df["新闻内容"]
|
158 |
+
.str.replace(r"<em>", "", regex=True)
|
159 |
+
.str.replace(r"</em>", "", regex=True)
|
160 |
+
)
|
161 |
+
temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\u3000", "", regex=True)
|
162 |
+
temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\r\n", " ", regex=True)
|
163 |
+
return temp_df
|
164 |
+
|
165 |
+
def get_news(symbol, max_page = 3):
|
166 |
+
|
167 |
+
df_list = []
|
168 |
+
for page in range(1, max_page):
|
169 |
+
|
170 |
+
try:
|
171 |
+
df_list.append(stock_news_em(symbol, page))
|
172 |
+
except KeyError:
|
173 |
+
print(str(symbol) + "pages obtained for symbol: " + page)
|
174 |
+
break
|
175 |
+
|
176 |
+
news_df = pd.concat(df_list, ignore_index=True)
|
177 |
+
return news_df
|
178 |
+
|
179 |
+
def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
|
180 |
+
|
181 |
+
# load data
|
182 |
+
return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
|
183 |
+
|
184 |
+
# process timestamp
|
185 |
+
return_data["日期"] = pd.to_datetime(return_data["日期"])
|
186 |
+
return_data.set_index("日期", inplace=True)
|
187 |
+
|
188 |
+
# resample and filled with forward data
|
189 |
+
weekly_data = return_data["收盘"].resample("W").ffill()
|
190 |
+
weekly_returns = weekly_data.pct_change()[1:]
|
191 |
+
weekly_start_prices = weekly_data[:-1]
|
192 |
+
weekly_end_prices = weekly_data[1:]
|
193 |
+
weekly_data = pd.DataFrame({
|
194 |
+
'起始日期': weekly_start_prices.index,
|
195 |
+
'起始价': weekly_start_prices.values,
|
196 |
+
'结算日期': weekly_end_prices.index,
|
197 |
+
'结算价': weekly_end_prices.values,
|
198 |
+
'周收益': weekly_returns.values
|
199 |
+
})
|
200 |
+
weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
|
201 |
+
# check enddate
|
202 |
+
if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
|
203 |
+
weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
|
204 |
+
|
205 |
+
return weekly_data
|
206 |
+
|
207 |
+
def get_basic(symbol, data):
|
208 |
+
|
209 |
+
key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']
|
210 |
+
|
211 |
+
# load quarterly basic data
|
212 |
+
basic_quarter_financials = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
|
213 |
+
basic_fin_dict = basic_quarter_financials.to_dict("index")
|
214 |
+
basic_fin_list = [dict([(key, val) for key, val in basic_fin_dict[i].items() if (key in key_financials) and val]) for i in range(len(basic_fin_dict))]
|
215 |
+
|
216 |
+
# match basic financial data to news dataframe
|
217 |
+
matched_basic_fin = []
|
218 |
+
for i, row in data.iterrows():
|
219 |
+
|
220 |
+
newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
|
221 |
+
|
222 |
+
matched_basic = {}
|
223 |
+
for basic in basic_fin_list:
|
224 |
+
# match the most current financial report
|
225 |
+
if basic["报告期"] < newsweek_enddate:
|
226 |
+
matched_basic = basic
|
227 |
+
break
|
228 |
+
matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
|
229 |
+
|
230 |
+
data['基本面'] = matched_basic_fin
|
231 |
+
|
232 |
+
return data
|
233 |
+
# ------------------------------------------------------------------------------
|
234 |
+
# Structure Data
|
235 |
+
# ------------------------------------------------------------------------------
|
236 |
+
def cur_financial_data(symbol, start_date, end_date, with_basics = True):
    """Build the weekly dataset (returns + matched news, optional fundamentals).

    Args:
        symbol: A-share stock symbol.
        start_date / end_date: date-range bounds passed to ``get_cur_return``.
        with_basics: when True, also attach quarterly fundamentals via
            ``get_basic``; otherwise fill the 基本面 column with empty JSON.

    Returns:
        DataFrame with weekly return rows plus 新闻 (news, JSON string per
        week) and 基本面 (fundamentals, JSON string per week) columns.
    """
    # weekly return data
    data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)

    news_df = get_news(symbol=symbol)
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)

    # match weekly news to each return row
    news_list = []
    for _, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)

        weekly_news = news_df.loc[(news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)]

        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            } for _, n in weekly_news.iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))

    data["新闻"] = news_list

    if with_basics:
        data = get_basic(symbol=symbol, data=data)
        # data.to_csv(symbol+start_date+"_"+end_date+".csv")
    else:
        # BUG FIX: this branch previously assigned to data['新闻'], wiping out
        # the news column built above. The empty-JSON placeholder belongs in
        # the 基本面 column, which downstream prompt code reads unconditionally.
        data['基本面'] = [json.dumps({})] * len(data)
        # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")

    return data
|
273 |
+
# ------------------------------------------------------------------------------
|
274 |
+
# Formate Instruction
|
275 |
+
# ------------------------------------------------------------------------------
|
276 |
+
def get_company_prompt_new(symbol):
    """Build the Chinese company-introduction prompt for *symbol*.

    Returns:
        (formatted_profile, stockname): the filled-in introduction text and
        the company's short name (股票简称).

    Raises:
        RuntimeError: when the company-profile request fails.
    """
    try:
        company_profile = dict(ak.stock_individual_info_em(symbol).values)
    except Exception as err:
        # BUG FIX: the original bare `except:` only printed a message and fell
        # through, after which the next line raised NameError on the unbound
        # `company_profile`. Fail explicitly with a chained, actionable error.
        raise RuntimeError("Company Info Request Time Out! Please wait and retry.") from err

    # normalize the listing date to the Chinese YYYY年MM月DD日 format
    company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")

    template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
        "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"

    formatted_profile = template.format(**company_profile)
    stockname = company_profile['股票简称']
    return formatted_profile, stockname
|
289 |
+
|
290 |
+
def get_prompt_by_row_new(stock, row):
    """Turn one weekly row into prompt pieces: (head, filtered_news, basics).

    Args:
        stock: display name (or symbol) used inside the prompt text.
        row: one row of the weekly DataFrame, carrying 起始日期/结算日期,
            起始价/结算价, 简化周收益, 新闻 (JSON) and 基本面 (JSON) fields.

    Returns:
        head: one-sentence summary of the week's price move.
        filtered_news: list of formatted "[新闻标题]/[新闻内容]" strings that
            passed the quality checks.
        basics: formatted fundamentals section (or a "none" placeholder).
    """
    week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
    week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
    term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
    chg = map_return_label(row['简化周收益'])
    head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
        week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)

    news = json.loads(row["新闻"])

    # `left` walks the news list; `right` trails one behind and always points
    # at the previous item, which check_news_quality compares against for
    # near-duplicate detection.
    left, right = 0, 0
    filtered_news = []
    while left < len(news):
        n = news[left]

        if left == 0:
            # quality gate for the first item: content must not start with a
            # digit, must not be the string 'nan', and must be dated within
            # the week.
            if (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容']) == 'nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
                filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
            # BUG FIX: advance unconditionally; if the increment were gated on
            # the quality check, a rejected first item would loop forever.
            left += 1
        else:
            news_check = check_news_quality(n, last_n=news[right], week_end_date=week_end_date, repeat_rate=0.5)
            if news_check:
                filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
            left += 1
            right += 1

    basics = json.loads(row['基本面'])
    if basics:
        # BUG FIX: the exclusion filter previously tested k != 'period' (a key
        # from the English FinGPT variant); here the report-date key is
        # '报告期', so the date was duplicated in the header and the listing.
        basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
            stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != '报告期')
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    return head, filtered_news, basics
|
328 |
+
|
329 |
+
def get_all_prompts_online(symbol, with_basics=True, max_news_perweek=3, weeks_before=2):
    """Assemble the complete forecasting prompt for *symbol*.

    Fetches the last *weeks_before* weeks of return/news/fundamentals data,
    formats each week (sampling at most *max_news_perweek* news items), and
    wraps everything in the system-prompt instruction template.

    Args:
        symbol: A-share stock symbol.
        with_basics: include quarterly fundamentals in the prompt.
        max_news_perweek: cap on sampled news items per week.
        weeks_before: number of past weeks of context to include.

    Returns:
        (info, prompt): the raw information section, and the full
        instruction-wrapped prompt string.
    """
    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)

    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)

    # one (head, news, basics) triple per week, oldest first
    weekly_records = [get_prompt_by_row_new(symbol, row) for _, row in data.iterrows()]

    prompt = ""
    for head, news, _ in weekly_records:
        prompt += "\n" + head
        picked = sample_news(news, min(max_news_perweek, len(news)))
        if picked:
            prompt += "\n".join(picked)
        else:
            prompt += "No relative news reported."

    # the week being predicted: from today to one week out
    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)

    # only the most recent week's fundamentals go into the prompt
    if with_basics:
        basics = weekly_records[-1][2]
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    info = company_prompt + '\n' + prompt + '\n' + basics

    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST

    return info, prompt
|
371 |
+
|
372 |
|
373 |
def ask(symbol, weeks_before):
|
374 |
|
requirement → requirements.txt
RENAMED
File without changes
|