import os
import time

import numpy as np
import torch
from transformers import BertTokenizer

from bert.modeling_jointbert import JointBERT
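# `bert.modeling_jointbert` is assumed to be a local package shipped with this app: it
# provides JointBERT, a BERT encoder with a joint intent-classification head and a
# token-level slot-filling head, and is not part of `transformers` itself.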


class Estimator:
    class Args:
        adam_epsilon = 1e-08
        batch_size = 16
        data_dir = 'data'
        device = 'cpu'
        do_eval = True
        do_train = False
        dropout_rate = 0.1
        eval_batch_size = 64
        gradient_accumulation_steps = 1
        ignore_index = 0
        intent_label_file = 'data/intent_label.txt'
        learning_rate = 5e-05
        logging_steps = 50
        max_grad_norm = 1.0
        max_seq_len = 50
        max_steps = -1
        model_dir = 'book_model'
        model_name_or_path = 'bert-base-chinese'
        model_type = 'bert-chinese'
        no_cuda = False
        num_train_epochs = 5.0
        save_steps = 200
        seed = 1234
        slot_label_file = 'data/slot_label.txt'
        slot_loss_coef = 1.0
        slot_pad_label = 'PAD'
        task = 'book'
        train_batch_size = 32
        use_crf = False
        warmup_steps = 0
        weight_decay = 0.0
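    # NOTE: these values mirror the training-time hyperparameters. The ones the model is
    # rebuilt from at load time (use_crf, dropout_rate) and the two label files are
    # assumed to match the checkpoint saved in `model_dir`; otherwise the loaded weights
    # will not line up with the constructed network.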

    def __init__(self, args=Args):
        self.intent_label_lst = [label.strip() for label in open(args.intent_label_file, 'r', encoding='utf-8')]
        self.slot_label_lst = [label.strip() for label in open(args.slot_label_file, 'r', encoding='utf-8')]

        # Check whether the model exists
        if not os.path.exists(args.model_dir):
            raise Exception("Model doesn't exist! Train first!")

        self.model = JointBERT.from_pretrained(args.model_dir,
                                               args=args,
                                               intent_label_lst=self.intent_label_lst,
                                               slot_label_lst=self.slot_label_lst)
        self.model.to(args.device)
        self.model.eval()

        self.args = args
        self.tokenizer = BertTokenizer.from_pretrained(self.args.model_name_or_path)

    def convert_input_to_tensor_data(self, input, tokenizer, pad_token_label_id,
                                     cls_token_segment_id=0,
                                     pad_token_segment_id=0,
                                     sequence_a_segment_id=0,
                                     mask_padding_with_zero=True):
        # Setting based on the current model type
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token
        unk_token = tokenizer.unk_token
        pad_token_id = tokenizer.pad_token_id

        slot_label_mask = []
        words = list(input)
        tokens = []
        for word in words:
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                word_tokens = [unk_token]  # For handling the bad-encoded word
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Account for [CLS] and [SEP]
        special_tokens_count = 2
        if len(tokens) > self.args.max_seq_len - special_tokens_count:
            tokens = tokens[:(self.args.max_seq_len - special_tokens_count)]
            slot_label_mask = slot_label_mask[:(self.args.max_seq_len - special_tokens_count)]

        # Add [SEP] token
        tokens += [sep_token]
        token_type_ids = [sequence_a_segment_id] * len(tokens)
        slot_label_mask += [pad_token_label_id]

        # Add [CLS] token
        tokens = [cls_token] + tokens
        token_type_ids = [cls_token_segment_id] + token_type_ids
        slot_label_mask = [pad_token_label_id] + slot_label_mask

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = self.args.max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
        slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)

        # Change to Tensor
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
        token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
        slot_label_mask = torch.tensor([slot_label_mask], dtype=torch.long)

        data = [input_ids, attention_mask, token_type_ids, slot_label_mask]

        return data
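    # Illustrative shapes (assuming max_seq_len = 50): each returned tensor is [1, 50].
    # For the two-character input "订书" ("order a book") the tokens become
    # [CLS] 订 书 [SEP] followed by [PAD]s, attention_mask is 1 1 1 1 0 ... 0, and
    # slot_label_mask is pad_token_label_id everywhere except the two word-initial
    # positions, which get pad_token_label_id + 1.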

    def predict(self, input):
        # Convert the input text to tensors
        pad_token_label_id = self.args.ignore_index
        batch = self.convert_input_to_tensor_data(input, self.tokenizer, pad_token_label_id)

        # Predict
        batch = tuple(t.to(self.args.device) for t in batch)
        with torch.no_grad():
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "token_type_ids": batch[2],
                      "intent_label_ids": None,
                      "slot_labels_ids": None}
            outputs = self.model(**inputs)
            _, (intent_logits, slot_logits) = outputs[:2]

            # Intent prediction
            intent_pred = intent_logits.detach().cpu().numpy()

            # Slot prediction
            if self.args.use_crf:
                # decode() in `torchcrf` returns the list of best label indices directly
                slot_preds = np.array(self.model.crf.decode(slot_logits))
            else:
                slot_preds = slot_logits.detach().cpu().numpy()
            all_slot_label_mask = batch[3].detach().cpu().numpy()

        intent_pred = np.argmax(intent_pred, axis=1)[0]

        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)

        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        slot_preds_list = []
        for i in range(slot_preds.shape[1]):
            if all_slot_label_mask[0, i] != pad_token_label_id:
                slot_preds_list.append(slot_label_map[slot_preds[0][i]])

        words = list(input)
        slots = dict()
        slot = str()
        for i in range(len(words)):
            if slot_preds_list[i] == 'O':
                if slot == '':
                    continue
                slots[slot_preds_list[i - 1].split('-')[1]] = slot
                slot = str()
            else:
                slot += words[i]
        if slot != '':
            slots[slot_preds_list[len(words) - 1].split('-')[1]] = slot

        return self.intent_label_lst[intent_pred], slots
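    # predict() returns (intent_label, slots), where `slots` maps a slot type to the
    # characters predicted for it. The grouping above assumes BIO-style slot labels of
    # the form 'B-<type>' / 'I-<type>' plus 'O' in data/slot_label.txt (one label per
    # line), since the slot type is recovered with split('-')[1].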


if __name__ == "__main__":
    e = Estimator()
    while True:
        print(e.predict(input(">>")))
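
# Illustrative session (the intent/slot names below are hypothetical; the real ones come
# from data/intent_label.txt, data/slot_label.txt, and the checkpoint in book_model/):
#   >>帮我订一本三体  ("order a copy of The Three-Body Problem for me")
#   ('order_book', {'book_name': '三体'})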