K00B404 committed
Commit bfdcf5e
1 Parent(s): ec519ae

Create app.py

Files changed (1)
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
+ # import spaces   # unused: the @spaces.GPU decorator was removed below
+ # import argparse # unused
+ import os
+ import time
+ from os import path
+
+ # Route all Hugging Face downloads into a local "models" folder next to this file.
+ # These variables must be set before importing huggingface_hub or diffusers,
+ # which resolve their cache paths at import time.
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
+ os.environ["HF_HUB_CACHE"] = cache_path
+ os.environ["HF_HOME"] = cache_path
+
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+
+ import gradio as gr
+ import torch
+ from diffusers import StableDiffusionXLPipeline, LCMScheduler
+
+ # CUDA-specific configuration removed for CPU-only execution:
+ # torch.backends.cuda.matmul.allow_tf32 = True
+
+ # Minimal context manager for printing wall-clock timings
+ class timer:
+     def __init__(self, method_name="timed process"):
+         self.method = method_name
+
+     def __enter__(self):
+         self.start = time.time()
+         print(f"{self.method} starts")
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         end = time.time()
+         print(f"{self.method} took {round(end - self.start, 2)}s")
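+
+ # usage: `with timer("inference"): ...` prints the elapsed time on exit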
+
+ os.makedirs(cache_path, exist_ok=True)
+
+ # Run the pipeline entirely on CPU in float32 (no CUDA on this Space)
+ pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float32)
+ pipe.to(device="cpu")
+
+ # Swap in ByteDance's Hyper-SDXL one-step UNet weights, loaded on CPU
+ unet_state = load_file(hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors"), device="cpu")
+ pipe.unet.load_state_dict(unet_state)
+ # LCM scheduler with trailing timestep spacing for single-step sampling
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
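+ # NOTE: even at a single inference step, SDXL at 1024x1024 is heavy on CPU;
+ # expect generation to take on the order of minutes per image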
+ with gr.Blocks() as demo:
+     with gr.Column():
+         with gr.Row():
+             with gr.Column():
+                 num_images = gr.Slider(label="Number of Images", minimum=1, maximum=8, step=1, value=4, interactive=True)
+                 height = gr.Number(label="Image Height", value=1024, interactive=True)
+                 width = gr.Number(label="Image Width", value=1024, interactive=True)
+                 prompt = gr.Text(label="Prompt", value="a photo of a cat", interactive=True)
+                 seed = gr.Number(label="Seed", value=3413, interactive=True)
+                 btn = gr.Button(value="run")
+             with gr.Column():
+                 output = gr.Gallery(height=1024)
+
+     # (the @spaces.GPU decorator was removed since inference runs on CPU)
+     def process_image(num_images, height, width, prompt, seed):
+         global pipe
+         with torch.inference_mode(), timer("inference"):
+             return pipe(
+                 prompt=[prompt] * int(num_images),  # slider values can arrive as floats
+                 generator=torch.Generator().manual_seed(int(seed)),
+                 num_inference_steps=1,
+                 guidance_scale=0.,  # the one-step UNet is used without CFG
+                 height=int(height),
+                 width=int(width),
+                 timesteps=[800],  # fixed timestep for the one-step Hyper-SDXL UNet
+             ).images
+
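+     # Gradio passes the components' current values to process_image in this order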
+     reactive_controls = [num_images, height, width, prompt, seed]
+     btn.click(process_image, inputs=reactive_controls, outputs=[output])
+
+ if __name__ == "__main__":
+     demo.launch()