pongjin committed on
Commit
e41b138
•
1 Parent(s): 5eb04ff

Update README.md

Files changed (1)
  1. README.md +59 -1
README.md CHANGED
@@ -9,4 +9,62 @@ metrics:
  pipeline_tag: zero-shot-classification
  ---
 
- This model is
+ This model was created with reference to the following repository: https://github.com/Huffon/klue-transformers-tutorial.git
+
+
+ Models that do not use token_type_ids, such as RoBERTa, cannot be used with the zero-shot pipeline as-is (as of transformers==4.7.0).
+ You therefore need to add argument-handling code like the following. This code has also been modified by hand.
+
+ ```python
+ from abc import ABC, abstractmethod
+
+ from transformers import AutoTokenizer
+
+ # The handler below reads a module-level `tokenizer`. Loading it from this model's
+ # checkpoint is an assumption; the original snippet does not define it.
+ tokenizer = AutoTokenizer.from_pretrained("pongjin/roberta_with_kornli")
+
+
+ class ArgumentHandler(ABC):
+     """
+     Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
+     """
+
+     @abstractmethod
+     def __call__(self, *args, **kwargs):
+         raise NotImplementedError()
+
+
+ class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
+     """
+     Handles arguments for zero-shot for text classification by turning each possible label into an NLI
+     premise/hypothesis pair.
+     """
+
+     def _parse_labels(self, labels):
+         # Accept either a list of labels or a single comma-separated string.
+         if isinstance(labels, str):
+             labels = [label.strip() for label in labels.split(",")]
+         return labels
+
+     def __call__(self, sequences, labels, hypothesis_template):
+         if len(labels) == 0 or len(sequences) == 0:
+             raise ValueError("You must include at least one label and at least one sequence.")
+         if hypothesis_template.format(labels[0]) == hypothesis_template:
+             raise ValueError(
+                 (
+                     'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
+                     "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
+                 ).format(hypothesis_template)
+             )
+
+         if isinstance(sequences, str):
+             sequences = [sequences]
+         labels = self._parse_labels(labels)
+
+         sequence_pairs = []
+         # Iterate over each sequence so the string itself, not the list's repr, is formatted.
+         for sequence in sequences:
+             for label in labels:
+                 # Modified part: join the two sentences with `sep_token` beforehand so that
+                 # `token_type_ids` are not attached automatically when they are passed as a pair.
+                 sequence_pairs.append(f"{sequence} {tokenizer.sep_token} {hypothesis_template.format(label)}")
+
+         return sequence_pairs, sequences
+ ```
+
+ Afterwards, this handler must be applied when defining the classifier.
+ ```python
+ from transformers import pipeline
+
+ classifier = pipeline(
+     "zero-shot-classification",
+     args_parser=CustomZeroShotClassificationArgumentHandler(),
+     model="pongjin/roberta_with_kornli",
+ )
+ ```
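
The joining step that the custom handler performs can be sketched without loading any model, to show what kind of strings end up being passed to the tokenizer. The helper name `join_nli_inputs`, the literal `"[SEP]"` separator, and the example sentence, labels, and template are all assumptions made for illustration; in the README's code the separator comes from `tokenizer.sep_token`.

```python
from typing import List, Union


def join_nli_inputs(
    sequences: Union[str, List[str]],
    labels: Union[str, List[str]],
    hypothesis_template: str,
    sep_token: str = "[SEP]",  # assumption: stands in for tokenizer.sep_token
) -> List[str]:
    """Build one "premise SEP hypothesis" string per (sequence, label) combination."""
    if isinstance(sequences, str):
        sequences = [sequences]
    if isinstance(labels, str):
        labels = [label.strip() for label in labels.split(",")]

    joined = []
    for sequence in sequences:
        for label in labels:
            # Each input is a single string, not a (premise, hypothesis) pair.
            joined.append(f"{sequence} {sep_token} {hypothesis_template.format(label)}")
    return joined


inputs = join_nli_inputs(
    "The match went to extra time.",
    "sports, politics",
    "This text is about {}.",
)
for text in inputs:
    print(text)
```

Because every input is one pre-joined string, the tokenizer receives a single segment per input, which appears to be the effect the README's modification is after. The results are ordered sequence by sequence, with one entry per label.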