File size: 17,422 Bytes
c171cd4
3739f4a
c171cd4
 
6ac2431
5165373
 
 
 
 
 
 
 
 
3739f4a
0d3bc03
c171cd4
 
 
 
 
0d3bc03
c171cd4
 
 
3739f4a
 
 
c171cd4
 
 
 
5165373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3739f4a
 
 
 
5165373
 
c171cd4
80c8075
c171cd4
 
80c8075
c171cd4
 
 
 
 
 
 
 
3739f4a
 
c171cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c8075
 
 
 
 
c171cd4
 
 
80c8075
c171cd4
 
 
 
 
 
342ab47
 
80c8075
 
342ab47
 
 
69ae4fa
342ab47
c171cd4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, TextStreamer
from peft import PeftModel
import re
import os
import akshare as ak
import pandas as pd
import random
import json
import requests
import math
from datetime import date
from datetime import date, datetime, timedelta


# Hugging Face access token, read from the TOKEN environment variable
# (os.environ[...] raises KeyError if it is not set).
access_token = os.environ["TOKEN"]

# load model
# Base chat model hub id and the FinGPT LoRA adapter applied on top of it.
# NOTE: `model` first holds the hub id string and is rebound below to the loaded model object.
model = "meta-llama/Llama-2-7b-chat-hf"
peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"

tokenizer = AutoTokenizer.from_pretrained(model, token = access_token, trust_remote_code=True)
# Reuse EOS as the padding token, padding on the right
# (presumably because the base tokenizer defines no pad token — confirm).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Streams generated tokens as they are produced; passed to model.generate in `ask`.
streamer = TextStreamer(tokenizer)

# Load the base model in 8-bit on CUDA, then wrap it with the LoRA adapter.
model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True, token = access_token, device_map="cuda", load_in_8bit=True, offload_folder="offload/")
model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")

# Switch to inference mode.
model = model.eval()

# Inference Data
# get company news online

# Llama-2 chat template delimiters: [INST]...[/INST] wraps the user turn,
# <<SYS>>...<</SYS>> wraps the system prompt inside it.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System prompt (Chinese): act as an experienced stock analyst, list the company's positive
# developments and potential concerns from recent news and quarterly financials, then predict
# and analyse next week's price move, answering in Chinese using the
# [积极发展] / [潜在担忧] / [预测和分析] section layout.
SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
    "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"

# ------------------------------------------------------------------------------
# Utils
# ------------------------------------------------------------------------------
def get_curday():
    """Return today's local date formatted as a ``YYYYMMDD`` string."""
    today = date.today()
    return today.strftime("%Y%m%d")

def n_weeks_before(date_string, n, format = "%Y%m%d"):
    """Shift a ``YYYYMMDD`` date string back by ``n`` weeks.

    Args:
        date_string: Input date; always parsed as ``%Y%m%d``.
        n: Number of weeks to go back (negative values move forward).
        format: strftime pattern used only for the returned string.

    Returns:
        The shifted date rendered with ``format``.
    """
    anchor = datetime.strptime(date_string, "%Y%m%d")
    shifted = anchor - timedelta(days=7 * n)
    return shifted.strftime(format=format)

def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
    """Decide whether news item ``n`` is worth keeping in the prompt.

    ``n`` is rejected when its content is empty, missing (stringifies to
    ``'nan'``), starts with a digit, or was published after ``week_end_date``.
    It is also rejected as a near-duplicate of the previous item ``last_n``
    (duplicates are assumed to be adjacent) when either:

    * the first 20 characters of both contents share at least
      ``20 * repeat_rate`` distinct characters, or
    * the distinct characters shared by both titles, divided by the length of
      ``last_n``'s title, exceed ``repeat_rate``.

    Args:
        n: Current news dict with '新闻内容', '新闻标题' and '发布时间'
            (``YYYYMMDD``-prefixed string) keys.
        last_n: Previous news dict with '新闻内容' and '新闻标题' keys.
        week_end_date: End of the week as ``YYYY-MM-DD``.
        repeat_rate: Similarity threshold.

    Returns:
        True to keep ``n``, False to drop it.

    Raises:
        Exception: "Check Error" when a field has an unexpected type; both
            items are printed first and the original TypeError is chained.
    """
    try:
        content = str(n['新闻内容'])
        # Check content availability; an empty string is rejected rather than
        # raising IndexError on content[0].
        if (not content or content[0].isdigit() or content == 'nan'
                or n['发布时间'][:8] > week_end_date.replace('-', '')):
            return False
        # No previous content to compare against: keep the item.
        if str(last_n['新闻内容']) == 'nan':
            return True
        # Check highly duplicated news (assume duplicates happen adjacently).
        shared_content = set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20])
        if len(shared_content) >= 20 * repeat_rate:
            return False
        last_title = last_n['新闻标题']
        shared_title = set(n['新闻标题']) & set(last_title)
        # max(..., 1) avoids ZeroDivisionError when the previous title is empty.
        return len(shared_title) / max(len(last_title), 1) <= repeat_rate
    except TypeError as err:
        print(n)
        print(last_n)
        raise Exception("Check Error") from err

def sample_news(news, k=5):
    """Pick ``k`` distinct items from ``news`` at random, keeping their original order.

    Raises ValueError (from ``random.sample``) if ``k`` exceeds ``len(news)``.
    """
    chosen = random.sample(range(len(news)), k)
    chosen.sort()
    return [news[idx] for idx in chosen]

def return_transform(ret):
    """Bucket a fractional return into a short label.

    The absolute percentage is rounded up to an integer. 0 maps to "平";
    otherwise the label is "涨"/"跌" followed by 1-5, with anything above 5
    reported as "5+".
    """
    pct = math.ceil(abs(100 * ret))
    if not pct:
        return "平"
    direction = '跌' if ret < 0 else '涨'
    bucket = '5+' if pct > 5 else str(pct)
    return direction + bucket

def map_return_label(return_lb):
    """Expand a short return label from ``return_transform`` into descriptive Chinese text.

    Example: "涨3" -> "上涨2-3%", "跌5+" -> "下跌超过5%", "平" -> "股价持平".
    Substitutions are applied in ascending digit order, so each step only
    introduces digits that earlier steps have already processed.
    """
    label = return_lb
    for src, dst in (
        ('涨', '上涨'),
        ('跌', '下跌'),
        ('平', '股价持平'),
        ('1', '0-1%'),
        ('2', '1-2%'),
        ('3', '2-3%'),
        ('4', '3-4%'),
    ):
        label = label.replace(src, dst)
    # "5+" (more than 5%) and "5" (4-5%) need different expansions.
    return label.replace('5+', '超过5%') if label.endswith('+') else label.replace('5', '4-5%')
# ------------------------------------------------------------------------------
# Get data from website
# ------------------------------------------------------------------------------
def _unwrap_jsonp(text):
    """Return the payload between the first '(' and the last ')' of a JSONP response.

    Falls back to the raw text when no wrapper is found, so that
    ``json.loads`` reports the malformed payload itself.
    """
    start = text.find("(")
    end = text.rfind(")")
    if start == -1 or end <= start:
        return text
    return text[start + 1:end]


def _strip_em_tags(series):
    """Remove Eastmoney ``<em>`` highlight tags (and a directly enclosing parenthesis) from a string Series."""
    return (
        series
        .str.replace(r"\(<em>", "", regex=True)
        .str.replace(r"</em>\)", "", regex=True)
        .str.replace(r"<em>", "", regex=True)
        .str.replace(r"</em>", "", regex=True)
    )


def stock_news_em(symbol: str = "300059", page = 1, timeout = 30) -> pd.DataFrame:
    """Fetch one page (up to 100 articles) of Eastmoney search results for ``symbol``.

    Args:
        symbol: Search keyword, typically an A-share stock code.
        page: 1-based result page index.
        timeout: Seconds to wait for the HTTP response (passed to ``requests.get``).

    Returns:
        DataFrame with columns 关键词, 新闻标题, 新闻内容, 发布时间, 文章来源, 新闻链接,
        with highlight tags and full-width spaces removed and CRLF replaced by a space.

    Raises:
        KeyError: when the response has no article list or lacks an expected
            column; ``get_news`` catches this to stop paging.
    """
    callback = "jQuery3510875346244069884_1668256937995"
    url = "https://search-api-web.eastmoney.com/search/jsonp"
    params = {
        "cb": callback,
        "param": '{"uid":"",'
        + f'"keyword":"{symbol}"'
        + ',"type":["cmsArticleWebOld"],"client":"web","clientType":"web","clientVersion":"curr","param":{"cmsArticleWebOld":{"searchScope":"default","sort":"default",' + f'"pageIndex":{page}'+ ',"pageSize":100,"preTag":"<em>","postTag":"</em>"}}}',
        "_": "1668256937996",
    }
    # Without a timeout a stalled connection would block the request indefinitely.
    r = requests.get(url, params=params, timeout=timeout)
    # The endpoint answers with JSONP: callback({...}). str.strip() removes a
    # *set of characters* rather than a prefix, so slice out the parentheses instead.
    data_json = json.loads(_unwrap_jsonp(r.text))
    temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
    temp_df.rename(
        columns={
            "date": "发布时间",
            "mediaName": "文章来源",
            "code": "-",
            "title": "新闻标题",
            "content": "新闻内容",
            "url": "新闻链接",
            "image": "-",
        },
        inplace=True,
    )
    temp_df["关键词"] = symbol
    # .copy() so the column assignments below operate on an independent frame
    # rather than a slice (avoids SettingWithCopyWarning).
    temp_df = temp_df[
        [
            "关键词",
            "新闻标题",
            "新闻内容",
            "发布时间",
            "文章来源",
            "新闻链接",
        ]
    ].copy()
    temp_df["新闻标题"] = _strip_em_tags(temp_df["新闻标题"])
    temp_df["新闻内容"] = _strip_em_tags(temp_df["新闻内容"])
    temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\u3000", "", regex=True)
    temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\r\n", " ", regex=True)
    return temp_df
    
def get_news(symbol, max_page = 3):
    """Collect Eastmoney news for ``symbol`` across result pages.

    Pages ``1 .. max_page - 1`` are requested (``range`` excludes ``max_page``).
    Paging stops early at the first page for which ``stock_news_em`` raises KeyError.

    Returns:
        A DataFrame with the columns produced by ``stock_news_em``. When no
        page could be fetched, an empty frame with those columns is returned.
    """
    df_list = []
    for page in range(1, max_page):
        try:
            df_list.append(stock_news_em(symbol, page))
        except KeyError:
            # Format rather than concatenate: `page` is an int, and str + int
            # raised TypeError inside this handler.
            print(f"{len(df_list)} pages obtained for symbol: {symbol}")
            break

    if not df_list:
        # pd.concat([]) raises ValueError; return an empty frame with the expected schema instead.
        return pd.DataFrame(columns=["关键词", "新闻标题", "新闻内容", "发布时间", "文章来源", "新闻链接"])
    return pd.concat(df_list, ignore_index=True)

def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
    """Build week-over-week closing-price returns for ``symbol``.

    Daily closes from ``ak.stock_zh_a_hist`` are resampled weekly (forward-filled).
    Each consecutive pair of weekly closes becomes one row with columns
    起始日期/起始价 (start), 结算日期/结算价 (end), 周收益 (fractional return)
    and 简化周收益 (the ``return_transform`` label).

    Args:
        symbol: A-share stock code.
        start_date, end_date: ``YYYYMMDD`` bounds passed to akshare.
        adjust: Price adjustment mode passed to akshare.

    Returns:
        The weekly DataFrame. The last row's end date is clamped to ``end_date``.
    """
    daily = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)

    # Index by trading date so the series can be resampled.
    daily["日期"] = pd.to_datetime(daily["日期"])
    daily.set_index("日期", inplace=True)

    closes = daily["收盘"].resample("W").ffill()
    starts, ends = closes[:-1], closes[1:]
    result = pd.DataFrame({
        '起始日期': starts.index,
        '起始价': starts.values,
        '结算日期': ends.index,
        '结算价': ends.values,
        '周收益': closes.pct_change()[1:].values,
    })
    result["简化周收益"] = result["周收益"].map(return_transform)

    # Weekly bins are labelled by their week-end date, which can lie after
    # end_date; clamp the last 结算日期 (column 2) to end_date.
    cutoff = pd.to_datetime(end_date)
    if result.iloc[-1, 2] > cutoff:
        result.iloc[-1, 2] = cutoff

    return result

def get_basic(symbol, data):
    """Attach the latest quarterly fundamentals preceding each week to ``data``.

    For every row, the report with the greatest 报告期 strictly before that row's
    结算日期 is selected. Only the ``key_financials`` indicators with truthy
    values are kept. The result is serialised as a JSON string into a new
    '基本面' column ("{}" when no report qualifies).

    Args:
        symbol: A-share stock code passed to ``ak.stock_financial_abstract_ths``.
        data: Weekly DataFrame with a datetime-like '结算日期' column; modified in place.

    Returns:
        ``data`` with the added '基本面' column.
    """
    key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']

    # load quarterly basic data
    basic_quarter_financials = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
    # One dict per report row, restricted to the key indicators with truthy values.
    basic_fin_list = [
        {key: val for key, val in record.items() if key in key_financials and val}
        for record in basic_quarter_financials.to_dict("records")
    ]

    # match basic financial data to news dataframe
    matched_basic_fin = []
    for _, row in data.iterrows():

        newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")

        # Pick the most current report before the week end regardless of the
        # row order akshare returns (the old first-match-and-break relied on
        # descending order). Assumes 报告期 is a "YYYY-MM-DD" string, so string
        # comparison is chronological — TODO confirm against akshare output.
        candidates = [
            basic for basic in basic_fin_list
            if '报告期' in basic and basic['报告期'] < newsweek_enddate
        ]
        matched_basic = max(candidates, key=lambda basic: basic['报告期']) if candidates else {}
        matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))

    data['基本面'] = matched_basic_fin

    return data
# ------------------------------------------------------------------------------
# Structure Data 
# ------------------------------------------------------------------------------
def cur_financial_data(symbol, start_date, end_date, with_basics = True):
    """Build the weekly table (returns, news, fundamentals) used for prompting.

    Args:
        symbol: A-share stock code, e.g. "600519".
        start_date, end_date: ``YYYYMMDD`` strings bounding the price history.
        with_basics: When True, attach quarterly fundamentals via ``get_basic``.
            Otherwise '基本面' is filled with an empty JSON object per row.

    Returns:
        The DataFrame from ``get_cur_return`` with two added columns:
        '新闻' (JSON list of {发布时间, 新闻标题, 新闻内容} dicts per week) and
        '基本面' (JSON object).
    """
    # get data
    data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)

    news_df = get_news(symbol=symbol)
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)

    # match weekly news for return data
    news_list = []
    for _, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)

        # Both bounds are strict comparisons against the dates at midnight.
        in_week = news_df.loc[(news_df["发布时间"] > week_start_date) & (news_df["发布时间"] < week_end_date)]

        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            } for _, n in in_week.iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))

    data["新闻"] = news_list

    if with_basics:
        data = get_basic(symbol=symbol, data=data)
    else:
        # Fill '基本面' (not '新闻'): the old code overwrote the collected news
        # and never created the '基本面' column that prompt building reads.
        data['基本面'] = [json.dumps({})] * len(data)

    return data
# ------------------------------------------------------------------------------
# Format Instructions
# ------------------------------------------------------------------------------
def get_company_prompt_new(symbol):
    """Build the company-introduction section of the prompt.

    Args:
        symbol: A-share stock code passed to ``ak.stock_individual_info_em``.

    Returns:
        ``(formatted_profile, stockname)``: the filled-in "[公司介绍]" text and
        the company's short name (股票简称).

    Raises:
        Whatever ``ak.stock_individual_info_em`` raises, re-raised after
        printing a retry hint. Previously the error was swallowed, which led
        to an UnboundLocalError on ``company_profile``.
    """
    try:
        company_profile = dict(ak.stock_individual_info_em(symbol).values)
    except Exception:
        print("Company Info Request Time Out! Please wait and retry.")
        raise
    company_profile["上市时间"] =  pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")

    template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
        "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"

    formatted_profile = template.format(**company_profile)
    stockname = company_profile['股票简称']
    return formatted_profile, stockname

def get_prompt_by_row_new(stock, row):
    """Turn one weekly row into prompt fragments.

    Args:
        stock: Name/code inserted into the header and fundamentals text.
        row: Weekly row with 起始日期, 结算日期 (str or datetime), 起始价, 结算价,
            简化周收益, 新闻 (JSON list) and 基本面 (JSON object) fields.

    Returns:
        ``(head, filtered_news, basics)``:
        - ``head`` is the price-move sentence for the week.
        - ``filtered_news`` is a list of formatted news strings that passed the quality checks.
        - ``basics`` is the fundamentals text, or a "no record" placeholder.
    """
    week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
    week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
    term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
    chg = map_return_label(row['简化周收益'])
    head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
        week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)

    news = json.loads(row["新闻"])

    filtered_news = []
    for idx, n in enumerate(news):
        if idx == 0:
            # The first item has no predecessor to dedup against: apply only the
            # content/date checks. Empty content is dropped instead of raising IndexError.
            content = str(n['新闻内容'])
            keep = (bool(content) and not content[0].isdigit() and content != 'nan'
                    and n['发布时间'][:8] <= week_end_date.replace('-', ''))
        else:
            # Compare against the immediately preceding item, whether or not it was kept.
            keep = check_news_quality(n, last_n = news[idx - 1], week_end_date= week_end_date, repeat_rate=0.5)
        if keep:
            filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))

    basics = json.loads(row['基本面'])
    if basics:
        # NOTE(review): the fundamentals keys appear to be Chinese (e.g. 报告期), so the
        # `!= 'period'` filter never matches and 报告期 is listed again. Left unchanged
        # to keep the prompt format stable — confirm intent.
        basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
            stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"

    return head, filtered_news, basics

def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):
    """Build the forecaster input for ``symbol`` from the last ``weeks_before`` weeks.

    Args:
        symbol: A-share stock code.
        with_basics: Include the latest fundamentals; otherwise a placeholder is used.
        max_news_perweek: Maximum number of news items randomly sampled per week.
        weeks_before: Look-back window in weeks, ending today.

    Returns:
        ``(info, prompt)``:
        - ``info`` is the plain context: company profile, weekly price moves with sampled news, and fundamentals.
        - ``prompt`` wraps ``info`` in the Llama-2 chat template with the system prompt and the question about next week.
    """
    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)

    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)

    weekly = [get_prompt_by_row_new(symbol, row) for _, row in data.iterrows()]

    # Per week: the header sentence followed by up to max_news_perweek sampled news items.
    sections = []
    for head, news, _ in weekly:
        sections.append("\n" + head)
        picked = sample_news(news, min(max_news_perweek, len(news)))
        sections.append("\n".join(picked) if picked else "No relative news reported.")
    history = "".join(sections)

    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)

    basics = weekly[-1][2] if with_basics else "[金融基本面]:\n\n 无金融基本面记录"

    info = company_prompt + '\n' + history + '\n' + basics

    # Extend the answer format with explicit "predicted change" and "summary" fields.
    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。" \
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。" + E_INST

    return info, prompt


def ask(symbol, weeks_before, withbasic):
    """Gradio handler: build the prompt for ``symbol`` and run the forecaster.

    Args:
        symbol: A-share stock code from the UI textbox.
        weeks_before: Look-back window in weeks, from the UI slider.
        withbasic: Whether to include the latest fundamentals.

    Returns:
        ``(info, answer)``: the context given to the model, and the decoded
        generation with everything up to the last "[/INST]" removed.
    """
    # load inference data
    info, prompt_text = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before, with_basics=withbasic)

    encoded = tokenizer(prompt_text, return_tensors='pt')
    # Move every input tensor onto the model's device.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    print("Inputs loaded onto devices.")

    generated = model.generate(
        **encoded,
        use_cache=True,
        streamer=streamer
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; the greedy DOTALL match drops everything
    # through the last [/INST] marker (plus trailing whitespace).
    answer = re.sub(r'.*\[/INST\]\s*', '', decoded, flags=re.DOTALL)
    return info, answer

# Gradio UI. Inputs: stock symbol, look-back weeks, and a fundamentals toggle.
# Outputs: the context given to the model and the model's response.
# Launched at import time.
server = gr.Interface(
    ask,
    inputs=[
        gr.Textbox(
            label="Symbol",
            value="600519",
            # Fixed typo: "Companys" -> "Companies".
            info="Companies from SZ50 are recommended"
        ),
        gr.Slider(
            minimum=1,
            maximum=3,
            value=2,
            step=1,
            label="weeks_before",
            info="Due to the token length constraint, you are recommended to input with 2"
        ),
        gr.Checkbox(
            label="Use Latest Basic Financials",
            value=True,
            info="If checked, the latest quarterly reported basic financials of the company is taken into account."
        )
    ],
    outputs=[
        gr.Textbox(
            label="Information Provided"
        ),
        gr.Textbox(
            label="Response"
        )
    ],
    title="FinGPT-Forecaster-Chinese",
    description="""This version allows the prediction based on the most current date. We will upgrade it to allow customized date soon.

    **The estimated time cost is 180s**

    This demo has been downgraded to using T4 with 8-bit inference due to cost considerations, speed & performance may be affected.
    
    **⚠️Warning**: This is just a demo showing what this model can do. During each individual inference, company news is randomly sampled from all the news from designated weeks, which might result in different predictions for the same period. We suggest users deploy the original model or clone this space and inference with more carefully selected news in their favorable ways.
    
    **Disclaimer**: Nothing herein is financial advice, and NOT a recommendation to trade real money. Please use common sense and always first consult a professional before trading or investing."""
)

server.launch()