from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
import re
import gradio as gr
from torch.nn import functional as F
import seaborn
import matplotlib
import platform
from transformers.file_utils import ModelOutput

if platform.system() == "Darwin":
    print("MacOS")
    matplotlib.use('Agg')

import matplotlib.pyplot as plt
import io
from PIL import Image
import matplotlib.font_manager as fm

# global var
MODEL_NAME = 'yseop/FNP_T5_D2T_complete'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)

# The checkpoint currently served by generate(); change_model_name() swaps it.
MODEL_BUF = {
    "name": MODEL_NAME,
    "tokenizer": tokenizer,
    "model": model,
    "config": config,
}

# Register the font files that matplotlib finds under the working directory,
# so the NanumGothicCoding family selected below can be resolved.
font_dir = ['./']
for font in fm.findSystemFonts(font_dir):
    print(font)
    fm.fontManager.addfont(font)
plt.rcParams["font.family"] = 'NanumGothicCoding'


def change_model_name(name):
    """Load checkpoint *name* and make it the active entry of MODEL_BUF.

    Tokenizer, model and config are all loaded before MODEL_BUF is modified,
    so a failing download leaves the previously active checkpoint intact
    instead of a mix of old and new components.

    Args:
        name: Hugging Face checkpoint id passed to ``from_pretrained``.
    """
    new_tokenizer = AutoTokenizer.from_pretrained(name)
    new_model = AutoModelForSeq2SeqLM.from_pretrained(name)
    new_config = AutoConfig.from_pretrained(name)
    MODEL_BUF.update(
        name=name,
        tokenizer=new_tokenizer,
        model=new_model,
        config=new_config,
    )


def _trim_to_last_sentence(text):
    """Return *text* stripped and cut right after its last '.'.

    Text without any '.' is returned stripped but otherwise unchanged.
    (Previously, such output became "." and then crashed the regex step.)
    """
    text = text.strip()
    end = text.rfind(".")
    return text[:end + 1] if end != -1 else text


def generate(model_name, text):
    """Generate a sentence from a linearised triple string.

    Args:
        model_name: checkpoint id selected in the UI dropdown. A falsy value
            (e.g. no selection) keeps the currently loaded checkpoint.
        text: input triples, e.g.
            "sales | incBy | € 115.7 million && € 115.7 million | dTime | in 2019".

    Returns:
        The decoded generation without special tokens, cut after its last
        full stop.
    """
    # Compare with the checkpoint that is actually loaded, not the startup
    # default. Otherwise switching back to MODEL_NAME kept serving the other
    # model, and any non-default model was reloaded on every request.
    if model_name and model_name != MODEL_BUF["name"]:
        change_model_name(model_name)
    tokenizer = MODEL_BUF["tokenizer"]
    model = MODEL_BUF["model"]
    model.eval()

    input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
    # NOTE(review): top_k/top_p are sampling parameters. They presumably have
    # no effect here unless the checkpoint's generation config enables
    # do_sample — kept as-is to preserve the original call.
    outputs = model.generate(
        input_ids,
        max_length=200,
        num_beams=2,
        repetition_penalty=2.5,
        top_k=50,
        top_p=0.98,
        length_penalty=1.0,
        early_stopping=True,
    )
    # Dropping <pad>/</s> at decode time replaces the old anchored regex, which
    # raised AttributeError whenever the decoded text did not start with a space.
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return _trim_to_last_sentence(output)


output_text = gr.outputs.Textbox()

if __name__ == '__main__':
    text = ['Group profit | valIs | € 115.7 million && € 115.7 million | dTime | in 2019']
    example = [
        ['Net income | valIs | $48 million && $48 million | diGeo | in France && Net income | jPose | the interest rate && the interest rate | valIs | 0.6%'],
        ['The retirement age | incBy | 7 years && 7 years | cTime | 2018 && The retirement age | jpose | life expectancy && life expectancy | incBy | 10 years'],
        ['sales | incBy | € 115.7 million && € 115.7 million | dTime | in 2019 && € 115.7 million | diGeo | Europe'],
    ]
    model_name_list = [
        'yseop/FNP_T5_D2T_complete',
        'yseop/FNP_T5_D2T_simple',
    ]
    # Each Gradio example row needs one value per input: (model name, text).
    # `text` is a list, so the old row passed a list to the textbox, and
    # `example` was never shown. Both are now used as plain strings.
    sample_texts = text + [row[0] for row in example]
    app = gr.Interface(
        fn=generate,
        inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'],
        outputs=output_text,
        examples=[[MODEL_BUF["name"], sample] for sample in sample_texts],
        title="FTG",
        description="Financial Text Generation",
    )
    app.launch(inline=False)