from typing import Any, Dict, List, Tuple, Union

import torch
from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
class EncodePipeline(Pipeline):
    """Tokenizes text and returns sentence embeddings from the model's `encode` method."""

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No runtime parameters are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(self, inputs: Union[str, List[str]], **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        # Tokenize a single text or a batch of texts into fixed-length tensors.
        return self.tokenizer(
            inputs,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> List[List[float]]:
        # Assumes the underlying model exposes an `encode` method that returns
        # a (batch, hidden_size) tensor of embeddings.
        embeddings = self.model.encode(**input_tensors)
        return embeddings.tolist()

    def postprocess(self, model_outputs: List[List[float]], **postprocess_parameters: Dict) -> Any:
        return model_outputs
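
# Usage sketch for EncodePipeline (assumptions: "path/to/checkpoint" is a
# placeholder, and the loaded model class must expose the `encode` method this
# pipeline relies on; neither is defined in this file):
#
#     from transformers import AutoModel, AutoTokenizer
#
#     model = AutoModel.from_pretrained("path/to/checkpoint")
#     tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
#     pipe = EncodePipeline(max_length=256, model=model, tokenizer=tokenizer)
#     embeddings = pipe("An example sentence.")  # one embedding vector per input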
class SimilarPipeline(Pipeline):
    """Scores the cosine similarity between sentence pairs."""

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No runtime parameters are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(self, inputs: Union[Tuple[str, str], List[Tuple[str, str]]], **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        # Split the pair(s) into the two sides of the comparison.
        if isinstance(inputs, list):
            a = [pair[0] for pair in inputs]
            b = [pair[1] for pair in inputs]
        else:
            a, b = inputs
        tensors_a = self.tokenizer(
            a,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        tensors_b = self.tokenizer(
            b,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Stack both sides along the batch dimension so the model runs a single
        # forward pass; _forward splits them back apart.
        for key in tensors_a:
            tensors_a[key] = torch.cat((tensors_a[key], tensors_b[key]), dim=0)
        return tensors_a

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> List[float]:
        # Assumes the model's forward pass returns a tuple whose second element
        # is a (batch, hidden_size) tensor of embeddings; the first is unused.
        _, embeddings = self.model(**input_tensors)
        # The first half of the batch is side A, the second half side B.
        half = embeddings.size(0) // 2
        similarity = torch.nn.functional.cosine_similarity(embeddings[:half], embeddings[half:])
        return similarity.tolist()

    def postprocess(self, model_outputs: List[float], **postprocess_parameters: Dict) -> Any:
        return model_outputs
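
# Usage sketch for SimilarPipeline (same assumptions as above: the checkpoint
# path is a placeholder and the model must return a (_, embeddings) tuple):
#
#     pipe = SimilarPipeline(max_length=256, model=model, tokenizer=tokenizer)
#     score = pipe(("first sentence", "second sentence"))  # one cosine score
#     scores = pipe([("a1", "b1"), ("a2", "b2")])          # batched pairs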