annoyingpixel committed on
Commit
7c02427
Β·
verified Β·
1 Parent(s): 9ccc0e4

Upload flux_space_model_manager.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. flux_space_model_manager.py +209 -0
flux_space_model_manager.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Model Manager for FLUX.1 Space - Handles multiple models and LoRA integration
4
+ """
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from safetensors.torch import load_file
9
+ import os
10
+ from typing import Dict, Optional, Tuple
11
+
12
class FluxModelManager:
    """
    Manages multiple FLUX models and LoRA integration.

    Keeps a registry of known checkpoints, loads one diffusion pipeline at a
    time onto the best available device, and tracks the LoRA adapters applied
    to the active pipeline.
    """

    def __init__(self):
        # Registry of selectable checkpoints; 'model_id' is what gets passed
        # to DiffusionPipeline.from_pretrained().
        self.models = {
            'flux1-dev': {
                'repo_id': 'black-forest-labs/FLUX.1-dev',
                'model_id': 'black-forest-labs/FLUX.1-dev',
                'description': 'Original FLUX.1-dev model'
            },
            'flux1-krea': {
                'repo_id': 'black-forest-labs/FLUX.1-Krea-dev',
                'model_id': 'black-forest-labs/FLUX.1-Krea-dev',
                'description': 'FLUX.1-Krea-dev model'
            },
            'merged': {
                'repo_id': 'local/merged_krea_55_flux_45_complete',
                'model_id': 'local/merged_krea_55_flux_45_complete',
                'description': 'Merged Krea 55% + FLUX 45% model'
            }
        }

        self.current_model: Optional[str] = None  # key into self.models, or None
        self.current_pipeline = None              # active DiffusionPipeline, or None
        self.loaded_loras: Dict[str, Dict] = {}   # adapter name -> {'path', 'strength', 'state_dict'}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_name: str) -> bool:
        """
        Load a specific FLUX model.

        Args:
            model_name: Key into self.models.

        Returns:
            True on success, False if the name is unknown or loading fails.
        """
        if model_name not in self.models:
            print(f"❌ Model '{model_name}' not found")
            return False

        try:
            print(f"πŸ”„ Loading model: {model_name}")
            model_info = self.models[model_name]

            # Load the pipeline in bf16; FLUX checkpoints ship safetensors.
            self.current_pipeline = DiffusionPipeline.from_pretrained(
                model_info['model_id'],
                torch_dtype=torch.bfloat16,
                use_safetensors=True
            )

            if self.device == "cuda":
                self.current_pipeline = self.current_pipeline.to(self.device)
                # Memory-saving toggles; VAE slicing is not exposed by every
                # pipeline class, so probe for it instead of assuming.
                self.current_pipeline.enable_attention_slicing()
                if hasattr(self.current_pipeline, "enable_vae_slicing"):
                    self.current_pipeline.enable_vae_slicing()

            self.current_model = model_name
            print(f"βœ… Model '{model_name}' loaded successfully")
            return True

        except Exception as e:
            print(f"❌ Error loading model '{model_name}': {e}")
            return False

    def load_lora(self, lora_path: str, lora_name: str = None, strength: float = 1.0) -> bool:
        """
        Load and apply a LoRA to the current model.

        Args:
            lora_path: Path to a .safetensors LoRA file.
            lora_name: Adapter name; defaults to the file's basename.
            strength: Adapter weight applied via set_adapters.

        Returns:
            True on success, False otherwise.
        """
        if self.current_pipeline is None:
            print("❌ No model loaded. Load a model first.")
            return False

        try:
            print(f"πŸ”„ Loading LoRA: {lora_path}")

            # Resolve the adapter name *before* registering it, so the name
            # stored in self.loaded_loras always matches the adapter name the
            # pipeline knows (otherwise unload_lora can never find it).
            lora_name = lora_name or os.path.basename(lora_path)

            # Load LoRA weights from disk.
            lora_state_dict = load_file(lora_path)

            # Apply LoRA to the pipeline. weight_name= is only meaningful when
            # loading from a repo/directory, not for an in-memory state dict.
            self.current_pipeline.load_lora_weights(
                lora_state_dict,
                adapter_name=lora_name
            )

            # Make `strength` actually take effect (it was previously only
            # recorded, never applied).
            self.current_pipeline.set_adapters(
                list(self.loaded_loras.keys()) + [lora_name],
                adapter_weights=[info['strength'] for info in self.loaded_loras.values()] + [strength]
            )

            # Store LoRA info for introspection / later unloading.
            self.loaded_loras[lora_name] = {
                'path': lora_path,
                'strength': strength,
                'state_dict': lora_state_dict
            }

            print(f"βœ… LoRA '{lora_name}' loaded with strength {strength}")
            return True

        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
            return False

    def unload_lora(self, lora_name: str) -> bool:
        """
        Unload a specific LoRA by adapter name.

        Returns:
            True on success, False if the adapter is unknown or removal fails.
        """
        if lora_name in self.loaded_loras:
            try:
                # delete_adapters removes a *named* adapter;
                # unload_lora_weights() takes no name and drops everything.
                self.current_pipeline.delete_adapters(lora_name)
                del self.loaded_loras[lora_name]
                print(f"βœ… LoRA '{lora_name}' unloaded")
                return True
            except Exception as e:
                print(f"❌ Error unloading LoRA: {e}")
                return False
        else:
            print(f"❌ LoRA '{lora_name}' not found")
            return False

    def unload_all_loras(self) -> bool:
        """
        Unload all LoRAs currently tracked by this manager.

        Returns:
            True on success, False if any removal raised.
        """
        try:
            # Iterate over a snapshot since unload_lora mutates the dict.
            for lora_name in list(self.loaded_loras.keys()):
                self.unload_lora(lora_name)
            print("βœ… All LoRAs unloaded")
            return True
        except Exception as e:
            print(f"❌ Error unloading LoRAs: {e}")
            return False

    def get_model_info(self) -> Dict:
        """
        Get information about the current model and loaded LoRAs.

        Returns:
            Dict with current model, description, device, loaded LoRA names,
            and the list of available model keys.
        """
        info = {
            'current_model': self.current_model,
            'model_description': self.models.get(self.current_model, {}).get('description', 'Unknown'),
            'device': self.device,
            'loaded_loras': list(self.loaded_loras.keys()),
            'available_models': list(self.models.keys())
        }
        return info

    def generate_image(self, prompt: str, negative_prompt: str = "",
                       num_inference_steps: int = 50, guidance_scale: float = 7.5,
                       width: int = 1024, height: int = 1024, seed: int = None) -> Tuple[torch.Tensor, Dict]:
        """
        Generate an image with the current model and LoRAs.

        Args:
            prompt: Text prompt.
            negative_prompt: Negative text prompt (passed through; whether the
                underlying FLUX pipeline honors it depends on the pipeline
                class — NOTE(review): confirm for the pinned diffusers version).
            num_inference_steps: Denoising steps.
            guidance_scale: Classifier-free guidance scale.
            width: Output width in pixels.
            height: Output height in pixels.
            seed: Optional seed for reproducible generation.

        Returns:
            Tuple of (image tensor, generation metadata dict).

        Raises:
            ValueError: If no model is loaded.
        """
        if self.current_pipeline is None:
            raise ValueError("No model loaded. Load a model first.")

        # Use a per-call Generator rather than seeding the global RNG: this is
        # device-correct on CUDA and does not perturb unrelated torch state.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # Generate image as a torch tensor ("pt" output).
        result = self.current_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=generator,
            output_type="pt"
        )

        # Record the settings used, for display / reproducibility.
        generation_info = {
            'model': self.current_model,
            'loras': list(self.loaded_loras.keys()),
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'seed': seed
        }

        return result.images[0], generation_info
189
+
190
# Example usage for Gradio integration
def create_model_manager():
    """Instantiate and return a fresh FluxModelManager for the UI layer."""
    manager = FluxModelManager()
    return manager
196
+
197
def get_model_options():
    """Return the available model keys, e.g. for a dropdown widget."""
    return [name for name in FluxModelManager().models]
203
+
204
def get_model_descriptions():
    """Map each available model name to its human-readable description."""
    descriptions = {}
    for name, info in FluxModelManager().models.items():
        descriptions[name] = info['description']
    return descriptions