import os
import pandas as pd
import numpy as np
import json
from collections import OrderedDict

import torch
from torch import nn
from transformers import DistilBertTokenizer
from transformers import pipeline


class ClassificationModel(nn.Module):
    """Score a title (and optionally a summary) against ``output`` term classes.

    A frozen DistilBERT encoder produces the first-token ([CLS]) embedding of
    each text; two independent MLP heads (``title`` and ``summary``) map those
    embeddings to log-probabilities via ``LogSoftmax``.

    Args:
        text_fe: object exposing ``.model.distilbert`` (here a Hugging Face
            ``text-classification`` pipeline); only that encoder is kept, and
            its parameters are frozen.
        device: torch device the model is meant to run on (stored as-is).
        summary_hid_size: hidden width of the summary head.
        title_hid_size: hidden width of the title head.
        output: number of output classes.
        dropout_p: dropout probability inside each head.
    """

    def __init__(self, text_fe, device, summary_hid_size=1024, title_hid_size=256,
                 output=126, dropout_p=0.2):
        super().__init__()
        self.text_fe = text_fe.model.distilbert
        # The encoder is a fixed feature extractor; only the heads are trainable.
        for param in self.text_fe.parameters():
            param.requires_grad = False
        self.device = device
        features_size = 768  # hidden size of distilbert-base
        # Attribute names and registration order match the original layout so
        # saved state_dicts ('summary.bnorm0.weight', ...) still load.
        self.summary = self._make_head(features_size, summary_hid_size, output, dropout_p)
        self.title = self._make_head(features_size, title_hid_size, output, dropout_p)

    @staticmethod
    def _make_head(in_size, hid_size, out_size, dropout_p):
        """Build one head; the layer names are part of the checkpoint format."""
        return nn.Sequential(OrderedDict([
            ('bnorm0', nn.BatchNorm1d(in_size)),
            ('in2hid', nn.Linear(in_size, hid_size)),
            ('act', nn.ReLU()),
            ('drop', nn.Dropout(dropout_p)),
            ('bnorm1', nn.BatchNorm1d(hid_size)),
            ('hid2out', nn.Linear(hid_size, out_size)),
            ('log_soft', nn.LogSoftmax(dim=-1)),
        ]))

    def _embed(self, batch):
        """Return the first-token embedding for a tokenized batch.

        ``batch`` must provide ``'input_ids'`` and ``'attention_mask'``. 2-D
        tensors are used directly; 3-D tensors are assumed to be
        (batch, 1, seq_len) — TODO confirm against the training DataLoader —
        and have dim 1 removed.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        if input_ids.dim() > 2:
            # squeeze(1) rather than squeeze(): a bare squeeze() would also
            # drop the batch dimension when the batch holds a single sample.
            input_ids = input_ids.squeeze(1)
            attention_mask = attention_mask.squeeze(1)
        return self.text_fe(input_ids, attention_mask).last_hidden_state[:, 0, :]

    def forward(self, title_batch, summary_batch=None):
        """Return title log-probs, plus summary log-probs when a summary is given.

        Returns:
            ``title_probs`` when ``summary_batch`` is None, otherwise the tuple
            ``(title_probs, summary_probs)``; each has ``output`` columns.
        """
        title_probs = self.title(self._embed(title_batch))
        if summary_batch is None:
            return title_probs
        summary_probs = self.summary(self._embed(summary_batch))
        return title_probs, summary_probs


def create_model_and_optimizer(model_class, model_params, lr=1e-5):
    """Build ``model_class(**model_params)`` on its device plus an Adam optimizer.

    ``model_params`` must contain a ``'device'`` entry. Only parameters with
    ``requires_grad`` set are given to the optimizer, so frozen submodules are
    not updated.

    Returns:
        ``(model, optimizer)``.
    """
    model = model_class(**model_params).float()
    model = model.to(model_params['device'])
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.999))
    return model, optimizer


def get_input():
    """Prompt on stdin for a title and a summary; return ``(title, summary)``."""
    print('Write title')
    title = input()
    print('Write summary')
    summary = input()
    return title, summary


def get_prediction(tokenizer, model, device, title, summary=None, top_k=5):
    """Return the indices of the ``top_k`` highest-scoring classes for one text.

    When ``summary`` is given, the ranking comes from the summary head;
    otherwise it comes from the title head. Titles are tokenized to 33 tokens
    and summaries to 512, which matches the original script.

    The model should be in eval mode. ``BatchNorm1d`` in training mode rejects
    the single-sample batch built here.

    Returns:
        1-D tensor of ``top_k`` class indices, most likely first.
    """
    title_tokenized = tokenizer(title, max_length=33, return_tensors='pt',
                                truncation=True, padding='max_length').to(device)
    with torch.no_grad():
        if summary is None:
            # Without a summary the model returns only the title scores.
            scores = model(title_tokenized, None)
        else:
            summary_tokenized = tokenizer(summary, max_length=512, return_tensors='pt',
                                          truncation=True, padding='max_length').to(device)
            _, scores = model(title_tokenized, summary_tokenized)
    return torch.argsort(scores, dim=1, descending=True)[0, :top_k]


def load_all(model_path, dict_of_term_path, device=None):
    """Load the tokenizer, a trained ClassificationModel and the term dictionary.

    Args:
        model_path: checkpoint file containing a ``'model_state_dict'`` entry.
        dict_of_term_path: JSON file holding the term dictionary.
        device: device to build the model on and map the checkpoint to;
            defaults to CPU.

    Returns:
        ``(tokenizer, model, dict_of_term)``, with the model in eval mode.
    """
    if device is None:
        device = torch.device('cpu')
    with open(dict_of_term_path, 'r', encoding='utf-8') as f:
        dict_of_term = json.load(f)
    token = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    # Only ``.model.distilbert`` of this pipeline is used, as the frozen encoder.
    distilbert = pipeline('text-classification', model='distilbert-base-uncased')
    new_model, _ = create_model_and_optimizer(
        model_class=ClassificationModel,
        model_params={
            'device': device,
            'text_fe': distilbert,
            'summary_hid_size': 1024,
            'title_hid_size': 256,
            'dropout_p': 0.2,
        },
        lr=1e-4,
    )
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    new_model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode: BatchNorm uses running stats and dropout is disabled.
    new_model.eval()
    return token, new_model, dict_of_term


if __name__ == '__main__':
    device = torch.device('cpu')
    # get_device_properties only accepts CUDA devices.
    if device.type == 'cuda':
        print(torch.cuda.get_device_properties(device))
    else:
        print(device)
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    tokenizer, model, dict_of_term = load_all('chkp/model#21/model#21#11.pt',
                                              'dict_of_terms.json', device)
    title, summary = get_input()
    # An empty summary line falls back to the title-only head.
    prediction = get_prediction(tokenizer, model, device, title, summary or None)
    # Presumably the JSON maps term -> int class index; verify against the file.
    reversed_dict_of_term = {index: term for term, index in dict_of_term.items()}
    # .tolist() yields plain ints: tensors hash by identity and would never
    # match the dict keys.
    for i, ind in enumerate(prediction.tolist()):
        print('Place#{0} take term {1} (with number {2})'.format(i, reversed_dict_of_term[ind], ind))