Spaces:
Running
on
L4
Running
on
L4
File size: 3,789 Bytes
0a3525d 0c92d16 0a3525d 69e8a46 0a3525d 0c92d16 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d 69e8a46 0a3525d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
from pathlib import Path
import click
import hydra
import numpy as np
import soundfile as sf
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf
from fish_speech.utils.file import AUDIO_EXTENSIONS
# Register an "eval" resolver so that OmegaConf interpolations of the form
# ${eval:...} are handed to Python's built-in eval().
# NOTE(review): this evaluates arbitrary expressions found in config files --
# only compose configs that come from a trusted source.
OmegaConf.register_new_resolver("eval", eval)
def load_model(config_name, checkpoint_path, device="cuda"):
    """Build the model described by a Hydra config and load checkpoint weights.

    Args:
        config_name: Name of the Hydra config to compose; it is looked up
            under ``../../fish_speech/configs``.
        checkpoint_path: Path passed to ``torch.load``.
        device: Used as ``map_location`` when loading the checkpoint and as
            the argument of ``model.to``.

    Returns:
        The instantiated model after ``load_state_dict(strict=False)``,
        ``.eval()`` and ``.to(device)`` have been called on it.
    """
    # Clear the GlobalHydra singleton before initializing (presumably so that
    # repeated calls do not trip over an already-initialized Hydra).
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        model_cfg = compose(config_name=config_name)

    model = instantiate(model_cfg)

    weights = torch.load(checkpoint_path, map_location=device)

    # Some checkpoints nest the weights under a "state_dict" key.
    weights = weights["state_dict"] if "state_dict" in weights else weights

    # When generator entries are present, keep only the "generator." keys and
    # drop that marker from their names.
    has_generator_keys = any("generator" in key for key in weights)
    if has_generator_keys:
        generator_weights = {}
        for key, value in weights.items():
            if "generator." in key:
                generator_weights[key.replace("generator.", "")] = value
        weights = generator_weights

    # strict=False: mismatched keys are reported in the result, not raised.
    load_result = model.load_state_dict(weights, strict=False)
    model.eval()
    model.to(device)

    logger.info(f"Loaded model: {load_result}")
    return model
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    """Reconstruct audio through the VQ model.

    An audio input is encoded to indices (also saved as a .npy file next to
    the output path) and decoded again; a .npy input is decoded directly.
    The decoded audio is written to the output path.
    """
    # NOTE: click uses the docstring above as the command's --help text, so it
    # is kept short. Details for maintainers:
    #   input_path      -- pathlib.Path; suffix must be in AUDIO_EXTENSIONS or ".npy"
    #   output_path     -- pathlib.Path of the audio file written by sf.write
    #   config_name / checkpoint_path / device -- forwarded to load_model
    # Raises ValueError for an unsupported input suffix or for loaded indices
    # that are not 2-D.
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        # Load audio
        audio, sr = torchaudio.load(str(input_path))
        # Average the channels when more than one is present.
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(
            audio, sr, model.spec_transform.sample_rate
        )

        # Add a leading dimension before handing the audio to the model.
        audios = audio[None].to(device)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
        indices = model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        # Explicit check instead of `assert`: assertions are removed under
        # `python -O`, which would let malformed input reach the decoder.
        if indices.ndim != 2:
            raise ValueError(f"Expected 2D indices, got {indices.ndim}")
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore
    feature_lengths = torch.tensor([indices.shape[1]], device=device)
    fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

    # Save audio
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
    logger.info(f"Saved audio to {output_path}")
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
|