import argparse
import faulthandler
import gc
import os
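import subprocess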
import tempfile

import torch
import whisperx

from whisperx.asr import FasterWhisperPipeline


def get_device():
    """Pick the torch device for inference: CUDA when available, otherwise CPU."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "mps" if torch.backends.mps.is_available() else device
    return device
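

def parse_audio(video_path: str, audio_file_path: str) -> None:
    """Extract the audio track of ``video_path`` into ``audio_file_path``.

    Minimal sketch, assuming a plain ffmpeg extraction to mono 16 kHz WAV
    (the sample rate Whisper models expect) is sufficient here; it requires
    the ``ffmpeg`` binary to be available on PATH.
    """
    subprocess.run(
        [
            "ffmpeg",
            "-y",              # overwrite the temp file already created by mkstemp
            "-i", video_path,  # input video
            "-vn",             # drop the video stream
            "-ac", "1",        # mono
            "-ar", "16000",    # 16 kHz
            audio_file_path,
        ],
        check=True,
    )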


def generate_subtitles_from_audio(
        audio_file_path: str,
        model: FasterWhisperPipeline,
        batch_size: int = 8,
):
    """Transcribe an audio file with an already-loaded WhisperX pipeline (Russian audio)."""
    audio = whisperx.load_audio(audio_file_path)
    result = model.transcribe(audio, batch_size=batch_size, language="ru")
    return result


def generate_subtitles_from_video(
        video_path: str,
        model_name: str = "base",
        batch_size: int = 8,
        compute_type: str = "int8",
):
    """Extract the audio track from a video file and transcribe it with WhisperX."""
    # mkstemp returns an open OS-level file descriptor; close it so only the path is used.
    # The .wav suffix lets ffmpeg pick the output container in parse_audio above.
    fd, audio_file = tempfile.mkstemp(suffix=".wav")
    os.close(fd)

    device = get_device()

    print("Loading model:")
    model = whisperx.load_model(model_name, device, compute_type=compute_type, language="ru")
    print("Parsing audio:")
    parse_audio(video_path, audio_file)
    print("Generating subtitles:")
    result = generate_subtitles_from_audio(audio_file, model, batch_size=batch_size)

    os.remove(audio_file)
    del model
    gc.collect()
    return result
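

def segments_to_srt(result: dict) -> str:
    """Render a transcription result as SRT text.

    Optional sketch, not wired into the CLI below; it assumes the usual
    WhisperX result layout where ``result["segments"]`` entries carry
    ``start``/``end`` times in seconds and a ``text`` field.
    """
    def fmt(seconds: float) -> str:
        # SRT timestamps look like HH:MM:SS,mmm
        ms = int(round(seconds * 1000))
        hours, rest = divmod(ms, 3_600_000)
        minutes, rest = divmod(rest, 60_000)
        secs, millis = divmod(rest, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    blocks = []
    for index, segment in enumerate(result["segments"], start=1):
        blocks.append(str(index))
        blocks.append(f"{fmt(segment['start'])} --> {fmt(segment['end'])}")
        blocks.append(segment["text"].strip())
        blocks.append("")
    return "\n".join(blocks)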


def add_whisper_args(arg_parser: argparse.ArgumentParser):
    """Register the CLI arguments for the transcription pipeline."""
    arg_parser.add_argument("video", help="Path to the input video file")
    arg_parser.add_argument("--compute_type", help="Compute type for the model", default="int8",
                            choices=["int8", "float16", "float32"])
    arg_parser.add_argument("--whisper_model", help="Whisper model to use", default="large-v2")
    arg_parser.add_argument("--batch_size", help="Batch size for inference", default=4, type=int)


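# Example invocation (the script file name is hypothetical):
#   python generate_subtitles.py input.mp4 --whisper_model large-v2 --batch_size 4 --compute_type int8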
if __name__ == "__main__":
    faulthandler.enable()
    parser = argparse.ArgumentParser(description="Generate subtitles from a video file")
    add_whisper_args(parser)
    args = parser.parse_args()
    print(generate_subtitles_from_video(
        args.video,
        model_name=args.whisper_model,
        batch_size=args.batch_size,
        compute_type=args.compute_type,
    ))