import re

import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
from transformers import pipeline

from styles import css, header_html, footer_html
from examples import examples

# Speech-recognition pipeline used to turn the recorded audio into a text prompt.
asr_model = pipeline("automatic-speech-recognition")

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# If you are running this code locally, you need to either do a `huggingface-cli login`
# or paste your User Access Token from https://huggingface.co/settings/tokens into the
# use_auth_token field below.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(device)

# When running locally you won't have access to this dataset, so you can remove this part.
word_list_dataset = load_dataset(
    "stabilityai/word-list", data_files="list.txt", use_auth_token=True
)
word_list = word_list_dataset["train"]["text"]


def transcribe(audio):
    """Transcribe the recorded audio file into a text prompt."""
    text = asr_model(audio)["text"]
    return text


def infer(audio, samples, steps, scale, seed):
    prompt = transcribe(audio)

    # When running locally you can also remove this filter.
    for word in word_list:
        if re.search(rf"\b{word}\b", prompt):
            raise gr.Error(
                "Unsafe content found. Please try again with different prompts."
            )

    generator = torch.Generator(device=device).manual_seed(seed)

    # If you are running locally on CPU, you can remove the `with autocast("cuda")` branch.
    if device == "cuda":
        with autocast("cuda"):
            images_list = pipe(
                [prompt] * samples,
                num_inference_steps=steps,
                guidance_scale=scale,
                generator=generator,
            )
    else:
        images_list = pipe(
            [prompt] * samples,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )

    # Replace any image flagged by the safety checker with a placeholder.
    images = []
    safe_image = Image.open(r"unsafe.png")
    for i, image in enumerate(images_list["sample"]):
        if images_list["nsfw_content_detected"][i]:
            images.append(safe_image)
        else:
            images.append(image)
    return images


block = gr.Blocks(css=css)

with block:
    gr.HTML(header_html)
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                audio = gr.Audio(
                    label="Describe a prompt", source="microphone", type="filepath"
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

        ex = gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[audio, samples, steps, scale, seed],
            outputs=gallery,
        )
        ex.dataset.headers = [""]

        audio.submit(infer, inputs=[audio, samples, steps, scale, seed], outputs=gallery)
        btn.click(infer, inputs=[audio, samples, steps, scale, seed], outputs=gallery)

        # Toggle the visibility of the advanced-options row from the browser.
        advanced_button.click(
            None,
            [],
            audio,
            _js="""
            () => {
                const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
                options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
            }""",
        )
    gr.HTML(footer_html)

block.queue(max_size=25).launch()