---
license: apache-2.0
datasets:
- kor_nli
language:
- ko
metrics:
- accuracy
pipeline_tag: zero-shot-classification
---

**This model was built with reference to the following repository: https://github.com/Huffon/klue-transformers-tutorial.git**

ํ•ด๋‹น ๋ชจ๋ธ์€ ์œ„ ๊นƒํ—ˆ๋ธŒ๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ klue/roberta-base ๋ชจ๋ธ์„ kor_nli ์˜ mnli, xnli๋กœ ํŒŒ์ธํŠœ๋‹ํ•œ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.   
| train_loss | val_loss | accuracy | epochs | batch_size | learning_rate |
| --- | --- | --- | --- | --- | --- |
| 0.326 | 0.538 | 0.811 | 3 | 32 | 2e-5 |


RoBERTa์™€ ๊ฐ™์ด token_type_ids๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š” ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ, zero-shot pipeline์„ ๋ฐ”๋กœ ์ ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค(transformers==4.7.0 ๊ธฐ์ค€)  
๋”ฐ๋ผ์„œ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋ณ€ํ™˜ํ•˜๋Š” ์ฝ”๋“œ๋ฅผ ๋„ฃ์–ด์ค˜์•ผ ํ•ฉ๋‹ˆ๋‹ค. ํ•ด๋‹น ์ฝ”๋“œ ๋˜ํ•œ ์œ„ ๊นƒํ—ˆ๋ธŒ์˜ ์ฝ”๋“œ๋ฅผ ์ˆ˜์ •ํ•˜์˜€์Šต๋‹ˆ๋‹ค.  

```python
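from abc import ABC, abstractmethod

from transformers import AutoTokenizer

# The handler below reads `tokenizer.sep_token`, so load the tokenizer first
# (here from the same checkpoint that the classifier uses).
tokenizer = AutoTokenizer.from_pretrained("pongjin/roberta_with_kornli")


class ArgumentHandler(ABC):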
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot for text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]
        labels = self._parse_labels(labels)

        sequence_pairs = []
        for sequence in sequences:
            for label in labels:
                # Modification: passing the two sentences as a pair makes the tokenizer attach
                # `token_type_ids` automatically, so join them into one string with `sep_token` beforehand.
                sequence_pairs.append(f"{sequence} {tokenizer.sep_token} {hypothesis_template.format(label)}")

        return sequence_pairs, sequences
```
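
To make the effect of the handler concrete, the following stand-alone sketch reproduces only its string-joining step, without loading a model. The `build_inputs` helper and the `"[SEP]"` value are illustrative assumptions, not part of the original code; the real separator is whatever `tokenizer.sep_token` returns for the checkpoint.

```python
# Stand-alone sketch: no transformers import, so it runs without downloading a model.
SEP_TOKEN = "[SEP]"  # assumption: stand-in for tokenizer.sep_token


def build_inputs(sequences, labels, hypothesis_template, sep_token=SEP_TOKEN):
    """Join each sequence with each label's hypothesis into one plain string."""
    if isinstance(sequences, str):
        sequences = [sequences]
    if isinstance(labels, str):
        labels = [label.strip() for label in labels.split(",")]
    return [
        f"{sequence} {sep_token} {hypothesis_template.format(label)}"
        for sequence in sequences
        for label in labels
    ]


# Hypothetical English inputs, used only to show the shape of the result.
pairs = build_inputs("KOSPI rises", "stocks, real estate", "This is about {}.")
for pair in pairs:
    print(pair)
```

Because each input is a single string rather than a (premise, hypothesis) pair, the tokenizer is never asked to encode two segments.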

Then pass this handler when defining the classifier.
```python
from transformers import pipeline

classifier = pipeline(
    "zero-shot-classification",
    args_parser=CustomZeroShotClassificationArgumentHandler(),
    model="pongjin/roberta_with_kornli"
)
```
#### results
```python
# Headline: "KOSPI one day before the ex-dividend date, climbing to the 2330 level... foreign and institutional investors buying"
sequence = "배당락 D-1 코스피, 2330선 상승세...외인·기관 사자"
# Labels: foreign exchange, exchange rate, economy, finance, real estate, stocks
candidate_labels = ["외환", "환율", "경제", "금융", "부동산", "주식"]

classifier(
    sequence,
    candidate_labels,
    hypothesis_template='이는 {}에 관한 것이다.',  # "This is about {}."
)

>>{'sequence': '배당락 D-1 코스피, 2330선 상승세...외인·기관 사자',
 'labels': ['주식', '금융', '경제', '외환', '환율', '부동산'],
 'scores': [0.5052872896194458,
  0.17972524464130402,
  0.13852974772453308,
  0.09460823982954025,
  0.042949128895998,
  0.038900360465049744]}
```
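
The scores above sum to 1, which is consistent with the pipeline's default single-label mode, where the entailment logit of each candidate label is normalized with a softmax across all labels. The sketch below illustrates that normalization in plain Python. The function names and the logit values are hypothetical and were not produced by this model.

```python
import math


def entailment_softmax(entailment_logits):
    """Normalize one entailment logit per candidate label into scores that sum to 1."""
    peak = max(entailment_logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in entailment_logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(labels, entailment_logits):
    """Return (label, score) pairs sorted from most to least likely."""
    scores = entailment_softmax(entailment_logits)
    return sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)


# Hypothetical logits, chosen only for illustration (not model output).
labels = ["finance", "stocks", "real estate"]
logits = [1.0, 2.0, -0.5]
ranked = rank_labels(labels, logits)
for label, score in ranked:
    print(f"{label}: {score:.3f}")
```

With this normalization the labels compete with each other, so adding or removing a candidate label changes every score.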