check_pr / app.py
patrickvonplaten's picture
Update app.py
47fe7d6
raw
history blame
3 kB
from diffusers import DiffusionPipeline
import gradio as gr
import torch
import time
import psutil
# Timestamp taken at import; the script prints the elapsed build time at the end.
start_time = time.time()
# Human-readable hardware label shown in the UI header. `inference` also checks
# it for the substring "GPU" to pick the torch device.
if torch.cuda.is_available():
    device = "GPU 🔥"
else:
    device = "CPU 🥶"
def error_str(error, title="Error"):
    """Render *error* as a small Markdown block headed by *title*.

    Args:
        error: The error text (or any object with a useful ``str``). A falsy
            value (``None``, ``""``) means "no error".
        title: Text of the level-4 Markdown heading placed above the error.

    Returns:
        ``""`` when *error* is falsy; otherwise ``"#### <title>"`` followed by
        a newline and the error text.
    """
    if not error:
        return ""
    heading = f"#### {title}"
    return f"{heading}\n{error}"
def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    """Load a diffusers pipeline from a Hub PR revision and run *prompt* on it.

    Args:
        repo_id: Hub repository id, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        discuss_nr: Number of the Hub discussion/PR to test. It is resolved to
            the git revision ``refs/pr/<discuss_nr>``. Surrounding whitespace
            and a leading ``#`` are tolerated.
        prompt: Text prompt passed to the pipeline.

    Returns:
        A ``(images, status)`` pair matching the gallery and Markdown outputs.
        On success, ``images`` is the pipeline output's ``.images`` and
        ``status`` reports the seed. On any failure, ``images`` is ``None`` and
        ``status`` is a Markdown error block (from ``error_str``) linking to the PR.
    """
    print(psutil.virtual_memory())  # print memory usage

    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"
    # Fixed seed so repeated checks of the same PR are comparable.
    generator = torch.Generator(torch_device).manual_seed(seed)
    # Half precision on GPU, full precision otherwise.
    dtype = torch.float16 if torch_device == "cuda" else torch.float32

    # Inputs come from free-text boxes; tolerate copy/paste whitespace and "#171".
    repo_id = "" if repo_id is None else str(repo_id).strip()
    discuss_nr = "" if discuss_nr is None else str(discuss_nr).strip().lstrip("#")

    try:
        revision = f"refs/pr/{discuss_nr}"
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)
        images = pipe(prompt, generator=generator, num_inference_steps=25).images
    except Exception as e:  # UI boundary: report any failure to the user instead of crashing
        url = f"https://huggingface.co/{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
        # The exception must be converted explicitly: `str + Exception` raises
        # TypeError, which previously masked the real error.
        return None, error_str(message + str(e))
    return images, f"Done. Seed: {seed}"
with gr.Blocks(css="style.css") as demo:
    # Header banner; `device` is the hardware label computed at import time.
    gr.HTML(
        f"""
<div class="diffusion">
<p>
Space to test whether `diffusers` PRs work.
</p>
<p>
Running on <b>{device}</b>
</p>
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=55):
            with gr.Group():
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co/CompVis/stable-diffusion-v1-4",
                )
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co/CompVis/stable-diffusion-v1-4/discussions/171",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    # Gradio 3.x (this script uses the 3.x `queue(concurrency_count=...)`)
                    # takes the initial text via `value`. The 2.x-era `default`
                    # kwarg is not used there, so the prompt would start empty.
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(grid=[2], height="auto")
            # Receives the Markdown produced by `error_str` (empty on success paths).
            error_output = gr.Markdown()
            generate = gr.Button(value="Generate").style(
                rounded=(False, True, True, False)
            )

    # Order must match `inference(repo_id, discuss_nr, prompt)` and its
    # `(images, status)` return pair.
    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    outputs = [gallery, error_output]
    # Both pressing Enter in the prompt box and clicking the button run inference.
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)
# Time spent from the top-of-script `start_time` to here (imports + UI build).
build_seconds = time.time() - start_time
print(f"Space built in {build_seconds:.2f} seconds")

# A single queue worker; presumably so only one pipeline is loaded and run at
# a time (Gradio 3.x queue API) — confirm if concurrency requirements change.
demo.queue(concurrency_count=1)
demo.launch()