Commit 46f1d8c ("commit files to HF hub") by paulhindemith
Parent: cd708ec

Files changed:
- config.json (+3, -1)
- fasttext_fsc.py (+186, -46)
config.json CHANGED

@@ -19,7 +19,9 @@
   },
   "max_length": 128,
   "model_type": "fasttext_classification",
-  "ngram": 2,
+  "ngrams": [
+    2
+  ],
   "tokenizerI_class": "FastTextJpTokenizer",
   "tokenizer_class": "FastTextJpTokenizer",
   "torch_dtype": "float32",
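The scalar `ngram` key becomes a list-valued `ngrams` key, so one checkpoint can score several n-gram sizes. As a minimal sketch (mine, not part of the commit), the normalization that produces this field mirrors the updated `__init__` in fasttext_fsc.py below:

```python
# Sketch: how the updated config normalizes `ngram` into the `ngrams`
# list that lands in config.json (mirrors the __init__ shown below).
def normalize_ngram(ngram):
    if isinstance(ngram, int):
        return [ngram]          # serialized as "ngrams": [2]
    if isinstance(ngram, list):
        return ngram            # e.g. "ngrams": [1, 2, 3]
    raise TypeError(f"got unknown type {type(ngram)}")

assert normalize_ngram(2) == [2]
assert normalize_ngram([1, 2, 3]) == [1, 2, 3]
```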
fasttext_fsc.py CHANGED

Before (old file, lines 11-101; removed lines are marked with -):

@@ -11,82 +11,92 @@ class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
     model_type = "fasttext_classification"

     def __init__(self,
-                 ngram: int = 2,
                  tokenizer_class="FastTextJpTokenizer",
                  **kwargs):
         """Initialization.

         Args:
-            ngram (int, optional):
-                N-gram size used when splitting a sentence.
             tokenizer_class (str, optional):
                 Unless tokenizer_class is specified, the model cannot be
                 loaded from a pipeline. It is recorded in config.json.
         """
-        self.ngram = ngram
         kwargs["tokenizer_class"] = tokenizer_class
         super().__init__(**kwargs)


-class FastTextForSeuqenceClassification(FastTextJpModel):
-    """Performs classification based on fastText vectors.
-    """

-    def __init__(self, config: FastTextForSeuqenceClassificationConfig):

-        self.max_ngram = config.ngram
-        super().__init__(config)

-    def forward(self, **inputs) -> SequenceClassifierOutput:
-        """Classifies against the candidate labels.

         Returns:
-            SequenceClassifierOutput: probability that the candidate is correct
         """
-        input_ids = inputs["input_ids"]
-        outputs = self.word_embeddings(input_ids)

-        logits = []
-        for idx in range(len(outputs)):
-            output = outputs[idx]
-            # token_type_ids == 0 is the sentence, 1 is the label.
-            token_type_ids = inputs["token_type_ids"][idx]
-            # attention_mask == 1 marks tokens that are not padding.
-            attention_mask = inputs["attention_mask"][idx]

-            sentence = output[torch.logical_and(token_type_ids == 0,
-                                                attention_mask == 1)]
-            candidate_label = output[torch.logical_and(token_type_ids == 1,
-                                                       attention_mask == 1)]
-            sentence_words = self.split_ngram(sentence, self.max_ngram)
-            candidate_label_mean = torch.mean(candidate_label,
-                                              dim=-2,
-                                              keepdim=True)
-            p = self.cosine_similarity(sentence_words, candidate_label_mean)
-            logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
-        logits = torch.FloatTensor(logits)
-        return SequenceClassifierOutput(
-            loss=None,
-            logits=logits,
-            hidden_states=None,
-            attentions=None,
-        )

     def cosine_similarity(
-            self,
-            sentence_words, candidate_label_mean):
         res = torch.tensor(0.)
-        for i in range(len(sentence_words)):
-            sw = sentence_words[i]
             p = torch.nn.functional.cosine_similarity(sw,
-                                                      candidate_label_mean,
                                                       dim=0)
             if p > res:
                 res = p
         return res

-    def split_ngram(self, sentences: TensorType["
-                    n: int) -> TensorType["
         res = []
         if len(sentences) <= n:
             return torch.stack([torch.mean(sentences, dim=0, keepdim=False)])

@@ -96,6 +106,136 @@ class FastTextForSeuqenceClassification(FastTextJpModel):
         return torch.stack(res)


 # Registration with AutoModel is required, but the recommended way seems to keep changing and is not settled. (2022/11/6)
 # https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
 FastTextForSeuqenceClassificationConfig.register_for_auto_class()
After (new file, lines 11-241; added lines are marked with +):

     model_type = "fasttext_classification"

     def __init__(self,
+                 ngram: int | list[int] = 2,
                  tokenizer_class="FastTextJpTokenizer",
                  **kwargs):
         """Initialization.

         Args:
+            ngram (int | list[int], optional):
+                N-gram size(s) used when splitting a sentence.
             tokenizer_class (str, optional):
                 Unless tokenizer_class is specified, the model cannot be
                 loaded from a pipeline. It is recorded in config.json.
         """
+        if isinstance(ngram, int):
+            self.ngrams = [ngram]
+        elif isinstance(ngram, list):
+            self.ngrams = ngram
+        else:
+            raise TypeError(f"got unknown type {type(ngram)}")
         kwargs["tokenizer_class"] = tokenizer_class
         super().__init__(**kwargs)


+class NgramForSeuqenceClassification():
+
+    def __init__(self):
+        ...

+    def __call__(self, sentence: TensorType["A", "vectors"],
+                 candidate_label: TensorType["B", "vectors"],
+                 ngram: int) -> TensorType[3]:
+        """Splits the sentence into n-grams and scores it by cosine similarity.

+        Args:
+            sentence (TensorType["A", "vectors"]): sentence vectors
+            candidate_label (TensorType["B", "vectors"]): label vectors
+            ngram (int): n-gram size

         Returns:
+            TensorType[3]:
+                Similarity of the sentences, as [Entailment, Neutral, Contradiction].
         """

+        sentence_ngrams = self.split_ngram(sentence, ngram)

+        candidate_label_mean = torch.mean(candidate_label, dim=0, keepdim=True)
+        p = self.cosine_similarity(sentence_ngrams, candidate_label_mean)
+        return torch.tensor([torch.log(p), -torch.inf, torch.log(1 - p)])

     def cosine_similarity(
+            self, sentence_ngrams: TensorType["ngrams", "vectors"],
+            candidate_label_mean: TensorType[1, "vectors"]) -> TensorType[1]:
+        """Computes cosine similarity.
+
+        Args:
+            sentence_ngrams (TensorType["ngrams", "vectors"]):
+                Sentence vectors split into n-grams
+            candidate_label_mean (TensorType[1, "vectors"]):
+                Label vector
+
+        Returns:
+            TensorType[1]: the highest cosine similarity over the n-grams
+        """
+
         res = torch.tensor(0.)
+        for i in range(len(sentence_ngrams)):
+            sw = sentence_ngrams[i]
             p = torch.nn.functional.cosine_similarity(sw,
+                                                      candidate_label_mean[0],
                                                       dim=0)
             if p > res:
                 res = p
         return res

+    def split_ngram(self, sentences: TensorType["A", "vectors"],
+                    n: int) -> TensorType["ngrams", "vectors"]:
+        """Splits the sentence into n-grams.
+        Args:
+            sentences(TensorType["A", "vectors"]):
+                the target sentence
+            n(int):
+                n-gram size
+        Returns:
+            TensorType["ngrams", "vectors"]:
+                The sentence split into n-grams
+        """
+
         res = []
         if len(sentences) <= n:
             return torch.stack([torch.mean(sentences, dim=0, keepdim=False)])
@@ -96,6 +106,136 @@ class FastTextForSeuqenceClassification(FastTextJpModel):
         return torch.stack(res)
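The diff leaves split_ngram's sliding-window loop (old lines 93-95 / new 103-105) as unchanged context, so it is not shown. Below is a small self-contained sketch of the scoring path; the window loop is my assumption standing in for those hidden lines, inferred from the `len(sentences) <= n` base case:

```python
import torch

# Toy rerun of the n-gram scoring above. The sliding-window loop is an
# assumption for the unchanged lines 103-105 that the diff does not show.
def split_ngram(sentences: torch.Tensor, n: int) -> torch.Tensor:
    if len(sentences) <= n:
        return torch.stack([torch.mean(sentences, dim=0, keepdim=False)])
    res = [torch.mean(sentences[i:i + n], dim=0)
           for i in range(len(sentences) - n + 1)]
    return torch.stack(res)

words = torch.randn(5, 300)   # 5 word vectors, 300-dim like fastText
label = torch.randn(3, 300)   # 3 label-word vectors

ngrams = split_ngram(words, 2)                      # shape [4, 300]
label_mean = torch.mean(label, dim=0, keepdim=True)

# Best cosine similarity over the n-grams, as in cosine_similarity above.
# Note: p can be negative for unrelated vectors, making log(p) NaN; the
# committed code has the same property.
p = max(torch.nn.functional.cosine_similarity(g, label_mean[0], dim=0)
        for g in ngrams)
logits = torch.tensor([torch.log(p), -torch.inf, torch.log(1 - p)])
print(ngrams.shape, logits)
```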
The new file continues with the layer that scans every configured n-gram size:

+class NgramsForSeuqenceClassification():
+
+    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
+        self.max_ngrams = config.ngrams
+        self.ngram_layer = NgramForSeuqenceClassification()
+
+    def __call__(self, sentence: TensorType["A", "vectors"],
+                 candidate_label: TensorType["B", "vectors"]) -> TensorType[3]:
+        """Computes how related A and B are.
+        Args:
+            sentence(TensorType["A", "vectors"]):
+                the target sentence
+            candidate_label(TensorType["B", "vectors"]):
+                the label sentence
+
+        Returns:
+            TensorType[3]:
+                Similarity of the sentences, as [Entailment, Neutral, Contradiction].
+        """
+
+        res = [-torch.inf, -torch.inf, -torch.inf]
+        for ngram in self.max_ngrams:
+            logit = self.ngram_layer(sentence, candidate_label, ngram)
+            if logit[0] > res[0]:
+                res = logit
+        return torch.tensor(res)
+
+
+class BatchedNgramsForSeuqenceClassification():
+
+    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
+        self.ngrams_layer = NgramsForSeuqenceClassification(config)
+
+    def __call__(
+        self,
+        last_hidden_state: TensorType["batch", "A+B", "vectors"],
+        token_type_ids: TensorType["batch", "A+B"],
+        attention_mask: TensorType["batch", "A+B"],
+    ) -> TensorType["batch", 3]:
+        """Computes how related A and B are.
+        Args:
+            last_hidden_state(TensorType["batch", "A+B", "vectors"]):
+                Output of the embeddings.
+            token_type_ids(TensorType["A+B"]):
+                Segment id of the sentence: 0 or 1, 1 for B.
+            attention_mask(TensorType["A+B"]):
+                Identifies padding: 0 or 1, 1 for tokens that are not padding.
+
+        Returns:
+            TensorType["batch", 3]:
+                Similarity of the sentences, as [Entailment, Neutral, Contradiction].
+        """
+
+        logits = []
+        embeddings = last_hidden_state
+        for idx in range(len(embeddings)):
+            vec = embeddings[idx]
+            # token_type_ids == 0 is the sentence, 1 is the label.
+            token_type_ids = token_type_ids[idx]
+            # attention_mask == 1 marks tokens that are not padding.
+            attention_mask = attention_mask[idx]
+
+            sentence, candidate_label = self.split_sentence(
+                vec, token_type_ids, attention_mask)
+            logit = self.ngrams_layer(sentence, candidate_label)
+            logits.append(logit)
+        logits = torch.tensor(logits)
+        return logits
+
+    def split_sentence(
+        self, vec: TensorType["A+B", "vectors"],
+        token_type_ids: TensorType["A+B"], attention_mask: TensorType["A+B"]
+    ) -> tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
+        """Because the model is a cross-encoder, splits the input back into the two sentences.
+
+        Args:
+            vec(TensorType["A+B","vectors"]):
+                word vectors
+
+            token_type_ids(TensorType["A+B"]):
+                Segment id of the sentence: 0 or 1, 1 for B.
+
+            attention_mask(TensorType["A+B"]):
+                Identifies padding: 0 or 1, 1 for tokens that are not padding.
+
+        Returns:
+            tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
+                The A and B sentences, split apart.
+        """
+
+        sentence = vec[torch.logical_and(token_type_ids == 0,
+                                         attention_mask == 1)]
+        candidate_label = vec[torch.logical_and(token_type_ids == 1,
+                                                attention_mask == 1)]
+        return sentence, candidate_label
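Since the tokenizer packs sentence and label into one cross-encoded sequence, split_sentence recovers the two halves with boolean masks. A toy run of the same masking, reusing the commit's own indexing expressions:

```python
import torch

# Sketch: how split_sentence carves a cross-encoded sequence back into
# the sentence (token_type_ids == 0) and the label (token_type_ids == 1),
# dropping padded positions (attention_mask == 0).
vec = torch.arange(18.).reshape(6, 3)               # 6 tokens, 3-dim toy vectors
token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1])
attention_mask = torch.tensor([1, 1, 1, 1, 1, 0])   # last token is padding

sentence = vec[torch.logical_and(token_type_ids == 0, attention_mask == 1)]
candidate_label = vec[torch.logical_and(token_type_ids == 1, attention_mask == 1)]
print(sentence.shape)         # torch.Size([3, 3])
print(candidate_label.shape)  # torch.Size([2, 3])
```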
The new file closes with the model class that wires these layers into forward, plus the AutoClass registration:

+
+
+class FastTextForSeuqenceClassification(FastTextJpModel):
+    """Performs classification based on fastText vectors.
+    """
+
+    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
+
+        self.layer = BatchedNgramsForSeuqenceClassification(config)
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_ids: TensorType["batch", "A+B", "vectors"] = None,
+        attention_mask: TensorType["batch", "A+B"] = None,
+        token_type_ids: TensorType["batch", "A+B"] = None
+    ) -> SequenceClassifierOutput:
+        """Classifies against the candidate labels.
+
+        Returns:
+            SequenceClassifierOutput: probability that the candidate is correct
+        """
+        outputs = self.word_embeddings(input_ids)
+        logits = self.layer(last_hidden_state=outputs,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids)
+
+        return SequenceClassifierOutput(
+            loss=None,
+            logits=logits,
+            hidden_states=None,
+            attentions=None,
+        )


 # Registration with AutoModel is required, but the recommended way seems to keep changing and is not settled. (2022/11/6)
 # https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
 FastTextForSeuqenceClassificationConfig.register_for_auto_class()
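How the model is meant to be invoked is not part of this commit, but since forward returns [Entailment, Neutral, Contradiction] logits, it lines up with the zero-shot classification pipeline. A hedged usage sketch; the repo id and out-of-the-box pipeline compatibility are assumptions, not established by this diff:

```python
# Sketch only: load the repo's custom config/model/tokenizer via
# trust_remote_code and score candidate labels zero-shot. The repo id
# "paulhindemith/fasttext-classification" is an assumption.
from transformers import pipeline

classifier = pipeline("zero-shot-classification",
                      model="paulhindemith/fasttext-classification",
                      trust_remote_code=True)
print(classifier("これはテストです。", candidate_labels=["テスト", "ニュース"]))
```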