#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from functools import lru_cache, partial
from pathlib import Path
import shutil
import tempfile
from typing import List
import zipfile

import gradio as gr
import numpy as np
import torch
import torch.nn as nn

from project_settings import project_path
from toolbox.cv2.misc import erode, dilate
from toolbox.torch.utils.data.vocabulary import Vocabulary


@lru_cache(maxsize=100)
def load_model(model_file: Path):
    # The archive is expected to contain "<stem>/trace_model.zip"
    # (a TorchScript archive) and "<stem>/vocabulary/".
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    tgt_path = out_root / model_file.stem
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    with open(jit_model_file.as_posix(), "rb") as f:
        model = torch.jit.load(f)
    model.eval()

    # The extracted files are no longer needed once the model is in memory.
    shutil.rmtree(tgt_path)

    d = {
        "model": model,
        "vocabulary": vocabulary,
    }
    return d
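
# Usage sketch (the archive name below is hypothetical):
#
#   m = load_model(Path("trained_models/voicemail.zip"))
#   model, vocabulary = m["model"], m["vocabulary"]
#
# lru_cache keys on the Path argument, so repeated calls with the same file
# skip the unzip/load round-trip.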


class Tagger(object):
    def __init__(self,
                 model_file: str,
                 win_size: float,
                 win_step: float,
                 sample_rate: int = 8000,
                 ):
        # win_size and win_step are given in seconds; they are converted to
        # sample counts in tag().
        self.model_file = Path(model_file)
        self.win_size = win_size
        self.win_step = win_step
        self.sample_rate = sample_rate

        self.model: nn.Module = None
        self.vocabulary: Vocabulary = None
        self.load_models()

    def load_models(self):
        m = load_model(self.model_file)
        model = m["model"]
        vocabulary = m["vocabulary"]

        self.model = model
        self.vocabulary = vocabulary
        return model, vocabulary
    def tag(self, signal: np.ndarray):
        signal_length = len(signal)
        win_size = int(self.win_size * self.sample_rate)
        win_step = int(self.win_step * self.sample_rate)

        # Pad half a window of silence on each side so that every window is
        # centered on its hop offset in the original signal.
        signal = np.concatenate([
            np.zeros(shape=(win_size // 2,), dtype=np.int16),
            signal,
            np.zeros(shape=(win_size // 2,), dtype=np.int16),
        ])

        result = list()
        for i in range(0, signal_length, win_step):
            sub_signal = signal[i: i + win_size]
            if len(sub_signal) < win_size:
                break
            inputs = torch.tensor(sub_signal, dtype=torch.float32)
            inputs = torch.unsqueeze(inputs, dim=0)

            probs = self.model(inputs)
            probs = probs.tolist()[0]
            argidx = int(np.argmax(probs))
            label_str = self.vocabulary.get_token_from_index(argidx, namespace="labels")
            result.append(label_str)
        return result
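
# Window math (derived from tag() above): with the defaults sample_rate=8000,
# win_size=2.0 and win_step=0.25, each window is 16000 samples with a hop of
# 2000 samples. Thanks to the win_size // 2 padding, the window taken at
# offset i covers original samples [i - 8000, i + 8000), i.e. it is centered
# on sample i, and tag() emits one label per hop.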


def correct_labels(labels: List[str], target_label: str = "noise", n_erode: int = 2, n_dilate: int = 2):
    labels = erode(labels, erode_label=target_label, n=n_erode)
    labels = dilate(labels, dilate_label=target_label, n=n_dilate)
    return labels
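
# erode and dilate come from toolbox.cv2.misc and are not shown in this file;
# judging from their names and the n_erode/n_dilate counts, the likely intent
# is 1-D morphological smoothing of the label sequence: erosion trims short,
# isolated runs of target_label, and dilation then widens the surviving runs,
# so brief misclassifications neither create nor split segments.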


def split_signal_by_labels(signal: np.ndarray, labels: List[str], target_label: str):
    num_labels = len(labels)

    # Collect (begin, end) label-index pairs of consecutive target_label runs.
    segment_list = list()
    begin = None
    for idx, label in enumerate(labels):
        if label == target_label:
            if begin is None:
                begin = idx
        else:
            if begin is not None:
                segment_list.append((begin, idx))
                begin = None
    # Close a run that extends to the end of the label sequence.
    if begin is not None:
        segment_list.append((begin, num_labels))

    # Map label indices back to sample offsets; each label covers
    # len(signal) / num_labels samples.
    result = list()
    win_step = signal.shape[0] / num_labels
    for begin, end in segment_list:
        begin = int(begin * win_step)
        end = int(end * win_step)
        sub_signal = signal[begin: end + 1]
        result.append({
            "begin": begin,
            "end": end + 1,
            "sub_signal": sub_signal,
        })
    return result
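
# Worked example (plain numbers, traced through the function above): with an
# 8-sample signal and labels = ["a", "noise", "noise", "a"], target_label
# "noise" yields one run (1, 3); win_step = 8 / 4 = 2.0, so begin = 2,
# end = 6, and the slice is signal[2:7] -- the "+ 1" makes the end inclusive.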


@lru_cache(maxsize=100)
def get_tagger(model_file: str,
               win_size: float = 2.0,
               win_step: float = 0.25,
               ):
    tagger = Tagger(
        model_file=model_file,
        win_size=win_size,
        win_step=win_step,
    )
    return tagger


def when_model_name_change(model_name: str, split_trained_model_dir: Path):
    m = load_model(
        model_file=(split_trained_model_dir / f"{model_name}.zip")
    )
    token_to_index: dict = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
    label_choices = list(token_to_index.keys())

    split_label = gr.Dropdown(choices=label_choices, value=label_choices[0], label="label")
    return split_label


def get_split_tab(examples_dir: str, trained_model_dir: str):
    split_examples_dir = Path(examples_dir)
    split_trained_model_dir = Path(trained_model_dir)

    # models
    split_model_choices = list()
    for filename in split_trained_model_dir.glob("*.zip"):
        model_name = filename.stem
        # the "examples" archive is not a model
        if model_name == "examples":
            continue
        split_model_choices.append(model_name)
    model_choices = list(sorted(split_model_choices))

    # model_labels_choices
    m = load_model(
        model_file=(split_trained_model_dir / f"{model_choices[0]}.zip")
    )
    token_to_index = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
    model_labels_choices = list(token_to_index.keys())

    # examples: each example starts from the first model and its first label
    split_examples = list()
    for filename in split_examples_dir.glob("**/*/*.wav"):
        split_examples.append([
            filename.as_posix(),
            model_choices[0],
            model_labels_choices[0],
        ])

    with gr.TabItem("split"):
        with gr.Row():
            with gr.Column(scale=3):
                split_audio = gr.Audio(label="audio")
                with gr.Row():
                    split_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    split_label = gr.Dropdown(choices=model_labels_choices, value=model_labels_choices[0], label="label")
                split_win_size = gr.Number(value=2.0, minimum=0, maximum=5, step=0.05, label="win_size")
                split_win_step = gr.Number(value=0.25, minimum=0, maximum=5, step=0.05, label="win_step")
                split_n_erode = gr.Number(value=2, minimum=0, maximum=5, step=1, label="n_erode")
                split_n_dilate = gr.Number(value=2, minimum=0, maximum=5, step=1, label="n_dilate")
                split_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                split_sub_audio = gr.Audio(label="sub_audio")
                split_sub_audio_message = gr.Textbox(max_lines=10, label="sub_audio_message")
                split_sub_audio_dataset_state = gr.State(value=[])
                split_sub_audio_dataset = gr.Dataset(
                    components=[split_sub_audio, split_sub_audio_message],
                    samples=split_sub_audio_dataset_state.value,
                )
                # Clicking a Dataset row passes that row's sample (an
                # [audio, message] pair); the lambda routes the two values
                # into the preview components above.
                split_sub_audio_dataset.click(
                    fn=lambda x: (x[0], x[1]),
                    inputs=[split_sub_audio_dataset],
                    outputs=[split_sub_audio, split_sub_audio_message],
                )

        def when_click_split_button(audio_t,
                                    model_name: str,
                                    label: str,
                                    win_size: float,
                                    win_step: float,
                                    n_erode: int = 2,
                                    n_dilate: int = 2,
                                    ):
            max_wave_value = 32768.0
            sample_rate, signal = audio_t

            model_file = project_path / f"trained_models/{model_name}.zip"
            tagger = get_tagger(model_file.as_posix(), win_size, win_step)

            # Normalize int16 PCM to [-1, 1) for tagging, but split the
            # original signal so the sub-clips keep their scale.
            signal_ = signal / max_wave_value
            labels = tagger.tag(signal_)
            labels = correct_labels(labels, target_label=label, n_erode=n_erode, n_dilate=n_dilate)
            sub_signal_list = split_signal_by_labels(signal, labels, target_label=label)

            _split_sub_audio_dataset_state = [
                [
                    (sample_rate, item["sub_signal"]),
                    json.dumps({"begin": item["begin"], "end": item["end"]}, ensure_ascii=False, indent=2),
                ]
                for item in sub_signal_list
            ]
            _split_sub_audio_dataset = gr.Dataset(
                components=[split_sub_audio, split_sub_audio_message],
                samples=_split_sub_audio_dataset_state,
                visible=True,
            )
            return _split_sub_audio_dataset_state, _split_sub_audio_dataset
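
        # The callback returns the raw sample list (kept in gr.State) plus a
        # freshly constructed gr.Dataset, which Gradio applies as an update to
        # refresh the rows shown in the component.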
        gr.Examples(
            split_examples,
            inputs=[
                split_audio,
                split_model_name, split_label,
                split_win_size, split_win_step,
                split_n_erode, split_n_dilate,
            ],
            outputs=[split_sub_audio_dataset_state, split_sub_audio_dataset],
            fn=when_click_split_button,
            examples_per_page=5,
        )

        split_model_name.change(
            partial(when_model_name_change, split_trained_model_dir=split_trained_model_dir),
            inputs=[split_model_name],
            outputs=[split_label],
        )
        split_button.click(
            when_click_split_button,
            inputs=[
                split_audio,
                split_model_name, split_label,
                split_win_size, split_win_step,
                split_n_erode, split_n_dilate,
            ],
            outputs=[split_sub_audio_dataset_state, split_sub_audio_dataset],
        )

    return locals()


if __name__ == "__main__":
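    # Minimal standalone launch sketch. The --examples_dir default is an
    # assumption about the project layout; "trained_models" matches the path
    # used in when_click_split_button.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    args = parser.parse_args()

    with gr.Blocks() as blocks:
        get_split_tab(
            examples_dir=args.examples_dir,
            trained_model_dir=args.trained_model_dir,
        )
    blocks.launch()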