|
import unittest |
|
from transcribe.strategy import TranscriptChunk, TranscriptToken, SplitMode |
|
|
|
class TestTranscriptChunk(unittest.TestCase):
    """Unit tests for ``TranscriptChunk`` from ``transcribe.strategy``.

    Every test runs against a fresh four-token chunk whose joined text is
    ``"Hello , world ."``; it is rebuilt in :meth:`setUp` before each test,
    so tests cannot leak state into one another.
    """

    def setUp(self):
        """Build a fresh chunk of four space-separated, contiguous tokens.

        Each token's ``t1`` equals the next token's ``t0``, so the chunk spans
        0..400 in whatever time unit ``TranscriptToken`` uses.
        """
        self.tokens = [
            TranscriptToken(text="Hello", t0=0, t1=100),
            TranscriptToken(text=",", t0=100, t1=200),
            TranscriptToken(text="world", t0=200, t1=300),
            TranscriptToken(text=".", t0=300, t1=400),
        ]
        self.chunk = TranscriptChunk(items=self.tokens, separator=" ")

    def test_split_by_punctuation(self):
        """Splitting on punctuation ends each piece at a punctuation token.

        The expected result includes an empty trailing piece after the final
        ``"."``.
        """
        chunks = self.chunk.split_by(SplitMode.PUNCTUATION)
        # Comparing the whole list also checks the count and, on failure,
        # shows a diff of every piece instead of stopping at the first one.
        self.assertEqual(
            [piece.join() for piece in chunks],
            ["Hello ,", "world .", ""],
        )

    def test_get_split_first_rest(self):
        """The first punctuation piece is returned separately from the rest."""
        first, rest = self.chunk.get_split_first_rest(SplitMode.PUNCTUATION)
        self.assertEqual(first.join(), "Hello ,")
        self.assertEqual([piece.join() for piece in rest], ["world .", ""])

    def test_punctuation_numbers(self):
        """The chunk contains two punctuation tokens: ``","`` and ``"."``."""
        # NOTE(review): the method name is spelled ``puncation_numbers`` here;
        # presumably that matches transcribe.strategy — confirm before renaming.
        self.assertEqual(self.chunk.puncation_numbers(), 2)

    def test_length(self):
        """``length()`` counts every token, punctuation included."""
        self.assertEqual(self.chunk.length(), 4)

    def test_join(self):
        """``join()`` concatenates token texts with the chunk's separator."""
        self.assertEqual(self.chunk.join(), "Hello , world .")

    def test_compare(self):
        """A partially matching chunk yields a similarity strictly in (0, 1)."""
        other_chunk = TranscriptChunk(
            items=[
                TranscriptToken(text="Hello", t0=0, t1=100),
                TranscriptToken(text="!", t0=100, t1=200),
            ],
            separator=" ",
        )
        similarity = self.chunk.compare(other_chunk)
        # Separate bound checks report the actual value on failure, unlike
        # assertTrue(0 < similarity < 1), which only says "False is not true".
        self.assertGreater(similarity, 0)
        self.assertLess(similarity, 1)

    def test_has_punctuation(self):
        """A chunk containing ``","`` and ``"."`` reports punctuation."""
        self.assertTrue(self.chunk.has_punctuation())

    def test_get_buffer_index(self):
        """``get_buffer_index()`` maps the chunk's end to a buffer offset."""
        # NOTE(review): 64000 == 400 * 160. Presumably the last token's
        # t1 (400) is in 10 ms units and is converted to a 16 kHz sample
        # index — TODO confirm against get_buffer_index's implementation.
        self.assertEqual(self.chunk.get_buffer_index(), 64000)

    def test_is_end_sentence(self):
        """A chunk whose last token is ``"."`` is treated as ending a sentence."""
        self.assertTrue(self.chunk.is_end_sentence())
|
|
|
if __name__ == "__main__":
    # Running this file directly (rather than via a test runner) discovers
    # and executes the TestCase classes defined in this module.
    unittest.main()
|
|