Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| # Copyright (c) ByteDance, Inc. and its affiliates. | |
| # Copyright (c) Chutong Meng | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # Based on fairseq (https://github.com/facebookresearch/fairseq) and | |
| # Whisper (https://github.com/openai/whisper/) | |
| import io | |
| import logging | |
| import os | |
| from typing import Optional, Union | |
| import soundfile as sf | |
| import torch | |
| from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models | |
| from whisper.audio import log_mel_spectrogram | |
| from whisper.model import ModelDimensions | |
| from whisper_model import Whisper_ | |
# Module-wide logger. It is registered under the fixed name "dump_feature"
# rather than this module's own name, so its records appear under that name.
logger = logging.getLogger(name="dump_feature")
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Load a `Whisper_` model for feature extraction.

    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    The logic mirrors `whisper.load_model`, but a `Whisper_` instance is built
    instead of the stock `Whisper` model.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; when None, "cuda" is used if
        available, otherwise "cpu"
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
        (or "$XDG_CACHE_HOME/whisper" when that environment variable is set)
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper_
        The `Whisper_` model instance, moved to `device`

    Raises
    ------
    RuntimeError
        if `name` is neither an official model name nor an existing file
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        # Official model: fetch (or reuse) the checkpoint. With in_memory=True the
        # result is passed to io.BytesIO below, otherwise it is opened as a path.
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Read under a context manager so the file handle is closed promptly
            # instead of being left to the garbage collector.
            with open(name, "rb") as local_ckpt:
                checkpoint_file = local_ckpt.read()
        else:
            checkpoint_file = name
        # Local checkpoints carry no alignment-head table.
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    # NOTE(security): torch.load unpickles the checkpoint; only load checkpoint
    # files that come from a trusted source.
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Drop the (possibly large) in-memory copy of the raw checkpoint bytes.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
class WhisperFeatureReader:
    """Read 16 kHz audio files and return hidden features from one layer of a `Whisper_` model."""

    def __init__(self, root, ckpt, layer, device):
        """
        Load the model and remember which layer to extract.

        Parameters
        ----------
        root :
            forwarded to `load_model` as `download_root`
        ckpt :
            forwarded to `load_model` as `name` (model name or checkpoint path)
        layer :
            one-based index of the layer whose output is returned by `get_feats`
        device :
            device the model and the input audio are placed on
        """
        self.device = device
        logger.info(f"device = {self.device}")
        loaded = load_model(name=ckpt, device=self.device, download_root=root)
        self.model: Whisper_ = loaded.eval()
        self.model.decoder = None  # to save some memory by deleting the decoder
        self.layer = layer  # one-based

    def read_audio(self, path, ref_len=None):
        """
        Read the audio file at `path` and return its samples.

        The file must have a sample rate of 16000 (checked with an assertion).
        When `ref_len` is given and differs from the number of samples read by
        more than 160, a warning is logged; the samples are returned either way.
        """
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        if ref_len is None:
            return wav
        if abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """
        Return the features of layer `self.layer` for the audio file at `path`.

        The output is truncated to one frame per 320 input samples
        (20 ms at the enforced 16 kHz rate) and returned as a contiguous tensor.
        """
        wav = self.read_audio(path, ref_len)
        # One output frame is kept per 320 samples; presumably this matches the
        # encoder's overall downsampling factor -- confirm against Whisper_.
        num_frames = len(wav) // 320
        with torch.no_grad():
            waveform = torch.from_numpy(wav).float().to(self.device)
            mel = log_mel_spectrogram(waveform)
            layer_output = self.model.extract_features(
                mel.unsqueeze(0), target_layer=self.layer
            )
            # Drop the batch axis added by unsqueeze(0) and trim trailing frames.
            trimmed = layer_output[0, :num_frames]
        return trimmed.contiguous()

