dmariko commited on
Commit
0b71553
1 Parent(s): 43ba32a
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
2
+ import gradio as gr
3
+ from torch.nn import functional as F
4
+ import seaborn
5
+ import matplotlib
6
+ import platform
7
+ from transformers.file_utils import ModelOutput
8
+
9
+ if platform.system() == "Darwin":
10
+
11
+ print("MacOS")
12
+ matplotlib.use('Agg')
13
+
14
+ import matplotlib.pyplot as plt
15
+ import io
16
+ from PIL import Image
17
+ import matplotlib.font_manager as fm
18
+
19
+ # global var
20
+
21
+ MODEL_NAME = 'https://huggingface.co/yseop/FNP_T5_D2T_complete'
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
24
+ config = AutoConfig.from_pretrained(MODEL_NAME)
25
+
26
+ MODEL_BUF = {
27
+ "name": MODEL_NAME,
28
+ "tokenizer": tokenizer,
29
+ "model": model,
30
+ "config": config
31
+ }
32
+
33
+ font_dir = ['./']
34
+
35
+ for font in fm.findSystemFonts(font_dir):
36
+ print(font)
37
+ fm.fontManager.addfont(font)
38
+
39
+ plt.rcParams["font.family"] = 'NanumGothicCoding'
40
+
41
+ def change_model_name(name):
42
+
43
+ MODEL_BUF["name"] = name
44
+ MODEL_BUF["tokenizer"] = AutoTokenizer.from_pretrained(name)
45
+ MODEL_BUF["model"] = AutoModelForSequenceClassification.from_pretrained(name)
46
+ MODEL_BUF["config"] = AutoConfig.from_pretrained(name)
47
+
48
+
49
+ def generate(text, model, tokenizer):
50
+ model.eval()
51
+ input_ids = tokenizer.encode("webNLG:{}".format(text), return_tensors="pt")
52
+ outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
53
+ return tokenizer.decode(outputs[0])
54
+
55
+
56
+
57
+
58
+ if __name__ == '__main__':
59
+
60
+ text = 'Group profit | valIs | € 115.7 million && € 115.7 million | dTime | in 2019'
61
+
62
+ model_name_list = [
63
+
64
+ 'yseop/distilbert-base-financial-relation-extraction'
65
+
66
+ ]
67
+
68
+ inputs = [gradio.inputs.Textbox(label="Input")]
69
+
70
+ outputs = gradio.outputs.Textbox(label='Output')
71
+
72
+ iface = gradio.Interface(fn=generate, inputs=inputs, outputs=outputs, capture_session=True, examples=examples,
73
+ title=title, description=description, allow_flagging=False)