PaperClassifier / app.py
Koltochen's picture
Upload 3 files
a553f22
import os
import pandas as pd
import numpy as np
import json
from collections import OrderedDict
import torch
from torch import nn
from transformers import DistilBertTokenizer
from transformers import pipeline
class ClassificationModel(nn.Module):
    """Two classification heads on top of a frozen text encoder.

    The encoder is taken from ``text_fe.model.distilbert`` and all of its
    parameters are frozen. The ``title`` and ``summary`` heads each map the
    encoder's first-token embedding to log-probabilities over ``output``
    classes (the last layer of each head is ``LogSoftmax``).

    Args:
        text_fe: Object exposing the encoder as ``text_fe.model.distilbert``
            (``load_all`` passes a transformers ``pipeline``).
        device: Stored on the instance as ``self.device``; not used by
            ``forward``.
        summary_hid_size: Hidden width of the summary head.
        title_hid_size: Hidden width of the title head.
        output: Number of output classes of both heads.
        dropout_p: Dropout probability used in both heads.
    """

    def __init__(self, text_fe, device, summary_hid_size=1024, title_hid_size=256, output=126, dropout_p=0.2):
        super().__init__()
        self.text_fe = text_fe.model.distilbert
        # The encoder is only a feature extractor: freeze it so that only the
        # two heads contribute trainable parameters.
        for param in self.text_fe.parameters():
            param.requires_grad = False
        self.device = device
        # Width of the encoder embedding fed to the heads (hard-coded;
        # presumably the hidden size of distilbert-base -- confirm if the
        # encoder is ever swapped).
        features_size = 768
        # Attribute names and layer names are part of the checkpoint's
        # state-dict keys, so they must stay exactly as they are.
        self.summary = self._make_head(features_size, summary_hid_size, output, dropout_p)
        self.title = self._make_head(features_size, title_hid_size, output, dropout_p)

    @staticmethod
    def _make_head(features_size, hidden_size, output, dropout_p):
        """Build one BatchNorm -> Linear -> ReLU -> Dropout -> BatchNorm -> Linear -> LogSoftmax head."""
        return nn.Sequential(OrderedDict([
            ('bnorm0', nn.BatchNorm1d(features_size)),
            ('in2hid', nn.Linear(features_size, hidden_size)),
            ('act', nn.ReLU()),
            ('drop', nn.Dropout(dropout_p)),
            ('bnorm1', nn.BatchNorm1d(hidden_size)),
            ('hid2out', nn.Linear(hidden_size, output)),
            ('log_soft', nn.LogSoftmax(dim=-1)),
        ]))

    def _embed(self, batch):
        """Return the encoder's first-token embedding for a tokenized batch.

        ``batch`` must provide ``'input_ids'`` and ``'attention_mask'``
        tensors. Inputs with more than two dimensions are flattened to
        ``(-1, last_dim)`` before being passed to the encoder.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        if input_ids.dim() > 2:
            # A bare ``squeeze()`` would also drop a batch dimension of size 1
            # and hand the encoder a 1-D tensor; flattening the leading
            # dimensions keeps the result 2-D for every batch size.
            input_ids = input_ids.reshape(-1, input_ids.size(-1))
            attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))
        return self.text_fe(input_ids, attention_mask).last_hidden_state[:, 0, :]

    def forward(self, title_batch, summary_batch=None):
        """Score titles and, optionally, summaries.

        Args:
            title_batch: Mapping with ``'input_ids'`` and ``'attention_mask'``.
            summary_batch: Same structure for the summaries, or ``None``.

        Returns:
            The title head's log-probabilities when ``summary_batch`` is
            ``None``; otherwise a ``(title, summary)`` tuple of
            log-probabilities.
        """
        title_probs = self.title(self._embed(title_batch))
        if summary_batch is None:
            return title_probs
        summary_probs = self.summary(self._embed(summary_batch))
        return title_probs, summary_probs
def create_model_and_optimizer(model_class, model_params, lr=1e-5):
    """Instantiate a model and an Adam optimizer over its trainable parameters.

    Args:
        model_class: Callable invoked as ``model_class(**model_params)``.
        model_params: Keyword arguments for ``model_class``; must contain a
            ``'device'`` entry, which is also where the model is moved.
        lr: Learning rate handed to Adam.

    Returns:
        A ``(model, optimizer)`` tuple. The model is cast to float32; the
        optimizer only sees parameters whose ``requires_grad`` is true.
    """
    network = model_class(**model_params).float().to(model_params['device'])
    # Frozen parameters are left out of the optimizer entirely.
    trainable = [weight for weight in network.parameters() if weight.requires_grad]
    adam_betas = [0.9, 0.999]
    return network, torch.optim.Adam(trainable, lr, adam_betas)
def get_input():
    """Prompt on stdout for a title and then a summary, reading one line for each.

    Returns:
        A ``(title, summary)`` tuple of the two entered strings.
    """
    answers = []
    for prompt in ('Write title', 'Write summary'):
        print(prompt)
        answers.append(input())
    paper_title, paper_summary = answers
    return paper_title, paper_summary
def get_prediction(tokenizer, model, device, title, summary=None, top_k=5):
    """Return the indices of the highest-scoring classes for one paper.

    When ``summary`` is given, the ranking comes from the model's summary
    output and the title output is ignored; otherwise it comes from the title
    output.

    Args:
        tokenizer: Callable used as ``tokenizer(text, max_length=...,
            return_tensors='pt', truncation=True, padding='max_length')``
            whose result supports ``.to(device)``.
        model: Module called as ``model(title_batch, summary_batch_or_None)``.
            It returns a single score tensor when the summary is ``None`` and
            a ``(title_scores, summary_scores)`` pair otherwise.
        device: Device the tokenized inputs are moved to.
        title: Paper title.
        summary: Optional paper summary.
        top_k: How many class indices to return (default 5).

    Returns:
        1-D tensor with up to ``top_k`` class indices, best first.
    """
    title_tokenized = tokenizer(title, max_length=33, return_tensors='pt', truncation=True, padding='max_length')
    summary_tokenized = None
    if summary is not None:
        summary_tokenized = tokenizer(summary, max_length=512, return_tensors='pt', truncation=True,
                                      padding='max_length').to(device)

    # Run in eval mode: with a single sample, BatchNorm1d raises in train mode
    # and dropout would make the ranking random. The previous mode is restored
    # afterwards so the caller's model is left as it was.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if summary_tokenized is None:
                # Without a summary the model returns one tensor, not a pair,
                # so it must not be unpacked into two names.
                scores = model(title_tokenized.to(device), None)
            else:
                _, scores = model(title_tokenized.to(device), summary_tokenized)
    finally:
        model.train(was_training)

    return torch.argsort(scores, dim=1, descending=True)[0, :top_k]
def load_all(model_path, dict_of_term_path, device=None):
    """Load the tokenizer, the trained classifier and the term dictionary.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` that holds
            the weights under the ``'model_state_dict'`` key.
        dict_of_term_path: Path to the JSON file with the term dictionary.
        device: Device to build the model on. When omitted, a module-level
            ``device`` is used if one exists (the script entry point defines
            it), otherwise the CPU.

    Returns:
        A ``(tokenizer, model, dict_of_term)`` tuple. The model is returned
        in eval mode because the optimizer is discarded here.
    """
    if device is None:
        # Earlier callers relied on a global ``device``; honour it when it is
        # defined so they keep working, but do not crash when this module is
        # imported and no such global exists.
        device = globals().get('device', torch.device('cpu'))

    with open(dict_of_term_path, "r", encoding="utf-8") as f:
        dict_of_term = json.load(f)

    token = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    # The pipeline is only a carrier for the encoder: ClassificationModel
    # reads ``text_fe.model.distilbert`` from it.
    distilbert = pipeline('text-classification', model='distilbert-base-uncased')
    new_model, _ = create_model_and_optimizer(
        model_class=ClassificationModel,
        model_params={
            'device': device,
            'text_fe': distilbert,
            'summary_hid_size': 1024,
            'title_hid_size': 256,
            'dropout_p': 0.2,
        },
        lr=1e-4,
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded on the device used here.
    checkpoint = torch.load(model_path, map_location=device)
    new_model.load_state_dict(checkpoint['model_state_dict'])
    # Inference use: disable dropout and use the stored BatchNorm statistics.
    new_model.eval()
    return token, new_model, dict_of_term
if __name__ == '__main__':
    # Interactive entry point: load everything, ask for a title and a summary
    # and print the best-ranked terms.
    device = torch.device('cpu')
    # Device properties can only be queried for CUDA devices; asking for them
    # with a CPU device raises.
    if device.type == 'cuda':
        print(torch.cuda.get_device_properties(device))

    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    tokenizer, model, dict_of_term = load_all('chkp/model#21/model#21#11.pt', 'dict_of_terms.json')
    # Single-sample inference: BatchNorm1d and dropout must be in eval mode.
    model.eval()

    title, summary = get_input()
    prediction = get_prediction(tokenizer, model, device, title, summary)

    # dict_of_term is inverted here, so its values are used as lookup keys
    # (presumably term -> class index; confirm against dict_of_terms.json).
    reversed_dict_of_term = {index: term for term, index in dict_of_term.items()}
    # Convert the tensor to plain ints first: tensors hash by identity and
    # would never match the dictionary's keys.
    for i, ind in enumerate(prediction.tolist()):
        print('Place#{0} take term {1} (with number {2})'.format(i, reversed_dict_of_term[ind], ind))