SHSH0819 committed on
Commit
9db14d3
1 Parent(s): a7f2ad1

Upload app.py

Files changed (1)
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ import sys
+ sys.path.insert(0, os.path.abspath('./'))
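+ # (the path insert makes the local event_detection_dataset and
+ # event_detection_model modules importable from the app directory)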
+
+ import torch
+ from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader, random_split
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ from event_detection_dataset import *
+ from event_detection_model import *
+
+ import gradio as gr
+ # print(f"Gradio version: {gr.__version__}")
+
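+ # predict() tokenizes one input string, loads the fine-tuned DistilBERT
+ # checkpoint from disk, runs a single forward pass, and returns the predicted
+ # event label; the tokenizer and model are re-loaded on every call.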
+ def predict(data):
+     data = [data]
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     # print(f"Device {device}")
+
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased', use_fast=True)
+
+     # Tokenize the input
+     tokenized_inputs = tokenizer(
+         data,
+         add_special_tokens=True,
+         max_length=512,
+         padding='max_length',
+         return_token_type_ids=True,
+         truncation=True,
+         is_split_into_words=True
+     )
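+     # tokenized_inputs now holds 'input_ids', 'attention_mask', and
+     # 'token_type_ids', each padded or truncated to exactly 512 positions.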
+
+     # Load the model
+     model_path = "./"
+     # print("model_path:", model_path)
+
+     # print("================ load model ===========================")
+     model = DistillBERTClass('distilbert-base-cased')
+
+     # print("================ model init ===========================")
+     pretrained_model = torch.load(model_path + "/event_domain_final.pt", map_location=torch.device('cpu'))
+     model.load_state_dict(pretrained_model['model_state_dict'])
+     model.to(device)
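+     # The checkpoint is deserialized onto the CPU first (map_location), then
+     # the weights are moved to the selected device.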
+
+     # Make the prediction
+     model.eval()
+
+     ids = torch.tensor(tokenized_inputs['input_ids']).to(device)
+     mask = torch.tensor(tokenized_inputs['attention_mask']).to(device)
+
+     with torch.no_grad():
+         outputs = model(ids, mask)
+
+     max_val, max_idx = torch.max(outputs, dim=1)
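+     # max_idx is the index of the highest logit; it selects the predicted
+     # label from id2tags below.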
+
+     # print("=============== inference result =================")
+     # print(f"predicted class {max_idx}")
+     id2tags = {0: "Acquisition", 1: "I-Positive Clinical Trial & FDA Approval", 2: "Dividend Cut", 3: "Dividend Increase", 4: "Guidance Increase", 5: "New Contract", 6: "Dividend", 7: "Reverse Stock Split", 8: "Special Dividend", 9: "Stock Repurchase", 10: "Stock Split", 11: "Others"}
+     return id2tags[max_idx.item()]
+
+
+ title = "Financial Event Detection"
+ description = "Predict Financial Events."
+ article = "Modified from the model in the following paper: Zhou, Z., Ma, L., & Liu, H. (2021)."
+ example_list = [["Investors who receive dividends can choose to take them as cash or as additional shares."]]
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict,            # the function mapping input text to an output label
+                     inputs="text",         # a single free-text input box
+                     outputs="text",        # predict returns a single text label
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
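+
+ # For a quick smoke test outside the web UI, predict can be called directly
+ # before launching; the sample sentence below is illustrative:
+ # print(predict("The board approved a two-for-one stock split effective next month."))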
+
+ # Launch the demo!
+ demo.launch(debug=False, share=True)