# NOTE: page-header residue from the hosting site ("Spaces: Runtime error"); not part of the module.
import os
from abc import ABC, abstractmethod
from urllib.parse import parse_qs, urlparse

# import whisperx
import stable_whisper
from faster_whisper import WhisperModel
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import SRTFormatter, WebVTTFormatter


class Transcription(ABC):
    '''
    Common base for every transcription backend.

    Stores the media location, the resolved output directory, the media path
    without its extension and the requested subtitle format.
    '''

    def __init__(self, media_path, output_path, subtitle_format):
        '''
        Args:
            media_path: path of the media file to transcribe
            output_path: output directory, joined onto the current working directory
            subtitle_format: subtitle format requested by the caller
        '''
        # Media path with its extension dropped
        stem, _extension = os.path.splitext(media_path)
        working_directory = os.getcwd()

        self.media_path = media_path
        self.output_path = os.path.join(working_directory, output_path)
        self.filename = stem
        self.subtitle_format = subtitle_format

    def generate_transcript(self):
        '''
        Placeholder; does nothing in the base class
        '''
        return None

    def save_transcript(self):
        '''
        Placeholder; does nothing in the base class
        '''
        return None
class YouTubeTranscriptAPI(Transcription):
    '''
    Fetches an existing transcript of a YouTube video and writes it to a
    subtitle file ('srt' or 'vtt').
    '''

    def __init__(self, url, media_path, output_path, subtitle_format='srt', transcript_language='en'):
        '''
        Args:
            url: YouTube video URL (watch, youtu.be, embed, shorts, live or /v/ form)
            media_path: media file path; its extension-less form names the subtitle file
            output_path: output directory handed to the base class
            subtitle_format: 'srt' or 'vtt' (case-insensitive)
            transcript_language: language code of the transcript to fetch

        Raises:
            IndexError: if no video id can be found in url
            AssertionError: if subtitle_format is not supported
        '''
        super().__init__(media_path, output_path, subtitle_format)
        self.url = url
        self.video_id = self._extract_video_id(url)
        self.transcript_language = transcript_language
        self.supported_subtitle_formats = ['srt', 'vtt']

        # Normalising once so that validation, formatter selection and the
        # file extension all agree (previously 'SRT' passed validation and
        # then crashed in save_transcript)
        self.subtitle_format = self.subtitle_format.lower()

        # Raised explicitly (an assert statement is stripped under `python -O`);
        # AssertionError is kept so existing handlers still match
        if self.subtitle_format not in self.supported_subtitle_formats:
            raise AssertionError(
                f'Unsupported subtitle format {subtitle_format!r}; '
                f'expected one of {self.supported_subtitle_formats}'
            )

    @staticmethod
    def _extract_video_id(url):
        '''
        Returns the video id contained in a YouTube URL
        '''
        parsed_url = urlparse(url)

        # Standard watch URLs: ...?v=<id>&t=10s -> <id> (other parameters dropped)
        query_ids = parse_qs(parsed_url.query).get('v')
        if query_ids:
            return query_ids[0]

        host = (parsed_url.hostname or '').lower()
        path_parts = [part for part in parsed_url.path.split('/') if part]

        # Short links: youtu.be/<id>
        if host in ('youtu.be', 'www.youtu.be') and path_parts:
            return path_parts[0]

        # Path based links: /embed/<id>, /shorts/<id>, /live/<id>, /v/<id>
        if len(path_parts) >= 2 and path_parts[0] in ('embed', 'shorts', 'live', 'v'):
            return path_parts[1]

        # Legacy behaviour for anything else; raises IndexError when the
        # URL contains no 'v=' at all, exactly as before
        return url.split('v=')[1]

    def get_available_transcripts(self):
        '''
        Returns a dictionary of available transcripts & their info,
        keyed by language code
        '''
        # Getting List of all Available Transcripts
        transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id)

        # Converting Available Transcripts to Dictionary
        return {
            transcript.language_code: {
                'language': transcript.language,
                'is_generated': transcript.is_generated,
                'is_translatable': transcript.is_translatable
            }
            for transcript in transcript_list
        }

    def generate_transcript(self):
        '''
        Fetches the transcript in the requested language and stores it
        on self.transcript
        '''
        self.transcript = YouTubeTranscriptApi.get_transcript(self.video_id, languages=[self.transcript_language])

    def save_transcript(self):
        '''
        Writes the transcript into file and returns the file path.
        generate_transcript() must have been called first.
        '''
        # Getting the Formatter
        formatters = {'srt': SRTFormatter, 'vtt': WebVTTFormatter}
        formatter = formatters[self.subtitle_format.lower()]()

        # Getting the Formatted Transcript
        formatted_transcript = formatter.format_transcript(self.transcript)

        # Writing the Formatted Transcript
        file_path = f'{self.filename}.{self.subtitle_format}'
        with open(file_path, 'w', encoding='utf-8') as transcript_file:
            transcript_file.write(formatted_transcript)

        return file_path
class Whisper(Transcription):
    '''
    Shared base for the Whisper based backends.
    '''

    def __init__(self, media_path, output_path, subtitle_format, word_level):
        '''
        Args:
            media_path: path of the media file to transcribe
            output_path: output directory handed to the base class
            subtitle_format: 'ass', 'srt' or 'vtt' (case-insensitive)
            word_level: whether timestamps are produced per word instead of per segment

        Raises:
            AssertionError: if subtitle_format is not supported
        '''
        super().__init__(media_path, output_path, subtitle_format)
        self.word_level = word_level
        self.supported_subtitle_formats = ['ass', 'srt', 'vtt']

        # Normalising once so that later exact comparisons against
        # 'ass' / 'srt' / 'vtt' agree with this case-insensitive validation
        self.subtitle_format = self.subtitle_format.lower()

        # Raised explicitly (an assert statement is stripped under `python -O`);
        # AssertionError is kept so existing handlers still match
        if self.subtitle_format not in self.supported_subtitle_formats:
            raise AssertionError(
                f'Unsupported subtitle format {subtitle_format!r}; '
                f'expected one of {self.supported_subtitle_formats}'
            )
class FasterWhisper(Whisper):
    '''
    Transcription backend built on faster_whisper's WhisperModel.
    '''

    def __init__(self, media_path, output_path, subtitle_format='srt', word_level=True,
                 model_size='large-v2', device='cuda', compute_type='float16'):
        '''
        Args:
            media_path: path of the media file to transcribe
            output_path: output directory handed to the base class
            subtitle_format: 'ass', 'srt' or 'vtt'
            word_level: whether timestamps are produced per word instead of per segment
            model_size: model name passed to WhisperModel
            device: device passed to WhisperModel
            compute_type: compute type passed to WhisperModel

        The last three defaults reproduce the previously hard-coded setup.
        '''
        super().__init__(media_path, output_path, subtitle_format, word_level)
        self.model = WhisperModel(model_size, device=device, compute_type=compute_type)

    def generate_transcript(self):
        '''
        Generates the transcript for the media file

        Returns a dictionary with 'language', 'text' and 'segments', where
        each segment holds 'text', 'start' and 'end' (rounded to 2 decimals).
        The same values are stored on self.language, self.text and self.segments.
        '''
        if self.word_level:
            # Generating Word Level Transcript
            segments, info = self.model.transcribe(self.media_path, word_timestamps=True)
            units = (
                (word.word, word.start, word.end)
                for segment in segments
                for word in segment.words
            )
        else:
            # Generating Segment Level Transcript
            segments, info = self.model.transcribe(self.media_path, beam_size=5)
            units = ((segment.text, segment.start, segment.end) for segment in segments)

        # Converting to Dictionary
        all_text = []
        all_segments = []
        for text, start, end in units:
            all_text.append(text)
            all_segments.append({
                'text': text,
                'start': round(start, 2),
                'end': round(end, 2)
            })

        # Setting Transcript Properties
        self.text = ' '.join(all_text)
        self.language = info.language
        self.segments = all_segments

        # Returning Transcript Properties as Dictionary
        return {
            'language': self.language,
            'text': self.text,
            'segments': self.segments
        }

    def save_transcript(self, transcript, output_file):
        '''
        Writes the transcript into file

        Not implemented yet: this is currently a no-op and nothing is written.
        '''
        # TODO: Can't seem to find any built-in methods for writing transcript
        pass
class StableWhisper(Whisper):
    '''
    Transcription backend built on stable_whisper.
    '''

    def __init__(self, media_path, output_path, subtitle_format='srt', word_level=True,
                 model_size='large-v2'):
        '''
        Args:
            media_path: path of the media file to transcribe
            output_path: output directory handed to the base class
            subtitle_format: 'ass', 'srt' or 'vtt'
            word_level: whether timestamps are produced per word instead of per segment
            model_size: model name passed to stable_whisper.load_model
                (default reproduces the previously hard-coded model)
        '''
        super().__init__(media_path, output_path, subtitle_format, word_level)
        self.model = stable_whisper.load_model(model_size)

    def generate_transcript(self):
        '''
        Generates the transcript for the media file

        Returns a dictionary with 'language', 'text' and 'segments', where
        each segment holds 'text', 'start' and 'end' (rounded to 2 decimals).
        The raw result is kept on self.result for save_transcript().
        '''
        # Generating the Transcript (word timestamps only when word_level is set)
        self.result = self.model.transcribe(self.media_path, word_timestamps=self.word_level)

        # Converting to Dictionary
        self.resultdict = self.result.to_dict()

        # Formatting Dictionary
        if self.word_level:
            # One entry per word
            all_segments = [
                {
                    'text': word['word'],
                    'start': round(word['start'], 2),
                    'end': round(word['end'], 2)
                }
                for segment in self.resultdict['segments']
                for word in segment['words']
            ]
        else:
            # One entry per segment
            all_segments = [
                {
                    'text': segment['text'],
                    'start': round(segment['start'], 2),
                    'end': round(segment['end'], 2)
                }
                for segment in self.resultdict['segments']
            ]

        # Setting Transcript Properties
        self.text = self.resultdict['text']
        self.language = self.resultdict['language']
        self.segments = all_segments

        # Returning Transcript Properties as Dictionary
        return {
            'language': self.language,
            'text': self.text,
            'segments': self.segments
        }

    def save_transcript(self):
        '''
        Writes the transcript into file and returns the file path.
        generate_transcript() must have been called first.
        '''
        # Comparing on the lower-cased format so that e.g. 'SRT' is written
        # instead of being silently skipped
        subtitle_format = self.subtitle_format.lower()
        file_path = f'{self.filename}.{subtitle_format}'

        # Writing according to the Format
        if subtitle_format == 'ass':
            self.result.to_ass(file_path, segment_level=True, word_level=self.word_level)
        elif subtitle_format in ('srt', 'vtt'):
            self.result.to_srt_vtt(file_path, segment_level=True, word_level=self.word_level)

        return file_path