from typing import Dict

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel

from .configuration_dpr import CustomDPRConfig

class OBSSDPRModel(PreTrainedModel):
    """Hugging Face wrapper around the DPRModel bi-encoder defined below."""

    config_class = CustomDPRConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = DPRModel()

    def forward(self, batch):
        return self.model(batch)

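# A minimal sketch, not in the original file, of how this wrapper could be
# exposed through the Auto API; it assumes CustomDPRConfig defines a
# model_type string:
#
#   from transformers import AutoConfig
#   AutoConfig.register(CustomDPRConfig.model_type, CustomDPRConfig)
#   AutoModel.register(CustomDPRConfig, OBSSDPRModel)
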
class DPRModel(nn.Module):
    """Bi-encoder with separate question and context encoders (Contriever by default)."""

    def __init__(self,
                 question_model_name='facebook/contriever-msmarco',
                 context_model_name='facebook/contriever-msmarco'):
        super().__init__()
        self.question_model = AutoModel.from_pretrained(question_model_name)
        self.context_model = AutoModel.from_pretrained(context_model_name)

    def freeze_layers(self, freeze_params):
        # Freeze the first `freeze_params` fraction (0.0-1.0) of parameter
        # tensors in each encoder and make the rest trainable. Note this
        # counts parameter tensors, not transformer layers.
        for model in (self.context_model, self.question_model):
            params = list(model.parameters())
            cutoff = int(freeze_params * len(params))
            for param in params[:cutoff]:
                param.requires_grad = False
            for param in params[cutoff:]:
                param.requires_grad = True

    def batch_dot_product(self, context_output, question_output):
        # Row-wise dot product between paired embeddings: (B, D) x (B, D) -> (B,)
        mat1 = torch.unsqueeze(question_output, dim=1)  # (B, 1, D)
        mat2 = torch.unsqueeze(context_output, dim=2)   # (B, D, 1)
        result = torch.bmm(mat1, mat2)                  # (B, 1, 1)
        result = torch.squeeze(result, dim=1)           # (B, 1)
        result = torch.squeeze(result, dim=1)           # (B,)
        return result

    # Mean pooling over non-padding tokens, as used by Contriever.
    def mean_pooling(self, token_embeddings, mask):
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

    def forward(self, batch: Dict):
        # `batch` holds pre-tokenized inputs: each entry is a tokenizer output
        # dict with input_ids / attention_mask.
        context_tensor = batch['context_tensor']
        question_tensor = batch['question_tensor']
        context_model_output = self.context_model(**context_tensor)
        question_model_output = self.question_model(**question_tensor)
        # Pool the last hidden states into one embedding per sequence.
        embeddings_context = self.mean_pooling(context_model_output[0], context_tensor['attention_mask'])
        embeddings_question = self.mean_pooling(question_model_output[0], question_tensor['attention_mask'])
        # One similarity score per question/context pair.
        scores = self.batch_dot_product(embeddings_context, embeddings_question)
        return scores
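

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: score each
    # question/context pair with the bare DPRModel (row-wise, not all-pairs).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
    model = DPRModel()
    model.eval()

    questions = ['where was marie curie born?'] * 2
    contexts = [
        'Maria Sklodowska was born in Warsaw.',
        'The capital of France is Paris.',
    ]
    batch = {
        'question_tensor': tokenizer(questions, padding=True, truncation=True, return_tensors='pt'),
        'context_tensor': tokenizer(contexts, padding=True, truncation=True, return_tensors='pt'),
    }
    with torch.no_grad():
        scores = model(batch)  # shape: (2,), one score per pair
    print(scores)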