pipeline1 / impresso_langident_wrapper.py
Gleb Vinarskis
restructured config
1c7044c
raw
history blame
804 Bytes
from transformers import Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
class Pipeline_One(Pipeline):
    """Hugging Face ``Pipeline`` that delegates prediction to ``self.model.predict``.

    ``preprocess`` and ``postprocess`` pass their values through unchanged; the
    only real work is the ``self.model.predict(inputs, k=...)`` call in
    ``_forward``. NOTE(review): that call signature looks like the fastText
    ``predict`` API rather than a transformers model — confirm against the
    model actually loaded for this pipeline.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments across the pipeline stages.

        ``k`` (number of predictions to request from the model) is routed to
        ``_forward``; every other keyword argument is handed to ``preprocess``
        unchanged, exactly as before.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple.
        """
        forward_kwargs = {}
        # Only include ``k`` when the caller supplied it, so ``_forward``'s
        # default of 1 applies otherwise. ``kwargs`` is a fresh dict built
        # from the ** unpacking, so popping from it is safe.
        if "k" in kwargs:
            forward_kwargs["k"] = kwargs.pop("k")
        return kwargs, forward_kwargs, {}

    def preprocess(self, text, **kwargs):
        """Return ``text`` unchanged; extra keyword arguments are ignored."""
        return text

    def _forward(self, inputs, k=1):
        """Call ``self.model.predict`` on ``inputs`` and return its raw output.

        Args:
            inputs: The value produced by ``preprocess`` (the original input).
            k: Number of predictions to request; defaults to 1, matching the
                previously hard-coded value.
        """
        return self.model.predict(inputs, k=k)

    def postprocess(self, outputs, **kwargs):
        """Return the model output unchanged; extra keyword arguments are ignored."""
        return outputs
from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
# Register the custom "language-detection" task so that it resolves to
# ``Pipeline_One``. The PyTorch and TensorFlow auto-model classes are the
# loaders associated with the task.
# NOTE(review): ``Pipeline_One._forward`` calls ``model.predict(..., k=...)``,
# which is not a sequence-classification model method — confirm these model
# classes match the model that is actually loaded.
PIPELINE_REGISTRY.register_pipeline(
    "language-detection",
    tf_model=TFAutoModelForSequenceClassification,
    pt_model=AutoModelForSequenceClassification,
    pipeline_class=Pipeline_One,
)