marinone94
commited on
Commit
•
0713f7f
1
Parent(s):
c4280a8
fix tensor dim
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -334,9 +334,9 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
334 |
labels = labels[:, 1:]
|
335 |
|
336 |
# add start of sentence token to labels + language + task
|
337 |
-
labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim
|
338 |
-
labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim
|
339 |
-
labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim
|
340 |
|
341 |
batch["labels"] = labels
|
342 |
|
|
|
334 |
labels = labels[:, 1:]
|
335 |
|
336 |
# add start of sentence token to labels + language + task
|
337 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
|
338 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
|
339 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
|
340 |
|
341 |
batch["labels"] = labels
|
342 |
|