# BookSearch / predictOnce.py
# xyh1756 — "first commit" (ad16774)
import os
import time
import numpy as np
import torch
from transformers import BertTokenizer
from bert.modeling_jointbert import JointBERT
class Estimator:
    """Single-utterance intent/slot predictor wrapping a fine-tuned JointBERT model."""

    class Args:
        # Default configuration. Only some fields are read directly by Estimator
        # (label files, model_dir, device, model_name_or_path, max_seq_len,
        # ignore_index, use_crf); the whole object is also passed to
        # JointBERT.from_pretrained as ``args``.
        adam_epsilon = 1e-08
        batch_size = 16
        data_dir = 'data'
        device = 'cpu'
        do_eval = True
        do_train = False
        dropout_rate = 0.1
        eval_batch_size = 64
        gradient_accumulation_steps = 1
        ignore_index = 0
        intent_label_file = 'data/intent_label.txt'
        learning_rate = 5e-05
        logging_steps = 50
        max_grad_norm = 1.0
        max_seq_len = 50
        max_steps = -1
        model_dir = 'book_model'
        model_name_or_path = 'bert-base-chinese'
        model_type = 'bert-chinese'
        no_cuda = False
        num_train_epochs = 5.0
        save_steps = 200
        seed = 1234
        slot_label_file = 'data/slot_label.txt'
        slot_loss_coef = 1.0
        slot_pad_label = 'PAD'
        task = 'book'
        train_batch_size = 32
        use_crf = False
        warmup_steps = 0
        weight_decay = 0.0

    def __init__(self, args=Args):
        """Load the label vocabularies, the model and the tokenizer.

        Args:
            args: configuration object exposing the attributes of
                ``Estimator.Args`` (the class itself is the default).

        Raises:
            Exception: if ``args.model_dir`` does not exist.
        """
        self.intent_label_lst = self._read_label_file(args.intent_label_file)
        self.slot_label_lst = self._read_label_file(args.slot_label_file)
        # Check whether model exists
        if not os.path.exists(args.model_dir):
            raise Exception("Model doesn't exists! Train first!")
        self.model = JointBERT.from_pretrained(args.model_dir,
                                               args=args,
                                               intent_label_lst=self.intent_label_lst,
                                               slot_label_lst=self.slot_label_lst)
        self.model.to(args.device)
        self.model.eval()
        self.args = args
        self.tokenizer = BertTokenizer.from_pretrained(self.args.model_name_or_path)

    @staticmethod
    def _read_label_file(path):
        """Return one stripped label per line of the UTF-8 file at *path*.

        Every line is kept (blank lines included) so that list index == label id
        stays identical to the line number, and the file handle is closed.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return [label.strip() for label in f]

    def convert_input_to_tensor_data(self, input, tokenizer, pad_token_label_id,
                                     cls_token_segment_id=0,
                                     pad_token_segment_id=0,
                                     sequence_a_segment_id=0,
                                     mask_padding_with_zero=True):
        """Encode one utterance into padded model-input tensors.

        Each character of *input* is tokenized separately. Sequences are
        truncated to ``args.max_seq_len - 2`` tokens, wrapped in [CLS]/[SEP]
        and padded to ``args.max_seq_len``.

        Returns:
            ``[input_ids, attention_mask, token_type_ids, slot_label_mask]``,
            each a ``torch.long`` tensor of shape ``(1, args.max_seq_len)``.
            ``slot_label_mask`` holds ``pad_token_label_id + 1`` at the first
            sub-token of every character and ``pad_token_label_id`` elsewhere,
            so callers can pick exactly one prediction per character.
        """
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token
        unk_token = tokenizer.unk_token
        pad_token_id = tokenizer.pad_token_id

        tokens = []
        slot_label_mask = []
        for word in list(input):
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # e.g. whitespace tokenizes to nothing; keep one position per char
                word_tokens = [unk_token]
            tokens.extend(word_tokens)
            # Mark only the first sub-token of each character as "real"
            slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Leave room for [CLS] and [SEP]
        special_tokens_count = 2
        limit = self.args.max_seq_len - special_tokens_count
        if len(tokens) > limit:
            tokens = tokens[:limit]
            slot_label_mask = slot_label_mask[:limit]

        # Add [SEP] token
        tokens += [sep_token]
        token_type_ids = [sequence_a_segment_id] * len(tokens)
        slot_label_mask += [pad_token_label_id]

        # Add [CLS] token
        tokens = [cls_token] + tokens
        token_type_ids = [cls_token_segment_id] + token_type_ids
        slot_label_mask = [pad_token_label_id] + slot_label_mask

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # With mask_padding_with_zero, 1 marks real tokens and 0 marks padding.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad up to the fixed sequence length.
        padding_length = self.args.max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
        slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)

        # Batch of one
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
        token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
        slot_label_mask = torch.tensor([slot_label_mask], dtype=torch.long)
        return [input_ids, attention_mask, token_type_ids, slot_label_mask]

    @staticmethod
    def _decode_slots(words, labels):
        """Group tagged characters into a ``{slot_type: text}`` dict.

        ``labels`` is aligned with ``words`` and consumed with ``zip``, so words
        without a label (e.g. cut off by ``max_seq_len``) are ignored instead of
        raising. A label with no '-' ('O', 'PAD', 'UNK', ...) is outside any
        slot. Otherwise the label is split at its FIRST '-' into prefix and slot
        type. A new span starts on a 'B' prefix or when the slot type changes
        (assumes BIO-style tags). If a slot type occurs twice, the last span wins.
        """
        slots = {}
        span_type, span_chars = None, []
        for word, label in zip(words, labels):
            prefix, sep, slot_type = label.partition('-')
            starts_new = not sep or prefix == 'B' or slot_type != span_type
            if starts_new and span_chars:
                slots[span_type] = ''.join(span_chars)
                span_chars = []
            if not sep:
                span_type = None
                continue
            span_type = slot_type
            span_chars.append(word)
        if span_chars:
            slots[span_type] = ''.join(span_chars)
        return slots

    def predict(self, input):
        """Predict the intent and slot values of one utterance.

        Args:
            input: raw query string; each character is treated as one word.

        Returns:
            ``(intent_label, slots)`` where ``slots`` maps slot type to the
            characters tagged with it (see ``_decode_slots``).
        """
        pad_token_label_id = self.args.ignore_index
        batch = self.convert_input_to_tensor_data(input, self.tokenizer, pad_token_label_id)
        batch = tuple(t.to(self.args.device) for t in batch)

        with torch.no_grad():
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "token_type_ids": batch[2],
                      "intent_label_ids": None,
                      "slot_labels_ids": None}
            outputs = self.model(**inputs)
            # outputs[1] is (intent_logits, slot_logits); outputs[0] is unused
            # here (presumably the loss slot — TODO confirm against JointBERT).
            _, (intent_logits, slot_logits) = outputs[:2]

            intent_id = np.argmax(intent_logits.detach().cpu().numpy(), axis=1)[0]
            if self.args.use_crf:
                # decode() in `torchcrf` returns the best tag indices directly
                slot_preds = np.array(self.model.crf.decode(slot_logits))
            else:
                slot_preds = np.argmax(slot_logits.detach().cpu().numpy(), axis=2)

        slot_label_mask = batch[3].detach().cpu().numpy()
        # One predicted tag per kept character (first sub-token positions only).
        slot_labels = [self.slot_label_lst[pred]
                       for pred, keep in zip(slot_preds[0], slot_label_mask[0])
                       if keep != pad_token_label_id]
        slots = self._decode_slots(list(input), slot_labels)
        return self.intent_label_lst[intent_id], slots
if __name__ == "__main__":
    # Interactive loop: read a query, print (intent, slots).
    # End-of-input (Ctrl-D) or Ctrl-C at the prompt exits cleanly instead of
    # dumping a traceback.
    e = Estimator()
    while True:
        try:
            query = input(">>")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        print(e.predict(query))