import json
import typing as tp
from abc import ABC, abstractmethod

import torch
import torchaudio
from torch import nn

class AbstractVAE(ABC, nn.Module):
    """Abstract interface for audio VAEs used as latent codecs."""

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        """Latent frames per second."""
        ...

    @property
    @abstractmethod
    def orig_sample_rate(self) -> int:
        """Sample rate (Hz) the underlying VAE operates on."""
        ...

    @property
    @abstractmethod
    def channel_dim(self) -> int:
        """Dimensionality of the latent channels."""
        ...

    @property
    @abstractmethod
    def split_bands(self) -> int:
        """Number of frequency bands the latent is split into."""
        ...

    @property
    @abstractmethod
    def input_channel(self) -> int:
        """Number of audio channels the VAE expects."""
        ...

    @abstractmethod
    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode a waveform into latents."""
        ...

    @abstractmethod
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents back into a waveform."""
        ...

from .autoencoders import create_autoencoder_from_config, AudioAutoencoder

class StableVAE(AbstractVAE):
    def __init__(self, vae_ckpt, vae_cfg, sr=48000) -> None:
        super().__init__()
        # Build the autoencoder from its JSON config and load the checkpoint weights.
        with open(vae_cfg) as f:
            config = json.load(f)
        self.vae: AudioAutoencoder = create_autoencoder_from_config(config)
        self.vae.load_state_dict(torch.load(vae_ckpt, map_location=torch.device('cpu'))['state_dict'])
        self.sample_rate = sr
        # Resample incoming audio to the VAE's native rate when the two differ.
        self.rsp48k = torchaudio.transforms.Resample(sr, self.orig_sample_rate) if sr != self.orig_sample_rate else nn.Identity()

    @torch.no_grad()
    def encode(self, wav: torch.Tensor, sample=True) -> torch.Tensor:
        # `sample` is accepted but currently unused.
        wav = self.rsp48k(wav)
        # Inputs shorter than one encoder window cannot be encoded;
        # return an empty latent instead.
        if wav.shape[-1] < 2048:
            return torch.zeros((wav.shape[0], self.channel_dim, 0), device=wav.device, dtype=wav.dtype)
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)  # (B, T) -> (B, 1, T)
        if wav.shape[1] == 1:
            # Replicate mono audio across the channels the VAE expects.
            wav = wav.repeat(1, self.vae.in_channels, 1)
        latent = self.vae.encode_audio(wav)  # (B, channel_dim, T')
        return latent

    @torch.no_grad()
    def decode(self, latents: torch.Tensor, **kwargs) -> torch.Tensor:
        # latents: (B, channel_dim, T')
        audio_recon = self.vae.decode_audio(latents, **kwargs)
        return audio_recon

    @property
    def frame_rate(self) -> float:
        return float(self.vae.sample_rate) / self.vae.downsampling_ratio

    @property
    def orig_sample_rate(self) -> int:
        return self.vae.sample_rate

    @property
    def channel_dim(self) -> int:
        return self.vae.latent_dim

    @property
    def split_bands(self) -> int:
        return 1

    @property
    def input_channel(self) -> int:
        return self.vae.in_channels
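

# A minimal usage sketch (not part of the original module): the checkpoint and
# config paths below are hypothetical placeholders, and a 48 kHz mono batch is
# assumed as input.
if __name__ == "__main__":
    vae = StableVAE("ckpts/vae.ckpt", "ckpts/vae_config.json", sr=48000)
    vae.eval()
    wav = torch.randn(1, 48000)   # (B, T): one second of mono audio at 48 kHz
    latent = vae.encode(wav)      # (B, channel_dim, T') latent frames
    recon = vae.decode(latent)    # (B, input_channel, ~T) reconstructed audio
    print(latent.shape, recon.shape)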