|
from typing import Any, Dict, List, Tuple, Union

import torch
from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
|
|
|
class EncodePipeline(Pipeline):
    """Pipeline that maps raw text to sentence embeddings via the model's `encode` method."""

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No call-time parameters are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(self, inputs: Union[str, List[str]], **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        # Tokenize a single text or a batch of texts into fixed-length tensors.
        return self.tokenizer(
            inputs,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> List[List[float]]:
        # Assumes the wrapped model exposes an `encode` method that returns one
        # embedding vector per input row.
        embeddings = self.model.encode(**input_tensors)
        return embeddings.tolist()

    def postprocess(self, model_outputs: List[List[float]], **postprocess_parameters: Dict) -> Any:
        # Embeddings are already plain Python lists; pass them through unchanged.
        return model_outputs
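
# Minimal usage sketch (checkpoint and model names below are hypothetical, not
# from this repo; assumes the model class defines the `encode` method used above):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("some-encoder-checkpoint")
#   encode = EncodePipeline(model=my_encoder_model, tokenizer=tokenizer, max_length=256)
#   vectors = encode(["first sentence", "second sentence"])  # one embedding per text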
|
class SimilarPipeline(Pipeline):
    """Pipeline that scores text pairs by the cosine similarity of their embeddings."""

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No call-time parameters are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(
        self, inputs: Union[Tuple[str, str], List[Tuple[str, str]]], **preprocess_parameters: Dict
    ) -> Dict[str, GenericTensor]:
        # Accept a single (a, b) pair or a batch of pairs.
        if isinstance(inputs, list):
            a = [pair[0] for pair in inputs]
            b = [pair[1] for pair in inputs]
        else:
            a, b = inputs
        tensors = self.tokenizer(
            a,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        tensors_b = self.tokenizer(
            b,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Stack both sides into one batch: the first half of each tensor holds
        # the left texts, the second half the right texts, so a single forward
        # pass embeds both sides.
        for key in tensors:
            tensors[key] = torch.cat((tensors[key], tensors_b[key]), dim=0)
        return tensors

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> List[float]:
        # Assumes the model's forward returns an (aux, embeddings) tuple.
        _, embeddings = self.model(**input_tensors)
        # Undo the stacking from preprocess and compare the two halves row-wise.
        half = embeddings.size(0) // 2
        scores = torch.nn.functional.cosine_similarity(embeddings[:half], embeddings[half:])
        return scores.tolist()

    def postprocess(self, model_outputs: List[float], **postprocess_parameters: Dict) -> Any:
        # Similarity scores are already plain floats; pass them through unchanged.
        return model_outputs
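
# Minimal usage sketch (names are hypothetical; assumes the model's forward
# returns the (aux, embeddings) tuple unpacked in SimilarPipeline._forward):
#
#   similar = SimilarPipeline(model=my_pair_model, tokenizer=tokenizer, max_length=256)
#   score = similar(("a cat sat", "a feline rested"))   # cosine similarity of one pair
#   scores = similar([("a", "b"), ("c", "d")])          # one score per pair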