fantaxy committed
Commit 8d65abf · verified · 1 Parent(s): 5ef4699

Update app.py

Files changed (1): app.py +97 -115
app.py CHANGED
@@ -12,6 +12,11 @@ from gradio_imageslider import ImageSlider
 from PIL import Image
 from huggingface_hub import snapshot_download
 
+# Add gc for memory management
+import gc
+gc.collect()
+torch.cuda.empty_cache()
+
 css = """
 #col-container {
     margin: 0 auto;
@@ -19,97 +24,96 @@ css = """
 }
 """
 
-# Device and dtype setup with lower precision
+# Device setup with minimal memory usage
 if torch.cuda.is_available():
     power_device = "GPU"
     device = "cuda"
-    dtype = torch.float16  # Changed to float16 for less memory usage
+    dtype = torch.float16  # Use float16 for minimum memory
+    # Set CUDA memory fraction to 50%
+    torch.cuda.set_per_process_memory_fraction(0.5)
 else:
     power_device = "CPU"
     device = "cpu"
     dtype = torch.float32
 
-# Reduce CUDA memory usage
-torch.cuda.empty_cache()
-if torch.cuda.is_available():
-    torch.cuda.set_per_process_memory_fraction(0.7)  # Use only 70% of GPU memory
-
 huggingface_token = os.getenv("HUGGINFACE_TOKEN")
 
+# Minimal model configuration
+model_config = {
+    "low_cpu_mem_usage": True,
+    "torch_dtype": dtype,
+    "use_safetensors": True,
+    "variant": "fp16",  # Use fp16 variant if available
+}
+
 model_path = snapshot_download(
     repo_id="black-forest-labs/FLUX.1-dev",
     repo_type="model",
-    ignore_patterns=["*.md", "*..gitattributes"],
+    ignore_patterns=["*.md", "*..gitattributes", "*.bin"],  # Ignore unnecessary files
     local_dir="FLUX.1-dev",
     token=huggingface_token,
 )
 
-# Load pipeline with more memory optimizations
-controlnet = FluxControlNetModel.from_pretrained(
-    "jasperai/Flux.1-dev-Controlnet-Upscaler",
-    torch_dtype=dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True
-).to(device)
-
-pipe = FluxControlNetPipeline.from_pretrained(
-    model_path,
-    controlnet=controlnet,
-    torch_dtype=dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True
-)
+# Load models with minimal configuration
+try:
+    controlnet = FluxControlNetModel.from_pretrained(
+        "jasperai/Flux.1-dev-Controlnet-Upscaler",
+        **model_config
+    ).to(device)
+
+    pipe = FluxControlNetPipeline.from_pretrained(
+        model_path,
+        controlnet=controlnet,
+        **model_config
+    )
 
-# Enable all possible memory optimizations
-pipe.enable_model_cpu_offload()
-pipe.enable_attention_slicing(1)
-pipe.enable_sequential_cpu_offload()
-pipe.enable_vae_slicing()
+    # Enable all memory optimizations
+    pipe.enable_model_cpu_offload()
+    pipe.enable_attention_slicing(1)
+    pipe.enable_sequential_cpu_offload()
+    pipe.enable_vae_slicing()
+
+    # Clear memory after loading
+    gc.collect()
+    torch.cuda.empty_cache()
+
+except Exception as e:
+    print(f"Error loading models: {e}")
+    raise
 
-# Further reduce memory usage
+# Extremely reduced parameters
 MAX_SEED = 1000000
-MAX_PIXEL_BUDGET = 256 * 256  # Further reduced from 512 * 512
+MAX_PIXEL_BUDGET = 128 * 128  # Extremely reduced from 256 * 256
 
 def check_resources():
     if torch.cuda.is_available():
-        gpu_memory = torch.cuda.get_device_properties(0).total_memory
         memory_allocated = torch.cuda.memory_allocated(0)
-        if memory_allocated/gpu_memory > 0.8:  # 80% threshold
-            return False
+        memory_reserved = torch.cuda.memory_reserved(0)
+        if memory_allocated/memory_reserved > 0.7:  # 70% threshold
+            gc.collect()
+            torch.cuda.empty_cache()
     return True
 
 def process_input(input_image, upscale_factor, **kwargs):
-    # Convert image to RGB mode to ensure compatibility
     input_image = input_image.convert('RGB')
 
+    # Reduce image size more aggressively
     w, h = input_image.size
-    w_original, h_original = w, h
-    aspect_ratio = w / h
-
-    was_resized = False
-
-    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
-        warnings.warn(
-            f"Requested output image is too large. Resizing..."
-        )
-        gr.Info(
-            f"Resizing input image to fit memory constraints..."
-        )
-        input_image = input_image.resize(
-            (
-                int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
-                int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
-            ),
-            Image.LANCZOS
-        )
-        was_resized = True
-
-    # resize to multiple of 8
+    max_size = int(np.sqrt(MAX_PIXEL_BUDGET))
+    if w > max_size or h > max_size:
+        if w > h:
+            new_w = max_size
+            new_h = int(h * max_size / w)
+        else:
+            new_h = max_size
+            new_w = int(w * max_size / h)
+        input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
+
     w, h = input_image.size
     w = w - w % 8
    h = h - h % 8
-
-    return input_image.resize((w, h)), w_original, h_original, was_resized
+
+    return input_image.resize((w, h)), w, h, True
 
 @spaces.GPU
 def infer(
@@ -122,55 +126,32 @@ def infer(
     progress=gr.Progress(track_tqdm=True),
 ):
     try:
-        if not check_resources():
-            gr.Warning("System resources are running low. Try reducing parameters.")
-            return None
-
-        # Clear CUDA cache before processing
-        if device == "cuda":
-            torch.cuda.empty_cache()
-
+        gc.collect()
+        torch.cuda.empty_cache()
+
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
 
-        true_input_image = input_image
-        input_image, w_original, h_original, was_resized = process_input(
-            input_image, upscale_factor
-        )
-
-        # rescale with upscale factor
-        w, h = input_image.size
-        control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
-
-        generator = torch.Generator().manual_seed(seed)
-
-        gr.Info("Upscaling image...")
-        with torch.inference_mode():  # Use inference mode to save memory
+        input_image, w, h, _ = process_input(input_image, upscale_factor)
+
+        with torch.inference_mode():
+            generator = torch.Generator().manual_seed(seed)
             image = pipe(
                 prompt="",
-                control_image=control_image,
+                control_image=input_image,
                 controlnet_conditioning_scale=controlnet_conditioning_scale,
                 num_inference_steps=num_inference_steps,
-                guidance_scale=3.5,
-                height=control_image.size[1],
-                width=control_image.size[0],
+                guidance_scale=2.0,  # Reduced from 3.5
+                height=h,
+                width=w,
                 generator=generator,
             ).images[0]
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return [input_image, image, seed]
 
-        if was_resized:
-            gr.Info(
-                f"Resizing output image to final size..."
-            )
-
-        # resize to target desired size
-        image = image.resize((w_original * upscale_factor, h_original * upscale_factor))
-        return [true_input_image, image, seed]
-
-    except RuntimeError as e:
-        if "out of memory" in str(e):
-            gr.Warning("Not enough GPU memory. Try reducing the upscale factor or image size.")
-            return None
-        raise e
     except Exception as e:
         gr.Error(f"An error occurred: {str(e)}")
         return None
@@ -184,25 +165,25 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
             input_im = gr.Image(label="Input Image", type="pil")
         with gr.Column(scale=1):
             num_inference_steps = gr.Slider(
-                label="Number of Inference Steps",
-                minimum=8,
-                maximum=30,  # Reduced from 50
+                label="Steps",
+                minimum=1,
+                maximum=20,  # Reduced from 30
                 step=1,
-                value=20,  # Reduced from 28
+                value=10,  # Reduced from 20
             )
             upscale_factor = gr.Slider(
-                label="Upscale Factor",
+                label="Scale",
                 minimum=1,
-                maximum=2,
+                maximum=1,  # Fixed at 1
                 step=1,
-                value=1,  # Reduced default
+                value=1,
             )
             controlnet_conditioning_scale = gr.Slider(
-                label="Controlnet Conditioning Scale",
+                label="Control Scale",
                 minimum=0.1,
-                maximum=1.0,  # Reduced from 1.5
+                maximum=0.5,  # Reduced from 1.0
                 step=0.1,
-                value=0.5,  # Reduced from 0.6
+                value=0.3,  # Reduced from 0.5
             )
             seed = gr.Slider(
                 label="Seed",
@@ -211,18 +192,17 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
                 step=1,
                 value=42,
             )
-
-    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            randomize_seed = gr.Checkbox(label="Random Seed", value=True)
 
     with gr.Row():
-        result = ImageSlider(label="Input / Output", type="pil", interactive=True)
+        result = ImageSlider(label="Result", type="pil", interactive=True)
 
     current_dir = os.path.dirname(os.path.abspath(__file__))
 
     examples = gr.Examples(
         examples=[
-            [42, False, os.path.join(current_dir, "z1.webp"), 20, 1, 0.5],  # Reduced parameters
-            [42, False, os.path.join(current_dir, "z2.webp"), 20, 1, 0.5],  # Reduced parameters
+            [42, False, os.path.join(current_dir, "z1.webp"), 10, 1, 0.3],
+            [42, False, os.path.join(current_dir, "z2.webp"), 10, 1, 0.3],
         ],
         inputs=[
             seed,
@@ -234,7 +214,7 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         ],
         fn=infer,
         outputs=result,
-        cache_examples="lazy",
+        cache_examples=False,  # Disable caching
     )
 
     gr.on(
@@ -252,11 +232,13 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         show_api=False,
     )
 
-# Launch with minimal memory usage
+# Launch with minimal resources
 demo.queue(max_size=1).launch(
     share=False,
     debug=True,
     show_error=True,
     max_threads=1,
-    enable_queue=True
+    enable_queue=True,
+    cache_examples=False,
+    quiet=True,
 )
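The headroom check behind the new check_resources() can be read on its own: compare the bytes held by live tensors against the caching allocator's reserved pool, and flush when the ratio climbs past a threshold. A minimal sketch, assuming a CUDA build of PyTorch; the free_cuda_headroom name and the guard against an empty reserved pool are illustrative additions, not part of the commit:

import gc
import torch

def free_cuda_headroom(threshold=0.7):
    # Release cached CUDA blocks when live tensors approach the reserved pool.
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated(0)  # bytes held by live tensors
    reserved = torch.cuda.memory_reserved(0)    # bytes held by the caching allocator
    # Guard added here; the commit divides unconditionally.
    if reserved > 0 and allocated / reserved > threshold:
        gc.collect()              # drop unreachable Python references first
        torch.cuda.empty_cache()  # then hand cached blocks back to the driver

torch.cuda.empty_cache() only returns blocks the allocator has cached but is not currently using; memory referenced by live tensors is untouched, which is why the commit pairs each flush with gc.collect().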
 
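The resize policy introduced in process_input() follows the same budget idea: clamp the longest side to the square root of MAX_PIXEL_BUDGET while preserving aspect ratio, then round both sides down to multiples of 8 (the step the old code labeled "resize to multiple of 8"). A minimal standalone sketch, assuming Pillow; fit_to_budget is a hypothetical name, not part of the commit:

from PIL import Image

MAX_PIXEL_BUDGET = 128 * 128  # the budget the commit sets

def fit_to_budget(img, budget=MAX_PIXEL_BUDGET):
    # Clamp the longest side to sqrt(budget), preserving aspect ratio.
    img = img.convert("RGB")
    w, h = img.size
    max_size = int(budget ** 0.5)  # 128 for a 128 * 128 budget
    if w > max_size or h > max_size:
        if w > h:
            w, h = max_size, int(h * max_size / w)
        else:
            w, h = int(w * max_size / h), max_size
        img = img.resize((w, h), Image.LANCZOS)
    # Round both sides down to multiples of 8 before handing off to the pipeline.
    w, h = img.size
    return img.resize((w - w % 8, h - h % 8))

Note that the new process_input() always returns True for its resized flag and infer() discards it with _, so unlike the old code the output is never resized back up to the original dimensions.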