rhg-elan-transcriber / functions.py
micahg's picture
Initial commit
4dc8f4b
# coding=utf8
import functools
import os
import shutil
import subprocess
import sys
import tempfile
import xml.etree.ElementTree as ET

from transformers import pipeline

from environment import DEFAULT_MODEL_LANGUAGE, DEFAULT_MODEL, DEFAULT_LANGUAGE, DEVICE
# class for annotation segments
class Segment:
    """A single time-aligned annotation taken from an ELAN tier.

    Attributes:
        segment_id: The annotation's ANNOTATION_ID.
        start: Start time of the annotation.
        end: End time of the annotation.
        transcription: Text for the segment; empty until it is filled in.

    NOTE(review): despite the ``int`` hints, getAnnotationSegments in this
    module passes start/end as the raw TIME_VALUE attribute strings.
    """

    def __init__(self, segment_id: str, start: int, end: int):
        self.segment_id, self.start, self.end = segment_id, start, end
        self.transcription = ''
def getTimeSlots(eaf):
    """Map each TIME_SLOT_ID under the document's TIME_ORDER to its TIME_VALUE.

    Args:
        eaf: Root element of a parsed EAF document.

    Returns:
        dict of slot id -> raw TIME_VALUE attribute string (None when a slot
        carries no TIME_VALUE). A repeated id keeps its last value.
    """
    time_order = eaf.find('TIME_ORDER')
    return {
        time_slot.get('TIME_SLOT_ID'): time_slot.get('TIME_VALUE')
        for time_slot in time_order
    }
def getAnnotationSegments(eaf, tier_type):
    """Collect the time-aligned annotations on tiers of one linguistic type.

    Args:
        eaf: Root element of a parsed EAF document.
        tier_type: Value compared against each TIER's LINGUISTIC_TYPE_REF.

    Returns:
        A list of Segment objects in document order, one per
        ALIGNABLE_ANNOTATION on a matching tier. start/end are the raw
        TIME_VALUE attribute values (strings) of the referenced time slots.

    Raises:
        KeyError: if an annotation references a time slot id that is not
            present in the document's TIME_ORDER.
    """
    time_slot_dic = getTimeSlots(eaf)
    segment_list = []
    for tier in eaf.findall('TIER'):
        if tier.get('LINGUISTIC_TYPE_REF') != tier_type:
            continue
        for annotation in tier:
            alignable_annotation = annotation.find('ALIGNABLE_ANNOTATION')
            # An ANNOTATION may wrap something other than an
            # ALIGNABLE_ANNOTATION (in EAF, a REF_ANNOTATION), which has no
            # time slots of its own and so cannot be cut from the audio.
            # Skip it instead of crashing with AttributeError on None.
            if alignable_annotation is None:
                continue
            segment_list.append(Segment(
                alignable_annotation.get('ANNOTATION_ID'),
                time_slot_dic[alignable_annotation.get('TIME_SLOT_REF1')],
                time_slot_dic[alignable_annotation.get('TIME_SLOT_REF2')],
            ))
    return segment_list
def splice_audio(audio_path, segment_id, start, end, temp_dir):
    """Cut the span [start, end] out of audio_path into a WAV clip with ffmpeg.

    Args:
        audio_path: Source media file handed to ffmpeg's ``-i``.
        segment_id: Base name of the output clip.
        start: Start offset in milliseconds (int or numeric string).
        end: End offset in milliseconds (int or numeric string).
        temp_dir: Directory the clip is written into.

    Returns:
        Path of the written clip, ``<temp_dir>/<segment_id>.wav``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ffmpeg executable cannot be found.
    """
    file_path = os.path.join(temp_dir, f"{segment_id}.wav")
    # Clear any stale clip first: without -y (and with -nostdin) ffmpeg will
    # not overwrite an existing output file.
    if os.path.exists(file_path):
        os.remove(file_path)
    subprocess.run(
        [
            "ffmpeg",
            "-loglevel", "fatal",
            "-hide_banner",
            "-nostdin",
            "-i", audio_path,
            # ffmpeg takes seconds; the offsets arrive in milliseconds.
            "-ss", f"{int(start)/1000}",
            "-to", f"{int(end)/1000}",
            file_path,
        ],
        # Fail here, rather than returning a path to a clip that was never
        # written and letting the caller trip over it later.
        check=True,
    )
    return file_path
@functools.lru_cache(maxsize=1)
def _get_asr_pipeline(model_id):
    """Build, or reuse, the automatic-speech-recognition pipeline for model_id.

    Loading a model is expensive. The most recently used pipeline is kept
    alive and reused instead of being rebuilt for every audio segment.
    """
    return pipeline(
        task="automatic-speech-recognition",
        model=model_id,
        chunk_length_s=30,
        device=DEVICE,
    )


def transcribe_audio(model_id, audio_path, language='bengali'):
    """Transcribe a single audio file and return the text.

    Args:
        model_id: Model passed to the transformers ``pipeline``.
        audio_path: Path of the audio file to transcribe.
        language: Language name given to the tokenizer's
            ``get_decoder_prompt_ids``. Defaults to 'bengali', the value
            previously hard-coded here.
            NOTE(review): the module imports DEFAULT_LANGUAGE but never used
            it here; confirm whether it was meant to be the default.

    Returns:
        The transcription text with surrounding whitespace stripped.
    """
    transcribe = _get_asr_pipeline(model_id)
    # Pin the decoder prompt to the requested language and the "transcribe"
    # task. This is set on every call because the pipeline object is shared.
    transcribe.model.config.forced_decoder_ids = transcribe.tokenizer.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )
    result = transcribe(audio_path, max_new_tokens=448)
    return result["text"].strip()
def transcribe_eaf(eaf_path, audio_path, tier_type, model_id=None):
    """Transcribe every aligned segment of one tier type and save a new EAF.

    Each ALIGNABLE_ANNOTATION on a tier whose LINGUISTIC_TYPE_REF equals
    tier_type is cut out of audio_path, transcribed, and printed as a
    ``<segment_id>\\t<transcription>`` line. Its ANNOTATION_VALUE text is
    then replaced by the transcription in the written copy.

    Args:
        eaf_path: Path of the input .eaf file (left unmodified).
        audio_path: Media file the annotations are aligned to.
        tier_type: Linguistic type whose tiers are transcribed.
        model_id: Model used for transcription; None means DEFAULT_MODEL.

    Returns:
        Path of the written file: eaf_path with its extension replaced by
        ``_autotranscribed.eaf``.
    """
    if model_id is None:
        model_id = DEFAULT_MODEL
    eaf_tree = ET.parse(eaf_path)
    eaf_root = eaf_tree.getroot()
    segment_list = getAnnotationSegments(eaf_root, tier_type)

    # Use a private temporary directory. It is removed even if ffmpeg or the
    # model raises, and it never collides with (or deletes) a user's folder.
    with tempfile.TemporaryDirectory() as temp_dir:
        for segment in segment_list:
            segment_audio_file = splice_audio(
                audio_path, segment.segment_id, segment.start, segment.end, temp_dir
            )
            segment.transcription = transcribe_audio(model_id, segment_audio_file)
            # Drop each clip as soon as it is used to keep disk usage flat.
            os.remove(segment_audio_file)
            print(f'{segment.segment_id}\t{segment.transcription}')

    # One pass over the tree instead of a full scan per segment. If a
    # segment id repeats, the last one wins, as before.
    transcriptions = {segment.segment_id: segment.transcription for segment in segment_list}
    for annotation in eaf_root.iter('ALIGNABLE_ANNOTATION'):
        annotation_id = annotation.get('ANNOTATION_ID')
        if annotation_id in transcriptions:
            annotation.find('ANNOTATION_VALUE').text = transcriptions[annotation_id]

    new_eaf_path = f'{os.path.splitext(eaf_path)[0]}_autotranscribed.eaf'
    eaf_tree.write(new_eaf_path, encoding='utf-8', xml_declaration=True)
    return new_eaf_path