import torch
from torch.utils.data import DataLoader, TensorDataset

from bert.tokenize import extract_inputs_masks, tokenize_encode_corpus


def predict(samples, tokenizer, scaler, model, device, max_len, batch_size,
            return_scaled=False):
    """Run *model* over *samples* and return a flat list of predictions.

    The samples are tokenized with *tokenizer*, truncated/padded to
    *max_len*, and fed through *model* in batches of *batch_size* on
    *device* with gradients disabled.

    Args:
        samples: Iterable of raw text inputs to score.
        tokenizer: Tokenizer passed to ``tokenize_encode_corpus``.
        scaler: Fitted scaler (e.g. sklearn) whose ``inverse_transform``
            maps model outputs back to the original target scale.
        model: Torch module; called as ``model(input_ids, attention_mask)``
            and assumed to return a tensor of predictions.  # NOTE(review): a
            HuggingFace model returning a ModelOutput would need ``.logits``.
        device: Torch device to run inference on.
        max_len: Maximum tokenized sequence length.
        batch_size: Inference batch size.
        return_scaled: If True, skip ``inverse_transform`` and return the
            raw (scaled) model outputs.

    Returns:
        list[float]: One prediction per sample, in input order.
    """
    model.eval()

    encoded_corpus = tokenize_encode_corpus(tokenizer, samples, max_len)
    input_ids, attention_mask = extract_inputs_masks(encoded_corpus)

    # Build tensors directly on the target device. (The previous code
    # wrapped each array in an extra list and immediately indexed [0],
    # which was a no-op round trip.)
    input_ids = torch.tensor(input_ids).to(device)
    attention_mask = torch.tensor(attention_mask).to(device)

    dataset = TensorDataset(input_ids, attention_mask)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    output = []
    for batch_inputs, batch_masks in dataloader:
        # Tensors are already on `device`; no per-batch transfer needed.
        with torch.no_grad():
            # .view(-1) flattens (batch, 1) predictions to a 1-D list.
            output += model(batch_inputs, batch_masks).view(-1).tolist()

    if return_scaled:
        return output

    # The scaler expects a 2-D array, so wrap the predictions for the
    # call and flatten the result back to a plain list.
    return scaler.inverse_transform([output]).reshape(-1).tolist()