llk010502 committed on
Commit c171cd4
Parent(s): ba20939

application files

Files changed (3)
  1. Ashare_data_.py +347 -0
  2. Inference_datapipe_.py +155 -0
  3. app.py +70 -0
Ashare_data_.py ADDED
@@ -0,0 +1,347 @@
+ import akshare as ak
+ import pandas as pd
+ import os
+ import csv
+ import re
+ import time
+ import math
+ import json
+ import random
+ from datasets import Dataset
+ import datasets
+
+ # os.chdir("/Users/mac/Desktop/FinGPT_Forecasting_Project/")
+ # print(os.getcwd())
+
+ start_date = "20230201"
+ end_date = "20240101"
+
+ # ------------------------------------------------------------------------------
+ # Data Acquisition
+ # ------------------------------------------------------------------------------
+
+ # get return
+ def get_return(symbol, adjust="qfq"):
+     """
+     Get stock return data.
+
+     Args:
+         symbol: str
+             A-share market stock symbol
+         adjust: str ("qfq", "hfq")
+             price adjustment
+             default = "qfq" (前复权, forward-adjusted)
+
+     Return:
+         weekly forward-filled return data
+     """
+
+     # load daily price data
+     return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
+
+     # process timestamp
+     return_data["日期"] = pd.to_datetime(return_data["日期"])
+     return_data.set_index("日期", inplace=True)
+
+     # resample to weekly frequency and forward-fill closing prices
+     weekly_data = return_data["收盘"].resample("W").ffill()
+     weekly_returns = weekly_data.pct_change()[1:]
+     weekly_start_prices = weekly_data[:-1]
+     weekly_end_prices = weekly_data[1:]
+     weekly_data = pd.DataFrame({
+         '起始日期': weekly_start_prices.index,
+         '起始价': weekly_start_prices.values,
+         '结算日期': weekly_end_prices.index,
+         '结算价': weekly_end_prices.values,
+         '周收益': weekly_returns.values
+     })
+     weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
+
+     return weekly_data
+
+ def return_transform(ret):
+
+     up_down = '涨' if ret >= 0 else '跌'
+     integer = math.ceil(abs(100 * ret))
+     if integer == 0:
+         return "平"
+
+     return up_down + (str(integer) if integer <= 5 else '5+')
+
+ # get basics
+ def get_basic(symbol, data):
+     """
+     Get and match basic financial data to the news dataframe.
+
+     Args:
+         symbol: str
+             A-share market stock symbol
+         data: DataFrame
+             dated news data
+
+     Return:
+         financial news dataframe with matched basic financial info
+     """
+     key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']
+
+     # load quarterly basic financial data
+     basic_quarter_financials = ak.stock_financial_abstract_ths(symbol=symbol, indicator="按单季度")
+     basic_fin_dict = basic_quarter_financials.to_dict("index")
+     basic_fin_list = [dict([(key, val) for key, val in basic_fin_dict[i].items() if (key in key_financials) and val]) for i in range(len(basic_fin_dict))]
+
+     # match basic financial data to the news dataframe
+     matched_basic_fin = []
+     for i, row in data.iterrows():
+
+         newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
+
+         matched_basic = {}
+         for basic in basic_fin_list:
+             # match the most recent financial report published before the week's end date
+             if basic["报告期"] < newsweek_enddate:
+                 matched_basic = basic
+                 break
+         matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
+
+     data['基本面'] = matched_basic_fin
+
+     return data
+
+ def raw_financial_data(symbol, with_basics=True):
+
+     # get return data from the API
+     data = get_return(symbol=symbol)
+
+     # get news data from local files
+     file_name = "news_data" + symbol + ".csv"
+     news_df = pd.read_csv("HS300_news_data20240118/" + file_name, index_col=0)
+     news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
+     news_df.sort_values(by=["发布时间"], inplace=True)
+
+     # match weekly news to return data
+     news_list = []
+     for a, row in data.iterrows():
+         week_start_date = row['起始日期'].strftime('%Y-%m-%d')
+         week_end_date = row['结算日期'].strftime('%Y-%m-%d')
+         print(symbol, ': ', week_start_date, ' - ', week_end_date)
+
+         weekly_news = news_df.loc[(news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)]
+
+         weekly_news = [
+             {
+                 "发布时间": n["发布时间"].strftime('%Y%m%d'),
+                 "新闻标题": n['新闻标题'],
+                 "新闻内容": n['新闻内容'],
+             } for a, n in weekly_news.iterrows()
+         ]
+         news_list.append(json.dumps(weekly_news, ensure_ascii=False))
+
+     data["新闻"] = news_list
+
+     if with_basics:
+         data = get_basic(symbol=symbol, data=data)
+         # data.to_csv(symbol+start_date+"_"+end_date+".csv")
+     else:
+         # no fundamentals requested: leave an empty 基本面 column so downstream json.loads still works
+         data['基本面'] = [json.dumps({})] * len(data)
+         # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
+
+     return data
+
+ # ------------------------------------------------------------------------------
+ # Prompt Generation
+ # ------------------------------------------------------------------------------
+
+ # SYSTEM_PROMPT = "你是一个经验丰富的股票市场分析师。你的任务是根据过去几周的相关新闻和基本财务状况,列出公司的积极发展和潜在担忧,然后对公司未来一周的股价变化提供分析和预测。" \
+ #     "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
+ SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
+     "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
+
+ def get_company_prompt_new(symbol):
+     try:
+         company_profile = dict(ak.stock_individual_info_em(symbol).values)
+     except:
+         print("Company Info Request Timed Out! Please wait and retry.")
+         raise
+     company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")
+
+     template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
+         "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"
+
+     formatted_profile = template.format(**company_profile)
+     stockname = company_profile['股票简称']
+     return formatted_profile, stockname
+
+ def map_return_label(return_lb):
+     """
+     Expand the abbreviated return labels in the raw data.
+     Example:
+         涨1 -- 上涨0-1%
+         跌2 -- 下跌1-2%
+         平 -- 股价持平
+     """
+
+     lb = return_lb.replace('涨', '上涨')
+     lb = lb.replace('跌', '下跌')
+     lb = lb.replace('平', '股价持平')
+     lb = lb.replace('1', '0-1%')
+     lb = lb.replace('2', '1-2%')
+     lb = lb.replace('3', '2-3%')
+     lb = lb.replace('4', '3-4%')
+     if lb.endswith('+'):
+         lb = lb.replace('5+', '超过5%')
+     else:
+         lb = lb.replace('5', '4-5%')
+
+     return lb
+
+ # check news quality
+ def check_news_quality(n, last_n, week_end_date, repeat_rate=0.6):
+     try:
+         # check content availability
+         if not (not (str(n['新闻内容'])[0].isdigit()) and not (str(n['新闻内容']) == 'nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
+             return False
+         # check highly duplicated news
+         # (assume duplicated contents appear in adjacent rows)
+         elif str(last_n['新闻内容']) == 'nan':
+             return True
+         elif len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])) >= 20 * repeat_rate or len(set(n['新闻标题']) & set(last_n['新闻标题'])) / len(last_n['新闻标题']) > repeat_rate:
+             return False
+         else:
+             return True
+     except TypeError:
+         print(n)
+         print(last_n)
+         raise Exception("Check Error")
+
+ def get_prompt_by_row_new(stock, row):
+     """
+     Generate the prompt pieces for one row of the raw data.
+
+     Args:
+         stock: str
+             stock name
+         row: pandas.Series
+
+     Return:
+         head: heading prompt
+         news: filtered news info
+         basics: basic financial info
+     """
+
+     week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
+     week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
+     term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
+     chg = map_return_label(row['简化周收益'])
+     head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
+         week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)
+
+     news = json.loads(row["新闻"])
+
+     left, right = 0, 0
+     filtered_news = []
+     while left < len(news):
+         n = news[left]
+
+         if left == 0:
+             # check the first news item directly
+             if (not (str(n['新闻内容'])[0].isdigit()) and not (str(n['新闻内容']) == 'nan') and n['发布时间'][:8] <= week_end_date.replace('-', '')):
+                 filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
+             left += 1
+
+         else:
+             news_check = check_news_quality(n, last_n=news[right], week_end_date=week_end_date, repeat_rate=0.5)
+             if news_check:
+                 filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
+             left += 1
+             right += 1
+
+     basics = json.loads(row['基本面'])
+     if basics:
+         basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
+             stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
+     else:
+         basics = "[金融基本面]:\n\n 无金融基本面记录"
+
+     return head, filtered_news, basics
+
+ def sample_news(news, k=5):
+     """
+     Randomly select up to k past news items.
+
+     Args:
+         news:
+             list of news items in the time range
+         k: int
+             number of news items to select
+     """
+     return [news[i] for i in sorted(random.sample(range(len(news)), k))]
+
+ def get_all_prompts_new(symbol, min_past_week=1, max_past_weeks=2, with_basics=True):
+     """
+     Generate prompts. Each prompt consists of news from past weeks, basic financial information, and the weekly return.
+     History news in the prompt is drawn from the past min_past_week to max_past_weeks weeks,
+     with a cap on the number of randomly selected items per week (default: up to 3).
+
+     Args:
+         symbol: str
+             stock ticker
+         min_past_week: int
+         max_past_weeks: int
+         with_basics: bool
+             If true, add basic financial information to the prompt
+
+     Return:
+         Prompts for the date range
+     """
+
+     # Load Data
+     df = raw_financial_data(symbol, with_basics=with_basics)
+
+     company_prompt, stock = get_company_prompt_new(symbol)
+
+     prev_rows = []
+     all_prompts = []
+
+     for row_idx, row in df.iterrows():
+
+         prompt = ""
+
+         # only build a prompt once enough history weeks are available
+         if len(prev_rows) >= min_past_week:
+
+             # retrieve data of past weeks
+             # idx = min(random.choice(range(min_past_week, max_past_weeks+1)), len(prev_rows))
+             idx = min(max_past_weeks, len(prev_rows))
+             for i in range(-idx, 0):
+                 # Add Head
+                 prompt += "\n" + prev_rows[i][0]
+                 # Add History News (with a cap on the number of items)
+                 sampled_news = sample_news(
+                     prev_rows[i][1],
+                     min(3, len(prev_rows[i][1]))
+                 )
+                 if sampled_news:
+                     prompt += "\n".join(sampled_news)
+                 else:
+                     prompt += "无有关新闻报告"
+
+         head, news, basics = get_prompt_by_row_new(stock, row)
+
+         prev_rows.append((head, news, basics))
+
+         if len(prev_rows) > max_past_weeks:
+             prev_rows.pop(0)
+
+         # make sure there is history news for each considered date
+         if not prompt:
+             continue
+
+         prediction = map_return_label(row['简化周收益'])
+
+         prompt = company_prompt + '\n' + prompt + '\n' + basics
+
+         prompt += f"\n\n基于在{row['起始日期'].strftime('%Y-%m-%d')}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
+             f"那么让我们假设你对于下一周({row['起始日期'].strftime('%Y-%m-%d')}至{row['结算日期'].strftime('%Y-%m-%d')})的预测是{prediction}。提供一个总结分析来支持你的预测。预测结果需要从你最后的分析中推断出来,因此不作为你分析的基础因素。"
+
+         all_prompts.append(prompt.strip())
+
+     return all_prompts
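For orientation, here is a minimal usage sketch of this module (illustrative only, not part of the committed files). It assumes the file is importable as Ashare_data, matching the imports in the other two files, that akshare is installed, and that HS300_news_data20240118/news_data<symbol>.csv exists locally; the symbol 600519 is only an example.

# Illustrative sketch only (assumes akshare access and local news CSVs).
from Ashare_data import return_transform, map_return_label, get_all_prompts_new

# Weekly returns are bucketed into coarse labels before being spelled out in the prompt.
print(return_transform(0.023))                     # '涨3'  (ceil(2.3) = 3)
print(map_return_label(return_transform(0.023)))   # '上涨2-3%'
print(map_return_label(return_transform(-0.081)))  # '下跌超过5%'

# Build weekly training-style prompts for one symbol over the configured date range.
prompts = get_all_prompts_new("600519", min_past_week=1, max_past_weeks=2, with_basics=True)
print(len(prompts), "prompts generated")
print(prompts[0][:500])  # preview the first prompt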
Inference_datapipe_.py ADDED
@@ -0,0 +1,155 @@
+ # Inference Data
+ # get company news online
+ import akshare as ak
+ import pandas as pd
+ from datetime import date, datetime, timedelta
+ from Ashare_data import *
+
+ # default symbol
+ symbol = "600519"
+ B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+ def get_curday():
+
+     return date.today().strftime("%Y%m%d")
+
+ def n_weeks_before(date_string, n, format="%Y%m%d"):
+
+     dt = datetime.strptime(date_string, "%Y%m%d") - timedelta(days=7 * n)
+
+     return dt.strftime(format=format)
+
+
+ def get_news(symbol, max_page=3):
+
+     df_list = []
+     for page in range(1, max_page):
+
+         try:
+             df_list.append(ak.stock_news_em(symbol, page))
+         except KeyError:
+             print(f"{page - 1} page(s) of news obtained for symbol {symbol}")
+             break
+
+     news_df = pd.concat(df_list, ignore_index=True)
+     return news_df
+
+ # get return
+ def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
+     """
+     start_date / end_date format: "yyyymmdd"
+     """
+
+     # load daily price data
+     return_data = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
+
+     # process timestamp
+     return_data["日期"] = pd.to_datetime(return_data["日期"])
+     return_data.set_index("日期", inplace=True)
+
+     # resample to weekly frequency and forward-fill closing prices
+     weekly_data = return_data["收盘"].resample("W").ffill()
+     weekly_returns = weekly_data.pct_change()[1:]
+     weekly_start_prices = weekly_data[:-1]
+     weekly_end_prices = weekly_data[1:]
+     weekly_data = pd.DataFrame({
+         '起始日期': weekly_start_prices.index,
+         '起始价': weekly_start_prices.values,
+         '结算日期': weekly_end_prices.index,
+         '结算价': weekly_end_prices.values,
+         '周收益': weekly_returns.values
+     })
+     weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
+     # clip the last settlement date to end_date
+     if weekly_data.iloc[-1, 2] > pd.to_datetime(end_date):
+         weekly_data.iloc[-1, 2] = pd.to_datetime(end_date)
+
+     return weekly_data
+
+ # get basics
+ def cur_financial_data(symbol, start_date, end_date, with_basics=True):
+
+     # get return data
+     data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
+
+     news_df = get_news(symbol=symbol)
+     news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
+     news_df.sort_values(by=["发布时间"], inplace=True)
+
+     # match weekly news to return data
+     news_list = []
+     for a, row in data.iterrows():
+         week_start_date = row['起始日期'].strftime('%Y-%m-%d')
+         week_end_date = row['结算日期'].strftime('%Y-%m-%d')
+         print(symbol, ': ', week_start_date, ' - ', week_end_date)
+
+         weekly_news = news_df.loc[(news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)]
+
+         weekly_news = [
+             {
+                 "发布时间": n["发布时间"].strftime('%Y%m%d'),
+                 "新闻标题": n['新闻标题'],
+                 "新闻内容": n['新闻内容'],
+             } for a, n in weekly_news.iterrows()
+         ]
+         news_list.append(json.dumps(weekly_news, ensure_ascii=False))
+
+     data["新闻"] = news_list
+
+     if with_basics:
+         data = get_basic(symbol=symbol, data=data)
+         # data.to_csv(symbol+start_date+"_"+end_date+".csv")
+     else:
+         # no fundamentals requested: leave an empty 基本面 column so downstream json.loads still works
+         data['基本面'] = [json.dumps({})] * len(data)
+         # data.to_csv(symbol+start_date+"_"+end_date+"_nobasics.csv")
+
+     return data
+
+ def get_all_prompts_online(symbol, with_basics=True, max_news_perweek=3, weeks_before=2):
+
+     end_date = get_curday()
+     start_date = n_weeks_before(end_date, weeks_before)
+
+     company_prompt, stock = get_company_prompt_new(symbol)
+     data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)
+
+     prev_rows = []
+
+     for row_idx, row in data.iterrows():
+         head, news, basics = get_prompt_by_row_new(stock, row)
+         prev_rows.append((head, news, basics))
+
+     prompt = ""
+     for i in range(-len(prev_rows), 0):
+         prompt += "\n" + prev_rows[i][0]
+         sampled_news = sample_news(
+             prev_rows[i][1],
+             min(max_news_perweek, len(prev_rows[i][1]))
+         )
+         if sampled_news:
+             prompt += "\n".join(sampled_news)
+         else:
+             prompt += "No relevant news reported."
+
+     next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
+     end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
+     period = "{}至{}".format(end_date, next_date)
+
+     if with_basics:
+         basics = prev_rows[-1][2]
+     else:
+         basics = "[金融基本面]:\n\n 无金融基本面记录"
+
+     info = company_prompt + '\n' + prompt + '\n' + basics
+
+     new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
+     prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
+         f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST
+
+     return info, prompt
+
+ if __name__ == "__main__":
+     info, pt = get_all_prompts_online(symbol=symbol)
+     print(pt)
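A brief sketch of calling the online pipeline outside the __main__ block (illustrative only, not part of the committed files; it assumes the module is importable as Inference_datapipe, as app.py imports it, and that akshare is reachable; the parameter values are examples):

# Illustrative sketch only: build the inference prompt for the most recent weeks.
from Inference_datapipe import get_all_prompts_online

info, prompt = get_all_prompts_online(symbol="600519", with_basics=True,
                                      max_news_perweek=3, weeks_before=2)
print(info[:300])    # assembled company profile, weekly news and fundamentals
print(prompt[:300])  # the same context wrapped in the Llama-2 [INST]/<<SYS>> chat template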
app.py ADDED
@@ -0,0 +1,70 @@
+ import gradio as gr
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ import torch
+ from Ashare_data import *
+ from Inference_datapipe import *
+ import re
+
+ # load base model and LoRA adapter
+ base_model = "meta-llama/Llama-2-7b-chat-hf"
+ peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"
+
+ tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"
+
+ model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True, device_map='auto', offload_folder="offload/")
+ model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")
+
+ model = model.eval()
+
+
+ def ask(symbol, weeks_before):
+
+     # build the inference prompt from online data
+     info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before)
+     # print(info)
+
+     inputs = tokenizer(pt, return_tensors='pt')
+     inputs = {key: value.to(model.device) for key, value in inputs.items()}
+     print("Inputs loaded onto devices.")
+
+     # generation length falls back to the model's default generation config;
+     # set max_new_tokens explicitly if responses come back truncated
+     res = model.generate(
+         **inputs,
+         use_cache=True
+     )
+     output = tokenizer.decode(res[0], skip_special_tokens=True)
+     # keep only the model's reply, i.e. everything after the final [/INST] marker
+     output_cur = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
+     return info, output_cur
+
+ server = gr.Interface(
+     ask,
+     inputs=[
+         gr.Textbox(
+             label="Symbol",
+             value="600519",
+             info="Companies from the SZ50 index are recommended"
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=3,
+             value=2,
+             step=1,
+             label="weeks_before",
+             info="Due to the token length constraint, a value of 2 is recommended"
+         ),
+     ],
+     outputs=[
+         gr.Textbox(
+             label="Information"
+         ),
+         gr.Textbox(
+             label="Response"
+         )
+     ],
+     title="FinGPT-Forecaster-Chinese",
+     description="""This version makes predictions based on the most recent date. We will upgrade it to support custom dates soon."""
+ )
+
+ server.launch()
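The regular expression in ask keeps only the model's reply by discarding everything up to the final [/INST] marker. A self-contained check of that post-processing step (the decoded string below is made up for illustration):

import re

# Made-up decoder output: the echoed prompt followed by the model's reply.
decoded = "[INST]<<SYS>>\n系统提示\n<</SYS>>\n\n公司信息与新闻[/INST] [积极发展]:\n1. 示例要点"
reply = re.sub(r'.*\[/INST\]\s*', '', decoded, flags=re.DOTALL)
print(reply)  # -> "[积极发展]:\n1. 示例要点"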