File size: 453 Bytes
51b0cb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from abc import ABC, abstractmethod
import torch
from transformers import PreTrainedTokenizerBase


class SensorLocFinder(ABC):
    """Abstract interface for objects that compute sensor locations from token ids.

    Concrete subclasses must override both ``__init__`` and
    ``find_sensor_locs``; because both are abstract, this class cannot be
    instantiated directly. Instances are callable, and calling one simply
    forwards to ``find_sensor_locs``.
    """

    @abstractmethod
    def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Initialise the finder.

        Args:
            tokenizer: Tokenizer the subclass may use to interpret ids.
            **kwargs: Extra subclass-specific options.

        The base implementation ignores every argument and does nothing;
        subclasses decide how ``tokenizer`` and ``kwargs`` are used.
        """

    @abstractmethod
    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the sensor locations for ``input_ids``.

        The base implementation does nothing and returns ``None``.
        NOTE(review): the shape and meaning of the returned tensor are not
        defined here; confirm them against the concrete subclasses.
        """

    def __call__(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`find_sensor_locs`, so a finder can be used as a function."""
        locs = self.find_sensor_locs(input_ids)
        return locs