import gradio as gr
from pathlib import Path

import soundfile as sf
import torch

# Redirect torch.load's map_location to the active device (CPU when CUDA is unavailable)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
_old_load = torch.load


def safe_torch_load(*args, **kwargs):
    args = list(args)
    if len(args) >= 2:
        # second positional argument of torch.load is map_location
        args[1] = device
    else:
        kwargs['map_location'] = device
    return _old_load(*args, **kwargs)


torch.load = safe_torch_load

import torchaudio
import hydra
from omegaconf import OmegaConf
import diffusers.schedulers as noise_schedulers
from utils.config import register_omegaconf_resolvers
from models.common import LoadPretrainedBase
from huggingface_hub import hf_hub_download
import fairseq

register_omegaconf_resolvers()
config = OmegaConf.load("configs/infer.yaml")

# Download the STAR checkpoint and instantiate the model from the experiment config
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False,
)
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
    exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)

# Download the HuBERT checkpoint used to extract speech representations
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False,
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)

# Build the diffusion noise scheduler named in the inference config
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)


@torch.no_grad()
def infer(audio_path: str) -> str:
    # Load the input speech and convert it to 16 kHz mono, as expected by HuBERT
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)

    # Extract HuBERT features as the content input for speech-to-audio generation
    features, _ = hubert_model.extract_features(waveform_tts.to(device))

    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
    kwargs['content'] = [features]
    kwargs['condition'] = None
    kwargs['task'] = ["speech_to_audio"]

    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )

    output_file = "output_audio.wav"
    sf.write(
        output_file,
        waveform.squeeze().cpu().numpy(),
        samplerate=exp_config["sample_rate"],
    )
    return output_file


with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
    gr.Markdown("""
## 📚️ Introduction

STAR is the first end-to-end speech-to-audio generation framework, designed to improve efficiency and address the error propagation inherent in cascaded systems. In this space, you can control the model directly with speech input and generate the corresponding audio output.

## 🗣️ Input

A brief input speech utterance describing the overall audio scene.

> Example: A cat meowing and young female speaking

### 🎙️ Input Speech Example
    """)
    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
    gr.Markdown("""
## 🎧️ Output

The model captures both the auditory events and the scene cues in the speech and generates the corresponding audio.

### 🔊 Output Audio Example
    """)
    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
    gr.Markdown("""
---
## 🛠️ Online Inference

You can upload your own samples, or try the quick examples provided below.
    """)

    with gr.Column():
        input_audio = gr.Audio(label="🗣️ Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="🎧️ Generated Audio", type="filepath")
    btn.click(fn=infer, inputs=input_audio, outputs=output_audio)

    gr.Markdown("""
## 🎯 Quick Examples
    """)

    display_caption = gr.Textbox(label="📝 Caption", visible=False)

    with gr.Tabs():
        with gr.Tab("VITS Generated Speech"):
            gr.Examples(
                examples=[
                    ["wav/vits/1.wav", "A cat meowing and young female speaking"],
                    ["wav/vits/2.wav", "Sustained industrial engine noise"],
                    ["wav/vits/3.wav", "A woman talks and a baby whispers"],
                    ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/vits/6.wav", "A man speaking as birds are chirping"],
                    ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/vits/9.wav", "Several church bells ringing"],
                    ["wav/vits/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )
        with gr.Tab("Real Human Speech"):
            gr.Examples(
                examples=[
                    ["wav/human/1.wav", "A cat meowing and young female speaking"],
                    ["wav/human/2.wav", "Sustained industrial engine noise"],
                    ["wav/human/3.wav", "A woman talks and a baby whispers"],
                    ["wav/human/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/human/6.wav", "A man speaking as birds are chirping"],
                    ["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/human/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/human/9.wav", "Several church bells ringing"],
                    ["wav/human/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )

demo.launch()
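
# ---------------------------------------------------------------------------
# Optional, for reference only: a minimal sketch of calling the pipeline
# without the web UI. It assumes the bundled example file "wav/speech.wav"
# is present (the same file used for the "Input Speech Example" above).
# infer() resamples and downmixes internally and writes "output_audio.wav".
# Uncomment and use in place of demo.launch() for headless, one-off use.
#
# result_path = infer("wav/speech.wav")
# print(f"Generated audio written to {result_path}")
# ---------------------------------------------------------------------------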