# whisper-youtube-2-hf_dataset / test / test_datapipeline.py
# juancopi81's picture
# Duplicate from Whispering-GPT/whisper-youtube-2-hf_dataset
# 7288748
import os
import pytest
import sqlite3
from pathlib import Path
from youtube_transcriber.datapipeline import DataPipeline
from youtube_transcriber.datapipeline import create_hardcoded_data_pipeline
from youtube_transcriber.preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from youtube_transcriber.loading.loaderiterator import LoaderIterator
from youtube_transcriber.loading.serialization import JsonSerializer
from youtube_transcriber.transforming.addtitletransform import AddTitleTransform
from youtube_transcriber.transforming.adddescriptiontransform import AddDescriptionTransform
from youtube_transcriber.transforming.whispertransform import WhisperTransform
from youtube_transcriber.transforming.batchtransformer import BatchTransformer
from youtube_transcriber.storing.sqlitebatchvideostorer import SQLiteBatchVideoStorer
from youtube_transcriber.storing.sqlitecontextmanager import SQLiteContextManager
from youtube_transcriber.storing.createdb import create_db
@pytest.fixture
def expected_db_output():
    """Return the rows the tests expect to read back from the VIDEO table.

    Each row is a ``(channel name, URL, title, transcription)`` tuple, in the
    same column order the tests in this module select them.
    """
    steve_jobs_row = (
        "Tquotes",
        "https://www.youtube.com/watch?v=NSkoGZ8J1Ag",
        "Steve Jobs quotes Bob Dylan",
        " Good morning. Good morning and welcome to Apple's 1984 annual shareholders meeting. I'd like to open the meeting with a part of an old poem about a 20-year-old poem by Dylan. That's Bob Dylan. Come writers and critics who prophesize with your pens and keep your eyes wide, the chance won't come again. And don't speak too soon for the wheels still in spin. And there's no telling who that it's naming. For the loser now will be later to win for the times they are a change in. Now.",
    )
    star_wars_row = (
        "changminjen",
        "https://www.youtube.com/watch?v=Ak516vtDTEA",
        "My allegiance is to the Republic, to democracy!",
        " I have brought peace, freedom, justice and security to my new empire. Your new empire don't make me kill you. Anakin, my allegiance is to the Republic, to democracy! If you're not with me, then you're my enemy. Only a Sith deals an absolute.",
    )
    return [steve_jobs_row, star_wars_row]
@pytest.fixture
def data_pipeline():
    """Return a DataPipeline assembled from the real pipeline components.

    The pipeline reads through a ``LoaderIterator`` using ``JsonSerializer``,
    applies the title, description and Whisper transforms (in that order) via
    a ``BatchTransformer``, and stores through ``SQLiteBatchVideoStorer``
    with a context manager pointing at ``"dummy.db"``.
    """
    return DataPipeline(
        # NOTE(review): the 2 is presumably the batch size -- confirm
        # against LoaderIterator's signature.
        LoaderIterator(JsonSerializer(), 2),
        BatchTransformer(
            [
                AddTitleTransform(),
                AddDescriptionTransform(),
                WhisperTransform(),
            ]
        ),
        SQLiteBatchVideoStorer(),
        SQLiteContextManager("dummy.db"),
    )
def test_datapipeline_init():
    """Check that DataPipeline exposes each constructor argument as an attribute."""
    # Plain strings stand in for the collaborators; only the attribute
    # wiring is being verified here.
    pipeline = DataPipeline("loader_iterator", "transformer", "storer", "context")

    assert type(pipeline) == DataPipeline

    expected_attributes = {
        "loader_iterator": "loader_iterator",
        "batch_transformer": "transformer",
        "storer": "storer",
        "sqlite_context_manager": "context",
    }
    for attribute, value in expected_attributes.items():
        assert getattr(pipeline, attribute) == value
def test_process_files(data_pipeline, expected_db_output):
    """Run the pipeline over two JSON files and check the rows stored in the DB.

    Creates ``dummy.db``, calls ``data_pipeline.process`` on ``6.json`` and
    ``7.json``, then reads the VIDEO table back and compares it with
    ``expected_db_output``. The database file is removed afterwards.
    """
    # NOTE(review): the input files are looked up under the user's home
    # directory, so this test only runs on a machine with that layout.
    test_folder = Path.home()/"whisper_gpt_pipeline/youtube_transcriber/test"
    files = [test_folder/"files/6.json", test_folder/"files/7.json"]
    connection = None
    try:
        create_db("dummy.db")
        connection = sqlite3.connect("dummy.db")
        cursor = connection.cursor()
        data_pipeline.process(files)
        cursor.execute("SELECT CHANNEL_NAME, URL, TITLE, TRANSCRIPTION FROM VIDEO")
        videos = cursor.fetchall()
        # Compare the whole result set: a per-index loop over ``videos``
        # would pass vacuously if the pipeline stored no (or too few) rows.
        assert videos == expected_db_output
    finally:
        # Close before deleting so the file handle is released (removing an
        # open database file fails on some platforms).
        if connection is not None:
            connection.close()
        # Guard the removal so a failure before the file was created is not
        # masked by a FileNotFoundError raised during cleanup.
        if os.path.exists("dummy.db"):
            os.remove("dummy.db")
def test_process_video_batch(data_pipeline, expected_db_output):
    """Process one in-memory batch of videos and check the rows stored in the DB.

    Creates ``dummy.db``, passes a cursor and two ``channel_name``/``url``
    dicts to ``data_pipeline._process_video_batch``, then reads the VIDEO
    table back through the same cursor and compares it with
    ``expected_db_output``. The database file is removed afterwards.
    """
    video_data = [
        {
            "channel_name": "Tquotes",
            "url": "https://www.youtube.com/watch?v=NSkoGZ8J1Ag",
        },
        {
            "channel_name": "changminjen",
            "url": "https://www.youtube.com/watch?v=Ak516vtDTEA",
        },
    ]
    connection = None
    try:
        create_db("dummy.db")
        connection = sqlite3.connect("dummy.db")
        cursor = connection.cursor()
        data_pipeline._process_video_batch(cursor, video_data)
        cursor.execute("SELECT CHANNEL_NAME, URL, TITLE, TRANSCRIPTION FROM VIDEO")
        videos = cursor.fetchall()
        # Compare the whole result set: a per-index loop over ``videos``
        # would pass vacuously if the batch stored no (or too few) rows.
        assert videos == expected_db_output
    finally:
        # Close before deleting so the file handle is released (removing an
        # open database file fails on some platforms).
        if connection is not None:
            connection.close()
        # Guard the removal so a failure before the file was created is not
        # masked by a FileNotFoundError raised during cleanup.
        if os.path.exists("dummy.db"):
            os.remove("dummy.db")
def test_hardcoded_data_pipeline_is_instantiated():
    """Check that create_hardcoded_data_pipeline returns exactly a DataPipeline."""
    assert type(create_hardcoded_data_pipeline()) == DataPipeline