Update README.md
Browse files
README.md
CHANGED
@@ -9,4 +9,62 @@ metrics:
|
|
9 |
pipeline_tag: zero-shot-classification
|
10 |
---
|
11 |
|
12 |
-
This model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
pipeline_tag: zero-shot-classification
|
10 |
---
|
11 |
|
12 |
+
This model was built with reference to the following repository: https://github.com/Huffon/klue-transformers-tutorial.git
|
13 |
+
|
14 |
+
|
15 |
+
RoBERTa์ ๊ฐ์ด token_type_ids๋ฅผ ์ฌ์ฉํ์ง ์๋ ๋ชจ๋ธ์ ๊ฒฝ์ฐ, zero-shot pipeline์ ๋ฐ๋ก ์ ์ฉํ ์ ์์ต๋๋ค(transformers==4.7.0 ๊ธฐ์ค)
|
16 |
+
따라서 다음과 같이 변환하는 코드를 넣어줘야 합니다. 해당 코드 또한 직접 수정하였습니다.
|
17 |
+
|
18 |
+
```python
|
19 |
+
class ArgumentHandler(ABC):
    """
    Abstract callable that turns raw pipeline arguments into model-ready inputs.

    Concrete handlers must override ``__call__``; the class cannot be instantiated
    while ``__call__`` is still abstract.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        # Reached only when a subclass explicitly delegates to this base method.
        raise NotImplementedError()
|
27 |
+
|
28 |
+
|
29 |
+
class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot text classification by turning each
    (sequence, label) combination into a single pre-joined NLI input string of the
    form ``"<sequence> <sep_token> <hypothesis>"``.

    NOTE: this class reads ``tokenizer.sep_token`` from a ``tokenizer`` object that
    must already be defined in the enclosing (module) scope before the handler is
    called.
    """

    def _parse_labels(self, labels):
        """Return ``labels`` as a list; a comma-separated string is split and stripped."""
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        """
        Build the joined premise/hypothesis strings for every sequence and label.

        Args:
            sequences: a single string or a list of strings to classify.
            labels: a list of candidate labels, or one comma-separated string.
            hypothesis_template: a format string with a ``{}`` placeholder that
                receives each label.

        Returns:
            A tuple ``(sequence_pairs, sequences)`` where ``sequence_pairs`` is a flat
            list of joined strings ordered sequence-major (all labels of the first
            sequence, then all labels of the second, ...) and ``sequences`` is the
            list of input sequences.

        Raises:
            ValueError: if no label or no sequence is given, or if formatting the
                template with a label leaves it unchanged.
        """
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]
        labels = self._parse_labels(labels)

        # Modification: to keep `token_type_ids` from being attached automatically when
        # two sentences are passed as a pair, the two sentences are joined around
        # `sep_token` up front and passed as one string.
        # Each individual sequence is interpolated (not the whole list, whose repr
        # would otherwise leak into the model input), giving one entry per
        # (sequence, label) combination.
        sequence_pairs = [
            f"{sequence} {tokenizer.sep_token} {hypothesis_template.format(label)}"
            for sequence in sequences
            for label in labels
        ]

        return sequence_pairs, sequences
|
61 |
+
```
|
62 |
+
|
63 |
+
์ดํ classifier๋ฅผ ์ ์ํ ๋ ์ด๋ฅผ ์ ์ฉํด์ผ ๋ฉ๋๋ค.
|
64 |
+
```python
|
65 |
+
# Build the zero-shot pipeline with the custom argument handler plugged in.
classifier = pipeline(
    "zero-shot-classification",
    model="pongjin/roberta_with_kornli",
    args_parser=CustomZeroShotClassificationArgumentHandler(),
)
|
70 |
+
```
|