select [cls] token only
Browse files- modeling_tunbert.py +1 -1
modeling_tunbert.py
CHANGED
|
@@ -31,7 +31,7 @@ class TunBERT(PreTrainedModel):
|
|
| 31 |
# the [cls] token is used in bert to identify the class of the sentence
|
| 32 |
# meaning that we need only the first token of each sentence
|
| 33 |
# and the model representation of the rest of the sentence does not concern us
|
| 34 |
-
|
| 35 |
loss =None
|
| 36 |
if labels is not None :
|
| 37 |
loss_func = nn.CrossEntropyLoss()
|
|
|
|
| 31 |
# the [cls] token is used in bert to identify the class of the sentence
|
| 32 |
# meaning that we need only the first token of each sentence
|
| 33 |
# and the model representation of the rest of the sentence does not concern us
|
| 34 |
+
logits = logits[:,0,:] # [bs, seq, class]
|
| 35 |
loss =None
|
| 36 |
if labels is not None :
|
| 37 |
loss_func = nn.CrossEntropyLoss()
|