File size: 959 Bytes
075c00d
 
8e5dcff
a4c6322
075c00d
 
 
 
 
 
 
 
a0a2517
137cb44
075c00d
a0a2517
03cf626
 
075c00d
a0a2517
 
03cf626
 
 
 
3776ec2
03cf626
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import Pipeline


class LangDetectionPipeline(Pipeline):
    """Language-identification pipeline around a label/probability classifier.

    The input text is passed straight to ``self.model``, which must return a
    ``(predictions, probabilities)`` pair. The top label of the first row is
    returned together with its confidence expressed as a percentage.

    NOTE(review): the ``__label__`` prefix handling and the ``[0][0]`` indexing
    in ``postprocess`` suggest a fastText-style model that returns one row of
    labels/probabilities per input — confirm against the model wrapper.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split kwargs into (preprocess, forward, postprocess) parameter dicts.

        This pipeline takes no extra parameters, so all three dicts are empty;
        unrecognised kwargs are ignored.

        Raises:
            TypeError: if ``text`` is given as a keyword. The input must be
                passed positionally (``pipe("some text")``). Forwarding it as a
                preprocess kwarg would collide with the positional ``text``
                argument of ``preprocess`` and fail with "got multiple values
                for argument 'text'".
        """
        if "text" in kwargs:
            raise TypeError(
                "LangDetectionPipeline: pass the input text positionally, "
                "e.g. pipe('some text'), not as a 'text' keyword argument"
            )
        return {}, {}, {}

    def preprocess(self, text, **kwargs):
        """Return ``text`` unchanged; the model is called on the raw input."""
        return text

    def _forward(self, text, **kwargs):
        """Run the model on ``text`` and return its (predictions, probabilities)."""
        predictions, probabilities = self.model(text)
        return predictions, probabilities

    def postprocess(self, outputs, **kwargs):
        """Turn the raw model output into a JSON-serialisable dict.

        Args:
            outputs: The ``(predictions, probabilities)`` pair produced by
                ``_forward``. Only the first label and first probability of the
                first row are used.

        Returns:
            dict: ``{"label": ..., "confidence": ...}``. ``label`` has a leading
            ``"__label__"`` prefix removed if present. ``confidence`` is the
            probability times 100, rounded to two decimal places.
        """
        predictions, probabilities = outputs
        prefix = "__label__"
        label = predictions[0][0]
        # Strip only a *leading* prefix; str.replace would also delete any
        # occurrence of the substring elsewhere in the label.
        if label.startswith(prefix):
            label = label[len(prefix):]
        # float() guarantees a plain Python float for JSON serialisation
        # (the model may hand back numpy scalars).
        confidence = float(probabilities[0][0])
        return {"label": label, "confidence": round(confidence * 100, 2)}