gokaygokay committed
Commit 9d6b28e
1 Parent(s): 237335d

Create app.py

Files changed (1):
  1. app.py +443 -0
app.py ADDED
@@ -0,0 +1,443 @@
+ import os
+ import requests
+ import time
+ import io
+ import torch
+ from PIL import Image
+ import cv2
+ import numpy as np
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
+ from diffusers.models import AutoencoderKL
+ from RealESRGAN import RealESRGAN
+ import gradio as gr
+ import subprocess
+ from tqdm import tqdm
+ import shutil
+ import uuid
+ import json
+ import threading
+
+ # Constants
+ USE_TORCH_COMPILE = False
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+
+ # Ensure CUDA is available
+ if not torch.cuda.is_available():
+     raise RuntimeError("CUDA is not available. This script requires a CUDA-capable GPU.")
+
+ device = torch.device("cuda")
+ print(f"Using device: {device}")
+
+ # Replace the global abort_status with an Event
+ abort_event = threading.Event()
+
+ css = """
+ .gradio-container {
+     max-width: 100% !important;
+     padding: 20px !important;
+ }
+ #component-0 {
+     height: auto !important;
+     overflow: visible !important;
+ }
+ """
+
+ def abort_job():
+     if abort_event.is_set():
+         return "Job is already being aborted."
+     abort_event.set()
+     return "Aborting job... Processing will stop after the current frame."
+
+ def check_ffmpeg():
+     try:
+         subprocess.run(["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
+         return True
+     except (subprocess.CalledProcessError, FileNotFoundError):
+         return False
+
+ def download_file(url, folder_path, filename):
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path)
+     file_path = os.path.join(folder_path, filename)
+
+     if os.path.isfile(file_path):
+         print(f"File already exists: {file_path}")
+     else:
+         response = requests.get(url, stream=True)
+         if response.status_code == 200:
+             with open(file_path, 'wb') as file:
+                 for chunk in response.iter_content(chunk_size=1024):
+                     file.write(chunk)
+             print(f"File successfully downloaded and saved: {file_path}")
+         else:
+             print(f"Error downloading the file. Status code: {response.status_code}")
+
+ def download_models():
+     models = {
+         "MODEL": ("https://huggingface.co/dantea1118/juggernaut_reborn/resolve/main/juggernaut_reborn.safetensors?download=true", "models/models/Stable-diffusion", "juggernaut_reborn.safetensors"),
+         "UPSCALER_X2": ("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth?download=true", "models/upscalers/", "RealESRGAN_x2.pth"),
+         "UPSCALER_X4": ("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth?download=true", "models/upscalers/", "RealESRGAN_x4.pth"),
+         "NEGATIVE_1": ("https://huggingface.co/philz1337x/embeddings/resolve/main/verybadimagenegative_v1.3.pt?download=true", "models/embeddings", "verybadimagenegative_v1.3.pt"),
+         "NEGATIVE_2": ("https://huggingface.co/datasets/AddictiveFuture/sd-negative-embeddings/resolve/main/JuggernautNegative-neg.pt?download=true", "models/embeddings", "JuggernautNegative-neg.pt"),
+         "LORA_1": ("https://huggingface.co/philz1337x/loras/resolve/main/SDXLrender_v2.0.safetensors?download=true", "models/Lora", "SDXLrender_v2.0.safetensors"),
+         "LORA_2": ("https://huggingface.co/philz1337x/loras/resolve/main/more_details.safetensors?download=true", "models/Lora", "more_details.safetensors"),
+         "CONTROLNET": ("https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth?download=true", "models/ControlNet", "control_v11f1e_sd15_tile.pth"),
+         "VAE": ("https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors?download=true", "models/VAE", "vae-ft-mse-840000-ema-pruned.safetensors"),
+     }
+
+     for model, (url, folder, filename) in models.items():
+         download_file(url, folder, filename)
+
+ def timer_func(func):
+     def wrapper(*args, **kwargs):
+         start_time = time.time()
+         result = func(*args, **kwargs)
+         end_time = time.time()
+         print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
+         return result
+     return wrapper
+
+ class ModelManager:
+     def __init__(self):
+         self.pipe = None
+         self.realesrgan_x2 = None
+         self.realesrgan_x4 = None
+
+     def load_models(self, progress=gr.Progress()):
+         if self.pipe is None:
+             progress(0, desc="Loading Stable Diffusion pipeline...")
+             self.pipe = self.setup_pipeline()
+             self.pipe.to(device)
+             if USE_TORCH_COMPILE:
+                 progress(0.5, desc="Compiling the model...")
+                 self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+         if self.realesrgan_x2 is None:
+             progress(0.7, desc="Loading RealESRGAN x2 model...")
+             self.realesrgan_x2 = RealESRGAN(device, scale=2)
+             self.realesrgan_x2.load_weights('models/upscalers/RealESRGAN_x2.pth', download=False)
+
+         if self.realesrgan_x4 is None:
+             progress(0.9, desc="Loading RealESRGAN x4 model...")
+             self.realesrgan_x4 = RealESRGAN(device, scale=4)
+             self.realesrgan_x4.load_weights('models/upscalers/RealESRGAN_x4.pth', download=False)
+
+         progress(1.0, desc="All models loaded successfully")
+
+     def setup_pipeline(self):
+         controlnet = ControlNetModel.from_single_file(
+             "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=torch.float16
+         )
+         model_path = "models/models/Stable-diffusion/juggernaut_reborn.safetensors"
+         pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
+             model_path,
+             controlnet=controlnet,
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+             safety_checker=None
+         )
+         vae = AutoencoderKL.from_single_file(
+             "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
+             torch_dtype=torch.float16
+         )
+         pipe.vae = vae
+         pipe.load_textual_inversion("models/embeddings/verybadimagenegative_v1.3.pt")
+         pipe.load_textual_inversion("models/embeddings/JuggernautNegative-neg.pt")
+         pipe.load_lora_weights("models/Lora/SDXLrender_v2.0.safetensors")
+         pipe.fuse_lora(lora_scale=0.5)
+         pipe.load_lora_weights("models/Lora/more_details.safetensors")
+         pipe.fuse_lora(lora_scale=1.)
+         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+         pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
+         return pipe
+
+     @timer_func
+     def process_image(self, input_image, resolution, num_inference_steps, strength, hdr, guidance_scale):
+         condition_image = self.prepare_image(input_image, resolution, hdr)
+
+         prompt = "masterpiece, best quality, highres"
+         negative_prompt = "low quality, normal quality, ugly, blurry, blur, lowres, bad anatomy, bad hands, cropped, worst quality, verybadimagenegative_v1.3, JuggernautNegative-neg"
+
+         options = {
+             "prompt": prompt,
+             "negative_prompt": negative_prompt,
+             "image": condition_image,
+             "control_image": condition_image,
+             "width": condition_image.size[0],
+             "height": condition_image.size[1],
+             "strength": strength,
+             "num_inference_steps": num_inference_steps,
+             "guidance_scale": guidance_scale,
+             "generator": torch.Generator(device=device).manual_seed(0),
+         }
+
+         print("Running inference...")
+         result = self.pipe(**options).images[0]
+         print("Image processing completed successfully")
+
+         return result
+
+     def prepare_image(self, input_image, resolution, hdr):
+         condition_image = self.resize_and_upscale(input_image, resolution)
+         condition_image = self.create_hdr_effect(condition_image, hdr)
+         return condition_image
+
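+     # resize_and_upscale below normalizes any supported input (path, file object,
+     # PIL image, or numpy array) to RGB, snaps the short side to the target
+     # resolution on a 64-pixel grid, then runs Real-ESRGAN (x2 for resolutions up
+     # to 2048, x4 above) before the diffusion pass.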
+     @timer_func
+     def resize_and_upscale(self, input_image, resolution):
+         scale = 2 if resolution <= 2048 else 4
+
+         if isinstance(input_image, str):
+             input_image = Image.open(input_image).convert("RGB")
+         elif isinstance(input_image, io.IOBase):
+             input_image = Image.open(input_image).convert("RGB")
+         elif isinstance(input_image, Image.Image):
+             input_image = input_image.convert("RGB")
+         elif isinstance(input_image, np.ndarray):
+             input_image = Image.fromarray(input_image).convert("RGB")
+         else:
+             raise ValueError(f"Unsupported input type for input_image: {type(input_image)}")
+
+         W, H = input_image.size
+         k = float(resolution) / min(H, W)
+         H = int(round(H * k / 64.0)) * 64
+         W = int(round(W * k / 64.0)) * 64
+         img = input_image.resize((W, H), resample=Image.LANCZOS)
+
+         if scale == 2:
+             img = self.realesrgan_x2.predict(img)
+         else:
+             img = self.realesrgan_x4.predict(img)
+
+         return img
+
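+     # create_hdr_effect below fakes an exposure bracket: it scales the frame's
+     # brightness by nine factors derived from the hdr slider and fuses the stack
+     # with OpenCV's Mertens exposure fusion, so hdr=0 is a no-op and higher
+     # values add more tonal range.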
+     @timer_func
+     def create_hdr_effect(self, original_image, hdr):
+         if hdr == 0:
+             return original_image
+         cv_original = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
+         factors = [1.0 - 0.9 * hdr, 1.0 - 0.7 * hdr, 1.0 - 0.45 * hdr,
+                    1.0 - 0.25 * hdr, 1.0, 1.0 + 0.2 * hdr,
+                    1.0 + 0.4 * hdr, 1.0 + 0.6 * hdr, 1.0 + 0.8 * hdr]
+         images = [cv2.convertScaleAbs(cv_original, alpha=factor) for factor in factors]
+         merge_mertens = cv2.createMergeMertens()
+         hdr_image = merge_mertens.process(images)
+         hdr_image_8bit = np.clip(hdr_image * 255, 0, 255).astype('uint8')
+         hdr_result = Image.fromarray(cv2.cvtColor(hdr_image_8bit, cv2.COLOR_BGR2RGB))
+
+         return hdr_result
+
+ model_manager = ModelManager()
+
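+ # Frame extraction and reassembly both assume a fixed 30 fps: extract_frames
+ # resamples the source with ffmpeg's fps=30 filter, and process_video later
+ # rebuilds the clip at the same rate.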
+ def extract_frames(video_path, output_folder):
+     os.makedirs(output_folder, exist_ok=True)
+     command = [
+         'ffmpeg',
+         '-i', video_path,
+         '-vf', 'fps=30',
+         f'{output_folder}/frame_%06d.png'
+     ]
+     subprocess.run(command, check=True)
+
+ def frames_to_video(input_folder, output_path, fps, original_video_path):
+     # First, create the video from frames without audio
+     temp_output_path = output_path + "_temp.mp4"
+     video_command = [
+         'ffmpeg',
+         '-framerate', str(fps),
+         '-i', f'{input_folder}/frame_%06d.png',
+         '-c:v', 'libx264',
+         '-pix_fmt', 'yuv420p',
+         temp_output_path
+     ]
+     subprocess.run(video_command, check=True)
+
+     # Then, copy the audio from the original video and add it to the new video
+     final_command = [
+         'ffmpeg',
+         '-i', temp_output_path,
+         '-i', original_video_path,
+         '-c:v', 'copy',
+         '-c:a', 'aac',
+         '-map', '0:v:0',
+         '-map', '1:a:0?',
+         '-shortest',
+         output_path
+     ]
+     subprocess.run(final_command, check=True)
+
+     # Remove the temporary file
+     os.remove(temp_output_path)
+
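+ # process_video drives the whole pipeline: it creates a jobs/<uuid> folder with a
+ # config.json, extracts frames, upscales every frame_interval-th frame through the
+ # ControlNet tile pipeline (copying the previous processed frame for the ones it
+ # skips), and reassembles the result even if the job is aborted partway through.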
+ @timer_func
+ def process_video(input_video, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames=None, frame_interval=1, preserve_frames=False, progress=gr.Progress()):
+     abort_event.clear()  # Clear the abort flag at the start of a new job
+     print("Starting video processing...")
+     model_manager.load_models(progress)  # Ensure models are loaded
+
+     # Create a new job folder
+     job_id = str(uuid.uuid4())
+     job_folder = os.path.join("jobs", job_id)
+     os.makedirs(job_folder, exist_ok=True)
+
+     # Save job config
+     config = {
+         "resolution": resolution,
+         "num_inference_steps": num_inference_steps,
+         "strength": strength,
+         "hdr": hdr,
+         "guidance_scale": guidance_scale,
+         "max_frames": max_frames,
+         "frame_interval": frame_interval,
+         "preserve_frames": preserve_frames
+     }
+     with open(os.path.join(job_folder, "config.json"), "w") as f:
+         json.dump(config, f)
+
+     # If input_video is a file object or has a 'name' attribute, use its name
+     if isinstance(input_video, io.IOBase) or hasattr(input_video, 'name'):
+         input_video = input_video.name
+
+     # Set up folders
+     frames_folder = os.path.join(job_folder, "video_frames")
+     processed_frames_folder = os.path.join(job_folder, "processed_frames")
+     os.makedirs(frames_folder, exist_ok=True)
+     os.makedirs(processed_frames_folder, exist_ok=True)
+
+     # Extract frames
+     progress(0.1, desc="Extracting frames...")
+     extract_frames(input_video, frames_folder)
+
+     # Process selected frames
+     frame_files = sorted(os.listdir(frames_folder))
+     total_frames = len(frame_files)
+     frames_to_process = min(max_frames, total_frames) if max_frames else total_frames
+
+     try:
+         progress(0.2, desc="Processing frames...")
+         for i, frame_file in enumerate(tqdm(frame_files[:frames_to_process], desc="Processing frames")):
+             if abort_event.is_set():
+                 print("Job aborted. Stopping processing of new frames.")
+                 break
+
+             output_frame_path = os.path.join(processed_frames_folder, frame_file)
+             if not preserve_frames or not os.path.exists(output_frame_path):
+                 if i % frame_interval == 0:
+                     # Process this frame
+                     input_image = Image.open(os.path.join(frames_folder, frame_file))
+                     processed_image = model_manager.process_image(input_image, resolution, num_inference_steps, strength, hdr, guidance_scale)
+                     processed_image.save(output_frame_path)
+                 else:
+                     # Copy the previous processed frame or the original frame
+                     prev_frame = f"frame_{int(frame_file.split('_')[1].split('.')[0]) - 1:06d}.png"
+                     prev_frame_path = os.path.join(processed_frames_folder, prev_frame)
+                     if os.path.exists(prev_frame_path):
+                         shutil.copy2(prev_frame_path, output_frame_path)
+                     else:
+                         shutil.copy2(os.path.join(frames_folder, frame_file), output_frame_path)
+             progress((0.2 + 0.7 * (i + 1) / frames_to_process), desc=f"Processing frame {i+1}/{frames_to_process}")
+
+         # Always attempt to reassemble video
+         progress(0.9, desc="Reassembling video...")
+         input_filename = os.path.splitext(os.path.basename(input_video))[0]
+         output_video = os.path.join(job_folder, f"{input_filename}_upscaled.mp4")
+         frames_to_video(processed_frames_folder, output_video, 30, input_video)
+
+         if abort_event.is_set():
+             progress(1.0, desc="Video processing aborted, but partial result saved")
+             print("Video processing aborted, but partial result saved")
+         else:
+             progress(1.0, desc="Video processing completed successfully")
+             print("Video processing completed successfully")
+
+         return output_video
+
+     except Exception as e:
+         print(f"An error occurred during processing: {str(e)}")
+         progress(1.0, desc=f"Error: {str(e)}")
+         return None
+
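+ # gradio_process_media is the single Gradio entry point: video extensions
+ # (.mp4/.avi/.mov/.mkv) are routed to process_video, anything else is treated
+ # as an image and the result is saved under processed_images/.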
+ def gradio_process_media(input_media, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, progress=gr.Progress()):
+     abort_event.clear()  # Clear the abort flag at the start of a new job
+     if input_media is None:
+         return None, "No input media provided."
+
+     print(f"Input media type: {type(input_media)}")
+
+     # Get the file path
+     if isinstance(input_media, str):
+         file_path = input_media
+     elif isinstance(input_media, io.IOBase):
+         file_path = input_media.name
+     elif hasattr(input_media, 'name'):
+         file_path = input_media.name
+     else:
+         raise ValueError(f"Unsupported input type: {type(input_media)}")
+
+     print(f"File path: {file_path}")
+
+     # Ensure models are loaded
+     model_manager.load_models(progress)
+
+     # Check if the file is a video
+     video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
+     if file_path.lower().endswith(video_extensions):
+         print("Processing video...")
+         result = process_video(file_path, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames, progress)
+         if result:
+             return result, "Video processing completed successfully."
+         else:
+             return None, "Error occurred during video processing."
+     else:
+         print("Processing image...")
+         result = model_manager.process_image(file_path, resolution, num_inference_steps, strength, hdr, guidance_scale)
+         if result:
+             # Save the processed image
+             output_path = os.path.join("processed_images", f"processed_{os.path.basename(file_path)}")
+             os.makedirs(os.path.dirname(output_path), exist_ok=True)
+             result.save(output_path)
+             return output_path, "Image processing completed successfully."
+         else:
+             return None, "Error occurred during image processing."
+
+ # Update the Gradio interface
+ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue")) as iface:
+     gr.Markdown(
+         """
+         # SimpleSlowVideoUpscaler
+
+         Built by [Hrishi](https://twitter.com/hrishioa) and Claude
+
+         This project is based on [gokaygokay/Tile-Upscaler](https://huggingface.co/spaces/gokaygokay/Tile-Upscaler), which in turn is inspired by ideas from [@philz1337x/clarity-upscaler](https://github.com/philz1337x/clarity-upscaler) and [@BatouResearch/controlnet-tile-upscale](https://github.com/BatouResearch/controlnet-tile-upscale).
+
+         If you find this project useful, please consider [starring it on GitHub](https://github.com/hrishioa/SimpleSlowVideoUpscaler)!
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             input_media = gr.File(label="Input Media (Image or Video)")
+             resolution = gr.Slider(256, 2048, 512, step=256, label="Resolution")
+             num_inference_steps = gr.Slider(1, 50, 10, step=1, label="Number of Inference Steps")
+             strength = gr.Slider(0, 1, 0.3, step=0.01, label="Strength")
+             hdr = gr.Slider(0, 1, 0, step=0.1, label="HDR Effect")
+             guidance_scale = gr.Slider(0, 20, 5, step=0.5, label="Guidance Scale")
+             max_frames = gr.Number(label="Max Frames to Process (leave empty for full video)", precision=0)
+             frame_interval = gr.Slider(1, 30, 1, step=1, label="Frame Interval (process every nth frame)")
+             preserve_frames = gr.Checkbox(label="Preserve Existing Processed Frames", value=True)
+
+         with gr.Column(scale=1):
+             submit_button = gr.Button("Process Media")
+             abort_button = gr.Button("Abort Job")
+             output = gr.File(label="Processed Media")
+             status = gr.Markdown("Ready to process media.")
+
+     submit_button.click(
+         gradio_process_media,
+         inputs=[input_media, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames, frame_interval, preserve_frames],
+         outputs=[output, status]
+     )
+
+     abort_button.click(abort_job, inputs=[], outputs=status)
+
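+ # Note: download_models() is defined above but never called in this file, so the
+ # checkpoints under models/ are assumed to already exist locally. Uncommenting the
+ # hypothetical call below (not part of this commit) would fetch them on first run.
+ # download_models()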
+ # Launch the Gradio app
+ iface.launch()