# CAPT-ReadAloud / wav2vec_aligen.py
# Author: seba3y — commit be1b9b7 ("Upload 4 files"), 1.83 kB.
import os

# These settings are read by transformers / protobuf when those libraries are
# first imported, so they must be in the environment BEFORE the third-party
# imports below; assigning them after the imports leaves e.g.
# TRANSFORMERS_VERBOSITY without effect.
# protobuf==3.20.0
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'

import torch  # noqa: E402
import librosa  # noqa: E402
from model import Wav2Vec2ForWav2Vec2ForCTCAndUttranceRegression  # noqa: E402
from transformers import Wav2Vec2Processor  # noqa: E402
from optimum.bettertransformer import BetterTransformer  # noqa: E402

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNG before loading — presumably so that any randomly
# initialised weights are reproducible across runs (TODO confirm intent).
torch.random.manual_seed(0)

model_name = "seba3y/wav2vec-base-en-pronunciation-assesment"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForWav2Vec2ForCTCAndUttranceRegression.from_pretrained(model_name).to(device)
# Convert the model with optimum's BetterTransformer for faster inference.
model = BetterTransformer.transform(model)
def load_audio(audio_path, processor):
    """Load an audio file and turn it into model input features.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        processor: Callable such as a ``Wav2Vec2Processor`` that accepts a
            waveform plus ``sampling_rate``/``return_tensors`` and returns an
            object with an ``input_values`` attribute.

    Returns:
        The ``input_values`` produced by ``processor`` with
        ``return_tensors="pt"``.
    """
    # Presumably the rate the model's feature extractor expects — TODO confirm.
    target_sr = 16000
    # librosa resamples to target_sr, so the rate it returns is always
    # target_sr and can be discarded.
    audio, _ = librosa.load(audio_path, sr=target_sr)
    return processor(audio, sampling_rate=target_sr, return_tensors="pt").input_values
@torch.inference_mode()
def get_emissions(input_values, model):
    """Run ``model`` on ``input_values`` and return its score dictionary.

    The model's output is expected to expose a dict-like ``.logits`` that
    contains a ``'logits'`` entry plus named score entries; the ``'logits'``
    entry is removed and the remaining entries are returned.

    Args:
        input_values: Input tensor for the model (e.g. from ``load_audio``).
        model: ``torch.nn.Module`` whose forward output has a dict-like
            ``.logits`` attribute.

    Returns:
        The ``.logits`` mapping with its ``'logits'`` key popped.
    """
    # The inputs must live on the same device as the model's weights; the
    # processor produces CPU tensors while the model may have been moved to
    # CUDA, which would otherwise fail with a device-mismatch error.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_values = input_values.to(first_param.device)
    results = model(input_values).logits
    results.pop('logits')
    return results
def speaker_pronunciation_assesment(audio_path):
    """Score the speech in ``audio_path`` with the module-level model.

    The file is featurised with the module-level ``processor``, run through
    the module-level ``model``, and each single-value score tensor is rounded
    to the nearest integer with Python's ``round``.

    Returns:
        dict with keys ``'pronunciation_accuracy'``, ``'content_scores'``,
        ``'total_score'`` and ``'fluency_score'`` (in that order).
    """
    emissions = get_emissions(load_audio(audio_path, processor), model)

    rounded = {
        head: round(emissions[head].cpu().item())
        for head in ('content', 'accuracy', 'fluency', 'total score')
    }

    return {
        'pronunciation_accuracy': rounded['accuracy'],
        'content_scores': rounded['content'],
        'total_score': rounded['total score'],
        'fluency_score': rounded['fluency'],
    }
if __name__ == '__main__':
    # Script entry point: prints the module name ('__main__' when executed
    # directly).
    print(__name__)