import gradio as gr
from pathlib import Path
import soundfile as sf
# Patch torch.load so every checkpoint is mapped onto the selected device
# (falls back to CPU when CUDA is unavailable).
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
_old_load = torch.load
def safe_torch_load(*args, **kwargs):
    args = list(args)
    if len(args) >= 2:
        # the second positional argument of torch.load is map_location
        args[1] = device
    else:
        kwargs['map_location'] = device
    return _old_load(*args, **kwargs)
torch.load = safe_torch_load
import torchaudio
import hydra
from omegaconf import OmegaConf
import diffusers.schedulers as noise_schedulers
from utils.config import register_omegaconf_resolvers
from models.common import LoadPretrainedBase
from huggingface_hub import hf_hub_download
import fairseq
register_omegaconf_resolvers()
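# configs/infer.yaml carries the inference-time settings (noise scheduler choice and sampling arguments).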
config = OmegaConf.load("configs/infer.yaml")
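# Fetch the pretrained STAR weights from the Hugging Face Hub (cached locally after the first download).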
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False
)
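# Build the STAR model from the experiment config and point its pretrained_ckpt entry at the downloaded weights.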
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
    exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)
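# Fetch the HuBERT checkpoint and load it with fairseq; it provides the speech content features for generation.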
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)
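# Instantiate the diffusers noise scheduler class named in the inference config.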
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)
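# Full speech-to-audio pipeline: input speech -> HuBERT features -> diffusion sampling -> generated waveform.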
@torch.no_grad()
def infer(audio_path: str) -> str:
    # Load the input speech and convert it to 16 kHz mono for the HuBERT encoder.
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
    # Extract self-supervised speech representations (gradients are already disabled by the decorator).
    features, _ = hubert_model.extract_features(waveform_tts.to(device))
    # Assemble the sampling arguments from the config and run diffusion inference.
    kwargs = OmegaConf.to_container(config["infer_args"], resolve=True)
    kwargs['content'] = [features]
    kwargs['condition'] = None
    kwargs['task'] = ["speech_to_audio"]
    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )
    output_file = "output_audio.wav"
    sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])
    return output_file
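# Gradio UI: introduction, input/output examples, an online inference panel, and clickable quick examples.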
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
    gr.Markdown("""
## 📚️ Introduction
STAR is the first end-to-end speech-to-audio generation framework, designed to improve efficiency and mitigate the error propagation inherent in cascaded systems.
In this space, you can drive the model directly with voice input and generate the corresponding audio output.
## 🗣️ Input
A brief speech utterance describing the overall audio scene.
> Example: A cat meowing and young female speaking
### 🎙️ Input Speech Example
""")
    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
    gr.Markdown("""
## 🎧️ Output
The model captures both the auditory events and the scene cues in the speech and generates the corresponding audio.
### 🔊 Output Audio Example
""")
    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
    gr.Markdown("""
---
## 🛠️ Online Inference
Upload your own sample, or try the quick examples below.
""")
    with gr.Column():
        input_audio = gr.Audio(label="🗣️ Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="🎧️ Generated Audio", type="filepath")
    btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
    gr.Markdown("""
## 🎯 Quick Examples
""")
    display_caption = gr.Textbox(label="📝 Caption", visible=False)
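    # display_caption stays hidden; it just receives the caption column of the example rows below.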
    with gr.Tabs():
        with gr.Tab("VITS Generated Speech"):
            gr.Examples(
                examples=[
                    ["wav/vits/1.wav", "A cat meowing and young female speaking"],
                    ["wav/vits/2.wav", "Sustained industrial engine noise"],
                    ["wav/vits/3.wav", "A woman talks and a baby whispers"],
                    ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/vits/6.wav", "A man speaking as birds are chirping"],
                    ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/vits/9.wav", "Several church bells ringing"],
                    ["wav/vits/10.wav", "A telephone rings with bell sounds"]
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )
        with gr.Tab("Real Human Speech"):
            gr.Examples(
                examples=[
                    ["wav/human/1.wav", "A cat meowing and young female speaking"],
                    ["wav/human/2.wav", "Sustained industrial engine noise"],
                    ["wav/human/3.wav", "A woman talks and a baby whispers"],
                    ["wav/human/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/human/6.wav", "A man speaking as birds are chirping"],
                    ["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/human/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/human/9.wav", "Several church bells ringing"],
                    ["wav/human/10.wav", "A telephone rings with bell sounds"]
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )
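# Optionally, demo.queue() could be enabled here to serialize concurrent GPU inference requests.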
demo.launch()