from abc import ABC, abstractmethod | |
import torch | |
from transformers import PreTrainedTokenizerBase | |
class SensorLocFinder(ABC):
    """Abstract base for strategies that locate "sensor" positions in token ids.

    Subclasses must implement :meth:`find_sensor_locs`. Instances are callable
    and forward directly to that method.

    What counts as a "sensor location", and the exact shape/dtype of the
    returned tensor, is defined by each subclass. NOTE(review): this base
    class does not specify it; confirm against concrete implementations.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Accept the constructor arguments shared by all finders.

        The base implementation ignores both ``tokenizer`` and ``kwargs``.
        The signature exists so every finder can be constructed the same way;
        subclasses presumably use the tokenizer to resolve the tokens they
        search for.
        """

    @abstractmethod
    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the sensor locations for ``input_ids``.

        Declared abstract so that a subclass which forgets to override it
        fails at instantiation time, rather than silently returning ``None``
        from a method annotated to return a tensor.
        """

    def __call__(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`find_sensor_locs`, so a finder can be used like a function."""
        return self.find_sensor_locs(input_ids)