File size: 13,768 Bytes
289fb74
 
 
 
 
d454ef7
 
289fb74
 
 
 
 
 
 
 
 
 
 
 
 
 
d454ef7
 
 
 
289fb74
 
d454ef7
289fb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d454ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289fb74
d454ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289fb74
 
 
 
 
 
 
 
 
 
f91fd95
 
d454ef7
 
 
 
 
 
 
 
 
f91fd95
d454ef7
 
 
 
 
 
 
 
 
 
 
 
 
f91fd95
d454ef7
f91fd95
289fb74
 
 
f91fd95
 
d454ef7
 
 
 
f91fd95
d454ef7
f91fd95
d454ef7
f91fd95
 
 
289fb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d454ef7
 
 
 
 
289fb74
d454ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289fb74
 
 
 
 
 
 
d454ef7
 
 
 
 
289fb74
d454ef7
 
 
 
 
 
 
 
 
 
 
 
289fb74
 
 
 
 
 
 
d454ef7
 
 
 
 
 
 
 
 
289fb74
 
 
 
 
 
d454ef7
 
 
 
 
289fb74
 
 
 
 
 
d454ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289fb74
 
 
 
 
 
d454ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289fb74
 
 
 
 
 
d454ef7
 
 
 
289fb74
 
 
 
 
d454ef7
 
 
 
289fb74
 
 
 
d454ef7
 
 
 
 
 
289fb74
 
 
 
d454ef7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
from typing import Any, List, Callable, Dict, Optional
import cv2
import threading
import numpy
from functools import lru_cache
from pathlib import Path

import SwitcherAI.processors.frame.core as frame_processors
from SwitcherAI.typing import Frame, Face
from SwitcherAI.utilities import conditional_download, resolve_relative_path

# Module-level state shared by every function in this processor.
# Processor identifier string.
NAME = 'FACEFUSION.FRAME_PROCESSOR.FRAME_ENHANCER'
# Cached Real-ESRGAN upscaler; stays None until get_frame_processor() builds it.
FRAME_PROCESSOR = None
# Held by get_frame_processor() while it checks and initializes FRAME_PROCESSOR.
THREAD_LOCK = threading.Lock()
# Single-slot semaphore taken around processor.enhance(...) calls, so only one
# enhancement runs at a time.
THREAD_SEMAPHORE = threading.Semaphore(1)

# Per-model settings, built once and memoized (get_model_config.cache_clear() rebuilds them).
@lru_cache(maxsize=None)
def get_model_config() -> Dict[str, Any]:
    """
    Return the settings of every supported upscaling model, keyed by model id.

    Only ``'real_esrgan_x4'`` is defined. Its entry holds:
    - the RealESRGAN_x4plus weights path under ``resolve_relative_path('../.assets/models')``;
    - the RRDB network sizes used to build it;
    - the tile size/padding used for tiled enhancement.

    The result is cached, so all callers share the same dict object; treat it as
    read-only.
    """
    models_dir = resolve_relative_path('../.assets/models')
    # resolve_relative_path may return a plain string; normalize to Path so "/" works.
    if isinstance(models_dir, str):
        models_dir = Path(models_dir)

    real_esrgan_x4 = {
        'model_path': models_dir / 'RealESRGAN_x4plus.pth',
        'scale': 4,          # upscale factor passed to the network and RealESRGANer
        'tile_size': 256,    # tile edge length in pixels
        'tile_pad': 16,
        'num_feat': 64,
        'num_block': 23,
        'num_grow_ch': 32,
    }
    return {'real_esrgan_x4': real_esrgan_x4}


def get_frame_processor() -> Any:
    """
    Return the shared Real-ESRGAN upscaler, creating it on first use.

    Initialization is serialized by ``THREAD_LOCK`` and the instance is cached in
    the module-level ``FRAME_PROCESSOR``. A message is printed and ``None`` is
    returned (the cache stays empty, so a later call tries again) when:
    - the Real-ESRGAN packages cannot be imported;
    - the weights are missing and cannot be downloaded;
    - construction raises.
    """
    global FRAME_PROCESSOR

    with THREAD_LOCK:
        if FRAME_PROCESSOR is None:
            try:
                # Heavy optional dependencies are imported lazily so this module can
                # be imported even when Real-ESRGAN is not installed.
                from basicsr.archs.rrdbnet_arch import RRDBNet
                from realesrgan import RealESRGANer
                import torch

                settings = get_model_config()['real_esrgan_x4']
                weights_path = settings['model_path']

                if not weights_path.exists():
                    print(f"⚠️ Real-ESRGAN model not found at: {weights_path}")
                    print("🔄 Attempting to download model...")
                    if not pre_check():
                        print("❌ Failed to download Real-ESRGAN model")
                        return None

                # RRDB backbone with 3 input / 3 output channels, sized from the settings.
                network = RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=settings['num_feat'],
                    num_block=settings['num_block'],
                    num_grow_ch=settings['num_grow_ch'],
                    scale=settings['scale'],
                )
                FRAME_PROCESSOR = RealESRGANer(
                    model_path=str(weights_path),
                    model=network,
                    device=frame_processors.get_device(),
                    tile=settings['tile_size'],
                    tile_pad=settings['tile_pad'],
                    pre_pad=0,
                    scale=settings['scale'],
                )

                # NOTE(review): selects CUDA device 0 whenever CUDA exists, regardless
                # of what frame_processors.get_device() returned; confirm this is
                # intended on multi-GPU or CPU-forced setups.
                if torch.cuda.is_available():
                    torch.cuda.set_device(0)

                print("✅ Real-ESRGAN frame processor initialized")

            except ImportError as e:
                print(f"⚠️ Real-ESRGAN not available: {e}")
                print("💡 Install with: pip install realesrgan basicsr")
                FRAME_PROCESSOR = None
            except Exception as e:
                print(f"⚠️ Failed to initialize Real-ESRGAN: {e}")
                FRAME_PROCESSOR = None

    return FRAME_PROCESSOR


def clear_frame_processor() -> None:
    """
    Drop the cached Real-ESRGAN upscaler so the next get_frame_processor() rebuilds it.

    Takes ``THREAD_LOCK``, the same lock get_frame_processor() holds while it checks
    and assigns ``FRAME_PROCESSOR``. A clear therefore waits for any in-progress
    initialization instead of interleaving with it.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        FRAME_PROCESSOR = None


def pre_check() -> bool:
    """
    Ensure the Real-ESRGAN x4plus weights exist in the models directory.

    Creates the directory if needed, asks ``conditional_download`` to fetch the
    weights, then confirms the file exists and is non-empty.

    Returns:
        True when a non-empty weights file is present afterwards, False otherwise
        (including when any step raises; the error is printed).
    """
    try:
        models_dir = resolve_relative_path('../.assets/models')
        # Normalize to Path so mkdir() and "/" are available.
        if isinstance(models_dir, str):
            models_dir = Path(models_dir)
        models_dir.mkdir(parents=True, exist_ok=True)

        conditional_download(
            str(models_dir),
            ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'],
        )

        weights = models_dir / 'RealESRGAN_x4plus.pth'
        if not (weights.exists() and weights.stat().st_size > 0):
            print("❌ Real-ESRGAN model download failed or file is empty")
            return False

        print(f"✅ Real-ESRGAN model verified: {weights.stat().st_size / (1024*1024):.1f}MB")
        return True

    except Exception as error:
        print(f"❌ Real-ESRGAN pre-check failed: {error}")
        return False


def pre_process() -> bool:
    """
    Report whether frame enhancement can run.

    Returns:
        True if get_frame_processor() yields a processor. False otherwise, or if
        obtaining it raises; a warning is printed in both failure cases.
    """
    try:
        available = get_frame_processor() is not None
        if not available:
            print("⚠️ Real-ESRGAN not available, frame enhancement will be skipped")
        return available
    except Exception as error:
        print(f"⚠️ Frame enhancement pre-process failed: {error}")
        return False


def post_process() -> None:
    """Release cached state: the memoized model settings and the shared upscaler."""
    # Forget the memoized settings so paths are re-resolved on next use.
    get_model_config.cache_clear()
    clear_frame_processor()


def create_tile_frames(temp_vision_frame: Frame, tile_size: tuple = (256, 256)) -> tuple:
    """
    Split a frame into equally sized tiles, reflect-padding the bottom/right edges.

    The frame is padded so both spatial dimensions become multiples of the tile
    size, then cut into tiles in row-major order (left-to-right, top-to-bottom).

    Args:
        temp_vision_frame: Image array whose first two axes are height and width.
            Any trailing axes (e.g. colour channels) are carried through unpadded,
            so both 2-D (grayscale) and 3-D (H, W, C) frames are accepted.
        tile_size: ``(tile_height, tile_width)`` in pixels.

    Returns:
        ``(tiles, pad_width, pad_height)``: the list of tile slices of the padded
        frame, and the number of padding pixels added on the right and bottom.
    """
    height, width = temp_vision_frame.shape[:2]
    tile_height, tile_width = tile_size[0], tile_size[1]

    # Smallest padding that makes each dimension a multiple of the tile size
    # (the outer modulo turns a full-tile pad into 0 when already aligned).
    pad_height = (tile_height - height % tile_height) % tile_height
    pad_width = (tile_width - width % tile_width) % tile_width

    if pad_height > 0 or pad_width > 0:
        # Pad only the two spatial axes. Add one (0, 0) entry per trailing axis so the
        # spec matches the array's rank; a hard-coded 3-D spec rejected 2-D frames.
        pad_spec = ((0, pad_height), (0, pad_width)) + ((0, 0),) * (temp_vision_frame.ndim - 2)
        temp_vision_frame = numpy.pad(temp_vision_frame, pad_spec, mode='reflect')

    padded_height, padded_width = temp_vision_frame.shape[:2]
    tiles = [
        temp_vision_frame[y:y + tile_height, x:x + tile_width]
        for y in range(0, padded_height, tile_height)
        for x in range(0, padded_width, tile_width)
    ]
    return tiles, pad_width, pad_height


def merge_tile_frames(tiles: List[Frame], original_width: int, original_height: int, 
                     pad_width: int, pad_height: int, tile_size: tuple) -> Frame:
    """
    Reassemble row-major tiles into one frame and crop away the padding.

    This is the inverse of ``create_tile_frames``. Tiles are placed left-to-right,
    top-to-bottom on a ``(original_height + pad_height, original_width + pad_width)``
    canvas, then the bottom/right padding is removed.

    Args:
        tiles: Tiles in row-major order, each ``tile_size`` in its first two axes.
            Tiles beyond the grid are ignored; missing tiles leave zeros.
        original_width: Width of the result after cropping.
        original_height: Height of the result after cropping.
        pad_width: Padding that was added on the right.
        pad_height: Padding that was added at the bottom.
        tile_size: ``(tile_height, tile_width)`` of every tile.

    Returns:
        A ``uint8`` frame of shape ``(original_height, original_width, ...)``. The
        trailing axes follow the first tile (3 channels when ``tiles`` is empty).
    """
    tile_height, tile_width = tile_size[0], tile_size[1]
    padded_height = original_height + pad_height
    padded_width = original_width + pad_width

    # Take the trailing (channel) axes from the tiles instead of assuming 3 channels,
    # so grayscale or 4-channel tiles broadcast into the canvas. dtype stays uint8.
    channel_shape = tuple(tiles[0].shape[2:]) if len(tiles) > 0 else (3,)
    result = numpy.zeros((padded_height, padded_width) + channel_shape, dtype=numpy.uint8)

    grid = (
        (y, x)
        for y in range(0, padded_height, tile_height)
        for x in range(0, padded_width, tile_width)
    )
    # zip stops at the shorter of grid/tiles: extra tiles are ignored, missing ones stay zero.
    for (y, x), tile in zip(grid, tiles):
        result[y:y + tile_height, x:x + tile_width] = tile

    if pad_height > 0 or pad_width > 0:
        result = result[:original_height, :original_width]

    return result


def enhance_frame_with_tiling(temp_frame: Frame) -> Frame:
    """
    Upscale ``temp_frame`` with Real-ESRGAN, one tile at a time.

    The frame is split into ``tile_size`` tiles by ``create_tile_frames``. Each tile
    is enhanced at the model scale (under ``THREAD_SEMAPHORE``), and the results are
    stitched back by ``merge_tile_frames`` and cropped to ``original size * scale``.
    A tile whose enhancement raises is replaced by a bicubic resize of that tile to
    the same scaled size, so the frame still assembles.

    Returns:
        The upscaled frame. Returns ``temp_frame`` itself (the same object) when the
        processor is unavailable or any step outside the per-tile loop fails.

    NOTE(review): tiles are enhanced independently with no overlap, so seams may be
    visible at tile borders. RealESRGANer is already constructed with its own
    ``tile``/``tile_pad`` settings in get_frame_processor(), which presumably handle
    overlap internally; consider enhancing the whole frame in one call instead.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Real-ESRGAN processor not available, returning original frame")
            return temp_frame

        config = get_model_config()['real_esrgan_x4']
        tile_size = (config['tile_size'], config['tile_size'])
        scale = config['scale']

        tiles, pad_width, pad_height = create_tile_frames(temp_frame, tile_size)
        enhanced_tiles = []

        with THREAD_SEMAPHORE:
            for tile in tiles:
                try:
                    # Process each tile individually to bound memory use.
                    enhanced_tile, _ = processor.enhance(tile, outscale=scale)
                except Exception as e:
                    print(f"⚠️ Tile enhancement failed: {e}")
                    # Fall back to a plain resize of the original tile. It must be
                    # scaled as well: merge_tile_frames places every tile in a slot of
                    # tile_size * scale, so an unscaled tile fails to broadcast and
                    # the whole frame would be returned un-upscaled.
                    tile_height, tile_width = tile.shape[:2]
                    enhanced_tile = cv2.resize(
                        tile,
                        (tile_width * scale, tile_height * scale),  # cv2 dsize is (width, height)
                        interpolation=cv2.INTER_CUBIC,
                    )
                enhanced_tiles.append(enhanced_tile)

        original_height, original_width = temp_frame.shape[:2]
        # All geometry is multiplied by the scale because every tile was upscaled.
        enhanced_frame = merge_tile_frames(
            enhanced_tiles,
            original_width * scale,
            original_height * scale,
            pad_width * scale,
            pad_height * scale,
            (tile_size[0] * scale, tile_size[1] * scale)
        )

        return enhanced_frame

    except Exception as e:
        print(f"⚠️ Enhanced tiling failed: {e}")
        return temp_frame


def enhance_frame(temp_frame: Frame) -> Frame:
    """
    Enhance a single frame, falling back to a whole-frame pass if tiling fails.

    ``enhance_frame_with_tiling`` catches its own exceptions and signals failure by
    returning the very ``temp_frame`` object it was given. Failure is therefore
    detected by identity; an ``except`` around the call alone would never fire, which
    made the previous fallback unreachable. The fallback enhances the whole frame in
    one ``processor.enhance`` call at the model scale, so its output has the same
    size as the tiled path.

    Returns:
        The enhanced frame, or ``temp_frame`` unchanged if the processor is
        unavailable or every attempt fails.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, returning original frame")
            return temp_frame

        try:
            enhanced_frame = enhance_frame_with_tiling(temp_frame)
            if enhanced_frame is not temp_frame:
                return enhanced_frame
            print("⚠️ Tiling method failed, trying simple enhancement")
        except Exception as e:
            print(f"⚠️ Tiling method failed: {e}, trying simple enhancement")

        # Fallback: a single whole-frame pass. Use the model scale (not 1) so the
        # result matches the size the tiled path would have produced and frames of
        # one video stay consistent.
        scale = get_model_config()['real_esrgan_x4']['scale']
        with THREAD_SEMAPHORE:
            enhanced_frame, _ = processor.enhance(temp_frame, outscale=scale)
        return enhanced_frame

    except Exception as e:
        print(f"⚠️ Frame enhancement failed completely: {e}")
        return temp_frame


def blend_frame(original_frame: Frame, enhanced_frame: Frame, blend_ratio: float = 0.8) -> Frame:
    """
    Mix the original and enhanced frames: ``(1 - blend_ratio) * original + blend_ratio * enhanced``.

    If the shapes differ, the original is first resized to the enhanced frame's
    height/width. A ratio of 1 means fully enhanced. On any error a message is
    printed and ``enhanced_frame`` is returned unchanged.
    """
    try:
        base = original_frame
        if base.shape != enhanced_frame.shape:
            # cv2.resize takes dsize as (width, height).
            base = cv2.resize(base, (enhanced_frame.shape[1], enhanced_frame.shape[0]))

        original_weight = 1 - blend_ratio
        enhanced_weight = blend_ratio
        return cv2.addWeighted(base, original_weight, enhanced_frame, enhanced_weight, 0)
    except Exception as error:
        print(f"⚠️ Frame blending failed: {error}")
        return enhanced_frame


def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """
    Enhance one frame; entry point with the shared frame-processor signature.

    ``source_face`` and ``reference_face`` are accepted for interface compatibility
    and ignored. Any error is printed and the input frame is returned as-is.
    """
    try:
        enhanced = enhance_frame(temp_frame)
    except Exception as error:
        print(f"⚠️ Error in process_frame: {error}")
        return temp_frame
    return enhanced


def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """
    Enhance each frame file in place, advancing ``update`` once per frame.

    Args:
        source_path: Unused; present for the shared frame-processor interface.
        temp_frame_paths: Image files to read, enhance, and overwrite.
        update: Optional progress callback, called once for every entry in
            ``temp_frame_paths``, whether that frame succeeded, failed, or was
            skipped because the enhancer is unavailable.

    Errors (unreadable frames, failed writes, exceptions) are printed, not raised.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, skipping frame enhancement")
            # update is a per-frame tick (see the loop below). Call it once per
            # skipped frame rather than once in total, so a per-frame progress
            # display still reaches completion.
            if update:
                for _ in temp_frame_paths:
                    update()
            return

        for temp_frame_path in temp_frame_paths:
            try:
                temp_frame = cv2.imread(temp_frame_path)
                if temp_frame is None:
                    print(f"⚠️ Failed to read frame: {temp_frame_path}")
                else:
                    result_frame = process_frame(None, None, temp_frame)
                    # cv2.imwrite reports failure through its return value.
                    if not cv2.imwrite(temp_frame_path, result_frame):
                        print(f"⚠️ Failed to write frame: {temp_frame_path}")

            except Exception as e:
                print(f"⚠️ Error processing frame {temp_frame_path}: {e}")

            if update:
                update()

    except Exception as e:
        print(f"⚠️ Error in process_frames: {e}")


def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """
    Enhance the image at ``target_path`` and write it to ``output_path``.

    ``source_path`` is unused. When the enhancer is unavailable, the target is
    copied verbatim to the output instead. Unreadable input and any exception are
    reported with a printed message rather than raised.
    """
    try:
        if get_frame_processor() is None:
            print("⚠️ Frame enhancer not available, copying original image")
            import shutil
            shutil.copy2(target_path, output_path)
            return

        target_frame = cv2.imread(target_path)
        if target_frame is None:
            print(f"⚠️ Failed to read image: {target_path}")
            return

        cv2.imwrite(output_path, process_frame(None, None, target_frame))

    except Exception as error:
        print(f"⚠️ Error in process_image: {error}")


def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """
    Enhance extracted video frames via the shared ``frame_processors.process_video``.

    ``source_path`` is ignored; ``None`` is forwarded as the source, with
    ``process_frames`` as the per-chunk worker. Errors are printed, not raised.
    """
    worker = process_frames
    try:
        frame_processors.process_video(None, temp_frame_paths, worker)
    except Exception as error:
        print(f"⚠️ Error in process_video: {error}")


# Additional utility functions inspired by FaceFusion
def get_model_scale() -> int:
    """
    Return the upscale factor of the configured Real-ESRGAN model.

    Falls back to 1 if the configuration cannot be read. Only ``Exception`` is
    caught: the previous bare ``except:`` also swallowed KeyboardInterrupt and
    SystemExit.
    """
    try:
        return get_model_config()['real_esrgan_x4']['scale']
    except Exception:
        return 1


def prepare_frame(frame: Frame) -> Frame:
    """
    Return ``frame`` as ``uint8``, saturating out-of-range values.

    Frames that are already ``uint8`` are returned as-is (same object). Other dtypes
    are clipped to ``[0, 255]`` before the cast; a bare ``astype(uint8)`` would wrap
    integers modulo 256 (e.g. -5 -> 251, 300 -> 44) and give platform-dependent
    results for out-of-range floats. Fractional values are truncated by the cast.
    Inputs without a usable ``dtype`` are returned unchanged.
    """
    try:
        if frame.dtype != numpy.uint8:
            frame = numpy.clip(frame, 0, 255).astype(numpy.uint8)
        return frame
    except Exception:
        # Best-effort: non-array input is passed through rather than raising.
        return frame


def normalize_frame(frame: Frame) -> Frame:
    """
    Clip ``frame`` to ``[0, 255]`` and return it as a new ``uint8`` array.

    Fractional values are truncated by the cast. If clipping or casting raises
    (e.g. for input NumPy cannot treat as numeric), the input is returned
    unchanged. Only ``Exception`` is caught, so KeyboardInterrupt/SystemExit
    propagate.
    """
    try:
        return numpy.clip(frame, 0, 255).astype(numpy.uint8)
    except Exception:
        return frame