application files

Files added:
- Ashare_data_.py +347 -0
- Inference_datapipe_.py +155 -0
- app.py +70 -0

Ashare_data_.py
ADDED
@@ -0,0 +1,347 @@
import akshare as ak
import pandas as pd
import os
import csv
import re
import time
import math
import json
import random
from datasets import Dataset
import datasets

# os.chdir("/Users/mac/Desktop/FinGPT_Forecasting_Project/")
# print(os.getcwd())

start_date = "20230201"
end_date = "20240101"

# ------------------------------------------------------------------------------
# Data Acquisition
# ------------------------------------------------------------------------------

# get return
def get_return(symbol, adjust="qfq"):
    """
    Get stock return data.

    Args:
        symbol: str
            A-share market stock symbol
        adjust: str ("qfq", "hfq")
            price adjustment
            default = "qfq" (forward-adjusted prices)

    Return:
        weekly forward-filled return data
    """

    # load daily price history
    return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)

    # process timestamp
    return_data["日期"] = pd.to_datetime(return_data["日期"])
    return_data.set_index("日期", inplace=True)

    # resample to weekly frequency, forward-filling the close price
    weekly_data = return_data["收盘"].resample("W").ffill()
    weekly_returns = weekly_data.pct_change()[1:]
    weekly_start_prices = weekly_data[:-1]
    weekly_end_prices = weekly_data[1:]
    weekly_data = pd.DataFrame({
        '起始日期': weekly_start_prices.index,
        '起始价': weekly_start_prices.values,
        '结算日期': weekly_end_prices.index,
        '结算价': weekly_end_prices.values,
        '周收益': weekly_returns.values
    })
    weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)

    return weekly_data

def return_transform(ret):
    # bucket a weekly return into a coarse label: direction plus a capped
    # 1-5 percentage bucket ('5+' beyond 5%), or '平' for a flat week
    up_down = '涨' if ret >= 0 else '跌'
    integer = math.ceil(abs(100 * ret))
    if integer == 0:
        return "平"

    return up_down + (str(integer) if integer <= 5 else '5+')

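# Editor's sketch (not part of the original file): how return_transform buckets
# hypothetical weekly returns, assuming the definition above.
#   return_transform(0.013)   ->  '涨2'   (up 1-2%: ceil(1.3) = 2)
#   return_transform(-0.062)  ->  '跌5+'  (down more than 5%)
#   return_transform(0.0)     ->  '平'    (flat)
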
# get basics
def get_basic(symbol, data):
    """
    Get and match basic financial data to the news dataframe.

    Args:
        symbol: str
            A-share market stock symbol
        data: DataFrame
            dated news data

    Return:
        financial news dataframe with matched basic financial info
    """
    key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']

    # load quarterly basic data
    basic_quarter_financials = ak.stock_financial_abstract_ths(symbol=symbol, indicator="按单季度")
    basic_fin_dict = basic_quarter_financials.to_dict("index")
    basic_fin_list = [dict([(key, val) for key, val in basic_fin_dict[i].items() if (key in key_financials) and val]) for i in range(len(basic_fin_dict))]

    # match basic financial data to the news dataframe
    matched_basic_fin = []
    for i, row in data.iterrows():

        newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")

        matched_basic = {}
        for basic in basic_fin_list:
            # match the most recent financial report before the week's end
            if basic["报告期"] < newsweek_enddate:
                matched_basic = basic
                break
        matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))

    data['基本面'] = matched_basic_fin

    return data

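# Editor's illustration (hypothetical values, not from the source data): each
# row's '基本面' column ends up holding a JSON string shaped roughly like
#   {"报告期": "2023-09-30", "净利润同比增长率": "12.34%", "资产负债率": "45.67%"}
# or "{}" when no report predates that week's settlement date.
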
def raw_financial_data(symbol, with_basics = True):

    # get return data from the API
    data = get_return(symbol=symbol)

    # get news data from local files
    file_name = "news_data" + symbol + ".csv"
    news_df = pd.read_csv("HS300_news_data20240118/" + file_name, index_col=0)
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)

    # match weekly news to the return data
    news_list = []
    for a, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)

        weekly_news = news_df.loc[(news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)]

        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            } for a, n in weekly_news.iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))

    data["新闻"] = news_list

    if with_basics:
        data = get_basic(symbol=symbol, data=data)
        # data.to_csv(symbol+start_date+"_"+end_date+".csv")
    else:
        # fill the fundamentals column with empty records; the news column was
        # already populated above (the original assigned '新闻' here, which
        # would have overwritten it)
        data['基本面'] = [json.dumps({})] * len(data)
        # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")

    return data

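# Editor's note: per the code above, the returned frame has one row per trading
# week with columns 起始日期, 起始价, 结算日期, 结算价, 周收益, 简化周收益, 新闻,
# and 基本面 (the latter only meaningful when with_basics=True), and it expects
# the news CSVs to already exist under HS300_news_data20240118/.
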
# ------------------------------------------------------------------------------
# Prompt Generation
# ------------------------------------------------------------------------------

# SYSTEM_PROMPT = "你是一个经验丰富的股票市场分析师。你的任务是根据过去几周的相关新闻和基本财务状况,列出公司的积极发展和潜在担忧,然后对公司未来一周的股价变化提供分析和预测。" \
#     "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
    "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"

def get_company_prompt_new(symbol):
    try:
        company_profile = dict(ak.stock_individual_info_em(symbol).values)
    except:
        # the original swallowed this error and fell through to an undefined
        # company_profile; re-raise so the failure is visible to the caller
        print("Company Info Request Time Out! Please wait and retry.")
        raise
    company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")

    template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
        "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"

    formatted_profile = template.format(**company_profile)
    stockname = company_profile['股票简称']
    return formatted_profile, stockname

def map_return_label(return_lb):
    """
    Expand the abbreviated labels in the raw data.
    Example:
        涨1 -- 上涨1%
        跌2 -- 下跌2%
        平 -- 股价持平
    """

    lb = return_lb.replace('涨', '上涨')
    lb = lb.replace('跌', '下跌')
    lb = lb.replace('平', '股价持平')
    lb = lb.replace('1', '0-1%')
    lb = lb.replace('2', '1-2%')
    lb = lb.replace('3', '2-3%')
    lb = lb.replace('4', '3-4%')
    if lb.endswith('+'):
        lb = lb.replace('5+', '超过5%')
    else:
        lb = lb.replace('5', '4-5%')

    return lb

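# Editor's sketch (not in the original file): what the expansion produces for
# the labels return_transform emits, assuming the definitions above.
#   map_return_label('涨2')   ->  '上涨1-2%'
#   map_return_label('跌5+')  ->  '下跌超过5%'
#   map_return_label('平')    ->  '股价持平'
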
# check news quality
def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
    try:
        # check content availability
        if not (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容']) == 'nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
            return False
        # check for highly duplicated news
        # (assume duplicated content appears on adjacent rows)
        elif str(last_n['新闻内容']) == 'nan':
            return True
        elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20*repeat_rate or len(set(n['新闻标题']) & set(last_n['新闻标题']))/len(last_n['新闻标题']) > repeat_rate:
            return False

        else:
            return True
    except TypeError:
        print(n)
        print(last_n)
        raise Exception("Check Error")

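# Editor's illustration (hypothetical strings, not from the source): the
# duplicate check compares character sets, so two headlines sharing most of
# their characters count as repeats even when word order differs. For example,
# set('公司发布三季度业绩预告') equals set('公司三季度业绩预告发布'), giving an
# overlap ratio of 1.0 > repeat_rate, so the second item is dropped.
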
def get_prompt_by_row_new(stock, row):
    """
    Generate the prompt pieces for one row of the raw data.
    Args:
        stock: str
            stock name
        row: pandas.Series
    Return:
        head: heading prompt
        news: news info
        basics: basic financial info
    """

    week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
    week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
    term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
    chg = map_return_label(row['简化周收益'])
    head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
        week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)

    news = json.loads(row["新闻"])

    # two-pointer pass: left walks the news list while right trails one item
    # behind, so each candidate is compared against the previous one
    left, right = 0, 0
    filtered_news = []
    while left < len(news):
        n = news[left]

        if left == 0:
            # check first news quality
            if (not(str(n['新闻内容'])[0].isdigit()) and not(str(n['新闻内容']) == 'nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
                filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
            left += 1

        else:
            news_check = check_news_quality(n, last_n=news[right], week_end_date=week_end_date, repeat_rate=0.5)
            if news_check:
                filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
            left += 1
            right += 1

    basics = json.loads(row['基本面'])
    if basics:
        basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
            stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    return head, filtered_news, basics

def sample_news(news, k=5):
    """
    Randomly select up to k past news items, preserving chronological order.

    Args:
        news:
            news list in the time range
        k: int
            the number of selected news items
    """
    return [news[i] for i in sorted(random.sample(range(len(news)), k))]

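# Editor's note: sorting the sampled indices is what keeps the subset in its
# original (chronological) order; random.sample alone returns them shuffled.
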
def get_all_prompts_new(symbol, min_past_week=1, max_past_weeks=2, with_basics=True):
    """
    Generate prompts. Each prompt consists of news from past weeks, basic financial information, and the weekly return.
    History news in the prompt is drawn from the past min_past_week to max_past_weeks weeks,
    with a cap on the number of randomly selected items per week (up to 3 below).

    Args:
        symbol: str
            stock ticker
        min_past_week: int
        max_past_weeks: int
        with_basics: bool
            If true, add basic information to the prompt

    Return:
        Prompts for the date range
    """

    # Load Data
    df = raw_financial_data(symbol, with_basics=with_basics)

    company_prompt, stock = get_company_prompt_new(symbol)

    prev_rows = []
    all_prompts = []

    for row_idx, row in df.iterrows():

        prompt = ""

        # only build a prompt once enough history weeks are available
        if len(prev_rows) >= min_past_week:

            # retrieve data of past weeks
            # idx = min(random.choice(range(min_past_week, max_past_weeks+1)), len(prev_rows))
            idx = min(max_past_weeks, len(prev_rows))
            for i in range(-idx, 0):
                # Add Head
                prompt += "\n" + prev_rows[i][0]
                # Add History News (with number constraint)
                sampled_news = sample_news(
                    prev_rows[i][1],
                    min(3, len(prev_rows[i][1]))
                )
                if sampled_news:
                    prompt += "\n".join(sampled_news)
                else:
                    prompt += "无有关新闻报告"

        head, news, basics = get_prompt_by_row_new(stock, row)

        prev_rows.append((head, news, basics))

        if len(prev_rows) > max_past_weeks:
            prev_rows.pop(0)

        # skip rows without history news so every kept date has context
        if not prompt:
            continue

        prediction = map_return_label(row['简化周收益'])

        prompt = company_prompt + '\n' + prompt + '\n' + basics

        prompt += f"\n\n基于在{row['起始日期'].strftime('%Y-%m-%d')}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
            f"那么让我们假设你对于下一周({row['起始日期'].strftime('%Y-%m-%d')}至{row['结算日期'].strftime('%Y-%m-%d')})的预测是{prediction}。提供一个总结分析来支持你的预测。预测结果需要从你最后的分析中推断出来,因此不作为你分析的基础因素。"

        all_prompts.append(prompt.strip())

    return all_prompts
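# Editor's usage sketch (not part of the committed file): one way the generated
# prompts could be wrapped into a Hugging Face dataset, which is presumably why
# `datasets` is imported above. The symbol and output path are illustrative.
#
#   prompts = get_all_prompts_new("600519", with_basics=True)
#   train_ds = Dataset.from_dict({"prompt": prompts})
#   train_ds.save_to_disk("forecaster_prompts_600519")
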
Inference_datapipe_.py
ADDED
@@ -0,0 +1,155 @@
# Inference Data
# get company news online
import akshare as ak
import pandas as pd
from datetime import date, datetime, timedelta
from Ashare_data_ import *  # module file is Ashare_data_.py (note the trailing underscore)

# default symbol
symbol = "600519"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def get_curday():

    return date.today().strftime("%Y%m%d")

def n_weeks_before(date_string, n, format = "%Y%m%d"):

    # renamed from `date` to avoid shadowing the imported datetime.date
    date_ = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7*n)

    return date_.strftime(format=format)

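# Editor's sketch (hypothetical date, not from the source): the two helpers
# anchor the lookback window, assuming today is 2024-01-18.
#   get_curday()                                ->  "20240118"
#   n_weeks_before("20240118", 2)               ->  "20240104"
#   n_weeks_before("20240118", -1, "%Y-%m-%d")  ->  "2024-01-25"
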
def get_news(symbol, max_page = 3):

    df_list = []
    for page in range(1, max_page + 1):  # pages 1..max_page inclusive

        try:
            df_list.append(ak.stock_news_em(symbol, page))
        except KeyError:
            # the original concatenated the integer `page` to a string, which
            # raises TypeError; format the message instead
            print("{} pages obtained for symbol: {}".format(page - 1, symbol))
            break

    news_df = pd.concat(df_list, ignore_index=True)
    return news_df

+
# get return
|
40 |
+
def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
|
41 |
+
"""
|
42 |
+
date = "yyyymmdd"
|
43 |
+
"""
|
44 |
+
|
45 |
+
# load data
|
46 |
+
return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
|
47 |
+
|
48 |
+
# process timestamp
|
49 |
+
return_data["日期"] = pd.to_datetime(return_data["日期"])
|
50 |
+
return_data.set_index("日期", inplace=True)
|
51 |
+
|
52 |
+
# resample and filled with forward data
|
53 |
+
weekly_data = return_data["收盘"].resample("W").ffill()
|
54 |
+
weekly_returns = weekly_data.pct_change()[1:]
|
55 |
+
weekly_start_prices = weekly_data[:-1]
|
56 |
+
weekly_end_prices = weekly_data[1:]
|
57 |
+
weekly_data = pd.DataFrame({
|
58 |
+
'起始日期': weekly_start_prices.index,
|
59 |
+
'起始价': weekly_start_prices.values,
|
60 |
+
'结算日期': weekly_end_prices.index,
|
61 |
+
'结算价': weekly_end_prices.values,
|
62 |
+
'周收益': weekly_returns.values
|
63 |
+
})
|
64 |
+
weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
|
65 |
+
# check enddate
|
66 |
+
if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
|
67 |
+
weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
|
68 |
+
|
69 |
+
return weekly_data
|
70 |
+
|
# get basics
def cur_financial_data(symbol, start_date, end_date, with_basics = True):

    # get return data
    data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)

    # get news data online
    news_df = get_news(symbol=symbol)
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)

    # match weekly news to the return data
    news_list = []
    for a, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)

        weekly_news = news_df.loc[(news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)]

        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            } for a, n in weekly_news.iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))

    data["新闻"] = news_list

    if with_basics:
        data = get_basic(symbol=symbol, data=data)
        # data.to_csv(symbol+start_date+"_"+end_date+".csv")
    else:
        # fill the fundamentals column; as in Ashare_data_.py, the original
        # assigned '新闻' here, overwriting the news built above
        data['基本面'] = [json.dumps({})] * len(data)
        # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")

    return data

def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):

    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)

    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)

    prev_rows = []

    for row_idx, row in data.iterrows():
        # pass the stock name (the original passed the numeric symbol, which
        # would then appear verbatim in the generated Chinese text)
        head, news, basics = get_prompt_by_row_new(stock, row)
        prev_rows.append((head, news, basics))

    prompt = ""
    for i in range(-len(prev_rows), 0):
        prompt += "\n" + prev_rows[i][0]
        sampled_news = sample_news(
            prev_rows[i][1],
            min(max_news_perweek, len(prev_rows[i][1]))
        )
        if sampled_news:
            prompt += "\n".join(sampled_news)
        else:
            prompt += "No relevant news reported."

    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)

    if with_basics:
        basics = prev_rows[-1][2]
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    info = company_prompt + '\n' + prompt + '\n' + basics

    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST

    return info, prompt

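# Editor's illustration (schematic, not from the source): the assembled prompt
# follows the Llama-2 chat wrapping defined by B_INST/B_SYS above:
#   [INST] <<SYS>>
#   {system prompt}
#   <</SYS>>
#   {company profile + weekly heads/news + fundamentals + instruction} [/INST]
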
if __name__ == "__main__":
    info, pt = get_all_prompts_online(symbol=symbol)
    print(pt)
app.py
ADDED
@@ -0,0 +1,70 @@
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
from Ashare_data_ import *         # module files carry a trailing underscore
from Inference_datapipe_ import *  # (the original imports omitted it)
import re

# load model
base_model = "meta-llama/Llama-2-7b-chat-hf"
peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True, device_map='auto', offload_folder="offload/")
model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")

model = model.eval()

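# Editor's note (assumption, not in the original): a gated base model like
# meta-llama/Llama-2-7b-chat-hf also requires an accepted license and an HF
# token at load time, and full-precision 7B weights are heavy for small Spaces
# hardware; passing torch_dtype=torch.float16 to from_pretrained is a common
# way to roughly halve the memory footprint.
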
def ask(symbol, weeks_before):

    # load inference data
    info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before)
    # print(info)

    inputs = tokenizer(pt, return_tensors='pt')
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    print("Inputs loaded onto devices.")

    res = model.generate(
        **inputs,
        max_new_tokens=500,  # assumed budget: without it, generate() falls back to a very short default length
        use_cache=True
    )
    output = tokenizer.decode(res[0], skip_special_tokens=True)
    # keep only the model's answer, dropping everything up to [/INST]
    output_cur = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
    return info, output_cur

server = gr.Interface(
    ask,
    inputs=[
        gr.Textbox(
            label="Symbol",
            value="600519",
            info="Companies from the SZ50 index are recommended"
        ),
        gr.Slider(
            minimum=1,
            maximum=3,
            value=2,
            step=1,
            label="weeks_before",
            info="Due to the token length constraint, 2 is the recommended setting"
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Information"
        ),
        gr.Textbox(
            label="Response"
        )
    ],
    title="FinGPT-Forecaster-Chinese",
    description="""This version makes predictions based on the most recent date. We will soon upgrade it to support custom dates."""
)

server.launch()