aspram / aspram /utils.py
lilitket's picture
Move to package
cab7f7b
import re
# Character-class bodies used by ``clean_characters``. Everything NOT listed
# is removed. The leading "-" is literal (it directly follows "^" in the
# class); "a-z", "A-Z" and "0-9" are ranges. Parentheses need no escaping
# inside a character class, which avoids invalid "\(" string escapes.
_DEFAULT_ALLOWED_CHARS = (
    "-"
    "a-z"
    "A-Z"
    "0-9"
    "ԱԲԳԴԵԶԷԸԹԺԻԼԽԾԿՀՁՂՃՄՅՆՇՈՉՊՋՌՍՎՏՐՑՒՓՔՕՖ"
    "աբգդեզէըթժիլխծկհձղճմյնշոչպջռսվտրցւփքօֆև"
    " \"'։֊.:?;,ՙ՚՛՜՝՞՟()"
)
# Armenian letters plus space and hyphen only (the trailing "-" is literal
# because it is last in the class).
_MESROPATAR_ALLOWED_CHARS = (
    "ԱԲԳԴԵԶԷԸԹԺԻԼԽԾԿՀՁՂՃՄՅՆՇՈՉՊՋՌՍՎՏՐՑՒՓՔՕՖ"
    "աբգդեզէըթժիլխծկհձղճմյնշոչպջռսվտրցւփքօֆև"
    " -"
)
_DEFAULT_DISALLOWED_RE = re.compile(f"[^{_DEFAULT_ALLOWED_CHARS}]")
_MESROPATAR_DISALLOWED_RE = re.compile(f"[^{_MESROPATAR_ALLOWED_CHARS}]")


def clean_characters(sample, lower: bool = False, only_mesropatar: bool = False):
    """Remove disallowed characters from ``sample["sentence"]`` in place.

    If ``sample`` has no ``"sentence"`` key, it is first copied from
    ``"transcription"``.

    Args:
        sample: mutable mapping holding the text under ``"sentence"`` or
            ``"transcription"``.
        lower: lower-case the text before filtering.
        only_mesropatar: keep only Armenian letters, space and hyphen,
            instead of the default set (ASCII letters and digits, Armenian
            letters, space, and ASCII/Armenian punctuation and parentheses).

    Returns:
        The same ``sample`` object, with ``"sentence"`` cleaned.

    Raises:
        NotImplementedError: if neither ``"sentence"`` nor ``"transcription"``
            is present.
    """
    if "sentence" not in sample:
        if "transcription" not in sample:
            raise NotImplementedError(
                "sample has neither a 'sentence' nor a 'transcription' field"
            )
        sample["sentence"] = sample["transcription"]

    sentence = sample["sentence"]
    if lower:
        sentence = sentence.lower()
    pattern = (
        _MESROPATAR_DISALLOWED_RE if only_mesropatar else _DEFAULT_DISALLOWED_RE
    )
    sample["sentence"] = pattern.sub("", sentence)
    return sample
def extract_all_chars(batch):
    """Collect the distinct characters appearing in ``batch["sentence"]``.

    The sentences are joined with single spaces, so the space character is
    always part of the vocabulary when there are at least two sentences.
    Characters are returned sorted. ``list(set(...))`` would give an order
    that changes between processes, because string hashing is randomized.

    Args:
        batch: mapping whose ``"sentence"`` value is a list of strings.

    Returns:
        ``{"vocab": [chars], "all_text": [joined_text]}``. The one-element
        lists presumably let a batched ``datasets.map`` collapse the batch
        into a single row (assumption; confirm against callers).
    """
    all_text = " ".join(batch["sentence"])
    vocab = sorted(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}
def prepare_dataset(smaple, processor):
    """Add model inputs and labels to a single example, in place.

    Reads ``smaple["audio"]`` (a mapping with ``"array"`` and
    ``"sampling_rate"``) and ``smaple["sentence"]``. Writes
    ``"input_values"`` (first row of the processor's ``input_values``),
    ``"input_length"`` (its length) and ``"labels"`` (the processor's
    ``input_ids`` for the sentence, computed inside
    ``processor.as_target_processor()``).

    The parameter keeps its original spelling (``smaple``) so keyword
    callers continue to work.

    Returns:
        The same example object that was passed in.
    """
    example = smaple
    clip = example["audio"]
    encoded = processor(clip["array"], sampling_rate=clip["sampling_rate"])
    example["input_values"] = encoded.input_values[0]
    example["input_length"] = len(example["input_values"])
    # NOTE(review): recent transformers releases deprecate
    # as_target_processor() in favour of processor(text=...); confirm
    # against the pinned version before changing.
    with processor.as_target_processor():
        example["labels"] = processor(example["sentence"]).input_ids
    return example
def batched_prepare_dataset(batch, processor):
    """Batched counterpart of ``prepare_dataset``.

    Works on a shallow copy of ``batch`` (the caller's mapping gets no new
    keys). Adds ``"input_values"``, ``"input_length"`` and ``"labels"``
    columns.

    Args:
        batch: mapping with an ``"audio"`` list (each item holding
            ``"array"`` and ``"sampling_rate"``) and a ``"sentence"`` list.
        processor: callable applied to the list of audio arrays (with
            ``sampling_rate=``) and, inside ``as_target_processor()``, to the
            sentences. Its results must expose ``input_values`` and
            ``input_ids`` respectively.

    Returns:
        The shallow copy of ``batch`` with the three added columns.

    Raises:
        ValueError: if the clips in the batch report different sampling rates.
    """
    batch = batch.copy()
    audio = batch["audio"]

    # Pass the clips' real sampling rate (as prepare_dataset does) rather than
    # a hardcoded 16 kHz. Otherwise audio at any other rate is handed to the
    # processor labelled as 16 kHz.
    rates = {clip["sampling_rate"] for clip in audio}
    if len(rates) > 1:
        raise ValueError(
            f"batch mixes sampling rates {sorted(rates)}; "
            "resample to a single rate first"
        )
    # Empty batch: keep the historical 16 kHz value.
    sampling_rate = rates.pop() if rates else 16_000

    batch["input_values"] = processor(
        [clip["array"] for clip in audio], sampling_rate=sampling_rate
    ).input_values
    batch["input_length"] = [len(values) for values in batch["input_values"]]
    # NOTE(review): recent transformers releases deprecate
    # as_target_processor() in favour of processor(text=...); confirm
    # against the pinned version before changing.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch