AlekseyCalvin commited on
Commit
05a3ba6
1 Parent(s): 04250c0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces # Import this first to avoid CUDA initialization issues
2
+ import os
3
+ import gradio as gr
4
+ import json
5
+ import torch
6
+ import random
7
+ import time
8
+ from PIL import Image
9
+ from diffusers import DiffusionPipeline
10
+
11
+ # Define the device
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # Use the 'waffles' environment variable as the access token
15
+ hf_token = os.getenv('waffles')
16
+
17
+ # Ensure the token is loaded correctly
18
+ if not hf_token:
19
+ raise ValueError("Hugging Face API token not found. Please set the 'waffles' environment variable.")
20
+
21
+ # Load LoRAs from JSON file
22
+ with open('loras.json', 'r') as f:
23
+ loras = json.load(f)
24
+
25
+ # Initialize the base model with authentication and specify the device
26
+ # Initialize the base model with authentication and specify the device
27
+ pipe = DiffusionPipeline.from_pretrained(
28
+ "black-forest-labs/FLUX.1-schnell",
29
+ torch_dtype=torch.bfloat16,
30
+ token=hf_token
31
+ ).to(device)
32
+
33
+ MAX_SEED = 2**32 - 1
34
+
35
+ class calculateDuration:
36
+ def __init__(self, activity_name=""):
37
+ self.activity_name = activity_name
38
+
39
+ def __enter__(self):
40
+ self.start_time = time.time()
41
+ return self
42
+
43
+ def __exit__(self, exc_type, exc_value, traceback):
44
+ self.end_time = time.time()
45
+ self.elapsed_time = self.end_time - self.start_time
46
+ if self.activity_name:
47
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
48
+ else:
49
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
50
+
51
+ @spaces.GPU(duration=90)
52
+ def generate_images(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, num_images, progress):
53
+ generator = torch.Generator(device=device).manual_seed(seed)
54
+ images = []
55
+
56
+ with calculateDuration("Generating images"):
57
+ for _ in range(num_images):
58
+ # Generate each image
59
+ image = pipe(
60
+ prompt=f"{prompt} {trigger_word}",
61
+ num_inference_steps=steps,
62
+ guidance_scale=cfg_scale,
63
+ width=width,
64
+ height=height,
65
+ generator=generator,
66
+ joint_attention_kwargs={"scale": lora_scale},
67
+ ).images[0]
68
+ images.append(image)
69
+ return images
70
+
71
+ def run_lora(prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale, num_images, progress=gr.Progress(track_tqdm=True)):
72
+ if not selected_repo:
73
+ raise gr.Error("You must select a LoRA before proceeding.")
74
+
75
+ selected_lora = next((lora for lora in loras if lora["repo"] == selected_repo), None)
76
+ if not selected_lora:
77
+ raise gr.Error("Selected LoRA not found.")
78
+
79
+ lora_path = selected_lora["repo"]
80
+ trigger_word = selected_lora["trigger_word"]
81
+
82
+ # Load LoRA weights
83
+ with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
84
+ if "weights" in selected_lora:
85
+ pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
86
+ else:
87
+ pipe.load_lora_weights(lora_path)
88
+
89
+ # Set random seed for reproducibility
90
+ with calculateDuration("Randomizing seed"):
91
+ if randomize_seed:
92
+ seed = random.randint(0, MAX_SEED)
93
+
94
+ images = generate_images(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, num_images, progress)
95
+ pipe.to("cpu")
96
+ pipe.unload_lora_weights()
97
+ return images, seed
98
+
99
+ def update_selection(evt: gr.SelectData):
100
+ index = evt.index
101
+ selected_lora = loras[index]
102
+ return f"Selected LoRA: {selected_lora['title']}", selected_lora["repo"]
103
+
104
+ run_lora.zerogpu = True
105
+
106
+ css = '''
107
+ #gen_btn{height: 100%}
108
+ #title{text-align: center}
109
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
110
+ #title img{width: 100px; margin-right: 0.5em}
111
+ #gallery .grid-wrap{height: auto; width: auto;}
112
+ #gallery .gallery-item{width: 50px; height: 50px; margin: 0px;} /* Make buttons 50% height and width */
113
+ #gallery img{width: 100%; height: 100%; object-fit: cover;} /* Resize images to fit buttons */
114
+ #info_blob {
115
+ background-color: #f0f0f0;
116
+ border: 2px solid #ccc;
117
+ padding: 10px;
118
+ margin: 10px 0;
119
+ text-align: center;
120
+ font-size: 1.2em;
121
+ font-weight: bold;
122
+ color: #333;
123
+ border-radius: 8px;
124
+ }
125
+ '''
126
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
127
+ title = gr.HTML(
128
+ """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
129
+ elem_id="title",
130
+ )
131
+
132
+ # Info blob stating what the app is running
133
+ info_blob = gr.HTML(
134
+ """<div id="info_blob"> Activist, Futurist, and Realist LoRa-stocked Quick-Use Image Manufactory (over Flux Schnell)</div>"""
135
+ )
136
+
137
+ selected_lora_text = gr.Markdown("Selected LoRA: None")
138
+ selected_repo = gr.State(value="")
139
+
140
+ # Prompt takes the full line
141
+ prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Type a prompt after selecting a LoRA", elem_id="full_line_prompt")
142
+
143
+ with gr.Row():
144
+ with gr.Column(scale=1): # LoRA collection on the left
145
+ gallery = gr.Gallery(
146
+ [(item["image"], item["title"]) for item in loras],
147
+ label="LoRA Gallery",
148
+ allow_preview=False,
149
+ columns=3,
150
+ elem_id="gallery"
151
+ )
152
+ with gr.Column(scale=1): # Generated images on the right
153
+ result = gr.Gallery(label="Generated Images")
154
+ seed = gr.Number(label="Seed", value=0, interactive=False)
155
+
156
+ with gr.Column():
157
+ with gr.Row():
158
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=1)
159
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=4)
160
+
161
+ with gr.Row():
162
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
163
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
164
+
165
+ with gr.Row():
166
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
167
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
168
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
169
+ num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)
170
+
171
+ gallery.select(
172
+ fn=update_selection,
173
+ inputs=[],
174
+ outputs=[selected_lora_text, selected_repo]
175
+ )
176
+
177
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
178
+ generate_button.click(
179
+ run_lora,
180
+ inputs=[prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale, num_images],
181
+ outputs=[result, seed]
182
+ )
183
+
184
+ app.queue()
185
+ app.launch()