"""Helpers for argument parsing, transcript formatting/writing, and downloads."""

import re
import textwrap
import unicodedata
import urllib.request
import zlib
from typing import Iterator, Optional, TextIO

try:
    import tqdm
except ImportError:  # only used for the download progress bar, which is cosmetic
    tqdm = None

try:
    # Not used by any function below (downloads go through the standard library);
    # the import is retained so the module keeps exposing the name it always did.
    import urllib3  # noqa: F401
except ImportError:
    urllib3 = None


def exact_div(x, y):
    """Return ``x // y``, asserting that ``y`` divides ``x`` without remainder."""
    assert x % y == 0
    return x // y


def str2bool(string):
    """Convert the literal strings ``"True"`` / ``"False"`` to booleans.

    Raises:
        ValueError: if ``string`` is neither of the two accepted literals.
    """
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")


def optional_int(string):
    """Parse ``string`` as an int, mapping the literal ``"None"`` to ``None``."""
    return None if string == "None" else int(string)


def optional_float(string):
    """Parse ``string`` as a float, mapping the literal ``"None"`` to ``None``."""
    return None if string == "None" else float(string)


def compression_ratio(text) -> float:
    """Return ``len(text)`` divided by the size of its zlib-compressed UTF-8 bytes.

    Note that the numerator is a character count while the denominator is a
    byte count, so for non-ASCII text this is not a pure byte-to-byte ratio.
    """
    return len(text) / len(zlib.compress(text.encode("utf-8")))


def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    """Format a non-negative duration in seconds as ``[HH:]MM:SS<sep>mmm``.

    Args:
        seconds: duration to format; rounded to the nearest millisecond.
        always_include_hours: emit the ``HH:`` prefix even when hours is zero.
        fractionalSeperator: string placed between seconds and milliseconds
            (``'.'`` for WebVTT, ``','`` for SRT).

    Raises:
        AssertionError: if ``seconds`` is negative.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    whole_seconds, milliseconds = divmod(total_ms, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{fractionalSeperator}{milliseconds:03d}"


def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write the stripped ``'text'`` of every segment to ``file``, one per line."""
    for segment in transcript:
        print(segment['text'].strip(), file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth: Optional[int] = None):
    """Write a transcript to ``file`` in WebVTT format.

    Each segment must provide ``'start'``, ``'end'`` and ``'text'``. The text is
    stripped (as in the other writers), optionally wrapped to ``maxLineWidth``
    columns, and any ``-->`` inside it is replaced by ``->`` so it cannot be
    mistaken for the cue timing arrow.
    """
    print("WEBVTT\n", file=file)
    for segment in transcript:
        text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')

        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )


def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth: Optional[int] = None):
    """
    Write a transcript to a file in SRT format.

    Each segment must provide ``'start'``, ``'end'`` and ``'text'``. Cues are
    numbered from 1; the text is stripped, optionally wrapped to
    ``maxLineWidth`` columns, and any ``-->`` inside it is replaced by ``->``.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for i, segment in enumerate(transcript, start=1):
        text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')

        # SRT always carries the hours field and uses a comma before milliseconds
        start = format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')
        end = format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')

        print(
            f"{i}\n"
            f"{start} --> "
            f"{end}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )


def process_text(text: str, maxLineWidth: Optional[int] = None):
    """Wrap ``text`` to at most ``maxLineWidth`` columns, joining lines with ``\\n``.

    ``None`` or any non-positive width disables wrapping and returns ``text``
    unchanged. Zero is included because ``textwrap.wrap`` rejects widths that
    are not strictly positive.
    """
    if maxLineWidth is None or maxLineWidth <= 0:
        return text

    lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
    return '\n'.join(lines)


def slugify(value, allow_unicode=False):
    """
    Taken from https://github.com/django/django/blob/master/django/utils/text.py
    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value.lower())
    return re.sub(r'[-\s]+', '-', value).strip('-_')


def download_file(url: str, destination: str):
    """Download ``url`` to the path ``destination``, streaming in 8 KiB chunks.

    A tqdm progress bar is shown when tqdm is installed; without it the
    download proceeds silently. The bar's total comes from the response's
    ``Content-Length`` header and is left unknown when the header is absent.

    Raises:
        Whatever ``urllib.request.urlopen`` or ``open`` raise (e.g.
        ``urllib.error.URLError``, ``OSError``).
    """
    with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
        # Servers may omit Content-Length (e.g. chunked transfer encoding).
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None

        loop = None
        if tqdm is not None:
            # `tqdm` is the module; the progress-bar class is `tqdm.tqdm`.
            loop = tqdm.tqdm(
                total=total,
                ncols=80,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            )

        try:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                if loop is not None:
                    loop.update(len(buffer))
        finally:
            if loop is not None:
                loop.close()