Translator / tests /test_transcript_analysis.py
david
fix seg id error
a93e2b8
raw
history blame
3.06 kB
import unittest
from unittest.mock import MagicMock, patch
from transcribe.strategy import TranscriptStabilityAnalyzer, TranscriptChunk, TranscriptResult, SplitMode
class TestTranscriptStabilityAnalyzer(unittest.TestCase):
    """Exercise ``TranscriptStabilityAnalyzer.analysis`` against mocked chunks.

    Each test patches methods on the analyzer's ``_transcript_history`` so the
    previous chunk seen by ``analysis`` is fully controlled by the test.
    """

    def setUp(self):
        """Create a fresh analyzer before every test."""
        self.analyzer = TranscriptStabilityAnalyzer()

    def test_first_chunk_yields_pending_text(self):
        """No previous chunk: expect a single result whose context has the text."""
        incoming = MagicMock(spec=TranscriptChunk)
        incoming.join.return_value = "Hello world."
        history = self.analyzer._transcript_history

        # The history reports that nothing has been seen before this chunk.
        with patch.object(history, 'previous_chunk', return_value=None):
            produced = list(
                self.analyzer.analysis(" ", incoming, buffer_duration=5.0)
            )

        self.assertEqual(len(produced), 1)
        only_result = produced[0]
        self.assertIsInstance(only_result, TranscriptResult)
        self.assertIn("Hello", only_result.context)

    def test_short_buffer_with_high_similarity_and_end_sentence(self):
        """Short buffer with a 0.85-similarity, sentence-ending first segment."""
        # First segment of the current chunk: scores 0.85 on compare() and
        # reports itself as a punctuated, sentence-ending piece of text.
        head = MagicMock()
        head.configure_mock(**{
            'compare.return_value': 0.85,
            'is_end_sentence.return_value': True,
            'has_punctuation.return_value': True,
            'join.return_value': "This is a test sentence.",
            'get_buffer_index.return_value': 0,
        })

        # Remainder of the current chunk: one trailing segment.
        tail_item = MagicMock()
        tail_item.join.return_value = " Continuing..."
        tail = [tail_item]

        current = MagicMock(spec=TranscriptChunk)
        # Mock the items attribute with a real list so it is iterable.
        current.items = [head, tail_item]
        current.get_split_first_rest.return_value = (head, tail)

        # Previous chunk: a first segment and nothing after it.
        earlier_head = MagicMock()
        earlier = MagicMock(spec=TranscriptChunk)
        earlier.get_split_first_rest.return_value = (earlier_head, [])

        history = self.analyzer._transcript_history
        with patch.object(history, 'previous_chunk', return_value=earlier), \
                patch.object(history, 'add'):
            produced = list(
                self.analyzer.analysis(" ", current, buffer_duration=5.0)
            )

        self.assertGreaterEqual(len(produced), 1)
        self.assertTrue(any(item.is_end_sentence for item in produced))
        self.assertTrue(any("test" in item.context for item in produced))

    def test_long_buffer_triggers_commit(self):
        """A 15-second buffer yields an end-of-sentence result containing "Hello"."""
        # Three pieces handed back by split_by(), each with its own joined text.
        pieces = [
            MagicMock(**{'join.return_value': text})
            for text in ("Hello.", "How are", " you?")
        ]

        incoming = MagicMock(spec=TranscriptChunk)
        incoming.split_by.return_value = pieces
        incoming.get_buffer_index.return_value = 0

        history = self.analyzer._transcript_history
        with patch.object(history, 'previous_chunk', return_value=MagicMock()), \
                patch.object(history, 'add'):
            produced = list(
                self.analyzer.analysis(" ", incoming, buffer_duration=15.0)
            )

        self.assertTrue(any(item.is_end_sentence for item in produced))
        self.assertTrue(any("Hello" in item.context for item in produced))
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()