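"""Create a Hugging Face dataset by transcribing YouTube videos with Whisper.

The Gradio app below takes a YouTube channel name or playlist URL, transcribes
the requested number of videos, and pushes the transcriptions to a dataset on
the Hugging Face Hub.
"""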
import math
import os
import argparse
import sqlite3
import shutil
import uuid

from datasets import Dataset, concatenate_datasets
import gradio as gr
import torch

from storing.createdb import create_db
from preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from loading.serialization import JsonSerializer
from utils import nest_list, is_google_colab
from datapipeline import create_hardcoded_data_pipeline
from threadeddatapipeline import ThreadedDataPipeline
from dataset.hf_dataset import HFDataset
from huggingface_hub import DatasetCard

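# Number of data pipelines run in parallel; the downloaded videos are split
# into this many batches, one per pipeline.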
NUM_THREADS = 1

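# Detect the runtime (Google Colab or not) and whether a GPU is available;
# both are used in the UI text and the launch options below.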
is_colab = is_google_colab()
colab_instruction = "" if is_colab else """
<p>You can skip the queue using Colab:
<a href="https://colab.research.google.com/drive/1zNRnX1lXjlGtBMW8U8S9t4eY1cA0D6lm?usp=sharing">
<img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>"""
device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"

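# argparse type validator: only 1 to 12 videos may be transcribed per run.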
def numvideos_type(x):
    x = int(x)
    if x > 12:
        raise argparse.ArgumentTypeError("Maximum number of videos is 12")
    if x < 1:
        raise argparse.ArgumentTypeError("Minimum number of videos is 1")
    return x

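# Example CLI invocation (sketch; "app.py" is just an assumed filename for
# this script):
#   python app.py --channel_name <channel> --num_videos 4 --hf_token <token> \
#       --hf_dataset_identifier <user>/<dataset> --whisper_model base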
def parse_args():
    parser = argparse.ArgumentParser(usage="[arguments] --channel_name --num_videos",
                                     description="Program to transcribe YouTube videos.")
    parser.add_argument("--channel_name",
                        type=str,
                        required=True,
                        help="Name of the channel from where the videos will be transcribed")
    parser.add_argument("--num_videos",
                        type=numvideos_type,
                        required=True,
                        help="Number of videos (min. 1 - max. 12) to transcribe from --channel_name")
    parser.add_argument("--hf_token",
                        type=str,
                        required=True,
                        help="Token of your HF account. You need an HF account to upload the dataset")
    parser.add_argument("--hf_dataset_identifier",
                        type=str,
                        required=True,
                        help="The ID of the repository to push to in the following format: <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
    parser.add_argument("--whisper_model",
                        type=str,
                        required=True,
                        help="Select one of the available Whisper models",
                        choices=["tiny", "base", "small", "medium", "large"])

    args = parser.parse_args()
    return args

def transcribe(mode: str,
               channel_name: str,
               num_videos: int,
               hf_token: str,
               hf_dataset_identifier: str,
               whisper_model: str) -> str:
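    """Transcribe videos from a YouTube channel or playlist with Whisper and
    push the result to `hf_dataset_identifier` on the Hugging Face Hub.
    Returns a status message shown in the Gradio UI."""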
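    # Each run writes to its own uniquely named, temporary SQLite database.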
    unique_filename = str(uuid.uuid4())
    database_name = unique_filename + ".db"
    create_db(database_name)

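    # Preprocessor that fetches the videos from YouTube (by channel name or
    # playlist URL) and serializes its output as JSON.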
    yt_video_processor = YoutubeVideoPreprocessor(mode=mode,
                                                  serializer=JsonSerializer())

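    # IDs already present in the Hub dataset are skipped, so nothing is
    # transcribed twice.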
    hf_dataset = HFDataset(hf_dataset_identifier)
    videos_downloaded = hf_dataset.list_of_ids

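    # Download the new videos and split the resulting files into NUM_THREADS
    # batches, one per data pipeline.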
    paths, dataset_folder = yt_video_processor.preprocess(channel_name,
                                                           num_videos,
                                                           videos_downloaded)
    nested_listed_length = math.ceil(len(paths) / NUM_THREADS)
    nested_paths = nest_list(paths, nested_listed_length)
    data_pipelines = [create_hardcoded_data_pipeline(database_name, whisper_model) for _ in range(NUM_THREADS)]

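    # Run the pipelines in parallel threads and wait for all of them to finish.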
    threads = []
    for data_pipeline, thread_paths in zip(data_pipelines, nested_paths):
        threads.append(ThreadedDataPipeline(data_pipeline, thread_paths))
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

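    # Read the newly transcribed videos back from the temporary database.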
    connection = sqlite3.connect(database_name)
    cursor = connection.cursor()
    cursor.execute("SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO")
    videos = cursor.fetchall()
    num_new_videos = len(videos)

    dataset = Dataset.from_sql("SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO",
                               connection)

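    # Append to the existing dataset on the Hub if there is one; otherwise the
    # new transcriptions become the whole dataset.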
    if hf_dataset.exist and not hf_dataset.is_empty:
        dataset_to_upload = concatenate_datasets([hf_dataset.dataset["train"], dataset])
    else:
        dataset_to_upload = dataset

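    # Push the dataset and tag its card with the Whisper model that was used.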
    dataset_to_upload.push_to_hub(hf_dataset_identifier, token=hf_token)
    card = DatasetCard.load(hf_dataset_identifier)
    card.data.tags = ["whisper", "whispering", whisper_model]
    card.data.task_categories = ["automatic-speech-recognition"]
    card.push_to_hub(hf_dataset_identifier, token=hf_token)

    connection.close()

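    # Remove the temporary database and the downloaded audio folder.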
    os.remove(database_name)
    try:
        shutil.rmtree(dataset_folder)
    except OSError as e:
        print("Error: %s : %s" % (dataset_folder, e.strerror))

    return f"Dataset created or updated at {hf_dataset_identifier}. {num_new_videos} samples were added"

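# Gradio UI: collects the inputs and calls transcribe() when the user submits.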
with gr.Blocks() as demo:
    md = """# Use Whisper to create an HF dataset from YouTube videos

    This space lets you create an HF dataset by transcribing videos from YouTube.
    Enter the name of the YouTube channel or the URL of a YouTube playlist (in the form https://www.youtube.com/playlist?list=****),
    and the repo_id of the dataset (you need a Hugging Face account).
    If the dataset already exists, it will only transcribe videos that are not yet in the dataset.
    If it does not exist, it will create the dataset. To use this demo, you need a
    [Hugging Face token](https://huggingface.co/settings/tokens) with write role. Learn more about [tokens](https://huggingface.co/docs/hub/security-tokens).
    """
    gr.Markdown(md)
    gr.HTML(
        f"""
        <p style="margin-bottom: 10px; font-size: 94%">
            Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
        </p>
        """
    )

    with gr.Row():
        with gr.Column():
            whisper_model = gr.Radio(["tiny", "base", "small", "medium", "large"],
                                     label="Whisper model", value="base")
            mode = gr.Radio(["channel_name", "playlist"],
                            label="Get the videos from:", value="channel_name")
            channel_name = gr.Textbox(label="YouTube Channel or Playlist URL",
                                      placeholder="Enter the name of the YouTube channel or the URL of the playlist")
            num_videos = gr.Slider(1, 20000, value=4, step=1, label="Number of videos")
            hf_token = gr.Textbox(placeholder="Your HF write access token", type="password")
            hf_dataset_identifier = gr.Textbox(label="Dataset Name",
                                               placeholder="Enter in the format <username>/<repo_name>")
            submit_btn = gr.Button("Submit")

        with gr.Column():
            output = gr.Text()

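    # Wire the submit button to transcribe(); its return message is shown in
    # the output textbox.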
    submit_btn.click(fn=transcribe, inputs=[mode,
                                            channel_name,
                                            num_videos,
                                            hf_token,
                                            hf_dataset_identifier,
                                            whisper_model], outputs=[output])
    gr.Markdown('''
    ![visitors](https://visitor-badge.glitch.me/badge?page_id=juancopi81.whisper-youtube-2-hf_dataset)
    ''')

if not is_colab:
    demo.queue(concurrency_count=1)
demo.launch(debug=True, share=is_colab)