from abc import abstractmethod

from transformers import PretrainedConfig


class MeasurementPredictorConfig(PretrainedConfig):
    """Base configuration holding sensor and aggregation hyperparameters.

    Each constructor argument is stored verbatim as an attribute of the same
    name. After the ``PretrainedConfig`` initializer has run, ``emb_dim`` is
    filled in by calling :meth:`get_emb_dim`, which concrete subclasses must
    provide.

    Args:
        sensor_token: Token string stored as ``self.sensor_token``. The default
            ``" omit"`` deliberately includes a leading space.
        sensor_loc_type: Stored as ``self.sensor_loc_type``
            (default ``"locs_from_token"``).
        n_sensors: Stored as ``self.n_sensors`` (default ``3``).
        use_aggregated: Stored as ``self.use_aggregated`` (default ``True``).
        sensors_weight: Stored as ``self.sensors_weight`` (default ``0.7``).
        aggregate_weight: Stored as ``self.aggregate_weight`` (default ``0.3``).
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        NotImplementedError: If the class is constructed without overriding
            :meth:`get_emb_dim`.
    """

    def __init__(
        self,
        sensor_token: str = " omit",
        sensor_loc_type: str = "locs_from_token",
        n_sensors: int = 3,
        use_aggregated: bool = True,
        sensors_weight: float = 0.7,
        aggregate_weight: float = 0.3,
        **kwargs,
    ):
        # Sensor description.
        self.sensor_token = sensor_token
        self.sensor_loc_type = sensor_loc_type
        self.n_sensors = n_sensors

        # Aggregation switch and weights.
        self.use_aggregated = use_aggregated
        self.sensors_weight = sensors_weight
        self.aggregate_weight = aggregate_weight

        # The fields above are assigned before the base initializer runs.
        # Everything else is handled by the base class.
        super().__init__(**kwargs)

        # Computed last, so a subclass implementation can read every attribute
        # set above when deriving the embedding size.
        self.emb_dim = self.get_emb_dim()

    # NOTE(review): the only visible base is PretrainedConfig, and this class
    # is not declared with ABCMeta here. Unless the base supplies that
    # metaclass, @abstractmethod is not enforced at instantiation time. A
    # missing override is instead caught when __init__ calls this method and
    # the NotImplementedError below is raised.
    @abstractmethod
    def get_emb_dim(self):
        """Return the value to store as ``self.emb_dim``.

        Subclasses must override this. It is called once at the end of
        ``__init__``, after all configuration attributes have been set.
        """
        raise NotImplementedError