import copy
import os
import tempfile
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from huggingface_hub.utils import SoftTemporaryDirectory

from setfit.utils import set_docstring

from .. import logging
from ..modeling import SetFitModel
from .aspect_extractor import AspectExtractor


if TYPE_CHECKING:
    from spacy.tokens import Doc

logger = logging.get_logger(__name__)


@dataclass
class SpanSetFitModel(SetFitModel):
    spacy_model: str = "en_core_web_lg"
    span_context: int = 0

    attributes_to_save: Set[str] = field(
        init=False,
        repr=False,
        default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"},
    )

    def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]:
        for doc, aspects in zip(docs, aspects_list):
            for aspect_slice in aspects:
                aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context]
                # TODO: Investigate performance difference of different formats
                yield aspect.text + ":" + doc.text

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]:
        inputs_list = list(self.prepend_aspects(docs, aspects_list))
        preds = self.predict(inputs_list, as_numpy=True)
        iter_preds = iter(preds)
        return [[next(iter_preds) for _ in aspects] for aspects in aspects_list]
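
    # A minimal sketch of the input format built above, assuming a spaCy doc for the
    # hypothetical sentence "The food was great but the service was slow." with aspect
    # slices covering "food" and "service" and the default span_context of 0:
    #
    #   prepend_aspects(docs, aspects_list) yields
    #       "food:The food was great but the service was slow."
    #       "service:The food was great but the service was slow."
    #
    # __call__ then flattens every aspect into a single predict() call and regroups the
    # flat predictions into one list per document.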

    def create_model_card(self, path: str, model_name: Optional[str] = None) -> None:
        """Creates and saves a model card for a SetFit model.

        Args:
            path (str): The path to save the model card to.
            model_name (str, *optional*): The name of the model. Defaults to `SetFit Model`.
        """
        if not os.path.exists(path):
            os.makedirs(path)

        # If model_name is a path to a local folder inside a temporary directory (e.g. when
        # create_model_card is called via push_to_hub), only keep the last two path components
        # as the model name.
        model_path = Path(model_name)
        if model_path.exists() and Path(tempfile.gettempdir()) in model_path.resolve().parents:
            model_name = "/".join(model_path.parts[-2:])

        is_aspect = isinstance(self, AspectModel)
        aspect_model = "setfit-absa-aspect"
        polarity_model = "setfit-absa-polarity"
        if model_name is not None:
            if is_aspect:
                aspect_model = model_name
                if model_name.endswith("-aspect"):
                    polarity_model = model_name[: -len("-aspect")] + "-polarity"
            else:
                polarity_model = model_name
                if model_name.endswith("-polarity"):
                    aspect_model = model_name[: -len("-polarity")] + "-aspect"

        # Only executed once, before self.model_card_data.absa has been set:
        if self.model_card_data.absa is None and self.model_card_data.model_name:
            from spacy import __version__ as spacy_version

            self.model_card_data.model_name = self.model_card_data.model_name.replace(
                "SetFit", "SetFit Aspect Model" if is_aspect else "SetFit Polarity Model", 1
            )
            self.model_card_data.tags.insert(1, "absa")
            self.model_card_data.version["spacy"] = spacy_version
        self.model_card_data.absa = {
            "is_absa": True,
            "is_aspect": is_aspect,
            "spacy_model": self.spacy_model,
            "aspect_model": aspect_model,
            "polarity_model": polarity_model,
        }
        if self.model_card_data.task_name is None:
            self.model_card_data.task_name = "Aspect Based Sentiment Analysis (ABSA)"
        self.model_card_data.inference = False
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(self.generate_model_card())


docstring = SpanSetFitModel.from_pretrained.__doc__
cut_index = docstring.find("multi_target_strategy")
if cut_index != -1:
    docstring = (
        docstring[:cut_index]
        + """model_card_data (`SetFitModelCardData`, *optional*):
                A `SetFitModelCardData` instance storing data such as model language, license, dataset name,
                    etc. to be used in the automatically generated model cards.
            use_differentiable_head (`bool`, *optional*):
                Whether to load SetFit using a differentiable (i.e., Torch) head instead of Logistic Regression.
            normalize_embeddings (`bool`, *optional*):
                Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
            span_context (`int`, defaults to `0`):
                The number of words before and after the span candidate that should be prepended to the full sentence.
                By default, 0 for Aspect models and 3 for Polarity models.
            device (`Union[torch.device, str]`, *optional*):
                The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`."""
    )
    SpanSetFitModel.from_pretrained = set_docstring(SpanSetFitModel.from_pretrained, docstring, cls=SpanSetFitModel)


class AspectModel(SpanSetFitModel):
    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]:
        sentence_preds = super().__call__(docs, aspects_list)
        return [
            [aspect for aspect, pred in zip(aspects, preds) if pred == "aspect"]
            for aspects, preds in zip(aspects_list, sentence_preds)
        ]
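
# The AspectModel keeps only the span candidates that were classified as aspects. A small
# illustration with hypothetical values:
#   aspects_list   = [[slice(1, 2), slice(6, 7)]]
#   sentence_preds = [["aspect", "no aspect"]]
#   -> [[slice(1, 2)]]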


# The set_docstring magic has as a consequence that subclasses need to update the `cls` in the
# from_pretrained classmethod, otherwise the wrong class would be instantiated.
AspectModel.from_pretrained = types.MethodType(AspectModel.from_pretrained.__func__, AspectModel)


@dataclass
class PolarityModel(SpanSetFitModel):
    span_context: int = 3


PolarityModel.from_pretrained = types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel)


@dataclass
class AbsaModel:
    aspect_extractor: AspectExtractor
    aspect_model: AspectModel
    polarity_model: PolarityModel

    def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
        is_str = isinstance(inputs, str)
        inputs_list = [inputs] if is_str else inputs
        docs, aspects_list = self.aspect_extractor(inputs_list)
        if sum(aspects_list, []) == []:
            return aspects_list

        aspects_list = self.aspect_model(docs, aspects_list)
        if sum(aspects_list, []) == []:
            return aspects_list

        polarity_list = self.polarity_model(docs, aspects_list)
        outputs = []
        for doc, aspects, polarities in zip(docs, aspects_list, polarity_list):
            outputs.append(
                [
                    {"span": doc[aspect_slice].text, "polarity": polarity}
                    for aspect_slice, polarity in zip(aspects, polarities)
                ]
            )
        return outputs if not is_str else outputs[0]
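
    # A sketch of the expected output shape; the spans and polarity labels below are
    # illustrative, not taken from a real run:
    #
    #   model.predict("The food was great but the service was slow.")
    #   # -> [{"span": "food", "polarity": "positive"},
    #   #     {"span": "service", "polarity": "negative"}]
    #
    # For a list of input sentences, predict returns one such list per sentence.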

    @property
    def device(self) -> torch.device:
        return self.aspect_model.device

    def to(self, device: Union[str, torch.device]) -> "AbsaModel":
        self.aspect_model.to(device)
        self.polarity_model.to(device)
        return self

    def __call__(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
        return self.predict(inputs)

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        polarity_save_directory: Optional[Union[str, Path]] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        if polarity_save_directory is None:
            base_save_directory = Path(save_directory)
            save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect")
            polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity")
        self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
        self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        polarity_model_id: Optional[str] = None,
        spacy_model: Optional[str] = None,
        span_contexts: Tuple[Optional[int], Optional[int]] = (None, None),
        force_download: bool = None,
        resume_download: bool = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = None,
        use_differentiable_head: bool = None,
        normalize_embeddings: bool = None,
        **model_kwargs,
    ) -> "AbsaModel":
        revision = None
        if len(model_id.split("@")) == 2:
            model_id, revision = model_id.split("@")
        if spacy_model:
            model_kwargs["spacy_model"] = spacy_model
        aspect_model = AspectModel.from_pretrained(
            model_id,
            span_context=span_contexts[0],
            revision=revision,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            use_differentiable_head=use_differentiable_head,
            normalize_embeddings=normalize_embeddings,
            labels=["no aspect", "aspect"],
            **model_kwargs,
        )
        if polarity_model_id:
            model_id = polarity_model_id
            revision = None
            if len(model_id.split("@")) == 2:
                model_id, revision = model_id.split("@")
        # If model_card_data was provided, give the Polarity model its own deep copy so the
        # Aspect and Polarity models do not share the same instance.
        model_card_data = model_kwargs.pop("model_card_data", None)
        if model_card_data:
            model_kwargs["model_card_data"] = copy.deepcopy(model_card_data)
        polarity_model = PolarityModel.from_pretrained(
            model_id,
            span_context=span_contexts[1],
            revision=revision,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            use_differentiable_head=use_differentiable_head,
            normalize_embeddings=normalize_embeddings,
            **model_kwargs,
        )
        if aspect_model.spacy_model != polarity_model.spacy_model:
            logger.warning(
                "The Aspect and Polarity models are configured to use different spaCy models:\n"
                f"* {repr(aspect_model.spacy_model)} for the aspect model, and\n"
                f"* {repr(polarity_model.spacy_model)} for the polarity model.\n"
                f"This model will use {repr(aspect_model.spacy_model)}."
            )

        aspect_extractor = AspectExtractor(spacy_model=aspect_model.spacy_model)

        return cls(aspect_extractor, aspect_model, polarity_model)
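
    # A minimal loading sketch; the repository IDs below are placeholders, not real models:
    #
    #   model = AbsaModel.from_pretrained(
    #       "my-org/setfit-absa-restaurant-aspect",
    #       "my-org/setfit-absa-restaurant-polarity",
    #       spacy_model="en_core_web_sm",
    #   )
    #
    # A revision can be pinned with "model_id@revision". If polarity_model_id is omitted,
    # the same checkpoint is loaded for both the aspect and the polarity model.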

    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        if "/" not in repo_id:
            raise ValueError(
                '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )
        if polarity_repo_id is not None and "/" not in polarity_repo_id:
            raise ValueError(
                '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )
        commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model")

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp_dir:
            save_directory = Path(tmp_dir) / repo_id
            polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id
            self.save_pretrained(
                save_directory=save_directory,
                polarity_save_directory=polarity_save_directory,
                push_to_hub=True,
                commit_message=commit_message,
                **kwargs,
            )
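
        # Usage sketch with placeholder repository IDs; both IDs must include the namespace:
        #   model.push_to_hub("my-org/setfit-absa-restaurant-aspect",
        #                     "my-org/setfit-absa-restaurant-polarity")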