File size: 2,267 Bytes
da968a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from prime_radiant.configs.pr_config import PRConfig
from prime_radiant.models import MODEL_CLASSES
from prime_radiant.models.pr_model import PRModel


class LaBSeModelConfig(PRConfig):
    """
    Configuration for LaBSeModel.

    This config defines no options of its own: every keyword argument is
    forwarded unchanged to PRConfig, and validation is delegated to
    PRConfig.validate.
    """
    model_type = "LaBSe"

    def __init__(self, **kwargs):
        # All initialization (and any validation it triggers) lives in PRConfig.
        super().__init__(**kwargs)

    def validate(self, **kwargs: dict):
        """
        Validate the configuration.

        Only the parent's checks run; the LaBSe base model accepts no extra
        custom settings, so there is nothing further to verify here.
        """
        super().validate(**kwargs)


class LaBSeModel(PRModel):
    """
    LaBSeModel class for using LaBSe-based models.

    The concrete model class is looked up in ``MODEL_CLASSES`` by
    ``config.pr_model_name``; the loaded model is expected to expose
    ``encode``, ``max_seq_length`` and ``tokenizer`` (as used below).
    """
    config_class = LaBSeModelConfig

    def __init__(self, config: LaBSeModelConfig, ):
        """
        Build the model from its configuration.

        Args:
            config: LaBSe configuration; its ``pr_model_name`` selects the
                entry of ``MODEL_CLASSES`` to load.

        Raises:
            KeyError: if ``config.pr_model_name`` is not in ``MODEL_CLASSES``.
        """
        # Initialize the PRModel superclass
        super().__init__(config=config)

        # Get the appropriate model class based on the configuration
        model_class = MODEL_CLASSES[config.pr_model_name]

        # TODO: the original author notes that `.sbert()` does not exist on
        # `model_class.model`; replace this with the real SentenceTransformer
        # loader from `model_class` once that API is confirmed.
        self.model = model_class.model.sbert(model_class.checkpoint)

        # Cap the input length at 512 tokens.
        # NOTE(review): presumably chosen to match the encoder's limit — confirm.
        self.model.max_seq_length = 512

        # Expose the model's own tokenizer on the wrapper.
        self.tokenizer = self.model.tokenizer

    def predict(self, text: str):
        """Return the result of ``self.model.encode`` for a single text."""
        return self.model.encode(text)

    def predict_batch(self, text_list: list):
        """Return the result of ``self.model.encode`` for a list of texts."""
        return self.model.encode(text_list)

    def push_to_hub(self, only_base_model_flag=True, *args, **kwargs):
        """
        Push the base model to the hub.

        We need to use a different way of pushing it so that we can use the
        base model to instantiate other models like ROC.

        Args:
            only_base_model_flag: Default value forwarded to the parent as
                ``only_base_model``. An explicit ``only_base_model=...``
                keyword from the caller takes precedence over it.
            *args: Forwarded positionally to ``PRModel.push_to_hub``.
            **kwargs: Forwarded as keywords to ``PRModel.push_to_hub``.

        Returns:
            Whatever ``PRModel.push_to_hub`` returns.
        """
        # NOTE(review): ``only_base_model_flag`` is the first positional
        # parameter, so the first positional argument intended for the parent
        # binds here instead. Callers should pass parent arguments by keyword.

        # Honor a caller-supplied ``only_base_model`` rather than silently
        # discarding it; popping it also prevents a duplicate-keyword TypeError
        # when we pass ``only_base_model`` explicitly below.
        only_base_model = kwargs.pop("only_base_model", only_base_model_flag)

        return super().push_to_hub(*args, only_base_model=only_base_model, **kwargs)