from diffusers import DiffusionPipeline
import gradio as gr
import torch
import time
import psutil

start_time = time.time()

device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"


def error_str(error, title="Error"):
    # Render an error message as a Markdown block, or return "" if there is no error.
    return (
        f"""#### {title}
{error}"""
        if error
        else ""
    )


def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    print(psutil.virtual_memory())  # print memory usage

    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"
    generator = torch.Generator(torch_device).manual_seed(seed)
    dtype = torch.float16 if torch_device == "cuda" else torch.float32

    try:
        # An empty discussion number falls back to the repo's main branch.
        revision = f"refs/pr/{discuss_nr}" if discuss_nr else None
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)
        return pipe(prompt, generator=generator, num_inference_steps=25).images, f"Done. Seed: {seed}"
    except Exception as e:
        url = f"https://huggingface.co/{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with the diffusers weights of this PR: {url}. Error message:\n"
        return None, error_str(message + str(e))
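
# How the PR checkout works, as a minimal sketch: a discussion number N maps to
# the Hub git ref "refs/pr/N", which `from_pretrained` accepts via its
# `revision` argument. The repo id and PR number here are the illustrative
# values from the UI placeholders below, not guaranteed to exist:
#
#   pipe = DiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4",  # example repo from the placeholder text
#       revision="refs/pr/171",           # weights as proposed in discussion 171
#   )
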
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        f"""
        <div>
          <p>Space to test whether <code>diffusers</code> PRs work.</p>
          <p>Running on <b>{device}</b></p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=55):
            with gr.Group():
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co/CompVis/stable-diffusion-v1-4",
                )
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co/CompVis/stable-diffusion-v1-4/discussions/171",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(grid=[2], height="auto")
                error_output = gr.Markdown()
                generate = gr.Button(value="Generate").style(
                    rounded=(False, True, True, False)
                )

    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    outputs = [gallery, error_output]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")

demo.queue(concurrency_count=1)
demo.launch()