|
import torch |
|
from transformers import PreTrainedTokenizerBase |
|
|
|
from .sensor_loc_finder import SensorLocFinder |
|
|
|
|
|
class StoriesSensorLocFinder(SensorLocFinder):
    """Sensor-location finder for inputs containing a "## Questions" section.

    A sensor location is a question-mark token (the token for ``"?"`` or the
    first token of ``")?"``) that appears at or after the start of the
    ``"## Questions"`` header. Every sequence must contain exactly one such
    header and exactly ``_EXPECTED_QUESTION_COUNT`` sensor tokens after it.
    """

    # Number of sensor tokens each sequence must contain; used both by the
    # validity assertion and to build the per-question lookup targets.
    _EXPECTED_QUESTION_COUNT = 3

    def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Pre-tokenize the marker strings used to find sensor locations.

        Args:
            tokenizer: Tokenizer used to encode the header and question marks.
            **kwargs: Unused here.

        Raises:
            AssertionError: if ``"## Questions"`` does not encode to exactly
                two tokens.
        """
        # NOTE(review): SensorLocFinder.__init__ is not called; confirm the
        # base class does not require initialization.
        self.questions_section_toks = tokenizer.encode("## Questions")
        self.question_mark_tok = tokenizer.encode("?")[0]
        # NOTE(review): assumes ")?" encodes to a single merged token; if the
        # tokenizer splits it, [0] would be the ")" token instead -- confirm.
        self.other_question_mark_tok = tokenizer.encode(")?")[0]
        # _is_sensor_loc matches the header as a two-token bigram.
        assert len(self.questions_section_toks) == 2

    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the sequence index of each sensor token, per row.

        Args:
            input_ids: 2-D tensor of token ids, indexed as (batch, seq).

        Returns:
            Integer tensor of shape (batch, 3). Column k holds the index of the
            (k + 1)-th sensor token in that row, clamped to at most
            ``seq_len - 3``.

        Raises:
            AssertionError: if any row does not contain exactly 3 sensor
                tokens, or (from ``_is_sensor_loc``) exactly one header.
        """
        device = input_ids.device
        seq_len = input_ids.shape[-1]
        n_questions = self._EXPECTED_QUESTION_COUNT

        question_mark_locs = self._is_sensor_loc(input_ids)
        # Running count of sensor tokens seen up to and including each index.
        total_locs = torch.cumsum(question_mark_locs, dim=-1)
        total_overall = total_locs[:, -1]
        assert (
            total_overall == n_questions
        ).all(), "can handle different cases, but assuming this is easiest"

        # eqs[b, t, k] is True where exactly k + 1 sensor tokens have been seen
        # by index t. The targets are created on ``device`` directly to avoid a
        # host allocation plus device copy on every call.
        targets = torch.arange(1, n_questions + 1, device=device)
        eqs = total_locs[:, :, None] == targets[None, None]

        # The first index at which the count reaches k + 1 is the position of
        # the (k + 1)-th sensor token; torch.argmax returns the first maximal
        # index. The uint8 cast is presumably there because argmax over bool
        # is unsupported on some backends.
        # The torch.where fallback can only trigger if the assertion above is
        # stripped (python -O), since the count rises by at most 1 per step.
        # NOTE(review): the purpose of the ``seq_len - 3`` cap/fallback is not
        # visible here; it is kept unchanged -- confirm with callers.
        locs = torch.where(
            eqs.any(dim=-2),
            torch.argmax(eqs.to(torch.uint8), dim=-2),
            seq_len - 3,
        ).clamp(max=seq_len - 3)
        return locs

    def _is_sensor_loc(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a (batch, seq) bool mask of sensor tokens.

        A position is marked when it holds ``question_mark_tok`` or
        ``other_question_mark_tok`` and lies at or after the first token of
        the ``"## Questions"`` header.

        Raises:
            AssertionError: if any row does not contain the header exactly once.
        """
        first_tok = self.questions_section_toks[0]
        second_tok = self.questions_section_toks[1]

        # header_starts[b, t] is True where tokens t and t + 1 spell the
        # header; it has seq - 1 columns because a match spans two positions.
        header_starts = (input_ids[:, :-1] == first_tok) & (
            input_ids[:, 1:] == second_tok
        )
        assert (header_starts.sum(dim=-1, dtype=torch.int) == 1).all(), "could relax"

        # Pad back to seq columns with False (the last position cannot start a
        # two-token header), then mark every position from the header onward.
        padded = torch.cat(
            [header_starts, torch.zeros_like(header_starts[:, :1])], dim=-1
        )
        after_header = torch.cumsum(padded, dim=-1) > 0

        is_question_mark = (input_ids == self.question_mark_tok) | (
            input_ids == self.other_question_mark_tok
        )
        return after_header & is_question_mark
|
|