| import re | |
| from typing import List, Tuple | |
| import pathlib | |
| import torch | |
| from transformers import BertTokenizer | |
| from utils.sentence_retrieval_model import sentence_retrieval_model | |
# Directory containing this file.
# NOTE(review): THIS_DIR is not used by the visible code — the 'base/...'
# paths in ARGS below are resolved relative to the current working
# directory, not this file. Presumably they were meant to be anchored to
# THIS_DIR; confirm against how the module is launched before changing.
THIS_DIR = pathlib.Path(__file__).parent.absolute()

# Shared configuration for the tokenizer and the sentence-retrieval model.
ARGS = {
    'batch_size': 32,
    'bert_pretrain': 'base/bert_base',      # pretrained BERT weights / vocab dir
    'checkpoint': 'base/model.best.32.pt',  # fine-tuned retrieval checkpoint
    'dropout': 0.6,
    'bert_hidden_dim': 768,                 # hidden size of bert-base
    'max_len': 384,                         # tokenizer truncation length
    'cuda': torch.cuda.is_available()       # run inference on GPU when available
}

# Warn once at import time when falling back to CPU inference.
if not ARGS['cuda']:
    print('CUDA NOT AVAILABLE')
# FEVER-style bracket-token cleanup patterns, compiled once at import time.
# Raw strings fix the invalid `\s` escape sequences of the original
# (DeprecationWarning / SyntaxWarning on modern CPython); the patterns
# themselves are unchanged. Note the lazy `(\s*?)` groups always match the
# empty string, so the LRB/RRB rules are plain token replacements — kept
# verbatim to preserve behavior exactly.
_LSB_SPAN = re.compile(r"LSB.*?RSB")      # drop whole [ ... ] spans
_EMPTY_PARENS = re.compile(r"LRB\s*?RRB")  # drop empty ( ) pairs
_LRB = re.compile(r"(\s*?)LRB((\s*?))")
_RRB = re.compile(r"(\s*?)RRB((\s*?))")


def process_sent(sentence):
    """Normalize FEVER-tokenized text back toward plain punctuation.

    Removes LSB...RSB (square-bracket) spans, drops empty LRB/RRB pairs,
    turns remaining LRB/RRB tokens into '(' / ')', and collapses '--' and
    LaTeX-style quote ticks into '-' and '"'.

    Args:
        sentence: raw sentence string from a FEVER-style corpus.

    Returns:
        The cleaned sentence string.
    """
    sentence = _LSB_SPAN.sub("", sentence)
    sentence = _EMPTY_PARENS.sub("", sentence)
    sentence = _LRB.sub(r"\1(\2", sentence)
    sentence = _RRB.sub(r"\1)\2", sentence)
    # Fixed-string substitutions need no regex engine.
    sentence = sentence.replace("--", "-")
    sentence = sentence.replace("``", '"')
    sentence = sentence.replace("''", '"')
    return sentence
class SentenceRetrievalModule():
    """Scores (sentence_a, sentence_b) pairs with a fine-tuned BERT
    sentence-retrieval model.

    The tokenizer and model checkpoint are loaded once at construction;
    the model is moved to GPU when CUDA is available.
    """

    def __init__(self, max_len=None):
        """Load the tokenizer and model weights.

        Args:
            max_len: optional override of the tokenizer truncation length;
                falls back to ARGS['max_len'] when falsy.
        """
        # Apply the max_len override on a *copy* of ARGS instead of mutating
        # the shared module-level dict, so two modules constructed with
        # different lengths cannot interfere with each other.
        args = dict(ARGS, max_len=max_len) if max_len else dict(ARGS)
        self.max_len = args['max_len']
        self.tokenizer = BertTokenizer.from_pretrained(args['bert_pretrain'], do_lower_case=False)
        self.model = sentence_retrieval_model(args)
        # map_location='cpu' lets the checkpoint open even without a GPU;
        # the model is moved to CUDA afterwards when available.
        state = torch.load(args['checkpoint'], map_location=torch.device('cpu'))
        self.model.load_state_dict(state['model'])
        if ARGS['cuda']:
            self.model = self.model.cuda()

    def score_sentence_pairs(self, inputs: List[Tuple[str, str]]) -> List[float]:
        """Score each pair of raw sentences.

        Args:
            inputs: list of (sentence_a, sentence_b) string pairs.

        Returns:
            A list of model scores, one per input pair, in input order.
        """
        # Clean FEVER-style bracket tokens (LSB/RSB/LRB/RRB) before tokenizing.
        pairs = [(process_sent(a), process_sent(b)) for a, b in inputs]
        encodings = self.tokenizer(
            pairs,
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_len,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encodings['input_ids']
        attention_mask = encodings['attention_mask']
        token_type_ids = encodings['token_type_ids']
        if ARGS['cuda']:
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask, token_type_ids).tolist()
        # Sanity check: the model must emit exactly one score per pair.
        assert len(outputs) == len(inputs)
        return outputs