---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: zero-shot-classification
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Named entity recognition
---

# Whisper-NER

- Demo: https://huggingface.co/spaces/aiola/whisper-ner-v1
- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
- Code: https://github.com/aiola-lab/whisper-ner

We introduce WhisperNER, a novel model for joint speech transcription and named entity recognition (NER).

WhisperNER supports open-type NER: the entity types to recognize are supplied as a text prompt at inference time, so the model can recognize diverse and evolving entity types without retraining.
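
Because the entity types are plain text, changing what the model tags only means changing the prompt. The helper below is a minimal sketch that normalizes a list of entity types into the comma-separated, lower-case prompt format used in the Usage example of this card. `build_entity_prompt` is a hypothetical name for illustration; it is not part of the whisper-ner repo or of `transformers`.

```python
from typing import Iterable


def build_entity_prompt(entity_types: Iterable[str]) -> str:
    """Join entity types into a comma-separated, lower-case prompt.

    Hypothetical helper: it mirrors the prompt format shown in the usage
    example ("person, company, location") and is not part of whisper-ner.
    """
    seen = set()
    cleaned = []
    for entity_type in entity_types:
        # Collapse internal whitespace and lower-case, as the usage example does.
        name = " ".join(entity_type.split()).lower()
        # Skip empty entries and duplicates while keeping the original order.
        if name and name not in seen:
            seen.add(name)
            cleaned.append(name)
    return ", ".join(cleaned)


# Open-type NER: any label set can be requested at inference time.
print(build_entity_prompt(["Person", "Company", " location ", "person"]))
# person, company, location
```

Deduplicating and normalizing up front keeps the prompt short and consistent, which matters because the prompt tokens are prepended to the decoder input.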

The WhisperNER model is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.

---------

## Training Details

`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging.

The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.

---------

## Usage

Inference can be done with the following code (for more details, see the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner)):

```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load audio file: user is responsible for loading the audio files themselves
target_sample_rate = 16000  # Whisper expects 16 kHz audio
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)
# convert to mono by averaging channels (this also removes the channel dim)
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)

# pre-process to get the input (log-Mel) features
input_features = processor(
    signal.numpy(), sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

# the entity tags are passed to the decoder as a lower-cased text prompt
prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, remove prompt
transcription = processor.batch_decode(
    predicted_ids[:, prompt_ids.shape[0]:], skip_special_tokens=True
)[0]
print(transcription)
```
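
The decoded `transcription` carries the predicted entities inline with the text. This card does not document the exact tag syntax, so the following is a hedged sketch that *assumes* XML-like markup such as `<person>John Smith</person>`; check the output format in the whisper-ner repo and adjust the pattern before relying on it. `extract_entities` is a hypothetical helper, not part of the repo or of `transformers`.

```python
import re
from typing import List, Tuple

# Assumed markup: <label>span</label>. Labels may contain word characters,
# spaces, or hyphens; the closing tag must repeat the same label (\1).
_TAG_PATTERN = re.compile(r"<([\w \-]+)>(.*?)</\1>")


def extract_entities(transcription: str) -> Tuple[str, List[Tuple[str, str]]]:
    """Split a tagged transcription into plain text and (label, span) pairs.

    Hypothetical helper assuming XML-like inline tags.
    """
    entities = [
        (label.strip(), span.strip())
        for label, span in _TAG_PATTERN.findall(transcription)
    ]
    # Replace each tagged span with its bare text to recover the plain transcript.
    plain_text = _TAG_PATTERN.sub(lambda m: m.group(2), transcription)
    return plain_text, entities


text, entities = extract_entities(
    "<person>John Smith</person> moved to <location>Paris</location> "
    "to work for <company>Acme</company>."
)
print(text)      # John Smith moved to Paris to work for Acme.
print(entities)  # [('person', 'John Smith'), ('location', 'Paris'), ('company', 'Acme')]
```

Using a backreference for the closing tag avoids pairing an opening tag with a different label's closing tag; multi-word labels such as `<medical condition>` are also matched because the label class allows spaces.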