gokaygokay committed
Commit a3a1971
1 Parent(s): b150e59

Update app.py

Files changed (1)
  1. app.py +244 -337
app.py CHANGED
@@ -2,61 +2,30 @@ import spaces
  import os
  import requests
  import time
- import io
  import torch
  from PIL import Image
  import cv2
  import numpy as np
- from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
- from diffusers.models import AutoencoderKL
- from diffusers.models.attention_processor import AttnProcessor2_0
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
  from RealESRGAN import RealESRGAN
  import gradio as gr
- import subprocess
- from tqdm import tqdm
- import shutil
- import uuid
- import json
- import threading
-
- # Constants
  USE_TORCH_COMPILE = False
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

- # Ensure CUDA is available
- if not torch.cuda.is_available():
-     raise RuntimeError("CUDA is not available. This script requires a CUDA-capable GPU.")
-
- device = torch.device("cuda")
- print(f"Using device: {device}")
-
- # Replace the global abort_status with an Event
- abort_event = threading.Event()
-
- css = """
- .gradio-container {
-     max-width: 100% !important;
-     padding: 20px !important;
- }
- #component-0 {
-     height: auto !important;
-     overflow: visible !important;
- }
- """
-
- def abort_job():
-     if abort_event.is_set():
-         return "Job is already being aborted."
-     abort_event.set()
-     return "Aborting job... Processing will stop after the current frame."
-
- def check_ffmpeg():
-     try:
-         subprocess.run(["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
-         return True
-     except (subprocess.CalledProcessError, FileNotFoundError):
-         return False

  def download_file(url, folder_path, filename):
      if not os.path.exists(folder_path):
@@ -102,42 +71,44 @@ def timer_func(func):
          return result
      return wrapper

- class ModelManager:
      def __init__(self):
          self.pipe = None
-         self.realesrgan_x2 = None
-         self.realesrgan_x4 = None

-     def load_models(self):
-         if self.pipe is None:
              self.pipe = self.setup_pipeline()
              self.pipe.to(device)
-             self.pipe.unet.set_attn_processor(AttnProcessor2_0())
-             self.pipe.vae.set_attn_processor(AttnProcessor2_0())
              if USE_TORCH_COMPILE:
                  self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)

-         if self.realesrgan_x2 is None:
-             self.realesrgan_x2 = RealESRGAN(device, scale=2)
-             self.realesrgan_x2.load_weights('models/upscalers/RealESRGAN_x2.pth', download=False)
-
-         if self.realesrgan_x4 is None:
-             self.realesrgan_x4 = RealESRGAN(device, scale=4)
-             self.realesrgan_x4.load_weights('models/upscalers/RealESRGAN_x4.pth', download=False)
-
-
      def setup_pipeline(self):
          controlnet = ControlNetModel.from_single_file(
              "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=torch.float16
          )
-         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
          model_path = "models/models/Stable-diffusion/juggernaut_reborn.safetensors"
          pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
              model_path,
              controlnet=controlnet,
              torch_dtype=torch.float16,
              use_safetensors=True,
-             safety_checker=safety_checker
          )
          vae = AutoencoderKL.from_single_file(
              "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
@@ -154,291 +125,227 @@ class ModelManager:
          pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
          return pipe

-     @timer_func
-     def process_image_batch(self, input_images, resolution, num_inference_steps, strength, hdr, guidance_scale):
-         condition_images = [self.prepare_image(img, resolution, hdr) for img in input_images]
-
-         prompt = "masterpiece, best quality, highres"
-         negative_prompt = "low quality, normal quality, ugly, blurry, blur, lowres, bad anatomy, bad hands, cropped, worst quality, verybadimagenegative_v1.3, JuggernautNegative-neg"
-
-         options = {
-             "prompt": [prompt] * len(input_images),
-             "negative_prompt": [negative_prompt] * len(input_images),
-             "image": condition_images,
-             "control_image": condition_images,
-             "width": condition_images[0].size[0],
-             "height": condition_images[0].size[1],
-             "strength": strength,
-             "num_inference_steps": num_inference_steps,
-             "guidance_scale": guidance_scale,
-             "generator": torch.Generator(device=device).manual_seed(0),
-         }
-
-         print("Running inference on batch...")
-         results = self.pipe(**options).images
-         print("Batch processing completed successfully")
-
-         return results
-
-     def prepare_image(self, input_image, resolution, hdr):
-         condition_image = self.resize_and_upscale(input_image, resolution)
-         condition_image = self.create_hdr_effect(condition_image, hdr)
-         return condition_image

-     @timer_func
-     def resize_and_upscale(self, input_image, resolution):
-         scale = 2 if resolution <= 2048 else 4
-
-         if isinstance(input_image, str):
-             input_image = Image.open(input_image).convert("RGB")
-         elif isinstance(input_image, io.IOBase):
-             input_image = Image.open(input_image).convert("RGB")
-         elif isinstance(input_image, Image.Image):
-             input_image = input_image.convert("RGB")
-         elif isinstance(input_image, np.ndarray):
-             input_image = Image.fromarray(input_image).convert("RGB")
-         else:
-             raise ValueError(f"Unsupported input type for input_image: {type(input_image)}")

-         W, H = input_image.size
-         k = float(resolution) / min(H, W)
-         H = int(round(H * k / 64.0)) * 64
-         W = int(round(W * k / 64.0)) * 64
-         img = input_image.resize((W, H), resample=Image.LANCZOS)

-         if scale == 2:
-             img = self.realesrgan_x2.predict(img)
-         else:
-             img = self.realesrgan_x4.predict(img)

-         return img

-     @timer_func
-     def create_hdr_effect(self, original_image, hdr):
-         if hdr == 0:
-             return original_image
-         cv_original = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
-         factors = [1.0 - 0.9 * hdr, 1.0 - 0.7 * hdr, 1.0 - 0.45 * hdr,
-                    1.0 - 0.25 * hdr, 1.0, 1.0 + 0.2 * hdr,
-                    1.0 + 0.4 * hdr, 1.0 + 0.6 * hdr, 1.0 + 0.8 * hdr]
-         images = [cv2.convertScaleAbs(cv_original, alpha=factor) for factor in factors]
-         merge_mertens = cv2.createMergeMertens()
-         hdr_image = merge_mertens.process(images)
-         hdr_image_8bit = np.clip(hdr_image * 255, 0, 255).astype('uint8')
-         hdr_result = Image.fromarray(cv2.cvtColor(hdr_image_8bit, cv2.COLOR_BGR2RGB))
-
-         return hdr_result
-
- model_manager = ModelManager()
- model_manager.load_models() # Ensure models are loaded
-
- def extract_frames(video_path, output_folder):
-     os.makedirs(output_folder, exist_ok=True)
-     command = [
-         'ffmpeg',
-         '-i', video_path,
-         '-vf', 'fps=30',
-         f'{output_folder}/frame_%06d.png'
-     ]
-     subprocess.run(command, check=True)
-
- def frames_to_video(input_folder, output_path, fps, original_video_path):
-     # First, create the video from frames without audio
-     temp_output_path = output_path + "_temp.mp4"
-     video_command = [
-         'ffmpeg',
-         '-framerate', str(fps),
-         '-i', f'{input_folder}/frame_%06d.png',
-         '-c:v', 'libx264',
-         '-pix_fmt', 'yuv420p',
-         temp_output_path
-     ]
-     subprocess.run(video_command, check=True)
-
-     # Then, copy the audio from the original video and add it to the new video
-     final_command = [
-         'ffmpeg',
-         '-i', temp_output_path,
-         '-i', original_video_path,
-         '-c:v', 'copy',
-         '-c:a', 'aac',
-         '-map', '0:v:0',
-         '-map', '1:a:0?',
-         '-shortest',
-         output_path
-     ]
-     subprocess.run(final_command, check=True)
-
-     # Remove the temporary file
-     os.remove(temp_output_path)

  @timer_func
- def process_video(input_video, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames=None, frame_interval=1, preserve_frames=False, batch_size=8, progress=gr.Progress()):
-     abort_event.clear() # Clear the abort flag at the start of a new job
-     print("Starting video processing...")
-
-     # Create a new job folder
-     job_id = str(uuid.uuid4())
-     job_folder = os.path.join("jobs", job_id)
-     os.makedirs(job_folder, exist_ok=True)
-
-     # Save job config
-     config = {
-         "resolution": resolution,
-         "num_inference_steps": num_inference_steps,
-         "strength": strength,
-         "hdr": hdr,
-         "guidance_scale": guidance_scale,
-         "max_frames": max_frames,
-         "frame_interval": frame_interval,
-         "preserve_frames": preserve_frames,
-         "batch_size": batch_size
-     }
-     with open(os.path.join(job_folder, "config.json"), "w") as f:
-         json.dump(config, f)
-
-     # If input_video is a file object or has a 'name' attribute, use its name
-     if isinstance(input_video, io.IOBase) or hasattr(input_video, 'name'):
-         input_video = input_video.name
-
-     # Set up folders
-     frames_folder = os.path.join(job_folder, "video_frames")
-     processed_frames_folder = os.path.join(job_folder, "processed_frames")
-     os.makedirs(frames_folder, exist_ok=True)
-     os.makedirs(processed_frames_folder, exist_ok=True)
-
-     # Extract frames
-     progress(0.1, desc="Extracting frames...")
-     extract_frames(input_video, frames_folder)
-
-     # Process selected frames
-     frame_files = sorted(os.listdir(frames_folder))
-     total_frames = len(frame_files)
-     frames_to_process = min(max_frames, total_frames) if max_frames else total_frames
-
-     try:
-         progress(0.2, desc="Processing frames...")
-         for i in tqdm(range(0, frames_to_process, batch_size), desc="Processing batches"):
-             if abort_event.is_set():
-                 print("Job aborted. Stopping processing of new frames.")
-                 break
-
-             batch_frames = frame_files[i:min(i+batch_size, frames_to_process)]
-             input_images = [Image.open(os.path.join(frames_folder, frame)) for frame in batch_frames]
-
-             processed_images = model_manager.process_image_batch(input_images, resolution, num_inference_steps, strength, hdr, guidance_scale)
-
-             for frame_file, processed_image in zip(batch_frames, processed_images):
-                 output_frame_path = os.path.join(processed_frames_folder, frame_file)
-                 if not preserve_frames or not os.path.exists(output_frame_path):
-                     processed_image.save(output_frame_path)
-
-             progress((0.2 + 0.7 * (i + batch_size) / frames_to_process), desc=f"Processed batch {i//batch_size + 1}/{(frames_to_process-1)//batch_size + 1}")
-
-         # Always attempt to reassemble video
-         progress(0.9, desc="Reassembling video...")
-         input_filename = os.path.splitext(os.path.basename(input_video))[0]
-         output_video = os.path.join(job_folder, f"{input_filename}_upscaled.mp4")
-         frames_to_video(processed_frames_folder, output_video, 30, input_video)
-
-         if abort_event.is_set():
-             progress(1.0, desc="Video processing aborted, but partial result saved")
-             print("Video processing aborted, but partial result saved")
          else:
-             progress(1.0, desc="Video processing completed successfully")
-             print("Video processing completed successfully")
-
-         return output_video
-
-     except Exception as e:
-         print(f"An error occurred during processing: {str(e)}")
-         progress(1.0, desc=f"Error: {str(e)}")
-         return None
-
- @spaces.GPU(duration=200)
- def gradio_process_media(input_media, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, batch_size, progress=gr.Progress()):
-     abort_event.clear() # Clear the abort flag at the start of a new job
-     if input_media is None:
-         return None, "No input media provided."
-
-     print(f"Input media type: {type(input_media)}")
-
-     # Get the file path
-     if isinstance(input_media, str):
-         file_path = input_media
-     elif isinstance(input_media, io.IOBase):
-         file_path = input_media.name
-     elif hasattr(input_media, 'name'):
-         file_path = input_media.name
      else:
-         raise ValueError(f"Unsupported input type: {type(input_media)}")

-     print(f"File path: {file_path}")

-     # Check if the file is a video
-     video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
-     if file_path.lower().endswith(video_extensions):
-         print("Processing video...")
-         result = process_video(file_path, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, batch_size, progress)
-         if result:
-             return result, "Video processing completed successfully."
-         else:
-             return None, "Error occurred during video processing."
-     else:
-         print("Processing image...")
-         result = model_manager.process_image(file_path, resolution, num_inference_steps, strength, hdr, guidance_scale)
-         if result:
-             # Save the processed image
-             output_path = os.path.join("processed_images", f"processed_{os.path.basename(file_path)}")
-             os.makedirs(os.path.dirname(output_path), exist_ok=True)
-             result.save(output_path)
-             return output_path, "Image processing completed successfully."
-         else:
-             return None, "Error occurred during image processing."
-
- title = """
- <h1 align="center">Simple Slow Video Upscaler</h1>
- <p align="center">Currently takes too long :( Please try 1-2 seconds videos only.</p>
- <p align="center">
- <a href="https://twitter.com/hrishioa" target="_blank">[Hrishi]</a>
- <a href="https://huggingface.co/spaces/gokaygokay/Tile-Upscaler" target="_blank">[gokaygokay/Tile-Upscaler]</a>
- <a href="https://github.com/philz1337x/clarity-upscaler" target="_blank">[philz1337x]</a>
- <a href="https://github.com/BatouResearch/controlnet-tile-upscale" target="_blank">[BatouResearch]</a>
- <a href="https://github.com/hrishioa/SimpleSlowVideoUpscaler" target="_blank">[Hrishi GitHub]</a>
- </p>
- """
- # Update the Gradio interface
- with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue")) as iface:
-     gr.HTML(title)

      with gr.Row():
-         with gr.Column(scale=2):
-             input_media = gr.File(label="Input Media (Image or Video)")
-             resolution = gr.Slider(256, 2048, 512, step=256, label="Resolution")
-             num_inference_steps = gr.Slider(1, 50, 20, step=1, label="Number of Inference Steps")
-             strength = gr.Slider(0, 1, 0.05, step=0.01, label="Strength")
-             hdr = gr.Slider(0, 1, 0, step=0.1, label="HDR Effect")
-             guidance_scale = gr.Slider(0, 20, 5, step=0.5, label="Guidance Scale")
-             max_frames = gr.Number(label="Max Frames to Process (leave empty for full video)", precision=0)
-             frame_interval = gr.Slider(1, 30, 1, step=1, label="Frame Interval (process every nth frame)")
-             preserve_frames = gr.Checkbox(label="Preserve Existing Processed Frames", value=True)
-             batch_size = gr.Slider(1, 16, 8, step=1, label="Batch Size")
-
-         with gr.Column(scale=1):
-             submit_button = gr.Button("Process Media")
-             abort_button = gr.Button("Abort Job")
-             output = gr.File(label="Processed Media")
-             status = gr.Markdown("Ready to process media.")
-
-     submit_button.click(
-         gradio_process_media,
-         inputs=[input_media, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, batch_size],
-         outputs=[output, status]
-     )
-
-     abort_button.click(abort_job, inputs=[], outputs=status)
-
- # Launch the Gradio app
- iface.launch()

  import os
  import requests
  import time
+
  import torch
+
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler, DPMSolverMultistepScheduler
+ from diffusers.models import AutoencoderKL
+ from diffusers.models.attention_processor import AttnProcessor2_0
+
  from PIL import Image
  import cv2
  import numpy as np
+
  from RealESRGAN import RealESRGAN
+
+ import random
+ import math
+ from scipy.signal import gaussian
+
  import gradio as gr
+ from gradio_imageslider import ImageSlider
+
  USE_TORCH_COMPILE = False
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  def download_file(url, folder_path, filename):
      if not os.path.exists(folder_path):
          return result
      return wrapper

+ def get_scheduler(scheduler_name, config):
+     if scheduler_name == "DDIM":
+         return DDIMScheduler.from_config(config)
+     elif scheduler_name == "DPM++ 3M SDE Karras":
+         return DPMSolverMultistepScheduler.from_config(config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
+     elif scheduler_name == "DPM++ 3M Karras":
+         return DPMSolverMultistepScheduler.from_config(config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
+     else:
+         raise ValueError(f"Unknown scheduler: {scheduler_name}")
+
+ class LazyLoadPipeline:
      def __init__(self):
          self.pipe = None

+     @timer_func
+     def load(self):
+         if self.pipe is None:
+             print("Starting to load the pipeline...")
              self.pipe = self.setup_pipeline()
+             print(f"Moving pipeline to device: {device}")
              self.pipe.to(device)
              if USE_TORCH_COMPILE:
+                 print("Compiling the model...")
                  self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)

+     @timer_func
      def setup_pipeline(self):
+         print("Setting up the pipeline...")
          controlnet = ControlNetModel.from_single_file(
              "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=torch.float16
          )
          model_path = "models/models/Stable-diffusion/juggernaut_reborn.safetensors"
          pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
              model_path,
              controlnet=controlnet,
              torch_dtype=torch.float16,
              use_safetensors=True,
+             safety_checker=None
          )
          vae = AutoencoderKL.from_single_file(
              "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
          pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
          return pipe

+     def set_scheduler(self, scheduler_name):
+         if self.pipe is not None:
+             self.pipe.scheduler = get_scheduler(scheduler_name, self.pipe.scheduler.config)

+     def __call__(self, *args, **kwargs):
+         return self.pipe(*args, **kwargs)

+ class LazyRealESRGAN:
+     def __init__(self, device, scale):
+         self.device = device
+         self.scale = scale
+         self.model = None

+     def load_model(self):
+         if self.model is None:
+             self.model = RealESRGAN(self.device, scale=self.scale)
+             self.model.load_weights(f'models/upscalers/RealESRGAN_x{self.scale}.pth', download=False)
+     def predict(self, img):
+         self.load_model()
+         return self.model.predict(img)

+ lazy_realesrgan_x2 = LazyRealESRGAN(device, scale=2)
+ lazy_realesrgan_x4 = LazyRealESRGAN(device, scale=4)

+ @timer_func
+ def resize_and_upscale(input_image, resolution):
+     scale = 2 if resolution <= 2048 else 4
+     input_image = input_image.convert("RGB")
+     W, H = input_image.size
+     k = float(resolution) / min(H, W)
+     H = int(round(H * k / 64.0)) * 64
+     W = int(round(W * k / 64.0)) * 64
+     img = input_image.resize((W, H), resample=Image.LANCZOS)
+     if scale == 2:
+         img = lazy_realesrgan_x2.predict(img)
+     else:
+         img = lazy_realesrgan_x4.predict(img)
+     return img

+ @timer_func
+ def create_hdr_effect(original_image, hdr):
+     if hdr == 0:
+         return original_image
+     cv_original = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
+     factors = [1.0 - 0.9 * hdr, 1.0 - 0.7 * hdr, 1.0 - 0.45 * hdr,
+                1.0 - 0.25 * hdr, 1.0, 1.0 + 0.2 * hdr,
+                1.0 + 0.4 * hdr, 1.0 + 0.6 * hdr, 1.0 + 0.8 * hdr]
+     images = [cv2.convertScaleAbs(cv_original, alpha=factor) for factor in factors]
+     merge_mertens = cv2.createMergeMertens()
+     hdr_image = merge_mertens.process(images)
+     hdr_image_8bit = np.clip(hdr_image * 255, 0, 255).astype('uint8')
+     return Image.fromarray(cv2.cvtColor(hdr_image_8bit, cv2.COLOR_BGR2RGB))
+
+ lazy_pipe = LazyLoadPipeline()
+ lazy_pipe.load()

  @timer_func
+ def progressive_upscale(input_image, target_resolution, steps=3):
+     current_image = input_image.convert("RGB")
+     current_size = max(current_image.size)

+     for _ in range(steps):
+         if current_size >= target_resolution:
+             break
+
+         scale_factor = min(2, target_resolution / current_size)
+         new_size = (int(current_image.width * scale_factor), int(current_image.height * scale_factor))
+
+         if scale_factor <= 1.5:
+             current_image = current_image.resize(new_size, Image.LANCZOS)
          else:
+             current_image = lazy_realesrgan_x2.predict(current_image)
+
+         current_size = max(current_image.size)
+
+     # Final resize to exact target resolution
+     if current_size != target_resolution:
+         aspect_ratio = current_image.width / current_image.height
+         if current_image.width > current_image.height:
+             new_size = (target_resolution, int(target_resolution / aspect_ratio))
+         else:
+             new_size = (int(target_resolution * aspect_ratio), target_resolution)
+         current_image = current_image.resize(new_size, Image.LANCZOS)
+
+     return current_image
+
215
+ upscaled_image = progressive_upscale(input_image, resolution)
216
+ return create_hdr_effect(upscaled_image, hdr)
217
+
218
+ def create_gaussian_weight(tile_size, sigma=0.3):
219
+ x = np.linspace(-1, 1, tile_size)
220
+ y = np.linspace(-1, 1, tile_size)
221
+ xx, yy = np.meshgrid(x, y)
222
+ gaussian_weight = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
223
+ return gaussian_weight
224
+
225
+ def adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1024):
226
+ w, h = image_size
227
+ aspect_ratio = w / h
228
+ if aspect_ratio > 1:
229
+ tile_w = min(w, max_tile_size)
230
+ tile_h = min(int(tile_w / aspect_ratio), max_tile_size)
231
  else:
232
+ tile_h = min(h, max_tile_size)
233
+ tile_w = min(int(tile_h * aspect_ratio), max_tile_size)
234
+ return max(tile_w, base_tile_size), max(tile_h, base_tile_size)
235
 
236
+ def process_tile(tile, num_inference_steps, strength, guidance_scale, controlnet_strength):
237
+ prompt = "masterpiece, best quality, highres"
238
+ negative_prompt = "low quality, normal quality, ugly, blurry, blur, lowres, bad anatomy, bad hands, cropped, worst quality, verybadimagenegative_v1.3, JuggernautNegative-neg"
239
+
240
+ options = {
241
+ "prompt": prompt,
242
+ "negative_prompt": negative_prompt,
243
+ "image": tile,
244
+ "control_image": tile,
245
+ "num_inference_steps": num_inference_steps,
246
+ "strength": strength,
247
+ "guidance_scale": guidance_scale,
248
+ "controlnet_conditioning_scale": float(controlnet_strength),
249
+ "generator": torch.Generator(device=device).manual_seed(random.randint(0, 2147483647)),
250
+ }
251
+
252
+ return np.array(lazy_pipe(**options).images[0])
253
 
254
+ @spaces.GPU
+ @timer_func
+ def gradio_process_image(input_image, resolution, num_inference_steps, strength, hdr, guidance_scale, controlnet_strength, scheduler_name):
+     print("Starting image processing...")
+     torch.cuda.empty_cache()
+     lazy_pipe.set_scheduler(scheduler_name)
+
+     # Convert input_image to numpy array
+     input_array = np.array(input_image)
+
+     # Prepare the condition image
+     condition_image = prepare_image(input_image, resolution, hdr)
+     W, H = condition_image.size
+
+     # Adaptive tiling
+     tile_width, tile_height = adaptive_tile_size((W, H))
+
+     # Calculate the number of tiles
+     overlap = min(64, tile_width // 8, tile_height // 8) # Adaptive overlap
+     num_tiles_x = math.ceil((W - overlap) / (tile_width - overlap))
+     num_tiles_y = math.ceil((H - overlap) / (tile_height - overlap))
+
+     # Create a blank canvas for the result
+     result = np.zeros((H, W, 3), dtype=np.float32)
+     weight_sum = np.zeros((H, W, 1), dtype=np.float32)
+
+     # Create gaussian weight
+     gaussian_weight = create_gaussian_weight(max(tile_width, tile_height))
+
+     for i in range(num_tiles_y):
+         for j in range(num_tiles_x):
+             # Calculate tile coordinates
+             left = j * (tile_width - overlap)
+             top = i * (tile_height - overlap)
+             right = min(left + tile_width, W)
+             bottom = min(top + tile_height, H)
+
+             # Adjust tile size if it's at the edge
+             current_tile_size = (bottom - top, right - left)
+
+             tile = condition_image.crop((left, top, right, bottom))
+             tile = tile.resize((tile_width, tile_height))
+
+             # Process the tile
+             result_tile = process_tile(tile, num_inference_steps, strength, guidance_scale, controlnet_strength)
+
+             # Apply gaussian weighting
+             if current_tile_size != (tile_width, tile_height):
+                 result_tile = cv2.resize(result_tile, current_tile_size[::-1])
+                 tile_weight = cv2.resize(gaussian_weight, current_tile_size[::-1])
+             else:
+                 tile_weight = gaussian_weight[:current_tile_size[0], :current_tile_size[1]]
+
+             # Add the tile to the result with gaussian weighting
+             result[top:bottom, left:right] += result_tile * tile_weight[:, :, np.newaxis]
+             weight_sum[top:bottom, left:right] += tile_weight[:, :, np.newaxis]
+
+     # Normalize the result
+     final_result = (result / weight_sum).astype(np.uint8)
+
+     print("Image processing completed successfully")
+
+     return [input_array, final_result]
+
+ title = """<h1 align="center">Tiled Upscaler V2</h1>
+ <p align="center">The main ideas come from</p>
+ <p><center>
+ <a href="https://github.com/philz1337x/clarity-upscaler" target="_blank">[philz1337x]</a>
+ <a href="https://github.com/BatouResearch/controlnet-tile-upscale" target="_blank">[Pau-Lozano]</a>
+ </center></p>
+ """

+ with gr.Blocks() as demo:
+     gr.HTML(title)
      with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type="pil", label="Input Image")
+             run_button = gr.Button("Enhance Image")
+         with gr.Column():
+             output_slider = ImageSlider(label="Before / After", type="numpy")
+     with gr.Accordion("Advanced Options", open=False):
+         resolution = gr.Slider(minimum=128, maximum=2048, value=1536, step=128, label="Resolution")
+         num_inference_steps = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Inference Steps")
+         strength = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.01, label="Strength")
+         hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
+         guidance_scale = gr.Slider(minimum=0, maximum=20, value=6, step=0.5, label="Guidance Scale")
+         controlnet_strength = gr.Slider(minimum=0.0, maximum=2.0, value=0.75, step=0.05, label="ControlNet Strength")
+         scheduler_name = gr.Dropdown(
+             choices=["DDIM", "DPM++ 3M SDE Karras", "DPM++ 3M Karras"],
+             value="DDIM",
+             label="Scheduler"
+         )
+
+     run_button.click(fn=gradio_process_image,
+                      inputs=[input_image, resolution, num_inference_steps, strength, hdr, guidance_scale, controlnet_strength, scheduler_name],
+                      outputs=output_slider)
+
+ demo.launch(debug=True, share=True)
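
The core of the new gradio_process_image is the overlapping-tile blend: each diffused tile is accumulated into a float canvas under a centre-weighted Gaussian mask, and the canvas is divided by the summed weights at the end so the seams where tiles overlap average out. The sketch below isolates that blending step in plain NumPy. It is a simplified illustration, not code from the commit: the tile contents are random placeholders standing in for process_tile output, the loop over tile origins is simplified (the commit computes tile counts with math.ceil and resizes edge tiles), and the helper name gaussian_weight is made up here.

# Minimal sketch of the Gaussian-weighted tile blending, assuming placeholder
# tiles (random noise) in place of diffused output. Only NumPy is required.
import numpy as np

def gaussian_weight(tile_size, sigma=0.3):
    # Same idea as create_gaussian_weight(): peak in the tile centre,
    # falling off towards the edges so overlapping regions blend smoothly.
    x = np.linspace(-1, 1, tile_size)
    xx, yy = np.meshgrid(x, x)
    return np.exp(-(xx**2 + yy**2) / (2 * sigma**2))

H, W, tile, overlap = 96, 128, 64, 16          # illustrative sizes only
result = np.zeros((H, W, 3), dtype=np.float32)
weight_sum = np.zeros((H, W, 1), dtype=np.float32)
w = gaussian_weight(tile)

for top in range(0, H - overlap, tile - overlap):
    for left in range(0, W - overlap, tile - overlap):
        bottom, right = min(top + tile, H), min(left + tile, W)
        fake_tile = np.random.rand(bottom - top, right - left, 3)  # stand-in for a diffused tile
        tw = w[:bottom - top, :right - left]                       # crop the weight at image edges
        result[top:bottom, left:right] += fake_tile * tw[:, :, None]
        weight_sum[top:bottom, left:right] += tw[:, :, None]

blended = result / weight_sum  # per-pixel normalisation removes overlap double-counting
print(blended.shape)

Because the Gaussian weight is strictly positive everywhere, every pixel's accumulated weight is non-zero and the final division renormalises overlap regions rather than double-counting them, which is why tile seams fade out in the stitched result.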