|
""" |
|
ModelWrapper class |
|
-------------------------- |
|
|
|
""" |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
|
|
class ModelWrapper(ABC):
    """A model wrapper queries a model with a list of text inputs.

    Classification-based models return a list of lists, where each sublist
    represents the model's scores for a given input.

    Text-to-text models return a list of strings, where each string is the
    output – like a translation or summarization – for a given input.
    """

    @abstractmethod
    def __call__(self, text_input_list, **kwargs):
        """Query the wrapped model; concrete subclasses must override this.

        Args:
            text_input_list: the list of text inputs to send to the model.
            **kwargs: extra keyword arguments; their meaning is defined by
                the overriding subclass.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def get_grad(self, text_input):
        """Get gradient of loss with respect to input tokens.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError()

    def _tokenize(self, inputs):
        """Helper method for `tokenize`.

        Subclasses override this to do the actual tokenization. `tokenize`
        iterates the result as a sequence of per-input token sequences.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError()

    def tokenize(self, inputs, strip_prefix=False):
        """Helper method that tokenizes input strings.

        Args:
            inputs (list[str]): list of input strings
            strip_prefix (bool): If `True`, we strip auxiliary characters
                added to tokens as prefixes (e.g. "##" for BERT, "Ġ" for
                RoBERTa)

        Returns:
            tokens (list[list[str]]): List of list of tokens as strings
        """
        tokens = self._tokenize(inputs)
        if not strip_prefix:
            return tokens

        # Known auxiliary markers that tokenizers attach to the front of a
        # token. They depend on the model, so this list is best-effort.
        prefix_markers = ("##", "Ġ", "__")

        def strip(token):
            # Only remove a marker when it really is a prefix. Using
            # `str.replace` here would also delete the same characters from
            # the middle or end of a token (e.g. "a##b" -> "ab") and corrupt
            # it. Each marker is checked once, in the listed order.
            for marker in prefix_markers:
                if token.startswith(marker):
                    token = token[len(marker):]
            return token

        return [[strip(token) for token in sequence] for sequence in tokens]
|
|