import gradio as gr
import numpy as np
import torch
from pathlib import Path
from transformers import pipeline
from transformers.utils import logging
# Logging
#logging.set_verbosity_debug()
logger = logging.get_logger("transformers")
# Pipelines
## Automatic Speech Recognition
## https://huggingface.co/docs/transformers/task_summary#automatic-speech-recognition
## Requires ffmpeg to be installed
asr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
asr_model = "openai/whisper-tiny"
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    torch_dtype=asr_torch_dtype,
    device=asr_device
)
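## Illustrative only (not executed here): the ASR pipeline also accepts a plain
## filepath, in which case ffmpeg handles the decoding. A hypothetical call:
##   asr("sample.wav")    # "sample.wav" is a placeholder path
##   # -> {"text": "..."} # dict carrying the transcribed text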
## Token Classification / Named Entity Recognition
## https://huggingface.co/docs/transformers/task_summary#token-classification
tc_device = 0 if torch.cuda.is_available() else "cpu"
tc_model = "dslim/distilbert-NER"
tc = pipeline(
    "token-classification",  # ner
    model=tc_model,
    device=tc_device
)
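## Illustrative only (not executed here): a token-classification pipeline returns
## one dict per tagged token; the score and offsets below are made up.
##   tc("My name is Clara and I live in Berkeley.")
##   # -> [{"entity": "B-PER", "word": "Clara", "score": 0.99, "start": 11, "end": 16}, ...]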
# ---
# Transformers
# https://www.gradio.app/main/docs/gradio/audio#behavior
# As output component: expects audio data in any of these formats:
# - a str or pathlib.Path filepath
# - or URL to an audio file,
# - or a bytes object (recommended for streaming),
# - or a tuple of (sample rate in Hz, audio data as numpy array)
def transcribe(audio: str | Path | bytes | tuple[int, np.ndarray] | None):
    logger.debug(">Transcribe")
    if audio is None:
        return "..."
    # TODO Manage str/Path
    text = ""
    # https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
    # The Whisper pipeline's expected input format differs from the tuple the Gradio audio component outputs
    if asr_model.startswith("openai/whisper") and isinstance(audio, tuple):
        sampling_rate, raw = audio
        # Convert to mono if stereo
        if raw.ndim > 1:
            raw = raw.mean(axis=1)
        # Convert according to asr_torch_dtype
        raw = raw.astype(np.float16 if asr_torch_dtype == torch.float16 else np.float32)
        # Normalize to [-1, 1], guarding against all-zero (silent) input
        peak = np.max(np.abs(raw))
        if peak > 0:
            raw /= peak
        inputs = {"sampling_rate": sampling_rate, "raw": raw}
        logger.debug(inputs)
        transcript = asr(inputs)
        text = transcript["text"]
    logger.debug(text)
    return text
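## Illustrative only (not executed here): with the Gradio default type="numpy",
## transcribe() receives (sample_rate, data), e.g. a silent one-second mono clip
## just to show the expected shape:
##   transcribe((44100, np.zeros(44100, dtype=np.int16)))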
def tokenize(text: str):
    logger.debug(">Tokenize")
    entities = tc(text)
    logger.debug(entities)
    # TODO Add Text Classification for sentiment analysis
    return {"text": text, "entities": entities}
def classify(text: str):
    logger.debug(">Classify")
    # Placeholder for the Text Classification step (see TODO in tokenize); not wired into an interface yet
    return None
def transcribe_tokenize(*args):
    return tokenize(transcribe(*args))
# ---
# Gradio
## Interfaces
# https://www.gradio.app/main/docs/gradio/audio
input_audio = gr.Audio(
    sources=["upload", "microphone"],
    show_share_button=False
)
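## Note: gr.Audio defaults to type="numpy", so transcribe() receives the
## (sample rate, numpy array) tuple handled above; type="filepath" would
## instead hand it a str path (the TODO branch in transcribe).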
## App
asrner_app = gr.Interface(
    transcribe_tokenize,
    inputs=[
        input_audio
    ],
    outputs=[
        gr.HighlightedText()
    ],
    title="ASR>NER",
    description=(
        "Transcribe, Tokenize, Classify"
    ),
    flagging_mode="never"
)
ner_app = gr.Interface(
    tokenize,
    inputs=[
        gr.Textbox()
    ],
    outputs=[
        gr.HighlightedText()
    ],
    title="NER",
    description=(
        "Tokenize, Classify"
    ),
    flagging_mode="never"
)
gradio_app = gr.TabbedInterface(
    interface_list=[
        asrner_app,
        ner_app
    ],
    tab_names=[
        asrner_app.title,
        ner_app.title
    ],
    title="ASRNERSBX"
)
## Start!
gradio_app.launch()
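## On Hugging Face Spaces, launch() picks up the host/port it needs on its own;
## for local debugging you could pass debug=True to surface errors in the
## console, e.g. gradio_app.launch(debug=True) (illustrative, not required).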