llk010502 committed
Commit: 5165373
Parent: fe9f282

update app

Ashare_data_.py DELETED
@@ -1,347 +0,0 @@
- import akshare as ak
- import pandas as pd
- import os
- import csv
- import re
- import time
- import math
- import json
- import random
- from datasets import Dataset
- import datasets
-
- # os.chdir("/Users/mac/Desktop/FinGPT_Forecasting_Project/")
- # print(os.getcwd())
-
- start_date = "20230201"
- end_date = "20240101"
-
- # ------------------------------------------------------------------------------
- # Data Acquisition
- # ------------------------------------------------------------------------------
-
- # get return
- def get_return(symbol, adjust="qfq"):
-     """
-     Get stock return data.
-
-     Args:
-         symbol: str
-             A-share market stock symbol
-         adjust: str ("qfq", "hfq")
-             price adjustment
-             default = "qfq" 前复权
-
-     Return:
-         weekly forward-filled return data
-     """
-
-     # load data
-     return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
-
-     # process timestamp
-     return_data["日期"] = pd.to_datetime(return_data["日期"])
-     return_data.set_index("日期", inplace=True)
-
-     # resample weekly and forward-fill
-     weekly_data = return_data["收盘"].resample("W").ffill()
-     weekly_returns = weekly_data.pct_change()[1:]
-     weekly_start_prices = weekly_data[:-1]
-     weekly_end_prices = weekly_data[1:]
-     weekly_data = pd.DataFrame({
-         '起始日期': weekly_start_prices.index,
-         '起始价': weekly_start_prices.values,
-         '结算日期': weekly_end_prices.index,
-         '结算价': weekly_end_prices.values,
-         '周收益': weekly_returns.values
-     })
-     weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
-
-     return weekly_data
- def return_transform(ret):
-
-     up_down = '涨' if ret >= 0 else '跌'
-     integer = math.ceil(abs(100 * ret))
-     if integer == 0:
-         return "平"
-
-     return up_down + (str(integer) if integer <= 5 else '5+')
-
- # get basics
- def get_basic(symbol, data):
-     """
-     Get and match basic data to news dataframe.
-
-     Args:
-         symbol: str
-             A-share market stock symbol
-         data: DataFrame
-             dated news data
-
-     Return:
-         financial news dataframe with matched basic_financial info
-     """
-     key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']
-
-     # load quarterly basic data
-     basic_quarter_financials = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
-     basic_fin_dict = basic_quarter_financials.to_dict("index")
-     basic_fin_list = [dict([(key, val) for key, val in basic_fin_dict[i].items() if (key in key_financials) and val]) for i in range(len(basic_fin_dict))]
-
-     # match basic financial data to news dataframe
-     matched_basic_fin = []
-     for i, row in data.iterrows():
-
-         newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
-
-         matched_basic = {}
-         for basic in basic_fin_list:
-             # match the most recent financial report
-             if basic["报告期"] < newsweek_enddate:
-                 matched_basic = basic
-                 break
-         matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
-
-     data['基本面'] = matched_basic_fin
-
-     return data
-
- def raw_financial_data(symbol, with_basics = True):
-
-     # get return data from API
-     data = get_return(symbol=symbol)
-
-     # get news data from local files
-     file_name = "news_data" + symbol + ".csv"
-     news_df = pd.read_csv("HS300_news_data20240118/"+file_name, index_col=0)
-     news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
-     news_df.sort_values(by=["发布时间"], inplace=True)
-
-     # match weekly news to return data
-     news_list = []
-     for a, row in data.iterrows():
-         week_start_date = row['起始日期'].strftime('%Y-%m-%d')
-         week_end_date = row['结算日期'].strftime('%Y-%m-%d')
-         print(symbol, ': ', week_start_date, ' - ', week_end_date)
-
-         weekly_news = news_df.loc[(news_df["发布时间"]>week_start_date) & (news_df["发布时间"]<week_end_date)]
-
-         weekly_news = [
-             {
-                 "发布时间": n["发布时间"].strftime('%Y%m%d'),
-                 "新闻标题": n['新闻标题'],
-                 "新闻内容": n['新闻内容'],
-             } for a, n in weekly_news.iterrows()
-         ]
-         news_list.append(json.dumps(weekly_news,ensure_ascii=False))
-
-     data["新闻"] = news_list
-
-     if with_basics:
-         data = get_basic(symbol=symbol, data=data)
-         # data.to_csv(symbol+start_date+"_"+end_date+".csv")
-     else:
-         data['新闻'] = [json.dumps({})] * len(data)
-         # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
-
-     return data
-
- # ------------------------------------------------------------------------------
- # Prompt Generation
- # ------------------------------------------------------------------------------
-
- # SYSTEM_PROMPT = "你是一个经验丰富的股票市场分析师。你的任务是根据过去几周的相关新闻和基本财务状况,列出公司的积极发展和潜在担忧,然后对公司未来一周的股价变化提供分析和预测。" \
- #                 "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
- SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
-                 "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
-
- def get_company_prompt_new(symbol):
-     try:
-         company_profile = dict(ak.stock_individual_info_em(symbol).values)
-     except:
-         print("Company Info Request Time Out! Please wait and retry.")
-     company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")
-
-     template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
-                "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"
-
-     formatted_profile = template.format(**company_profile)
-     stockname = company_profile['股票简称']
-     return formatted_profile, stockname
-
- def map_return_label(return_lb):
-     """
-     Map abbreviations in the raw data
-     Example:
-         涨1 -- 上涨1%
-         跌2 -- 下跌2%
-         平 -- 股价持平
-     """
-
-     lb = return_lb.replace('涨', '上涨')
-     lb = lb.replace('跌', '下跌')
-     lb = lb.replace('平', '股价持平')
-     lb = lb.replace('1', '0-1%')
-     lb = lb.replace('2', '1-2%')
-     lb = lb.replace('3', '2-3%')
-     lb = lb.replace('4', '3-4%')
-     if lb.endswith('+'):
-         lb = lb.replace('5+', '超过5%')
-     else:
-         lb = lb.replace('5', '4-5%')
-
-     return lb
-
- # check news quality
- def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
-     try:
-         # check content availability
-         if not (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
-             return False
-         # check highly duplicated news
-         # (assume duplicated contents appear adjacently)
-
-         elif str(last_n['新闻内容'])=='nan':
-             return True
-         elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20*repeat_rate or len(set(n['新闻标题']) & set(last_n['新闻标题']))/len(last_n['新闻标题']) > repeat_rate:
-             return False
-
-         else:
-             return True
-     except TypeError:
-         print(n)
-         print(last_n)
-         raise Exception("Check Error")
-
- def get_prompt_by_row_new(stock, row):
-     """
-     Generate prompt for each row in the raw data
-     Args:
-         stock: str
-             stock name
-         row: pandas.Series
-     Return:
-         head: heading prompt
-         news: news info
-         basics: basic financial info
-     """
-
-     week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
-     week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
-     term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
-     chg = map_return_label(row['简化周收益'])
-     head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
-         week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)
-
-     news = json.loads(row["新闻"])
-
-     left, right = 0, 0
-     filtered_news = []
-     while left < len(news):
-         n = news[left]
-
-         if left == 0:
-             # check first news quality
-             if (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
-                 filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
-             left += 1
-
-         else:
-             news_check = check_news_quality(n, last_n = news[right], week_end_date= week_end_date, repeat_rate=0.5)
-             if news_check:
-                 filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
-             left += 1
-             right += 1
-
-
-     basics = json.loads(row['基本面'])
-     if basics:
-         basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
-             stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
-     else:
-         basics = "[金融基本面]:\n\n 无金融基本面记录"
-
-     return head, filtered_news, basics
-
- def sample_news(news, k=5):
-     """
-     Randomly select past news.
-
-     Args:
-         news:
-             news list in the time range
-         k: int
-             the number of selected news
-     """
-     return [news[i] for i in sorted(random.sample(range(len(news)), k))]
-
- def get_all_prompts_new(symbol, min_past_week=1, max_past_weeks=2, with_basics=True):
-     """
-     Generate prompts. Each prompt consists of news from past weeks, basic financial information, and the weekly return.
-     History news in the prompt is chosen from the past weeks, ranging from min_past_week to max_past_weeks,
-     and there is a number constraint on the randomly selected news (default: up to 5).
-
-     Args:
-         symbol: str
-             stock ticker
-         min_past_week: int
-         max_past_weeks: int
-         with_basics: bool
-             If true, add basic information to the prompt
-
-     Return:
-         Prompts for the date range
-     """
-
-     # Load Data
-     df = raw_financial_data(symbol, with_basics=with_basics)
-
-     company_prompt, stock = get_company_prompt_new(symbol)
-
-     prev_rows = []
-     all_prompts = []
-
-     for row_idx, row in df.iterrows():
-
-         prompt = ""
-
-         # check that history news is available
-         if len(prev_rows) >= min_past_week:
-
-             # set how many past weeks of data to retrieve
-             # idx = min(random.choice(range(min_past_week, max_past_weeks+1)), len(prev_rows))
-             idx = min(max_past_weeks, len(prev_rows))
-             for i in range(-idx, 0):
-                 # Add Head
-                 prompt += "\n" + prev_rows[i][0]
-                 # Add History News (with number constraint)
-                 sampled_news = sample_news(
-                     prev_rows[i][1],
-                     min(3, len(prev_rows[i][1]))
-                 )
-                 if sampled_news:
-                     prompt += "\n".join(sampled_news)
-                 else:
-                     prompt += "无有关新闻报告"
-
-         head, news, basics = get_prompt_by_row_new(stock, row)
-
-         prev_rows.append((head, news, basics))
-
-         if len(prev_rows) > max_past_weeks:
-             prev_rows.pop(0)
-
-         # skip dates without history news so every kept prompt has prior-week context
-         if not prompt:
-             continue
-
-         prediction = map_return_label(row['简化周收益'])
-
-         prompt = company_prompt + '\n' + prompt + '\n' + basics
-
-         prompt += f"\n\n基于在{row['起始日期'].strftime('%Y-%m-%d')}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
-                   f"那么让我们假设你对于下一周({row['起始日期'].strftime('%Y-%m-%d')}至{row['结算日期'].strftime('%Y-%m-%d')})的预测是{prediction}。提供一个总结分析来支持你的预测。预测结果需要从你最后的分析中推断出来,因此不作为你分析的基础因素。"
-
-         all_prompts.append(prompt.strip())
-
-     return all_prompts
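
The deleted module turned local HS300 news dumps plus akshare price and fundamentals data into weekly Chinese forecasting prompts. A minimal sketch of how its entry point could be driven is shown below; the module name without the trailing underscore, the ticker list, the output path, and the packing into a datasets.Dataset are assumptions, not shown in this commit:

    # hypothetical driver script; ticker list and output path are illustrative only
    from datasets import Dataset
    from Ashare_data import get_all_prompts_new

    symbols = ["600519", "000001"]  # assumed sample of HS300 tickers
    prompts = []
    for sym in symbols:
        # one prompt per week that has at least one prior week of news history
        prompts.extend(get_all_prompts_new(sym, min_past_week=1, max_past_weeks=2, with_basics=True))

    # the unused `datasets` import in the module suggests the prompts were packed into a Dataset
    dataset = Dataset.from_dict({"prompt": prompts})
    dataset.save_to_disk("ashare_forecasting_prompts")
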
Inference_datapipe_.py DELETED
@@ -1,155 +0,0 @@
- # Inference Data
- # get company news online
- from datetime import date
- import akshare as ak
- import pandas as pd
- from datetime import date, datetime, timedelta
- from Ashare_data import *
-
- # default symbol
- symbol = "600519"
- B_INST, E_INST = "[INST]", "[/INST]"
- B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
- def get_curday():
-
-     return date.today().strftime("%Y%m%d")
-
- def n_weeks_before(date_string, n, format = "%Y%m%d"):
-
-     date = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7*n)
-
-     return date.strftime(format=format)
-
-
- def get_news(symbol, max_page = 3):
-
-     df_list = []
-     for page in range(1, max_page):
-
-         try:
-             df_list.append(ak.stock_news_em(symbol, page))
-         except KeyError:
-             print(str(page - 1) + " pages obtained for symbol: " + str(symbol))
-             break
-
-     news_df = pd.concat(df_list, ignore_index=True)
-     return news_df
-
- # get return
- def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
-     """
-     date = "yyyymmdd"
-     """
-
-     # load data
-     return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
-
-     # process timestamp
-     return_data["日期"] = pd.to_datetime(return_data["日期"])
-     return_data.set_index("日期", inplace=True)
-
-     # resample weekly and forward-fill
-     weekly_data = return_data["收盘"].resample("W").ffill()
-     weekly_returns = weekly_data.pct_change()[1:]
-     weekly_start_prices = weekly_data[:-1]
-     weekly_end_prices = weekly_data[1:]
-     weekly_data = pd.DataFrame({
-         '起始日期': weekly_start_prices.index,
-         '起始价': weekly_start_prices.values,
-         '结算日期': weekly_end_prices.index,
-         '结算价': weekly_end_prices.values,
-         '周收益': weekly_returns.values
-     })
-     weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
-     # clip the last settlement date to end_date
-     if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
-         weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
-
-     return weekly_data
-
- # get return, news, and basics
- def cur_financial_data(symbol, start_date, end_date, with_basics = True):
-
-     # get data
-     data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
-
-     news_df = get_news(symbol=symbol)
-     news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
-     news_df.sort_values(by=["发布时间"], inplace=True)
-
-     # match weekly news to return data
-     news_list = []
-     for a, row in data.iterrows():
-         week_start_date = row['起始日期'].strftime('%Y-%m-%d')
-         week_end_date = row['结算日期'].strftime('%Y-%m-%d')
-         print(symbol, ': ', week_start_date, ' - ', week_end_date)
-
-         weekly_news = news_df.loc[(news_df["发布时间"]>week_start_date) & (news_df["发布时间"]<week_end_date)]
-
-         weekly_news = [
-             {
-                 "发布时间": n["发布时间"].strftime('%Y%m%d'),
-                 "新闻标题": n['新闻标题'],
-                 "新闻内容": n['新闻内容'],
-             } for a, n in weekly_news.iterrows()
-         ]
-         news_list.append(json.dumps(weekly_news,ensure_ascii=False))
-
-     data["新闻"] = news_list
-
-     if with_basics:
-         data = get_basic(symbol=symbol, data=data)
-         # data.to_csv(symbol+start_date+"_"+end_date+".csv")
-     else:
-         data['新闻'] = [json.dumps({})] * len(data)
-         # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
-
-     return data
-
- def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):
-
-     end_date = get_curday()
-     start_date = n_weeks_before(end_date, weeks_before)
-
-     company_prompt, stock = get_company_prompt_new(symbol)
-     data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)
-
-     prev_rows = []
-
-     for row_idx, row in data.iterrows():
-         head, news, basics = get_prompt_by_row_new(symbol, row)
-         prev_rows.append((head, news, basics))
-
-     prompt = ""
-     for i in range(-len(prev_rows), 0):
-         prompt += "\n" + prev_rows[i][0]
-         sampled_news = sample_news(
-             prev_rows[i][1],
-             min(max_news_perweek, len(prev_rows[i][1]))
-         )
-         if sampled_news:
-             prompt += "\n".join(sampled_news)
-         else:
-             prompt += "No relevant news reported."
-
-     next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
-     end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
-     period = "{}至{}".format(end_date, next_date)
-
-     if with_basics:
-         basics = prev_rows[-1][2]
-     else:
-         basics = "[金融基本面]:\n\n 无金融基本面记录"
-
-     info = company_prompt + '\n' + prompt + '\n' + basics
-
-     new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
-     prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
-              f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST
-
-     return info, prompt
-
- if __name__ == "__main__":
-     info, pt = get_all_prompts_online(symbol=symbol)
-     print(pt)
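
The deleted inference pipeline anchors its news window on the current date: get_curday() gives today as "YYYYMMDD", and n_weeks_before() walks back weeks_before * 7 calendar days for the data window or, with n = -1, forward one week for the forecast period. A small worked illustration of that date arithmetic follows; the concrete date is only an example:

    from datetime import datetime, timedelta

    def n_weeks_before(date_string, n, fmt="%Y%m%d"):
        # same arithmetic as the deleted helper: step back n * 7 calendar days
        d = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7 * n)
        return d.strftime(fmt)

    today = "20240118"                                 # stand-in for get_curday()
    print(n_weeks_before(today, 2))                    # 20240104 -> start of the two-week news window
    print(n_weeks_before(today, -1, fmt="%Y-%m-%d"))   # 2024-01-25 -> end of the forecast week
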
app.py CHANGED
@@ -1,10 +1,18 @@
  import gradio as gr
  from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
  from peft import PeftModel
- import torch
  from Ashare_data import *
  from Inference_datapipe import *
  import re
+ import akshare as ak
+ import pandas as pd
+ import random
+ import json
+ import requests
+ import math
+ from datetime import date
+ from datetime import date, datetime, timedelta
+

  # load model
  model = "meta-llama/Llama-2-7b-chat-hf"
@@ -19,6 +27,348 @@ model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")

  model = model.eval()

+ # Inference Data
+ # get company news online
+
+ B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+ SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
+                 "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
+
+ # ------------------------------------------------------------------------------
+ # Utils
+ # ------------------------------------------------------------------------------
+ def get_curday():
+
+     return date.today().strftime("%Y%m%d")
+
+ def n_weeks_before(date_string, n, format = "%Y%m%d"):
+
+     date = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7*n)
+
+     return date.strftime(format=format)
+
+ def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
+     try:
+         # check content availability
+         if not (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
+             return False
+         # check highly duplicated news
+         # (assume duplicated contents appear adjacently)
+
+         elif str(last_n['新闻内容'])=='nan':
+             return True
+         elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20*repeat_rate or len(set(n['新闻标题']) & set(last_n['新闻标题']))/len(last_n['新闻标题']) > repeat_rate:
+             return False
+
+         else:
+             return True
+     except TypeError:
+         print(n)
+         print(last_n)
+         raise Exception("Check Error")
+
+ def sample_news(news, k=5):
+
+     return [news[i] for i in sorted(random.sample(range(len(news)), k))]
+
+ def return_transform(ret):
+
+     up_down = '涨' if ret >= 0 else '跌'
+     integer = math.ceil(abs(100 * ret))
+     if integer == 0:
+         return "平"
+
+     return up_down + (str(integer) if integer <= 5 else '5+')
+
+ def map_return_label(return_lb):
+
+     lb = return_lb.replace('涨', '上涨')
+     lb = lb.replace('跌', '下跌')
+     lb = lb.replace('平', '股价持平')
+     lb = lb.replace('1', '0-1%')
+     lb = lb.replace('2', '1-2%')
+     lb = lb.replace('3', '2-3%')
+     lb = lb.replace('4', '3-4%')
+     if lb.endswith('+'):
+         lb = lb.replace('5+', '超过5%')
+     else:
+         lb = lb.replace('5', '4-5%')
+
+     return lb
+ # ------------------------------------------------------------------------------
+ # Get data from website
+ # ------------------------------------------------------------------------------
+ def stock_news_em(symbol: str = "300059", page = 1) -> pd.DataFrame:
+
+     url = "https://search-api-web.eastmoney.com/search/jsonp"
+     params = {
+         "cb": "jQuery3510875346244069884_1668256937995",
+         "param": '{"uid":"",'
+         + f'"keyword":"{symbol}"'
+         + ',"type":["cmsArticleWebOld"],"client":"web","clientType":"web","clientVersion":"curr","param":{"cmsArticleWebOld":{"searchScope":"default","sort":"default",' + f'"pageIndex":{page}'+ ',"pageSize":100,"preTag":"<em>","postTag":"</em>"}}}',
+         "_": "1668256937996",
+     }
+     r = requests.get(url, params=params)
+     data_text = r.text
+     data_json = json.loads(
+         data_text.strip("jQuery3510875346244069884_1668256937995(")[:-1]
+     )
+     temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
+     temp_df.rename(
+         columns={
+             "date": "发布时间",
+             "mediaName": "文章来源",
+             "code": "-",
+             "title": "新闻标题",
+             "content": "新闻内容",
+             "url": "新闻链接",
+             "image": "-",
+         },
+         inplace=True,
+     )
+     temp_df["关键词"] = symbol
+     temp_df = temp_df[
+         [
+             "关键词",
+             "新闻标题",
+             "新闻内容",
+             "发布时间",
+             "文章来源",
+             "新闻链接",
+         ]
+     ]
+     temp_df["新闻标题"] = (
+         temp_df["新闻标题"]
+         .str.replace(r"\(<em>", "", regex=True)
+         .str.replace(r"</em>\)", "", regex=True)
+     )
+     temp_df["新闻标题"] = (
+         temp_df["新闻标题"]
+         .str.replace(r"<em>", "", regex=True)
+         .str.replace(r"</em>", "", regex=True)
+     )
+     temp_df["新闻内容"] = (
+         temp_df["新闻内容"]
+         .str.replace(r"\(<em>", "", regex=True)
+         .str.replace(r"</em>\)", "", regex=True)
+     )
+     temp_df["新闻内容"] = (
+         temp_df["新闻内容"]
+         .str.replace(r"<em>", "", regex=True)
+         .str.replace(r"</em>", "", regex=True)
+     )
+     temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\u3000", "", regex=True)
+     temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\r\n", " ", regex=True)
+     return temp_df
+
+ def get_news(symbol, max_page = 3):
+
+     df_list = []
+     for page in range(1, max_page):
+
+         try:
+             df_list.append(stock_news_em(symbol, page))
+         except KeyError:
+             print(str(page - 1) + " pages obtained for symbol: " + str(symbol))
+             break
+
+     news_df = pd.concat(df_list, ignore_index=True)
+     return news_df
+
+ def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
+
+     # load data
+     return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
+
+     # process timestamp
+     return_data["日期"] = pd.to_datetime(return_data["日期"])
+     return_data.set_index("日期", inplace=True)
+
+     # resample weekly and forward-fill
+     weekly_data = return_data["收盘"].resample("W").ffill()
+     weekly_returns = weekly_data.pct_change()[1:]
+     weekly_start_prices = weekly_data[:-1]
+     weekly_end_prices = weekly_data[1:]
+     weekly_data = pd.DataFrame({
+         '起始日期': weekly_start_prices.index,
+         '起始价': weekly_start_prices.values,
+         '结算日期': weekly_end_prices.index,
+         '结算价': weekly_end_prices.values,
+         '周收益': weekly_returns.values
+     })
+     weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
+     # clip the last settlement date to end_date
+     if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
+         weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
+
+     return weekly_data
+
+ def get_basic(symbol, data):
+
+     key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']
+
+     # load quarterly basic data
+     basic_quarter_financials = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
+     basic_fin_dict = basic_quarter_financials.to_dict("index")
+     basic_fin_list = [dict([(key, val) for key, val in basic_fin_dict[i].items() if (key in key_financials) and val]) for i in range(len(basic_fin_dict))]
+
+     # match basic financial data to news dataframe
+     matched_basic_fin = []
+     for i, row in data.iterrows():
+
+         newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
+
+         matched_basic = {}
+         for basic in basic_fin_list:
+             # match the most recent financial report
+             if basic["报告期"] < newsweek_enddate:
+                 matched_basic = basic
+                 break
+         matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
+
+     data['基本面'] = matched_basic_fin
+
+     return data
+ # ------------------------------------------------------------------------------
+ # Structure Data
+ # ------------------------------------------------------------------------------
+ def cur_financial_data(symbol, start_date, end_date, with_basics = True):
+
+     # get data
+     data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
+
+     news_df = get_news(symbol=symbol)
+     news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
+     news_df.sort_values(by=["发布时间"], inplace=True)
+
+     # match weekly news to return data
+     news_list = []
+     for a, row in data.iterrows():
+         week_start_date = row['起始日期'].strftime('%Y-%m-%d')
+         week_end_date = row['结算日期'].strftime('%Y-%m-%d')
+         print(symbol, ': ', week_start_date, ' - ', week_end_date)
+
+         weekly_news = news_df.loc[(news_df["发布时间"]>week_start_date) & (news_df["发布时间"]<week_end_date)]
+
+         weekly_news = [
+             {
+                 "发布时间": n["发布时间"].strftime('%Y%m%d'),
+                 "新闻标题": n['新闻标题'],
+                 "新闻内容": n['新闻内容'],
+             } for a, n in weekly_news.iterrows()
+         ]
+         news_list.append(json.dumps(weekly_news,ensure_ascii=False))
+
+     data["新闻"] = news_list
+
+     if with_basics:
+         data = get_basic(symbol=symbol, data=data)
+         # data.to_csv(symbol+start_date+"_"+end_date+".csv")
+     else:
+         data['新闻'] = [json.dumps({})] * len(data)
+         # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
+
+     return data
+ # ------------------------------------------------------------------------------
+ # Format Instruction
+ # ------------------------------------------------------------------------------
+ def get_company_prompt_new(symbol):
+     try:
+         company_profile = dict(ak.stock_individual_info_em(symbol).values)
+     except:
+         print("Company Info Request Time Out! Please wait and retry.")
+     company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")
+
+     template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
+                "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"
+
+     formatted_profile = template.format(**company_profile)
+     stockname = company_profile['股票简称']
+     return formatted_profile, stockname
+
+ def get_prompt_by_row_new(stock, row):
+
+     week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
+     week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
+     term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
+     chg = map_return_label(row['简化周收益'])
+     head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
+         week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)
+
+     news = json.loads(row["新闻"])
+
+     left, right = 0, 0
+     filtered_news = []
+     while left < len(news):
+         n = news[left]
+
+         if left == 0:
+             # check first news quality
+             if (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容'])=='nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
+                 filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
+             left += 1
+
+         else:
+             news_check = check_news_quality(n, last_n = news[right], week_end_date= week_end_date, repeat_rate=0.5)
+             if news_check:
+                 filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
+             left += 1
+             right += 1
+
+
+     basics = json.loads(row['基本面'])
+     if basics:
+         basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
+             stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
+     else:
+         basics = "[金融基本面]:\n\n 无金融基本面记录"
+
+     return head, filtered_news, basics
+
+ def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):
+
+     end_date = get_curday()
+     start_date = n_weeks_before(end_date, weeks_before)
+
+     company_prompt, stock = get_company_prompt_new(symbol)
+     data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)
+
+     prev_rows = []
+
+     for row_idx, row in data.iterrows():
+         head, news, basics = get_prompt_by_row_new(symbol, row)
+         prev_rows.append((head, news, basics))
+
+     prompt = ""
+     for i in range(-len(prev_rows), 0):
+         prompt += "\n" + prev_rows[i][0]
+         sampled_news = sample_news(
+             prev_rows[i][1],
+             min(max_news_perweek, len(prev_rows[i][1]))
+         )
+         if sampled_news:
+             prompt += "\n".join(sampled_news)
+         else:
+             prompt += "No relevant news reported."
+
+     next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
+     end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
+     period = "{}至{}".format(end_date, next_date)
+
+     if with_basics:
+         basics = prev_rows[-1][2]
+     else:
+         basics = "[金融基本面]:\n\n 无金融基本面记录"
+
+     info = company_prompt + '\n' + prompt + '\n' + basics
+
+     new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
+     prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
+              f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST
+
+     return info, prompt
+

  def ask(symbol, weeks_before):

requirement → requirements.txt RENAMED
File without changes
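
The diff truncates the body of app.py's ask() handler; only its signature is visible. A plausible sketch of what such a Gradio handler could do with the output of get_all_prompts_online is given below; the tokenizer loading, generation settings, and answer extraction are assumptions, not taken from this commit:

    # hypothetical completion of ask(); not the committed implementation
    def ask(symbol, weeks_before):
        # build the [INST]-wrapped prompt from live akshare / EastMoney data
        info, prompt = get_all_prompts_online(symbol, weeks_before=weeks_before)

        inputs = tokenizer(prompt, return_tensors="pt")  # assumes a tokenizer is loaded alongside the PEFT model
        output = model.generate(**inputs, max_new_tokens=512, do_sample=False)
        answer = tokenizer.decode(output[0], skip_special_tokens=True)

        # keep only the model's reply after the closing [/INST] tag
        return info, answer.split("[/INST]")[-1].strip()
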