File size: 8,167 Bytes
7288748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import math
import os
import argparse
import sqlite3
import shutil
import uuid

from datasets import Dataset, concatenate_datasets
import gradio as gr
import torch

from storing.createdb import create_db
from preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from loading.serialization import JsonSerializer
from utils import nest_list, is_google_colab
from datapipeline import create_hardcoded_data_pipeline
from threadeddatapipeline import ThreadedDataPipeline
from dataset.hf_dataset import HFDataset
from huggingface_hub import DatasetCard

# Number of worker threads the transcription pipelines are spread across.
NUM_THREADS = 1

# Detect if code is running in Colab.
is_colab = is_google_colab()

# HTML snippet linking to the Colab notebook; left empty when already in Colab.
if is_colab:
    colab_instruction = ""
else:
    colab_instruction = """
<p>You can skip the queue using Colab: 
<a href="https://colab.research.google.com/drive/1zNRnX1lXjlGtBMW8U8S9t4eY1cA0D6lm?usp=sharing">
<img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>"""

# Label for the compute device that torch reports as available.
if torch.cuda.is_available():
    device_print = "GPU 🔥"
else:
    device_print = "CPU 🥶"

def numvideos_type(x):
    """Argparse ``type=`` callable that validates ``--num_videos``.

    Converts *x* with ``int()`` and checks that it lies in the inclusive
    range 1..12.

    Returns:
        The value as an ``int``.

    Raises:
        argparse.ArgumentTypeError: if the value is below 1 or above 12.
        ValueError: if ``int(x)`` fails (argparse reports this as an
            invalid value).
    """
    min_videos, max_videos = 1, 12
    x = int(x)
    if x > max_videos:
        raise argparse.ArgumentTypeError(f"Maximum number of videos is {max_videos}")
    if x < min_videos:
        # The message previously said 12 here, contradicting the check.
        raise argparse.ArgumentTypeError(f"Minimum number of videos is {min_videos}")
    return x

def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Every option is mandatory. ``--num_videos`` is validated by
    ``numvideos_type``, and ``--whisper_model`` is limited to the listed
    Whisper sizes.

    Returns:
        argparse.Namespace with ``channel_name``, ``num_videos``,
        ``hf_token``, ``hf_dataset_identifier`` and ``whisper_model``.
    """
    parser = argparse.ArgumentParser(
        usage="[arguments] --channel_name --num_videos",
        description="Program to transcribe YouTube videos.",
    )

    # (flag, keyword options) for each required argument, in declaration order.
    option_table = [
        ("--channel_name", {
            "type": str,
            "help": "Name of the channel from where the videos will be transcribed",
        }),
        ("--num_videos", {
            "type": numvideos_type,
            "help": "Number of videos (min. 1 - max. 12) to transcribe from --channel_name",
        }),
        ("--hf_token", {
            "type": str,
            "help": "Token of your HF account. You need a HF account to upload the dataset",
        }),
        ("--hf_dataset_identifier", {
            "type": str,
            "help": "The ID of the repository to push to in the following format: <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.",
        }),
        ("--whisper_model", {
            "type": str,
            "help": "Select one of the available whispers models",
            "choices": ["tiny", "base", "small", "medium", "large"],
        }),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, required=True, **options)

    return parser.parse_args()

# Columns copied from the temporary SQLite database into the HF dataset.
_VIDEO_QUERY = "SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO"


def _run_data_pipelines(database_name: str, whisper_model: str, paths: list) -> None:
    """Split *paths* across ``NUM_THREADS`` pipeline threads and wait for all of them."""
    # math.ceil(0 / NUM_THREADS) is 0 when nothing new needs transcribing
    # (e.g. every video is already in the dataset). A zero chunk length is
    # presumably invalid for nest_list, so clamp it to at least 1.
    chunk_length = max(1, math.ceil(len(paths) / NUM_THREADS))
    nested_paths = nest_list(paths, chunk_length)
    data_pipelines = [create_hardcoded_data_pipeline(database_name, whisper_model)
                      for _ in range(NUM_THREADS)]
    threads = [ThreadedDataPipeline(data_pipeline, thread_paths)
               for data_pipeline, thread_paths in zip(data_pipelines, nested_paths)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def _load_videos_dataset(database_name: str):
    """Read the rows selected by ``_VIDEO_QUERY`` from *database_name* into a ``Dataset``."""
    connection = sqlite3.connect(database_name)
    try:
        return Dataset.from_sql(_VIDEO_QUERY, connection)
    finally:
        # Close even if from_sql raises, so the database file can be deleted.
        connection.close()


def _push_dataset_and_card(hf_dataset, new_videos, hf_dataset_identifier: str,
                           hf_token: str, whisper_model: str) -> None:
    """Append *new_videos* to the existing Hub dataset (if any), push it and tag its card."""
    if hf_dataset.exist and not hf_dataset.is_empty:
        dataset_to_upload = concatenate_datasets([hf_dataset.dataset["train"], new_videos])
    else:
        dataset_to_upload = new_videos
    dataset_to_upload.push_to_hub(hf_dataset_identifier, token=hf_token)

    # Pass the token when loading too, so private datasets can be read back.
    card = DatasetCard.load(hf_dataset_identifier, token=hf_token)
    card.data.tags = ["whisper", "whispering", whisper_model]
    card.data.task_categories = ["automatic-speech-recognition"]
    card.push_to_hub(hf_dataset_identifier, token=hf_token)


def _remove_temporary_files(database_name: str, dataset_folder) -> None:
    """Delete the temporary database and the preprocessor's output folder, if present."""
    if os.path.exists(database_name):
        os.remove(database_name)
    if dataset_folder is None:
        # Preprocessing never returned a folder (it failed before that point).
        return
    try:
        shutil.rmtree(dataset_folder)
    except OSError as e:
        print("Error: %s : %s" % (dataset_folder, e.strerror))


def transcribe(mode: str,
               channel_name: str,
               num_videos: int,
               hf_token: str,
               hf_dataset_identifier: str,
               whisper_model: str) -> str:
    """Transcribe YouTube videos with Whisper and push them to a HF Hub dataset.

    Videos whose IDs are already in the target dataset are passed to the
    preprocessor as already downloaded. New transcriptions are gathered in a
    uniquely named temporary SQLite database and appended to the existing
    dataset, if there is one. The dataset card is then tagged and pushed.

    Args:
        mode: Source of the videos: "channel_name" or "playlist" (UI choices).
        channel_name: Channel name, or playlist URL when ``mode`` is "playlist".
        num_videos: Number of videos to request from the preprocessor.
        hf_token: Hugging Face token with write access.
        hf_dataset_identifier: Target repo id, e.g. "<user>/<dataset_name>".
        whisper_model: Whisper model size ("tiny" ... "large").

    Returns:
        A status message with the repo id and the number of samples added.

    The temporary database and the preprocessor's folder are removed even
    when a step fails. Exceptions still propagate to the caller.
    """
    # Create a unique name for the database.
    database_name = str(uuid.uuid4()) + ".db"
    dataset_folder = None
    try:
        create_db(database_name)

        # TODO: Let user select serializer
        yt_video_processor = YoutubeVideoPreprocessor(mode=mode,
                                                      serializer=JsonSerializer())
        hf_dataset = HFDataset(hf_dataset_identifier)
        paths, dataset_folder = yt_video_processor.preprocess(channel_name,
                                                              num_videos,
                                                              hf_dataset.list_of_ids)

        _run_data_pipelines(database_name, whisper_model, paths)

        new_videos = _load_videos_dataset(database_name)
        # Row count of the same query the dataset was built from.
        num_new_videos = len(new_videos)

        _push_dataset_and_card(hf_dataset, new_videos, hf_dataset_identifier,
                               hf_token, whisper_model)
    finally:
        _remove_temporary_files(database_name, dataset_folder)

    return f"Dataset created or updated at {hf_dataset_identifier}. {num_new_videos} samples were added"

with gr.Blocks() as demo:
    # Intro text shown at the top of the page.
    md = """# Use Whisper to create a HF dataset from YouTube videos
    This space will let you create a HF dataset by transcribing videos from YouTube.
    Enter the name of the YouTube channel or the URL of a YouTube playlist (in the form https://www.youtube.com/playlist?list=****), 
    and the repo_id of the dataset (you need a HuggingFace account).
    If the dataset already exists, it will only transcribe videos that are not in the dataset. 
    If it does not exist, it will create the dataset. For using this demo, you need a 
    [Hugging Face token](https://huggingface.co/settings/tokens) with write role. Learn more about [tokens](https://huggingface.co/docs/hub/security-tokens).
    """
    gr.Markdown(md)
    # Device / Colab status line. colab_instruction (empty inside Colab) was
    # previously built but never shown; render it here.
    gr.HTML(
        f"""
        <p style="margin-bottom: 10px; font-size: 94%">
          Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
        </p>
        {colab_instruction}
        """
    )

    with gr.Row():
        with gr.Column():
            whisper_model = gr.Radio([
                "tiny", "base", "small", "medium", "large"
            ], label="Whisper model", value="base")

            mode = gr.Radio([
                "channel_name", "playlist"
            ], label="Get the videos from:", value="channel_name")
            channel_name = gr.Textbox(label="YouTube Channel or Playlist URL",
                                      placeholder="Enter the name of the YouTube channel or the URL of the playlist")
            num_videos = gr.Slider(1, 20000, value=4, step=1, label="Number of videos")
            # Labelled like the other inputs; password type hides the token.
            hf_token = gr.Textbox(label="HF token",
                                  placeholder="Your HF write access token",
                                  type="password")
            hf_dataset_identifier = gr.Textbox(label='Dataset Name',
                                               placeholder="Enter in the format <username>/<repo_name>")
            submit_btn = gr.Button("Submit")

        with gr.Column():
            output = gr.Text()

        # Input order must match transcribe's positional parameters.
        submit_btn.click(fn=transcribe, inputs=[mode,
                                                channel_name,
                                                num_videos,
                                                hf_token,
                                                hf_dataset_identifier,
                                                whisper_model], outputs=[output])
    gr.Markdown('''
      ![visitors](https://visitor-badge.glitch.me/badge?page_id=juancopi81.whisper-youtube-2-hf_dataset)
    ''')

# Outside Colab, enable Gradio's request queue with a single worker. Inside
# Colab, the queue is skipped and a public share link is requested instead.
use_queue = not is_colab
if use_queue:
    demo.queue(concurrency_count=1)
demo.launch(debug=True, share=is_colab)