# File size: 453 bytes (revision 51b0cb7)
from abc import ABC, abstractmethod
import torch
from transformers import PreTrainedTokenizerBase
class SensorLocFinder(ABC):
    """Abstract interface for objects that compute "sensor locations" from token ids.

    What a sensor location means, and the exact tensor shapes involved, are
    defined by concrete subclasses; this base class fixes only the API:

    * a constructor that receives a tokenizer plus arbitrary keyword options;
    * ``find_sensor_locs(input_ids)``, which maps a tensor to a tensor;
    * calling an instance directly, which is shorthand for
      ``find_sensor_locs``.

    Both ``__init__`` and ``find_sensor_locs`` are abstract. A subclass must
    override both of them before it can be instantiated.
    """

    @abstractmethod
    def __init__(self, tokenizer: "PreTrainedTokenizerBase", **kwargs):
        """Configure the finder from ``tokenizer`` and extra keyword options.

        This base implementation ignores every argument and does nothing.
        Subclasses may still call it through ``super().__init__(...)``.
        """

    @abstractmethod
    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Compute sensor locations for ``input_ids``.

        NOTE(review): the expected shape/dtype of ``input_ids`` and of the
        returned tensor are not fixed by this interface; confirm against
        concrete subclasses. The base implementation returns ``None``.
        """

    def __call__(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Forward ``input_ids`` to :meth:`find_sensor_locs` and return its result."""
        locs = self.find_sensor_locs(input_ids)
        return locs
|