annoyingpixel committed
Commit 11cc416 · verified · Parent: 7ceefea

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
+ #!/usr/bin/env python3
+ """
+ FLUX.1 Space App Template - Enhanced with Model and LoRA Management
+ """
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import os
+
+ # Import our managers
+ from flux_space_model_manager import FluxModelManager
+ from flux_space_lora_manager import FluxLoRAManager
+
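+ # NOTE: flux_space_model_manager and flux_space_lora_manager are local modules
+ # assumed to sit next to app.py in the Space repo; they are not on PyPI. This
+ # file only relies on the attributes used below: models, current_pipeline,
+ # load_model(), get_model_info(), and generate_image() on the model manager,
+ # and loaded_loras, load_lora_file(), and apply_lora_to_model() on the LoRA
+ # manager.
+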
+ class FluxSpaceApp:
+     """
+     Enhanced FLUX.1 Space application with model and LoRA management
+     """
+
+     def __init__(self):
+         self.model_manager = FluxModelManager()
+         self.lora_manager = FluxLoRAManager()
+         self.current_model = None
+
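+     # NOTE: self.current_model tracks only the selected model *name*; the live
+     # pipeline object lives on self.model_manager.current_pipeline.
+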
+     def create_interface(self):
+         """
+         Create the Gradio interface
+         """
+         with gr.Blocks(title="FLUX.1 Enhanced Space", theme=gr.themes.Default()) as demo:
+
+             # Header
+             gr.Markdown("""
+             # FLUX.1 Enhanced Space
+             **Multiple Models + LoRA Support**
+
+             Choose your base model and load custom LoRAs for enhanced image generation.
+             """)
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     # Model Selection
+                     gr.Markdown("### Model Selection")
+                     model_selector = gr.Dropdown(
+                         choices=list(self.model_manager.models.keys()),
+                         value="flux1-dev",
+                         label="Base Model",
+                         info="Select the base model for generation"
+                     )
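+                     # NOTE: the "flux1-dev" default assumes FluxModelManager
+                     # registers a model under that key; adjust it to match the
+                     # keys your model registry actually exposes.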
+
+                     model_info = gr.Markdown("**Model Info:** Select a model to see details")
+
+                     # Load Model Button
+                     load_model_btn = gr.Button("Load Model", variant="primary")
+
+                     # Model Status
+                     model_status = gr.Markdown("**Status:** No model loaded")
+
+                 with gr.Column(scale=1):
+                     # LoRA Management
+                     gr.Markdown("### LoRA Management")
+
+                     lora_upload = gr.File(
+                         label="Upload LoRA (.safetensors)",
+                         file_types=[".safetensors"],
+                         file_count="single"
+                     )
+
+                     lora_name = gr.Textbox(
+                         label="LoRA Name (optional)",
+                         placeholder="Custom name for the LoRA"
+                     )
+
+                     lora_strength = gr.Slider(
+                         minimum=0.0,
+                         maximum=2.0,
+                         value=1.0,
+                         step=0.1,
+                         label="LoRA Strength",
+                         info="How strongly to apply the LoRA"
+                     )
+
+                     with gr.Row():
+                         load_lora_btn = gr.Button("Load LoRA", variant="secondary")
+                         unload_lora_btn = gr.Button("Unload LoRA", variant="stop")
+
+                     # LoRA Status
+                     lora_status = gr.Markdown("**LoRAs:** None loaded")
+
+             # Generation Parameters
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     gr.Markdown("### Generation")
+
+                     prompt = gr.Textbox(
+                         label="Prompt",
+                         placeholder="Enter your prompt here...",
+                         lines=3
+                     )
+
+                     negative_prompt = gr.Textbox(
+                         label="Negative Prompt",
+                         placeholder="Enter negative prompt...",
+                         lines=2
+                     )
+
+                     with gr.Row():
+                         with gr.Column():
+                             steps = gr.Slider(
+                                 minimum=10,
+                                 maximum=100,
+                                 value=50,
+                                 step=1,
+                                 label="Inference Steps"
+                             )
+                             guidance_scale = gr.Slider(
+                                 minimum=1.0,
+                                 maximum=20.0,
+                                 value=7.5,
+                                 step=0.1,
+                                 label="Guidance Scale"
+                             )
+
+                         with gr.Column():
+                             width = gr.Slider(
+                                 minimum=512,
+                                 maximum=2048,
+                                 value=1024,
+                                 step=64,
+                                 label="Width"
+                             )
+                             height = gr.Slider(
+                                 minimum=512,
+                                 maximum=2048,
+                                 value=1024,
+                                 step=64,
+                                 label="Height"
+                             )
+
+                     seed = gr.Number(
+                         label="Seed",
+                         value=-1,
+                         info="Use -1 for random seed"
+                     )
+
+                     generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
+
+                 with gr.Column(scale=1):
+                     # Advanced Options
+                     gr.Markdown("### Advanced")
+
+                     # LoRA Blending
+                     gr.Markdown("#### LoRA Blending")
+
+                     lora_list = gr.Dropdown(
+                         choices=[],
+                         label="Select LoRAs to Blend",
+                         multiselect=True
+                     )
+
+                     blend_weights = gr.Textbox(
+                         label="Blend Weights (comma-separated)",
+                         placeholder="1.0, 0.5, 0.3",
+                         info="Weights for each LoRA in order"
+                     )
+
+                     blend_btn = gr.Button("Blend LoRAs", variant="secondary")
+
+                     # Generation Info
+                     gr.Markdown("#### Generation Info")
+                     generation_info = gr.JSON(label="Last Generation Details")
+
+             # Output
+             with gr.Row():
+                 output_image = gr.Image(
+                     label="Generated Image",
+                     type="pil"
+                 )
+
+                 with gr.Column():
+                     gr.Markdown("### Generation Log")
+                     generation_log = gr.Textbox(
+                         label="Log",
+                         lines=10,
+                         max_lines=20,
+                         interactive=False
+                     )
+
+             # Event Handlers
+             def load_model_handler(model_name):
+                 """Handle model loading"""
+                 try:
+                     success = self.model_manager.load_model(model_name)
+                     if success:
+                         # Local name "info" avoids shadowing the model_info component.
+                         info = self.model_manager.get_model_info()
+                         status_text = f"**Status:** Model loaded: {model_name}"
+                         info_text = f"""
+ **Current Model:** {info['current_model']}
+ **Description:** {info['model_description']}
+ **Device:** {info['device']}
+ """
+                         self.current_model = model_name
+                     else:
+                         status_text = f"**Status:** Failed to load: {model_name}"
+                         info_text = "Error: Model loading failed"
+
+                     return status_text, info_text
+
+                 except Exception as e:
+                     return f"**Status:** Error: {str(e)}", "Error: Model loading failed"
+
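+             # NOTE: model loading runs synchronously inside the event handler, so
+             # large FLUX checkpoints will block the UI until they finish loading.
+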
+             def load_lora_handler(file, name, strength):
+                 """Handle LoRA loading"""
+                 try:
+                     if file is None:
+                         return "**LoRAs:** None loaded (no file uploaded)", gr.update()
+
+                     # gr.File may hand back a tempfile wrapper or a plain path
+                     # string depending on the Gradio version; support both.
+                     file_path = file if isinstance(file, str) else file.name
+                     lora_name = name if name else os.path.splitext(os.path.basename(file_path))[0]
+
+                     # Load LoRA
+                     result = self.lora_manager.load_lora_file(file_path, lora_name)
+
+                     if result['success']:
+                         # Apply to current model if available
+                         if self.model_manager.current_pipeline is not None:
+                             self.lora_manager.apply_lora_to_model(
+                                 lora_name,
+                                 self.model_manager.current_pipeline,
+                                 strength
+                             )
+
+                         # Update the status text and the blending dropdown choices
+                         loaded = list(self.lora_manager.loaded_loras.keys())
+                         return f"**LoRAs:** {', '.join(loaded)}", gr.update(choices=loaded)
+                     else:
+                         return f"**LoRAs:** Error: {result.get('error', 'Unknown error')}", gr.update()
+
+                 except Exception as e:
+                     return f"**LoRAs:** Error: {str(e)}", gr.update()
+
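+             # NOTE: lora_strength is only read when "Load LoRA" is clicked; moving
+             # the slider afterwards has no effect until the LoRA is loaded again.
+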
+             def generate_handler(prompt, negative_prompt, steps, guidance_scale, width, height, seed):
+                 """Handle image generation"""
+                 try:
+                     if self.model_manager.current_pipeline is None:
+                         return None, "Error: No model loaded", {}
+
+                     # Set seed (gr.Number delivers a float, so cast before comparing)
+                     seed = int(seed)
+                     if seed == -1:
+                         seed = torch.randint(0, 2**32, (1,)).item()
+
+                     # Generate image
+                     image, gen_info = self.model_manager.generate_image(
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_inference_steps=int(steps),
+                         guidance_scale=guidance_scale,
+                         width=int(width),
+                         height=int(height),
+                         seed=seed
+                     )
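+                     # NOTE: negative_prompt and guidance_scale are passed through
+                     # to the manager as-is; whether they do anything depends on how
+                     # generate_image maps them onto the underlying pipeline. FLUX.1
+                     # models are guidance-distilled, so negative prompts in
+                     # particular may simply be ignored.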
+
+                     # Convert to PIL
+                     if isinstance(image, torch.Tensor):
+                         image = image.cpu().numpy()
+                         if image.shape[0] == 3:  # CHW format -> HWC
+                             image = np.transpose(image, (1, 2, 0))
+                         image = (image * 255).astype(np.uint8)  # assumes float pixels in [0, 1]
+                         image = Image.fromarray(image)
+
+                     # Create log entry
+                     log_entry = f"""
+ Generation Complete
+ Prompt: {prompt}
+ Negative: {negative_prompt}
+ Steps: {steps}, Guidance: {guidance_scale}
+ Size: {width}x{height}
+ Seed: {seed}
+ Model: {gen_info.get('model', 'unknown')}
+ LoRAs: {', '.join(gen_info['loras']) if gen_info.get('loras') else 'None'}
+ """.strip()
+
+                     return image, log_entry, gen_info
+
+                 except Exception as e:
+                     return None, f"Error: {str(e)}", {}
+
+             # Connect events
+             load_model_btn.click(
+                 fn=load_model_handler,
+                 inputs=[model_selector],
+                 outputs=[model_status, model_info]
+             )
+
+             # load_lora_handler returns (status markdown, dropdown update), so
+             # each output component is listed exactly once.
+             load_lora_btn.click(
+                 fn=load_lora_handler,
+                 inputs=[lora_upload, lora_name, lora_strength],
+                 outputs=[lora_status, lora_list]
+             )
+
+             generate_btn.click(
+                 fn=generate_handler,
+                 inputs=[prompt, negative_prompt, steps, guidance_scale, width, height, seed],
+                 outputs=[output_image, generation_log, generation_info]
+             )
+
+             # Auto-load model when selected
+             model_selector.change(
+                 fn=load_model_handler,
+                 inputs=[model_selector],
+                 outputs=[model_status, model_info]
+             )
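+
+             # NOTE: unload_lora_btn and blend_btn are declared above but not yet
+             # wired to handlers; unloading and LoRA blending are left as TODOs. A
+             # blend handler would at minimum parse the weights box into floats,
+             # e.g. [float(w) for w in weights_text.split(",") if w.strip()], and
+             # hand them to whatever blending API FluxLoRAManager exposes (not
+             # shown in this file).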
+
+         return demo
+
+ # Main execution
+ if __name__ == "__main__":
+     app = FluxSpaceApp()
+     demo = app.create_interface()
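+     # NOTE: share=True only matters for local runs; Gradio ignores it when the
+     # app is running inside a hosted Hugging Face Space.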
+     demo.launch(share=True, debug=True)