|
from diffusers import DiffusionPipeline |
|
import gradio as gr |
|
import torch |
|
import time |
|
import psutil |
|
|
|
|
|
# Wall-clock timestamp taken once at module load; the "Space built in ..."
# message printed after the UI is constructed is measured against it.
start_time = time.time()

# Human-readable hardware label shown in the page header. inference() derives
# the torch device from it by checking whether the label contains "GPU".
if torch.cuda.is_available():
    device = "GPU 🔥"
else:
    device = "CPU 🥶"
|
|
|
|
|
def error_str(error, title="Error"):
    """Format ``error`` as a Markdown snippet with a heading.

    Args:
        error: Message to display; any object, interpolated with ``str``
            formatting. A falsy value means "no error".
        title: Text of the level-4 Markdown heading placed above the message.

    Returns:
        ``"#### <title>"``, a blank line, then the message; or an empty
        string when ``error`` is falsy.
    """
    if not error:
        return ""
    # Explicit newlines instead of a multi-line triple-quoted literal: such a
    # literal embeds the source file's indentation into the output, and in
    # Markdown a line indented by 4+ spaces becomes a code block. The blank
    # line makes the message its own paragraph below the heading.
    return f"#### {title}\n\n{error}"
|
|
|
|
|
def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    """Load a pipeline from a Hub PR revision and run it on ``prompt``.

    Args:
        repo_id: Hub repository id, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        discuss_nr: Discussion/PR number on that repository (the UI passes
            it as text).
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        A ``(images, message)`` tuple. On success ``images`` is the
        pipeline output's ``.images`` and ``message`` reports the seed. On
        any exception while loading or running the pipeline, ``images`` is
        ``None`` and ``message`` is a Markdown error (via ``error_str``)
        containing the PR discussion URL and the exception text.
    """
    # Log host memory usage; a full pipeline is loaded on every call.
    print(psutil.virtual_memory())

    # Fixed seed so results for a given PR and prompt are comparable.
    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"
    generator = torch.Generator(torch_device).manual_seed(seed)

    # Half precision on GPU, full precision on CPU.
    dtype = torch.float16 if torch_device == "cuda" else torch.float32

    # Textbox values often carry stray whitespace from copy/paste, which
    # would otherwise end up in the repo id / revision and make loading fail.
    repo_id = str(repo_id).strip()
    discuss_nr = str(discuss_nr).strip()

    try:
        # PR revisions on the Hub are addressed as ``refs/pr/<number>``.
        revision = f"refs/pr/{discuss_nr}"
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)
        images = pipe(prompt, generator=generator, num_inference_steps=25).images
    except Exception as e:  # UI boundary: report any failure instead of crashing
        url = f"https://huggingface.co/{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
        # ``e`` is an exception object; concatenating it to a str directly
        # raised TypeError inside this handler, hiding the real error.
        return None, error_str(message + str(e))

    return images, f"Done. Seed: {seed}"
|
|
|
|
|
with gr.Blocks(css="style.css") as demo:
    # Page header: what the Space is for and which hardware it runs on.
    gr.HTML(
        f"""
            <div class="diffusion">
                <p>
                    Space to test whether `diffusers` PRs work.
                </p>
                <p>
                    Running on <b>{device}</b>
                </p>
            </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=55):
            with gr.Group():
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co/CompVis/stable-diffusion-v1-4",
                )
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co/CompVis/stable-diffusion-v1-4/discussions/171",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    # Gradio 3 takes the initial text via ``value``; the old
                    # ``default`` keyword is ignored with an "unused kwarg"
                    # warning, so the example prompt never appeared.
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(grid=[2], height="auto")

            # Receives inference()'s status text ("Done. Seed: ..." or the
            # Markdown error built by error_str).
            error_output = gr.Markdown()

            generate = gr.Button(value="Generate").style(
                rounded=(False, True, True, False)
            )

    # Order must match inference(repo_id, discuss_nr, prompt) and its
    # (images, message) return tuple.
    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    outputs = [gallery, error_output]
    # Pressing Enter in the prompt box and clicking the button do the same.
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)
|
|
|
# Seconds elapsed since start_time was recorded, i.e. the UI construction above.
print("Space built in {:.2f} seconds".format(time.time() - start_time))

# Queue requests and handle them one at a time; presumably because each
# request loads a full pipeline into memory (see inference()).
demo.queue(concurrency_count=1)
demo.launch()
|
|