Spaces:
Runtime error
Runtime error
import os | |
import re | |
import time | |
import torch | |
import scrapetube | |
from pytube import YouTube | |
from faster_whisper import WhisperModel | |
from tqdm import tqdm | |
# Available models:
# tiny.en, tiny, base.en, base, small.en, small, medium.en, medium
# large-v1, large-v2, large-v3, large
MODEL_NAME = "large-v3"
AUDIO_SAVE_PATH = 'datasets/huggingface_audio/'
TRANSCRIPTS_SAVE_PATH = 'datasets/huggingface_audio_transcribed/'

# CUDA device ordinal used when a GPU is available. This was previously
# hard-coded to GPU 5, which requires at least six visible GPUs; set the
# WHISPER_DEVICE_INDEX environment variable to pick a different device.
WHISPER_DEVICE_INDEX = int(os.environ.get("WHISPER_DEVICE_INDEX", "0"))

# The model is loaded once at import time and shared by transcript_from_audio.
if torch.cuda.is_available():
    # requires: conda install -c anaconda cudnn
    print(f"Using {MODEL_NAME} on GPU and float16")
    model = WhisperModel(
        MODEL_NAME,
        device="cuda",
        compute_type="float16",
        device_index=WHISPER_DEVICE_INDEX,
    )
else:
    print(f"Using {MODEL_NAME} on CPU and int8")
    model = WhisperModel(MODEL_NAME, device="cpu", compute_type="int8")
def replace_unallowed_chars(filename: str) -> str:
    """Return *filename* with filesystem-unsafe characters turned into '_'.

    Spaces, forward and back slashes, and the characters
    ``: * ? " < > |`` are each replaced by a single underscore; every other
    character is kept as-is.
    """
    substitutions = str.maketrans(dict.fromkeys(' /\\:*?"<>|', '_'))
    return filename.translate(substitutions)
def get_videos_urls(channel_url: str) -> list[str]:
    """Collect a watch-page URL for every video scrapetube lists for a channel.

    Args:
        channel_url: Channel URL passed straight to ``scrapetube.get_channel``.

    Returns:
        One ``https://www.youtube.com/watch?v=<id>`` string per video, in the
        order scrapetube yields them.
    """
    urls: list[str] = []
    for entry in scrapetube.get_channel(channel_url=channel_url):
        video_id = entry['videoId']
        urls.append(f"https://www.youtube.com/watch?v={video_id}")
    return urls
def get_audio_from_video(video_url: str, save_path: str) -> tuple[str, str, str, int]:
    """Download the audio track of a YouTube video unless it is already saved.

    The audio-only stream is stored in ``save_path`` under the sanitized video
    title with a ``.mp3`` suffix.  The downloaded file is only renamed, never
    re-encoded, so the suffix is nominal; the bytes are whatever
    ``stream.download`` wrote.

    Args:
        video_url: Full YouTube watch URL.
        save_path: Directory in which to store audio files.

    Returns:
        ``(video_url, audio_path, title, length)``.  ``audio_path`` is the same
        full path whether the file was just downloaded or already existed, so
        callers can open it directly.

    Raises:
        ValueError: If the video exposes no audio-only stream.
        Any exception raised by pytube propagates to the caller unchanged.
    """
    yt = YouTube(video_url)
    title = yt.title
    # Compute the target path once so both branches return the same value.
    # The old "already exists" branch returned a bare file name (no directory,
    # spaces-only sanitization) that the transcription step could not open.
    audio_path = os.path.join(save_path, replace_unallowed_chars(title) + '.mp3')
    if os.path.exists(audio_path):
        print(f'Audio already exists for: {title}')
        return (video_url, audio_path, title, yt.length)

    print(f'Downloading audio for: {title}')
    stream = yt.streams.filter(only_audio=True).first()
    if stream is None:
        raise ValueError(f'No audio-only stream available for: {video_url}')
    out_file = stream.download(output_path=save_path)
    print(f'Saving audio to: {audio_path}')
    # os.replace overwrites the destination on every platform, whereas
    # os.rename raises on Windows when the destination already exists.
    os.replace(out_file, audio_path)
    print(f'Video length: {yt.length} seconds')
    return (video_url, audio_path, title, yt.length)
def check_if_file_exists(filename: str, save_path: str) -> bool:
    """Return True if *save_path* already holds a file derived from a title.

    This script names its output files after ``replace_unallowed_chars(title)``:
    ``<title>.mp3`` for audio and ``<title>_<start>_<end>.txt`` for transcripts.
    The title is therefore sanitized the same way and looked for as a
    substring of each directory entry.

    Args:
        filename: The raw video title (not a path), despite the parameter name.
        save_path: Directory to scan.

    Returns:
        True if any entry contains the sanitized title.  False if none does or
        the directory does not exist.
    """
    if not os.path.isdir(save_path):
        return False
    # Sanitize exactly like the writers do. Replacing only spaces missed titles
    # containing ':', '?', '/' and similar, so those were re-processed on
    # every run.
    needle = replace_unallowed_chars(filename)
    return any(needle in entry for entry in os.listdir(save_path))
def transcript_from_audio(audio_path: str, beam_size: int = 10) -> list:
    """Transcribe an audio file with the module-level Whisper ``model``.

    Args:
        audio_path: Path to an audio file to pass to ``model.transcribe``.
        beam_size: Beam width for decoding.  Defaults to 10, the previously
            hard-coded value.

    Returns:
        The transcription segments as a list.  Callers in this file read
        ``.start``, ``.end`` and ``.text`` from each segment.
    """
    # NOTE(review): faster-whisper documents `segments` as a lazy generator,
    # so the list() call is presumably where decoding actually happens.
    segments, _info = model.transcribe(audio_path, beam_size=beam_size)
    return list(segments)
# Compiled once; a raw string avoids the invalid "\s" escape warning that the
# plain '\s+' literal triggered (a SyntaxWarning on Python 3.12+).
_WHITESPACE_RUN = re.compile(r'\s+')


def process_text(text: str) -> str:
    """Strip *text* and collapse every internal whitespace run to one space."""
    return _WHITESPACE_RUN.sub(' ', text.strip())
def merge_transcripts_segements(
    segments: list,
    file_title: str,
    num_segments_to_merge: int = 5,
) -> dict[str, str]:
    """Group consecutive transcription segments into larger text chunks.

    Each run of ``num_segments_to_merge`` consecutive segments (the final run
    may be shorter) is joined into a single whitespace-normalized string.

    Args:
        segments: Segments exposing ``.start``, ``.end`` and ``.text``.  Any
            iterable is accepted.
        file_title: Unused; kept so existing callers keep working.
        num_segments_to_merge: Group size; must be at least 1.

    Returns:
        A mapping from ``"<start>_<end>"`` to the merged text.  Both numbers are
        formatted to 2 decimals and come from the first and last segment of
        the group.  Insertion order follows the segments.

    Raises:
        ValueError: If ``num_segments_to_merge`` is less than 1.
    """
    if num_segments_to_merge < 1:
        raise ValueError(
            f'num_segments_to_merge must be >= 1, got {num_segments_to_merge}'
        )
    segments = list(segments)
    merged_segments: dict[str, str] = {}
    for offset in range(0, len(segments), num_segments_to_merge):
        group = segments[offset:offset + num_segments_to_merge]
        key = f'{group[0].start:.2f}_{group[-1].end:.2f}'
        merged_segments[key] = process_text(' '.join(seg.text for seg in group))
    return merged_segments
def _download_all_audio(videos_urls: list[str]) -> list[tuple[str, str, str, int]]:
    """Download audio for each URL, logging and skipping any video that fails."""
    audio_data = []
    for video_url in tqdm(videos_urls):
        try:
            audio_data.append(
                get_audio_from_video(video_url, save_path=AUDIO_SAVE_PATH)
            )
        except Exception as e:
            # Best effort: one bad video must not abort the whole channel.
            print(f'Error downloading video: {video_url}')
            print(e)
    return audio_data


def _save_transcript_chunks(
    video_url: str,
    title: str,
    merged_segments: dict[str, str],
) -> None:
    """Write each merged chunk to its own text file, prefixed with a timed source link."""
    safe_title = replace_unallowed_chars(title)
    for segment_key, text in merged_segments.items():
        # Keys look like "<start>_<end>"; the rounded start becomes the t= offset.
        start_seconds = float(segment_key.split('_')[0])
        video_url_with_time = f'{video_url}&t={start_seconds:.0f}'
        path = f'{TRANSCRIPTS_SAVE_PATH}{safe_title}_{segment_key}.txt'
        # Explicit UTF-8: transcripts may contain non-ASCII text that the
        # platform default encoding (e.g. cp1252 on Windows) cannot encode.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(f'source: {video_url_with_time}\n\n' + text)


def _transcribe_all(audio_data: list[tuple[str, str, str, int]]) -> None:
    """Transcribe every downloaded file that has no transcript yet, logging failures."""
    for video_url, filename, title, _audio_length in tqdm(audio_data):
        if check_if_file_exists(title, TRANSCRIPTS_SAVE_PATH):
            print(f'Transcript already exists for: {title}')
            continue
        try:
            print(f'Transcribing: {title}')
            start_time = time.time()
            segments = transcript_from_audio(filename)
            print(f'Transcription took: {time.time() - start_time:.1f} seconds')
            merged_segments = merge_transcripts_segements(
                segments,
                title,
                num_segments_to_merge=10,
            )
            _save_transcript_chunks(video_url, title, merged_segments)
        except Exception as e:
            # Best effort: keep going with the remaining files.
            print(f'Error transcribing: {title}')
            print(e)


def main(channel_url: str = 'https://www.youtube.com/@HuggingFace'):
    """Download the audio of every video on a channel, then transcribe it.

    Audio is stored under AUDIO_SAVE_PATH.  Transcripts are merged in groups
    of 10 segments and written as separate files under TRANSCRIPTS_SAVE_PATH.
    Audio and transcripts that already exist are skipped, so a rerun resumes
    where the previous one stopped.

    Args:
        channel_url: Channel to process.  Defaults to the HuggingFace channel.
    """
    os.makedirs(AUDIO_SAVE_PATH, exist_ok=True)
    os.makedirs(TRANSCRIPTS_SAVE_PATH, exist_ok=True)
    print('Getting videos urls')
    videos_urls = get_videos_urls(channel_url)
    print('Downloading audio files')
    audio_data = _download_all_audio(videos_urls)
    print('Transcribing audio files')
    _transcribe_all(audio_data)


if __name__ == '__main__':
    main()