# whisat/main.py
import argparse
import csv
import json
import os
import time
from pathlib import Path

import pandas as pd
import torch
import transformers
from peft import PeftConfig, PeftModel
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    GenerationConfig,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    pipeline,
)


def prepare_pipeline(model_path, generate_kwargs):
    """Prepare a pipeline for ASR inference.

    Args:
        model_path (str): path to a local model directory, or a Hugging Face model name
        generate_kwargs (dict): generation options forwarded to the model's generate()

    Returns:
        AutomaticSpeechRecognitionPipeline: ASR pipeline
    """
    processor = WhisperProcessor.from_pretrained(model_path)
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model_path,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_kwargs,
        model_kwargs={"load_in_8bit": False},
        device_map="auto",
    )
    return asr_pipeline
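

# Example usage (illustrative only: "openai/whisper-small" and "example.wav"
# are placeholder names, not artifacts shipped with this repo):
#   pipe = prepare_pipeline("openai/whisper-small",
#                           generate_kwargs={"max_new_tokens": 112})
#   print(pipe("example.wav")["text"])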


def ASRdirWhisat(
    audio_dir,
    out_dir='../whisat_results/',
    model_dir=".",
    max_new_tokens=112,
    num_beams=1,
    do_sample=False,
    repetition_penalty=1.0,
):
    """ASR over a directory using fine-tuned Transformers Whisper.

    Simply transcribes each file in the specified folder separately.
    Whisper takes 30-second input: anything shorter is zero-padded, and
    longer audio is transcribed in successive 30-second windows whose
    outputs are concatenated.
    Output is saved under the specified top-level folder, mirroring the
    input directory structure.
    """
    asr_model = prepare_pipeline(
        model_path=model_dir,
        generate_kwargs={
            'max_new_tokens': max_new_tokens,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            'do_sample': do_sample,
        },
    )
    audio_exts = {'mov', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'}
    audio_files = [
        str(f) for f in Path(audio_dir).rglob("*")
        if f.is_file() and f.suffix.lower().lstrip('.') in audio_exts
    ]
    os.makedirs(out_dir, exist_ok=True)
    message = "This may take a while on CPU." if asr_model.device.type == "cpu" else "Running on GPU"
    print(f'Running ASR for {len(audio_files)} files. {message} ...')
    # get the start time
    st = time.time()
    asrDir = out_dir
    for audiofile in tqdm(audio_files):
        sessname = Path(audiofile).stem
        sesspath = os.path.relpath(os.path.dirname(Path(audiofile).resolve()), Path(audio_dir).resolve())
        asrFullFile = os.path.join(asrDir, sesspath, f"{sessname}.asr.txt")  # full session ASR results file
        os.makedirs(os.path.join(asrDir, sesspath), exist_ok=True)
        with torch.no_grad():
            with autocast():
                try:
                    result = asr_model(audiofile)
                except ValueError as e:
                    print(f'{e}: {audiofile}')
                    continue
        asrtext = result['text']
        with open(asrFullFile, 'w') as outfile:
            outfile.write(asrtext)
    et = time.time()
    compute_time = et - st
    print(f'...transcription complete in {compute_time:.1f} sec')
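

# Minimal command-line entry point: a sketch, not the repo's actual interface.
# argparse is imported above, so the script is presumably run directly; the
# flag names below are illustrative assumptions.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Transcribe every audio file under a directory with Whisper')
    parser.add_argument('audio_dir', help='directory to search recursively for audio files')
    parser.add_argument('--out_dir', default='../whisat_results/',
                        help='top-level folder for .asr.txt results')
    parser.add_argument('--model_dir', default='.',
                        help='model directory or Hugging Face model name')
    parser.add_argument('--max_new_tokens', type=int, default=112)
    parser.add_argument('--num_beams', type=int, default=1)
    args = parser.parse_args()
    ASRdirWhisat(
        args.audio_dir,
        out_dir=args.out_dir,
        model_dir=args.model_dir,
        max_new_tokens=args.max_new_tokens,
        num_beams=args.num_beams,
    )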