import functools
import os
import sys

# Put the current working directory first on sys.path so the project-local
# modules imported below (event_detection_dataset / event_detection_model)
# can be found.
sys.path.insert(0, os.path.abspath('./'))

import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForMaskedLM
from event_detection_dataset import *
from event_detection_model import *
import gradio as gr

# Backbone name shared by the tokenizer and DistillBERTClass.
_BACKBONE = 'distilbert-base-cased'
# Fine-tuned weights, resolved relative to the current working directory.
# The file is read as a dict holding a 'model_state_dict' entry.
_CHECKPOINT_PATH = os.path.join("./", "event_domain_final.pt")

# Classifier output index -> human-readable event label.
# Class 1 previously carried a stray BIO-tag "I-" prefix and class 8 a
# trailing space; both were display typos and are cleaned up here.
_ID2TAGS = {
    0: "Acquisition",
    1: "Positive Clinical Trial & FDA Approval",
    2: "Dividend Cut",
    3: "Dividend Increase",
    4: "Guidance Increase",
    5: "New Contract",
    6: "Dividend",
    7: "Reverse Stock Split",
    8: "Special Dividend",
    9: "Stock Repurchase",
    10: "Stock Split",
    11: "Others",
}


@functools.lru_cache(maxsize=1)
def _load_inference_components():
    """Load the tokenizer, the fine-tuned classifier and the device once.

    The result is cached so every ``predict`` call reuses the same objects.
    Without the cache, each request would re-read the tokenizer and the
    checkpoint from disk.

    Returns:
        tuple: ``(tokenizer, model, device)``. The model has already been
        moved to ``device`` and switched to eval mode.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(_BACKBONE, use_fast=True)

    model = DistillBERTClass(_BACKBONE)
    # Map tensors onto the CPU first so that a checkpoint saved on a GPU
    # still loads on a CPU-only host; the model is moved to `device` after.
    checkpoint = torch.load(_CHECKPOINT_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return tokenizer, model, device


def predict(data):
    """Classify one piece of text into a financial event type.

    Args:
        data: The raw input text (the value of the Gradio text box).

    Returns:
        str: The ``_ID2TAGS`` label of the highest-scoring class.
    """
    tokenizer, model, device = _load_inference_components()

    # NOTE(review): the raw sentence is wrapped in a one-element list and
    # passed with is_split_into_words=True. The tokenizer most likely treats
    # it as a single, un-batched, pre-tokenized sequence. This is kept exactly
    # as before; confirm that DistillBERTClass accepts the resulting tensor
    # shape.
    tokenized_inputs = tokenizer(
        [data],
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        is_split_into_words=True,
    )

    ids = torch.tensor(tokenized_inputs['input_ids']).to(device)
    mask = torch.tensor(tokenized_inputs['attention_mask']).to(device)

    with torch.no_grad():
        outputs = model(ids, mask)

    # Index of the largest score along dim 1 (the class dimension).
    _, max_idx = torch.max(outputs, dim=1)
    return _ID2TAGS[max_idx.item()]


title = "Financial Event Detection"
description = "Predict Financial Events."
article = "modified the model in the following paper: Zhou, Z., Ma, L., & Liu, H. (2021)."

# Each example is a list with one entry per input component (here: one text box).
example_list = [["Investors who receive dividends can choose to take them as cash or as additional shares."]]

# Create the Gradio demo: one free-text input mapped by `predict` to one
# text output (the predicted event label).
demo = gr.Interface(
    fn=predict,        # input text -> single event-label string
    inputs="text",     # one text-box input
    outputs="text",    # predict returns exactly one value, so one text output
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo. This call sits at module level with no __main__ guard,
# so it runs whenever this file is executed or imported.
demo.launch(debug=False, share=False)