import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret

from .configuration_lang import ImpressoConfig

logger = logging.getLogger(__name__)


class LangDetectorModel(PreTrainedModel):
    """Hugging Face model wrapper that delegates prediction to a floret model.

    The class holds no trained torch weights of its own: it registers a single
    dummy parameter (so that ``device`` can be resolved) and loads a floret
    model from the filename carried by the configuration.
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Build the wrapper and load the floret model.

        Args:
            config: Configuration object. When it carries a nested ``config``
                attribute (as produced by ``from_pretrained`` when a
                ``config=`` keyword is forwarded into ``ImpressoConfig``), the
                floret filename is read from that nested object; otherwise it
                is read from ``config`` itself.

        Raises:
            AttributeError: If no ``filename`` attribute can be found.
        """
        super().__init__(config)
        self.config = config

        # Dummy parameter so that `device` has something to inspect.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # The original code read `config.config.filename` unconditionally,
        # which fails when the model is constructed directly from a config
        # that has no nested `config` attribute. Prefer the nested object when
        # present, otherwise fall back to the config itself.
        # NOTE(review): the fallback assumes ImpressoConfig exposes `filename`
        # directly -- confirm against configuration_lang.
        nested_config = getattr(config, "config", None)
        floret_config = config if nested_config is None else nested_config

        # Load floret model
        self.model_floret = floret.load_model(floret_config.filename)

    def forward(self, input_ids, **kwargs):
        """Run floret prediction on one text or a list of texts.

        Args:
            input_ids: Either a single ``str`` or a ``list`` of ``str``.
                Despite the name, these are raw texts, not token ids.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``(predictions, probabilities)`` tuple exactly as returned by
            ``self.model_floret.predict(texts, k=1)``.

        Raises:
            ValueError: If ``input_ids`` is neither a ``str`` nor a list
                containing only ``str`` items.
        """
        if isinstance(input_ids, str):
            # A single string is wrapped in a list before calling floret.
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(
            isinstance(t, str) for t in input_ids
        ):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # k=1 asks floret for the single top prediction per text.
        predictions, probabilities = self.model_floret.predict(texts, k=1)
        return (
            predictions,
            probabilities,
        )

    @property
    def device(self):
        """Return the device of the first parameter yielded by ``parameters()``."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Create the model from keyword arguments only; no weights are loaded.

        Positional arguments are ignored. All keyword arguments are forwarded
        to ``ImpressoConfig`` and the resulting config is passed to the class
        constructor.

        Returns:
            A new instance of ``cls``.
        """
        logger.debug("Ignoring weights and using custom initialization.")
        # Manually create the config
        config = ImpressoConfig(**kwargs)
        # Pass the manually created config to the class
        model = cls(config)
        return model