File size: 3,455 Bytes
5eb04ff 138378c e41b138 5a2aaf5 e41b138 138378c e41b138 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
---
license: apache-2.0
datasets:
- kor_nli
language:
- ko
metrics:
- accuracy
pipeline_tag: zero-shot-classification
---
**This model has been referred to the following link : https://github.com/Huffon/klue-transformers-tutorial.git**
ํด๋น ๋ชจ๋ธ์ ์ ๊นํ๋ธ๋ฅผ ์ฐธ๊ณ ํ์ฌ klue/roberta-base ๋ชจ๋ธ์ kor_nli ์ mnli, xnli๋ก ํ์ธํ๋ํ ๋ชจ๋ธ์
๋๋ค.
| train_loss | val_loss | acc | epoch | batch | lr |
| --- | --- | --- | --- | --- | --- |
| 0.326 | 0.538 | 0.811 | 3 | 32 | 2e-5 |
RoBERTa์ ๊ฐ์ด token_type_ids๋ฅผ ์ฌ์ฉํ์ง ์๋ ๋ชจ๋ธ์ ๊ฒฝ์ฐ, zero-shot pipeline์ ๋ฐ๋ก ์ ์ฉํ ์ ์์ต๋๋ค(transformers==4.7.0 ๊ธฐ์ค)
๋ฐ๋ผ์ ๋ค์๊ณผ ๊ฐ์ด ๋ณํํ๋ ์ฝ๋๋ฅผ ๋ฃ์ด์ค์ผ ํฉ๋๋ค. ํด๋น ์ฝ๋ ๋ํ ์ ๊นํ๋ธ์ ์ฝ๋๋ฅผ ์์ ํ์์ต๋๋ค.
```python
class ArgumentHandler(ABC):
"""
Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
"""
@abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError()
class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
"""
Handles arguments for zero-shot for text classification by turning each possible label into an NLI
premise/hypothesis pair.
"""
def _parse_labels(self, labels):
if isinstance(labels, str):
labels = [label.strip() for label in labels.split(",")]
return labels
def __call__(self, sequences, labels, hypothesis_template):
if len(labels) == 0 or len(sequences) == 0:
raise ValueError("You must include at least one label and at least one sequence.")
if hypothesis_template.format(labels[0]) == hypothesis_template:
raise ValueError(
(
'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
"Make sure the passed template includes formatting syntax such as {{}} where the label should go."
).format(hypothesis_template)
)
if isinstance(sequences, str):
sequences = [sequences]
labels = self._parse_labels(labels)
sequence_pairs = []
for label in labels:
# ์์ ๋ถ: ๋ ๋ฌธ์ฅ์ ํ์ด๋ก ์
๋ ฅํ์ ๋, `token_type_ids`๊ฐ ์๋์ผ๋ก ๋ถ๋ ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด ๋ฏธ๋ฆฌ ๋ ๋ฌธ์ฅ์ `sep_token` ๊ธฐ์ค์ผ๋ก ์ด์ด์ฃผ๋๋ก ํจ
sequence_pairs.append(f"{sequences} {tokenizer.sep_token} {hypothesis_template.format(label)}")
return sequence_pairs, sequences
```
์ดํ classifier๋ฅผ ์ ์ํ ๋ ์ด๋ฅผ ์ ์ฉํด์ผ ๋ฉ๋๋ค.
```python
classifier = pipeline(
"zero-shot-classification",
args_parser=CustomZeroShotClassificationArgumentHandler(),
model="pongjin/roberta_with_kornli"
)
```
#### results
```python
sequence = "๋ฐฐ๋น๋ฝ D-1 ์ฝ์คํผ, 2330์ ์์น์ธ...์ธ์ธยท๊ธฐ๊ด ์ฌ์"
candidate_labels =["์ธํ",'ํ์จ', "๊ฒฝ์ ", "๊ธ์ต", "๋ถ๋์ฐ","์ฃผ์"]
classifier(
sequence,
candidate_labels,
hypothesis_template='์ด๋ {}์ ๊ดํ ๊ฒ์ด๋ค.',
)
>>{'sequence': '๋ฐฐ๋น๋ฝ D-1 ์ฝ์คํผ, 2330์ ์์น์ธ...์ธ์ธยท๊ธฐ๊ด ์ฌ์',
'labels': ['์ฃผ์', '๊ธ์ต', '๊ฒฝ์ ', '์ธํ', 'ํ์จ', '๋ถ๋์ฐ'],
'scores': [0.5052872896194458,
0.17972524464130402,
0.13852974772453308,
0.09460823982954025,
0.042949128895998,
0.038900360465049744]}
``` |