Prove_KCL / utils /sentence_retrieval_model.py
Jongmo's picture
Upload 25 files
a5bbcdb verified
raw
history blame
751 Bytes
import torch
import torch.nn as nn
from utils.bert_model import BertForSequenceEncoder
class sentence_retrieval_model(nn.Module):
def __init__(self, args):
super(sentence_retrieval_model, self).__init__()
self.pred_model = BertForSequenceEncoder.from_pretrained(args['bert_pretrain'])
self.bert_hidden_dim = args['bert_hidden_dim']
self.dropout = nn.Dropout(args['dropout'])
self.proj_match = nn.Linear(self.bert_hidden_dim, 1)
def forward(self, inp_tensor, msk_tensor, seg_tensor):
_, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
inputs = self.dropout(inputs)
score = self.proj_match(inputs).squeeze(-1)
score = torch.tanh(score)
return score