# Prove_KCL / utils / sentence_retrieval_module.py
# Source: Hugging Face file view (uploaded by Jongmo, "Upload 25 files", commit a5bbcdb, 2.28 kB).
import re
from typing import List, Tuple
import pathlib
import torch
from transformers import BertTokenizer
from utils.sentence_retrieval_model import sentence_retrieval_model
# Absolute directory of this module. Note: the relative paths in ARGS below
# are not joined with it in the code shown here.
THIS_DIR = pathlib.Path(__file__).parent.absolute()

# Default configuration for the sentence-retrieval tokenizer and model.
# 'cuda' is probed once at import time.
ARGS = dict(
    batch_size=32,
    bert_pretrain='base/bert_base',
    checkpoint='base/model.best.32.pt',
    dropout=0.6,
    bert_hidden_dim=768,
    max_len=384,
    cuda=torch.cuda.is_available(),
)

# Import-time notice when no GPU is usable.
if not ARGS['cuda']:
    print('CUDA NOT AVAILABLE')
# Compiled once at import. Raw strings keep "\s" a regex escape instead of an
# invalid Python string escape (a SyntaxWarning on Python 3.12+).
_SQUARE_BRACKET_SPAN = re.compile(r"LSB.*?RSB")
_EMPTY_PAREN_PAIR = re.compile(r"LRB\s*?RRB")


def process_sent(sentence: str) -> str:
    """Turn an escaped, tokenized sentence back into more readable text.

    Steps, applied in this order:
      1. Delete every "LSB ... RSB" span (shortest match), contents included.
      2. Delete "LRB" followed, possibly after whitespace, directly by "RRB".
      3. Replace remaining "LRB" / "RRB" tokens with "(" / ")", keeping the
         surrounding whitespace as it is.
      4. Collapse "--" to "-" and turn both two-backtick and two-apostrophe
         quote marks into a double quote.

    Args:
        sentence: Text to normalise.

    Returns:
        The normalised text.
    """
    sentence = _SQUARE_BRACKET_SPAN.sub("", sentence)
    sentence = _EMPTY_PAREN_PAIR.sub("", sentence)
    # Only the token itself changes; adjacent whitespace is left untouched,
    # so a plain literal replacement is enough.
    sentence = sentence.replace("LRB", "(")
    sentence = sentence.replace("RRB", ")")
    sentence = sentence.replace("--", "-")
    sentence = sentence.replace("``", '"')
    sentence = sentence.replace("''", '"')
    return sentence
class SentenceRetrievalModule():
    """Score text pairs with a fine-tuned BERT sentence-retrieval model.

    The tokenizer is loaded from ``ARGS['bert_pretrain']`` and the model
    weights from the ``'model'`` entry of the checkpoint file at
    ``ARGS['checkpoint']`` (both given as relative paths).
    """

    def __init__(self, max_len=None):
        """Load the tokenizer and the checkpointed model.

        Args:
            max_len: Optional maximum token length per encoded pair. When
                truthy it overrides ``ARGS['max_len']`` for this instance
                only; otherwise the module default is used.
        """
        # Work on a per-instance copy. Writing the override into the shared
        # module-level ARGS would silently change max_len for every instance
        # created afterwards.
        self.args = dict(ARGS)
        if max_len:
            self.args['max_len'] = max_len
        self.tokenizer = BertTokenizer.from_pretrained(
            self.args['bert_pretrain'], do_lower_case=False
        )
        self.model = sentence_retrieval_model(self.args)
        # Map tensors to CPU at load time so the checkpoint also loads without
        # CUDA; the model is moved to the GPU below when one is available.
        # NOTE: torch.load unpickles the file - only point it at trusted checkpoints.
        checkpoint = torch.load(self.args['checkpoint'], map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['model'])
        if self.args['cuda']:
            self.model = self.model.cuda()

    def _encode(self, pairs):
        """Tokenize text pairs into (input_ids, attention_mask, token_type_ids) tensors."""
        # Every row is padded to max_len, so a row's encoding does not depend
        # on which other pairs share its batch.
        encodings = self.tokenizer(
            pairs,
            padding='max_length',
            truncation='longest_first',
            max_length=self.args['max_len'],
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        tensors = (
            encodings['input_ids'],
            encodings['attention_mask'],
            encodings['token_type_ids'],
        )
        if self.args['cuda']:
            tensors = tuple(t.cuda() for t in tensors)
        return tensors

    def score_sentence_pairs(self, inputs: List[Tuple[str, str]]) -> list:
        """Score each text pair with the retrieval model.

        Args:
            inputs: Pairs of strings. Elements 0 and 1 of each pair are
                normalised with ``process_sent`` before tokenisation.

        Returns:
            One entry per input pair, in input order. Each entry is the
            model's output for that row converted with ``Tensor.tolist()``;
            its exact form depends on ``sentence_retrieval_model``. Returns
            an empty list when ``inputs`` is empty.
        """
        if not inputs:
            return []
        pairs = [(process_sent(pair[0]), process_sent(pair[1])) for pair in inputs]
        batch_size = self.args['batch_size']
        outputs = []
        self.model.eval()
        with torch.no_grad():
            # Chunk by batch_size to bound peak memory. This assumes the model
            # output is batch-first, which the length check below also assumes.
            for start in range(0, len(pairs), batch_size):
                inp, msk, seg = self._encode(pairs[start:start + batch_size])
                outputs.extend(self.model(inp, msk, seg).tolist())
        assert len(outputs) == len(inputs)
        return outputs