Update modeling_moonshine.py
Browse files- modeling_moonshine.py +1 -1
modeling_moonshine.py
CHANGED
@@ -447,7 +447,7 @@ class MoonshineModelTorch(nn.Module):
|
|
447 |
sample = logits[:, -1].argmax(dim=-1, keepdim=True)
|
448 |
seq = torch.cat((seq, sample), dim=-1)
|
449 |
|
450 |
-
seq_len = int(src.shape[-1] *
|
451 |
while sample != eot_token and len(seq.flatten()) <= seq_len:
|
452 |
vals = self.decoder(
|
453 |
seq,
|
|
|
447 |
sample = logits[:, -1].argmax(dim=-1, keepdim=True)
|
448 |
seq = torch.cat((seq, sample), dim=-1)
|
449 |
|
450 |
+
seq_len = int(src.shape[-1] * 6.5 / 16000)
|
451 |
while sample != eot_token and len(seq.flatten()) <= seq_len:
|
452 |
vals = self.decoder(
|
453 |
seq,
|