File size: 3,679 Bytes
7c2d6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e404b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c2d6fa
 
 
 
e404b97
7c2d6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e404b97
7c2d6fa
 
 
 
 
 
 
e404b97
7c2d6fa
 
e404b97
7c2d6fa
e404b97
7c2d6fa
 
 
 
 
e404b97
7c2d6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os 
from peft import PeftModel, PeftConfig
import torch
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
from transformers import pipeline, AutomaticSpeechRecognitionPipeline
import argparse
import time
from pathlib import Path
import json
import pandas as pd
import csv 

def prepare_pipeline(model_path, generate_kwargs):
    """Build a Hugging Face automatic-speech-recognition pipeline for a Whisper model.

    The tokenizer and feature extractor are loaded from the ``WhisperProcessor``
    found at ``model_path`` and handed to the pipeline explicitly.

    Args:
        model_path (str): local model directory or Hugging Face Hub model name.
        generate_kwargs (dict): generation options forwarded by the pipeline to
            the model's ``generate`` call (e.g. ``max_new_tokens``, ``num_beams``).

    Returns:
        pipeline: the configured ASR pipeline. The model is placed with
        ``device_map="auto"`` and loaded without 8-bit quantization.
    """
    whisper_processor = WhisperProcessor.from_pretrained(model_path)

    pipeline_options = dict(
        model=model_path,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        generate_kwargs=generate_kwargs,
        model_kwargs={"load_in_8bit": False},
        device_map='auto',
    )
    return pipeline("automatic-speech-recognition", **pipeline_options)

# Media file extensions (compared lower-cased, including the leading dot)
# that ASRdirWhisat will try to transcribe.
_AUDIO_EXTENSIONS = frozenset(
    {'.mov', '.wav', '.mp4', '.mp3', '.m4a', '.aac', '.flac', '.alac', '.ogg'}
)


def _find_audio_files(audio_dir):
    """Return the sorted paths (as str) of all media files found recursively under audio_dir."""
    return sorted(
        str(f)
        for f in Path(audio_dir).rglob("*")
        if f.suffix.lower() in _AUDIO_EXTENSIONS and f.is_file()
    )


def ASRdirWhisat(
    audio_dir,
    out_dir='../whisat_results/',
    model_dir=".",
    max_new_tokens=112,
    num_beams=1,
    do_sample=False,
    repetition_penalty=1,
):
    """Transcribe every media file under ``audio_dir`` with a (fine-tuned) Whisper model.

    Each file is transcribed separately. Its text is written to
    ``<out_dir>/<subdir relative to audio_dir>/<file stem>.asr.txt``, so the
    output tree mirrors the input tree.

    Whisper consumes 30-second input windows, and shorter audio is zero-padded.
    NOTE(review): no ``chunk_length_s`` is passed to the pipeline, so how audio
    longer than 30 s is handled depends on the installed transformers
    version. Confirm before relying on full transcripts of long files.

    Args:
        audio_dir (str | Path): directory searched recursively for media files.
            Extensions are matched case-insensitively: mov, wav, mp4, mp3, m4a,
            aac, flac, alac, ogg.
        out_dir (str | Path): top-level directory for the results; created if missing.
        model_dir (str): model directory or Hugging Face model name, passed to
            ``prepare_pipeline`` as ``model_path``.
        max_new_tokens (int): generation option ``max_new_tokens``.
        num_beams (int): generation option ``num_beams`` (beam-search width).
        do_sample (bool): generation option ``do_sample``.
        repetition_penalty (float): generation option ``repetition_penalty``.

    Notes:
        Files for which the pipeline raises ``ValueError`` are reported on
        stdout and skipped. Files in the same folder that share a stem (for
        example ``a.wav`` and ``a.mp3``) map to the same output file, and the
        one processed later overwrites the earlier one.
    """
    # prepare_pipeline accepts only (model_path, generate_kwargs).
    asr_model = prepare_pipeline(
        model_path=model_dir,
        generate_kwargs={
            'max_new_tokens': max_new_tokens,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            'do_sample': do_sample,
        },
    )

    audio_files = _find_audio_files(audio_dir)
    os.makedirs(out_dir, exist_ok=True)

    device_type = asr_model.device.type
    # Mixed precision applies only to CUDA devices. Enabling CUDA autocast
    # anywhere else is a no-op that just emits a warning.
    on_gpu = device_type == "cuda"
    message = "This may take a while on CPU." if device_type == "cpu" else "Running on GPU"
    print(f'Running ASR for {len(audio_files)} files. {message} ...')

    audio_root = Path(audio_dir)
    out_root = Path(out_dir)
    start = time.time()
    for audiofile in tqdm(audio_files):
        audio_path = Path(audiofile)
        # rglob yields paths lexically under audio_root, so relative_to always
        # succeeds. It also keeps outputs inside out_dir even for symlinked inputs.
        out_subdir = out_root / audio_path.parent.relative_to(audio_root)
        out_subdir.mkdir(parents=True, exist_ok=True)
        asr_full_file = out_subdir / f"{audio_path.stem}.asr.txt"

        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=on_gpu):
            try:
                result = asr_model(audiofile)
            except ValueError as e:
                # Report the failing file and carry on with the rest of the batch.
                print(f'{e}: {audiofile}')
                continue

        # Transcripts are Unicode; do not depend on the platform default codec.
        with open(asr_full_file, 'w', encoding='utf-8') as outfile:
            outfile.write(result['text'])

    compute_time = time.time() - start
    print(f'...transcription complete in {compute_time:.1f} sec')