# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import io
import base64
import numpy as np
import scipy.io.wavfile
from typing import Text
from huggingface_hub import HfApi
import streamlit as st
from pyannote.audio import Pipeline
from pyannote.audio import Audio
from pyannote.core import Segment
import streamlit.components.v1 as components
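

# the uploaded waveform excerpt is embedded into assets/template.html as a
# base64-encoded WAV data URI, so that wavesurfer.js can display it directly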
def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
    """Convert waveform to base64 data"""
    # normalize to [-1, 1] to avoid clipping, then serialize as WAV and base64-encode
    waveform /= np.max(np.abs(waveform)) + 1e-8
    with io.BytesIO() as content:
        scipy.io.wavfile.write(content, sample_rate, waveform)
        content.seek(0)
        b64 = base64.b64encode(content.read()).decode()
        b64 = f"data:audio/x-wav;base64,{b64}"
        return b64
PYANNOTE_LOGO = "https://avatars.githubusercontent.com/u/7559051?s=400&v=4"
EXCERPT = 30.0  # only the first EXCERPT seconds of the uploaded file are processed
st.set_page_config(
    page_title="pyannote.audio pretrained pipelines", page_icon=PYANNOTE_LOGO
)
st.sidebar.image(PYANNOTE_LOGO)
st.markdown("""# 🎹 Pretrained pipelines
""")
# pyannote pipelines available on the Hugging Face Hub (pyannote/ namespace only)
PIPELINES = [
    p.modelId
    for p in HfApi().list_models(filter="pyannote-audio-pipeline")
    if p.modelId.startswith("pyannote/")
]
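# audio loader: downmixes to mono and resamples to 16 kHz, as expected by the pipelines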
audio = Audio(sample_rate=16000, mono=True)
selected_pipeline = st.selectbox("Select a pipeline", PIPELINES, index=0)
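# instantiate the selected pipeline, authenticating with a token stored in the
# Space secrets (gated pipelines require an authenticated user)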
with st.spinner("Loading pipeline..."):
    pipeline = Pipeline.from_pretrained(
        selected_pipeline, use_auth_token=st.secrets["PYANNOTE_TOKEN"]
    )
uploaded_file = st.file_uploader("Choose an audio file")
if uploaded_file is not None:
    try:
        duration = audio.get_duration(uploaded_file)
    except RuntimeError as e:
        st.error(e)
        st.stop()

    # only process the first EXCERPT seconds of the uploaded file
    waveform, sample_rate = audio.crop(
        uploaded_file, Segment(0, min(duration, EXCERPT))
    )

    # pyannote.audio pipelines accept in-memory {"waveform", "sample_rate"} dicts
    uri = "".join(uploaded_file.name.split())  # whitespace-free URI from the file name
    file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}

    with st.spinner(f"Processing first {EXCERPT:g} seconds..."):
        output = pipeline(file)
    # load the waveform visualization template and inject its stylesheet
    with open('assets/template.html') as html, open('assets/style.css') as css:
        html_template = html.read()
        st.markdown('<style>{}</style>'.format(css.read()), unsafe_allow_html=True)
    # translucent highlight colors (last two hex digits encode alpha),
    # cycled over the sorted output labels
    colors = [
        "#ffd70033",
        "#00ffff33",
        "#ff00ff33",
        "#00ff0033",
        "#9932cc33",
        "#00bfff33",
        "#ff7f5033",
        "#66cdaa33",
    ]
    num_colors = len(colors)
    label2color = {
        label: colors[k % num_colors]
        for k, label in enumerate(sorted(output.labels()))
    }
    BASE64 = to_base64(waveform.numpy().T)

    # build one wavesurfer.js region per track and one legend entry per label
    REGIONS = ""
    LEGENDS = ""
    labels = []
    for segment, _, label in output.itertracks(yield_label=True):
        REGIONS += f"var re = wavesurfer.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});"
        if label not in labels:
            LEGENDS += f"<li><span style='background-color:{label2color[label]}'></span>{label}</li>"
            labels.append(label)

    html = html_template.replace("BASE64", BASE64).replace("REGIONS", REGIONS)
    components.html(html, height=250, scrolling=True)

    st.markdown(
        "<div style='overflow : auto'><ul class='legend'>" + LEGENDS + "</ul></div>",
        unsafe_allow_html=True,
    )
st.markdown("---")
with io.StringIO() as fp:
output.write_rttm(fp)
content = fp.getvalue()
b64 = base64.b64encode(content.encode()).decode()
href = f'Download as <a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">RTTM</a> or run it on the whole {int(duration):d}s file:'
st.markdown(href, unsafe_allow_html=True)
    # minimal snippet to reproduce this run locally on the whole file
    code = f"""
from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("{selected_pipeline}")
output = pipeline("{uploaded_file.name}")
"""
    st.code(code, language='python')
st.sidebar.markdown(
    """
-------------------

To use these pipelines on more and longer files on your own servers (with a GPU, hence much faster), check the [documentation](https://github.com/pyannote/pyannote-audio).

For [technical questions](https://github.com/pyannote/pyannote-audio/discussions) and [bug reports](https://github.com/pyannote/pyannote-audio/issues), please use the [pyannote.audio](https://github.com/pyannote/pyannote-audio) GitHub repository.

For commercial enquiries and scientific consulting, please contact [me](mailto:herve@niderb.fr).
"""
)