| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import io |
| | import logging |
| | import os |
| | from typing import Optional, Union |
| |
|
| | import soundfile as sf |
| | import torch |
| | from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models |
| | from whisper.audio import log_mel_spectrogram |
| | from whisper.model import ModelDimensions |
| |
|
| | from whisper_model import Whisper_ |
| |
|
| | logger = logging.getLogger("dump_feature") |
| |
|
| |
|
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Load a `Whisper_` model for feature extraction.

    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    But we will load a `Whisper_` model for feature extraction.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper_
        The Whisper model instance (feature-extraction variant)

    Raises
    ------
    RuntimeError
        if `name` is neither an official model name nor a path to a checkpoint file.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Use a context manager so the file handle is closed
            # (original `open(name, "rb").read()` leaked the handle).
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        alignment_heads = None  # no alignment heads known for custom checkpoints
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Free the (possibly large) in-memory checkpoint blob / path reference early.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
| |
|
| |
|
class WhisperFeatureReader(object):
    """Read audio files and extract hidden-state features from a Whisper encoder.

    The decoder is dropped after loading since only encoder features are needed.
    """

    def __init__(self, root, ckpt, layer, device):
        # root: download/cache directory; ckpt: model name or checkpoint path;
        # layer: encoder layer to extract; device: torch device string/object.
        self.device = device
        logger.info(f"device = {self.device}")

        whisper_model = load_model(name=ckpt, device=self.device, download_root=root).eval()
        # Only the encoder is used for feature extraction; drop the decoder
        # so its weights can be garbage-collected.
        whisper_model.decoder = None
        self.model: Whisper_ = whisper_model
        self.layer = layer

    def read_audio(self, path, ref_len=None):
        """Load a 16 kHz mono waveform; warn if its length deviates from ref_len."""
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        # Tolerate up to 160 samples (10 ms at 16 kHz) of length mismatch silently.
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Return encoder hidden states for the audio at `path`, trimmed to its true length."""
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
            hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
            # One encoder frame per 320 input samples; drop padded frames.
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]
        return hidden.contiguous()
| |
|