File size: 795 Bytes
2b97f0b 7e59788 2b97f0b 7e59788 2b97f0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from abc import abstractmethod
from transformers import PretrainedConfig
class MeasurementPredictorConfig(PretrainedConfig):
    """Base configuration for measurement-predictor models.

    Stores the sensor / aggregation hyperparameters on the instance,
    forwards every other keyword argument to ``PretrainedConfig``, and then
    derives ``emb_dim`` by calling :meth:`get_emb_dim`. The base
    implementation of that hook raises ``NotImplementedError``, so this class
    cannot be instantiated directly; concrete subclasses must override it.
    """

    def __init__(
        self,
        sensor_token=" omit",
        sensor_loc_type="locs_from_token",
        n_sensors=3,
        use_aggregated=True,
        sensors_weight=0.7,
        aggregate_weight=0.3,
        **kwargs,
    ):
        """Record the measurement settings and initialise the base config.

        Args:
            sensor_token: Token string associated with sensors. Note the
                leading space in the default; presumably significant for
                tokenization -- confirm against the tokenizer in use.
            sensor_loc_type: Strategy name for locating sensors; the default
                is ``"locs_from_token"``. Valid values are not visible here.
            n_sensors: Number of sensors (default 3).
            use_aggregated: Boolean flag, stored as-is; presumably toggles use
                of an aggregated prediction -- verify against the model code.
            sensors_weight: Weight stored for the per-sensor part (default 0.7).
            aggregate_weight: Weight stored for the aggregate part (default
                0.3). The defaults sum to 1.0; nothing here enforces that.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        Raises:
            NotImplementedError: If the concrete class does not override
                :meth:`get_emb_dim`.
        """
        # Sensor-related settings.
        self.sensor_token = sensor_token
        self.sensor_loc_type = sensor_loc_type
        self.n_sensors = n_sensors

        # Aggregation / weighting settings.
        self.use_aggregated = use_aggregated
        self.sensors_weight = sensors_weight
        self.aggregate_weight = aggregate_weight

        super().__init__(**kwargs)

        # Derived only after the base initializer has run, so a subclass's
        # get_emb_dim may rely on any attribute PretrainedConfig set from kwargs.
        self.emb_dim = self.get_emb_dim()

    @abstractmethod
    def get_emb_dim(self):
        """Return the embedding dimension for this configuration.

        Subclasses must override this; ``__init__`` calls it and stores the
        result as ``self.emb_dim``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError