|
from textattack.models.wrappers import HuggingFaceModelWrapper |
|
|
|
|
|
class TADModelWrapper(HuggingFaceModelWrapper):
    """Adapt a model exposing ``infer()`` to TextAttack's model-wrapper interface.

    TextAttack calls a model wrapper with a batch of input strings and consumes
    one score entry per input. This wrapper queries the wrapped ``model`` one
    text at a time via ``model.infer(text, print_result=False, ...)`` and
    collects the ``"probs"`` entry of each result, so a batch of two texts
    yields ``[probs_0, probs_1]``.

    NOTE(review): each ``probs_i`` is presumably a per-class probability
    vector (e.g. ``[0.218, 0.782]``) — confirm against ``model.infer``.
    """

    def __init__(self, model):
        """Store the model to be wrapped.

        Args:
            model: Object providing an ``infer(text, print_result=..., **kwargs)``
                method whose return value is indexable by ``"probs"``.
        """
        # NOTE(review): HuggingFaceModelWrapper.__init__ is intentionally not
        # called; it presumably expects a transformers model/tokenizer pair,
        # which this wrapper does not have. As a result ``self.tokenizer`` is
        # never set here — confirm no inherited method that relies on it is used.
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``"probs"`` output of ``model.infer`` for each input text.

        Args:
            text_inputs: Iterable of input strings. A single string is treated
                as a one-element batch instead of being iterated per character.
            **kwargs: Extra keyword arguments forwarded to ``model.infer``. May
                include ``print_result`` to override the default of ``False``.

        Returns:
            list: One ``"probs"`` value per input, in input order.
        """
        if isinstance(text_inputs, str):
            text_inputs = [text_inputs]
        # Default to silent inference, but let callers override print_result
        # rather than crashing with "got multiple values for keyword argument".
        infer_kwargs = {"print_result": False, **kwargs}
        return [
            self.model.infer(text_input, **infer_kwargs)["probs"]
            for text_input in text_inputs
        ]
|
|