|
import re |
|
from typing import List, Tuple |
|
import pathlib |
|
|
|
import torch |
|
from transformers import BertTokenizer |
|
|
|
from utils.sentence_retrieval_model import sentence_retrieval_model |
|
|
|
|
|
# Absolute directory of this source file (not referenced below; presumably
# consumed elsewhere in the package — TODO confirm).
THIS_DIR = pathlib.Path(__file__).parent.absolute()

# Shared tokenizer/model configuration.
# NOTE(review): SentenceRetrievalModule.__init__ mutates 'max_len' in place,
# so a max_len override is process-global, not per-instance.
ARGS = {

    'batch_size': 32,  # not used in this file — presumably read by sentence_retrieval_model

    'bert_pretrain': 'base/bert_base',  # path to pretrained BERT vocab/weights

    'checkpoint': 'base/model.best.32.pt',  # fine-tuned checkpoint; stores {'model': state_dict}

    'dropout': 0.6,

    'bert_hidden_dim': 768,  # hidden size expected by the retrieval head

    'max_len': 384,  # max sequence length for tokenizer padding/truncation

    'cuda': torch.cuda.is_available()

}


# Inference still works on CPU; warn so slow runs are explainable.
if not ARGS['cuda']:

    print('CUDA NOT AVAILABLE')
|
|
|
|
|
def process_sent(sentence):
    """Normalize FEVER-style tokenized text back to plain prose.

    The corpus encodes brackets as the literal tokens ``LSB``/``RSB``
    (square) and ``LRB``/``RRB`` (round), and uses PTB-style quote/dash
    escapes. This strips or restores them.

    Args:
        sentence: raw sentence string containing the escaped tokens.

    Returns:
        The cleaned sentence string.
    """
    # Drop square-bracket spans entirely (e.g. reference markers).
    sentence = re.sub(r"LSB.*?RSB", "", sentence)
    # Drop empty round-bracket pairs (nothing but whitespace between them).
    sentence = re.sub(r"LRB\s*?RRB", "", sentence)
    # Turn remaining round-bracket tokens back into literal parentheses,
    # keeping whatever whitespace surrounded the token.
    sentence = re.sub(r"(\s*?)LRB((\s*?))", r"\1(\2", sentence)
    sentence = re.sub(r"(\s*?)RRB((\s*?))", r"\1)\2", sentence)
    # Un-escape PTB punctuation: double dash and LaTeX-style quotes.
    sentence = re.sub("--", "-", sentence)
    sentence = re.sub("``", '"', sentence)
    sentence = re.sub("''", '"', sentence)
    return sentence
|
|
|
class SentenceRetrievalModule():
    """Scores (claim, evidence-sentence) pairs with a fine-tuned BERT model.

    Wraps tokenization, device placement, and batched inference around
    ``sentence_retrieval_model``.
    """

    def __init__(self, max_len=None):
        """Load the tokenizer and model weights.

        Args:
            max_len: optional override for the maximum tokenized sequence
                length. NOTE: this mutates the module-level ``ARGS`` dict,
                so the override is process-global and affects every
                instance created afterwards.
        """
        if max_len:
            ARGS['max_len'] = max_len

        self.tokenizer = BertTokenizer.from_pretrained(ARGS['bert_pretrain'], do_lower_case=False)
        self.model = sentence_retrieval_model(ARGS)
        # Checkpoint stores {'model': state_dict}; load onto CPU first so
        # the same file works whether or not a GPU is present.
        self.model.load_state_dict(torch.load(ARGS['checkpoint'], map_location=torch.device('cpu'))['model'])
        if ARGS['cuda']:
            self.model = self.model.cuda()

    def score_sentence_pairs(self, inputs: List[Tuple[str, str]]):
        """Score each (claim, evidence-sentence) pair.

        Args:
            inputs: sequence of 2-tuples of raw strings; each element is
                cleaned with ``process_sent`` before tokenization.

        Returns:
            A list with one score per input pair (model output converted
            via ``.tolist()``; presumably one float per pair — confirm
            against sentence_retrieval_model's output shape).
        """
        # Avoid shadowing the builtin `input`; index explicitly so tuples
        # longer than 2 keep working as before (only the first two fields
        # are scored).
        pairs_processed = [(process_sent(pair[0]), process_sent(pair[1])) for pair in inputs]

        # Tokenize the whole batch at once; pad/truncate to the configured
        # maximum length so tensor shapes are uniform.
        encodings = self.tokenizer(
            pairs_processed,
            padding='max_length',
            truncation='longest_first',
            max_length=ARGS['max_len'],
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        inp = encodings['input_ids']
        msk = encodings['attention_mask']
        seg = encodings['token_type_ids']

        if ARGS['cuda']:
            inp = inp.cuda()
            msk = msk.cuda()
            seg = seg.cuda()

        # Inference only: disable dropout and gradient tracking.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(inp, msk, seg).tolist()

        assert len(outputs) == len(inputs)

        return outputs