File size: 45,942 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep!
# I am sure you are a nice person and, as you copy this code, you will officially give me proper credit:
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on Twitter
from __future__ import annotations

import gc
import math
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
from accelerate import init_empty_weights
from einops import rearrange
from safetensors.torch import load_file
from tqdm import tqdm

from mmgp import offload
from models.wan.modules.vae import WanVAE
from .attention_backend import log_sparse_backend, require_sparge_attention
from .tcdecoder import build_tcdecoder
from .utils import Causal_LQ4x_Proj
from .wan_video_dit import WanModel, precompute_freqs_cis_3d


# Variant identifiers; FlashVSRRuntime.load() falls back to FLASHVSR_VARIANT_TINY_LONG when no variant is passed.
FLASHVSR_VARIANT_TINY_LONG = "tiny-long"
FLASHVSR_VARIANT_TINY = "tiny"
FLASHVSR_VARIANT_FULL = "full"

FLASHVSR_TOPK_RATIO = 0.0  # 0 = auto area-scaled ratio; >0 = fixed sparse attention ratio.
# NOTE(review): presumably the lower bound applied to the auto ratio for the "full" variant; its consumer is not in this part of the file.
FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO = 1.5
FLASHVSR_KV_CACHE_WINDOWS = 1  # Stream cache windows kept between denoise chunks; each window is two latent frames.
# Default number of trailing output frames stored by _make_continue_cache().
FLASHVSR_CONTINUE_CACHE_FRAMES = 11
# NOTE(review): looks like a co-residency hint for the offloader (lq_proj alongside transformer); confirm where it is consumed.
FLASHVSR_COTENANTS_MAP = {"lq_proj": ["transformer"]}
# When True, _save_still_image_debug_video() writes its input frames to FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH.
FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO = False
# NOTE(review): the next five still-image switches are read outside this part of the file; verify their exact effect there.
FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS = False
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION = True
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT = None  # None = half-period output phase shift.
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD = 16
FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME = True
FLASHVSR_STILL_IMAGE_SHIFT_BLEND = 0.5  # Interpolation weight toward the shifted pass in _apply_still_image_shift_correction().
FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH = "flashvsr_still_image_debug.mp4"  # Resolved with os.path.abspath() when the debug video is saved.
FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS = 4  # Passed as fps to save_video() for the debug video.

# Keyword arguments unpacked into WanModel(**WAN_1_3B_CONFIG) by FlashVSRRuntime.load().
# load() also derives the rotary table size as dim // num_heads and hard-codes the
# LQ projector's out_dim to 1536, which equals "dim" below; keep them in sync.
WAN_1_3B_CONFIG = {
    "has_image_input": False,
    "patch_size": (1, 2, 2),
    "in_dim": 16,
    "dim": 1536,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "text_dim": 4096,
    "out_dim": 16,
    "num_heads": 12,
    "num_layers": 30,
    "eps": 1e-6,
}


@contextmanager
def _default_dtype(dtype: torch.dtype):
    """Temporarily make ``dtype`` the torch default dtype.

    The default that was active on entry is restored on exit, including when
    the managed block raises.
    """
    saved = torch.get_default_dtype()
    try:
        torch.set_default_dtype(dtype)
        yield
    finally:
        torch.set_default_dtype(saved)


@dataclass
class FlashVSRPaths:
    """File locations of the FlashVSR checkpoints handed to FlashVSRRuntime.load()."""

    # Transformer weights; loaded into the WanModel instance.
    transformer: str
    # Weights for the Causal_LQ4x_Proj low-quality projector.
    lq_proj: str
    # Safetensors file whose "context" tensor becomes the prompt context.
    posi_prompt: str
    # NOTE(review): presumably the tiny decoder weights; consumer is outside this part of the file.
    tcdecoder: str | None = None
    # NOTE(review): presumably the WanVAE weights; consumer is outside this part of the file.
    vae: str | None = None


def _preprocess_transformer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Run a raw checkpoint through WanModel's state-dict converter.

    Used as the ``preprocess_sd`` hook when the transformer weights are
    loaded. Only the converted state dict is kept; the second value returned
    by ``from_civitai`` is discarded.
    """
    converted, _discarded = WanModel.state_dict_converter().from_civitai(state_dict)
    return converted


def _sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
    """Build a sinusoidal embedding for a 1-D tensor of positions.

    The angles are computed in float64 and the result is cast back to the
    dtype of ``position``. Each output row holds ``dim // 2`` cosines followed
    by ``dim // 2`` sines.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inverse_freq = torch.pow(10000, -exponents)
    angles = torch.outer(position.type(torch.float64), inverse_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)


def _next_conditioning_frame_count(frame_count: int) -> int:
    """Return the padded conditioning length for ``frame_count`` frames.

    Four frames are added, the result is floored at 25, and the value is then
    rounded up to the next number of the form ``8 * k + 1``.
    """
    padded = max(25, frame_count + 4)
    # (1 - padded) % 8 is the distance to the next 8k + 1 (zero when already there).
    return padded + (1 - padded) % 8


def _aligned_output_size(height: int, width: int, scale: float) -> tuple[int, int]:
    """Return the scaled (height, width) rounded up to multiples of 128.

    Each side is scaled, truncated to an int with a minimum of 1, and then
    rounded up to the next multiple of 128 (never below 128).
    """
    def _round_up(extent: int) -> int:
        scaled = max(1, int(extent * scale))
        return max(128, math.ceil(scaled / 128) * 128)

    return _round_up(height), _round_up(width)


def _conditioning_sizes(sample: torch.Tensor, scale: float) -> tuple[int, int, int, int]:
    """Compute the scaled output size and its 128-aligned canvas for ``sample``.

    ``sample`` must be 4-D; its last two dimensions are read as height and
    width. Returns ``(output_height, output_width, padded_output_height,
    padded_output_width)`` and prints a notice whenever the canvas is larger
    than the plain scaled size.
    """
    _, _frame_total, source_h, source_w = sample.shape
    scaled_h = max(1, int(source_h * scale))
    scaled_w = max(1, int(source_w * scale))
    canvas_h, canvas_w = _aligned_output_size(source_h, source_w, scale)
    if canvas_h != scaled_h or canvas_w != scaled_w:
        print(f"[FlashVSR] Edge padding output canvas {scaled_w}x{scaled_h} -> {canvas_w}x{canvas_h}; final crop restores {scaled_w}x{scaled_h}")
    return scaled_h, scaled_w, canvas_h, canvas_w


def _prepare_conditioning_range(sample: torch.Tensor, start: int, end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Build the upscaled conditioning clip for frames ``start`` to ``end - 1``.

    ``sample`` is indexed by frame along dimension 1 (presumably a
    (channels, frames, height, width) layout; confirm with callers). Frame
    positions outside the clip reuse the first or last frame. uint8 input is
    mapped from [0, 255] to [-1, 1]; other input is clamped to [-1, 1]. The
    frames are resized with bicubic interpolation to the output size,
    replicate-padded on the bottom/right up to the padded size, clamped again
    and returned as a contiguous tensor of ``dtype`` in the input's dimension
    order.
    """
    last_frame = int(sample.shape[1]) - 1
    picked = [min(max(position, 0), last_frame) for position in range(start, end)]
    clip = sample[:, picked]
    if clip.dtype == torch.uint8:
        clip = clip.float().div_(127.5).sub_(1.0)
    else:
        clip = clip.detach().float().clamp_(-1.0, 1.0)
    # F.interpolate treats dim 0 as the batch, so frames go first while resizing.
    clip = clip.permute(1, 0, 2, 3).contiguous()
    clip = F.interpolate(clip, size=(output_height, output_width), mode="bicubic", align_corners=False)
    extra_rows = padded_output_height - output_height
    extra_cols = padded_output_width - output_width
    if extra_rows or extra_cols:
        clip = F.pad(clip, (0, extra_cols, 0, extra_rows), mode="replicate")
    # Bicubic resampling can overshoot, hence the second clamp.
    clip = clip.clamp_(-1.0, 1.0).to(dtype=dtype)
    return clip.permute(1, 0, 2, 3).contiguous()


def _pad_conditioning_frames(lq_video: torch.Tensor, target_frames: int) -> torch.Tensor:
    """Force the frame axis (dim 2) of ``lq_video`` to ``target_frames`` entries.

    A clip that is too short is extended by repeating its last frame; a clip
    that is long enough is truncated with a slice (a view, not a copy).
    """
    shortfall = target_frames - lq_video.shape[2]
    if shortfall > 0:
        filler = lq_video[:, :, -1:].repeat(1, 1, shortfall, 1, 1)
        return torch.cat([lq_video, filler], dim=2)
    return lq_video[:, :, :target_frames]


def _crop_output_frames(frames: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Crop the last two dimensions of ``frames`` to ``height`` x ``width``.

    The input is returned as-is when it already has that size; otherwise a
    contiguous top-left crop is returned.
    """
    if frames.shape[-2:] != (height, width):
        return frames[..., :height, :width].contiguous()
    return frames


def _shift_spatial_replicate(tensor: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor:
    """Translate the last two dimensions of ``tensor``, filling gaps by edge replication.

    Positive ``shift_y`` moves content down and positive ``shift_x`` moves it
    right. Shifts are limited to one less than the corresponding size, so at
    least one source row and column always remain. A zero shift returns a
    clone.
    """
    if shift_y == 0 and shift_x == 0:
        return tensor.clone()
    height, width = tensor.shape[-2:]
    shift_y = max(1 - height, min(height - 1, int(shift_y)))
    shift_x = max(1 - width, min(width - 1, int(shift_x)))
    pad_top, pad_bottom = max(0, shift_y), max(0, -shift_y)
    pad_left, pad_right = max(0, shift_x), max(0, -shift_x)
    # Drop what slides off one side, then grow the opposite side back by replication.
    kept = tensor[..., pad_bottom:height - pad_top, pad_right:width - pad_left]
    return F.pad(kept, (pad_left, pad_right, pad_top, pad_bottom), mode="replicate")


def _apply_still_image_shift_correction(base: torch.Tensor, shifted: torch.Tensor, scale: float) -> torch.Tensor:
    """Blend ``base`` toward ``shifted`` by FLASHVSR_STILL_IMAGE_SHIFT_BLEND.

    The interpolation runs in float32 on a private copy of ``base``. uint8
    input yields a rounded uint8 result clamped to [0, 255]; any other dtype
    is clamped to [-1, 1] and cast back to the dtype of ``base``. ``scale``
    is accepted but not used.
    """
    blend_weight = float(FLASHVSR_STILL_IMAGE_SHIFT_BLEND)
    # copy=True guarantees a separate buffer, so the in-place lerp never writes into ``base``.
    mixed = base.to(dtype=torch.float32, copy=True)
    mixed.lerp_(shifted.to(dtype=torch.float32), blend_weight)
    if base.dtype != torch.uint8:
        return mixed.clamp_(-1.0, 1.0).to(dtype=base.dtype)
    return mixed.round_().clamp_(0, 255).to(torch.uint8)


def _shift_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any:
    """Return a shallow copy of the cache whose "tail_frames" are spatially shifted.

    Anything that is not a dict holding a 4-D tensor under "tail_frames" is
    returned unchanged (same object). The input dict is never modified.
    """
    tail = continue_cache.get("tail_frames") if isinstance(continue_cache, dict) else None
    if torch.is_tensor(tail) and tail.ndim == 4:
        return {**continue_cache, "tail_frames": _shift_spatial_replicate(tail, shift_y, shift_x)}
    return continue_cache


def _two_pass_shifted_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any:
    """Return the cache to use for the shifted pass.

    When the cache carries a 4-D "tail_frames_shifted" tensor, a shallow copy
    is returned with that tensor installed as "tail_frames". Otherwise the
    regular tail is shifted via ``_shift_continue_cache``. Non-dict input is
    passed through untouched.
    """
    if not isinstance(continue_cache, dict):
        return continue_cache
    stored = continue_cache.get("tail_frames_shifted")
    if torch.is_tensor(stored) and stored.ndim == 4:
        return {**continue_cache, "tail_frames": stored}
    return _shift_continue_cache(continue_cache, shift_y, shift_x)


def _make_two_pass_continue_cache(base_cache: Any, shifted_cache: Any, shift_y: int, shift_x: int, out_shift_y: int, out_shift_x: int) -> Any:
    """Merge the shifted pass's tail into a copy of the base continue cache.

    A non-dict ``base_cache`` is returned unchanged. Otherwise a shallow copy
    is returned; when ``shifted_cache`` is a dict with a 4-D "tail_frames"
    tensor, that tensor is stored (contiguous) under "tail_frames_shifted"
    together with "two_pass" and the four shift values.
    """
    if not isinstance(base_cache, dict):
        return base_cache
    merged = dict(base_cache)
    if not isinstance(shifted_cache, dict):
        return merged
    candidate = shifted_cache.get("tail_frames")
    if not torch.is_tensor(candidate) or candidate.ndim != 4:
        return merged
    merged["tail_frames_shifted"] = candidate.contiguous()
    merged["two_pass"] = True
    merged["shift_y"] = shift_y
    merged["shift_x"] = shift_x
    merged["out_shift_y"] = out_shift_y
    merged["out_shift_x"] = out_shift_x
    return merged


def _select_still_image_frame(frames: torch.Tensor, frame_index: int) -> torch.Tensor:
    """Return frame ``frame_index`` as a contiguous tensor with a length-1 frame axis.

    The frame axis is dim 2 for 5-D input and dim 1 otherwise. Negative
    indices count from the end. An index outside the clip yields an empty
    frame axis rather than raising, matching plain slicing.
    """
    frame_dim = 2 if frames.ndim == 5 else 1
    frame_total = frames.shape[frame_dim]
    index = frame_index
    # Normalise negative indices first: the slice ``-1:0`` is empty, so without
    # this the last frame could never be selected with -1.
    if -frame_total <= index < 0:
        index += frame_total
    if frame_dim == 2:
        picked = frames[:, :, index:index + 1]
    else:
        picked = frames[:, index:index + 1]
    return picked.contiguous()


def _decoded_frames_to_cpu(frames: torch.Tensor, frame_count: int, height: int, width: int) -> torch.Tensor:
    """Return batch item 0 of ``frames`` as float32 on the CPU, cropped to size.

    The crop keeps the first ``frame_count`` frames and the top-left
    ``height`` x ``width`` region. When the cropped view is already a
    contiguous float32 CPU tensor it is returned directly (sharing storage
    with the input); otherwise it is copied into a freshly allocated CPU
    buffer, converting dtype and device in that single copy.
    """
    view = frames.detach()[0, :, :frame_count, :height, :width]
    reusable = view.device.type == "cpu" and view.dtype == torch.float32 and view.is_contiguous()
    if reusable:
        return view
    host = torch.empty(tuple(view.shape), dtype=torch.float32, device="cpu")
    host.copy_(view)
    return host


def _save_still_image_debug_video(frames: torch.Tensor) -> None:
    """Write ``frames`` to the debug video file when the debug switch is enabled.

    Does nothing unless FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO is truthy. The
    save is best effort: any exception is reported on stdout and swallowed so
    a failed debug dump never interrupts processing.
    """
    if not FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO:
        return
    target = os.path.abspath(FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH)
    try:
        # Imported lazily so the dependency is only needed when debugging is on.
        from shared.utils.audio_video import save_video
        snapshot = frames.detach().cpu()
        save_video(
            tensor=snapshot,
            save_file=target,
            fps=FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS,
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
            codec_type="libx264_8",
            container="mp4",
        )
        frame_total = int(snapshot.shape[2])
        print(f"[FlashVSR] Still image debug video saved to {target} ({frame_total} frames)")
        del snapshot
    except Exception as exc:
        print(f"[FlashVSR] Failed to save still image debug video: {exc}")

def _nested_tensors_to(value: Any, device: torch.device | str, dtype: torch.dtype | None = None) -> Any:
    """Move every tensor in a (possibly nested) list to ``device``.

    Tensors are detached and converted to ``dtype`` when one is given,
    otherwise they keep their own dtype. Lists are rebuilt recursively; every
    other value, tuples included, is returned unchanged.
    """
    if isinstance(value, list):
        return [_nested_tensors_to(entry, device, dtype) for entry in value]
    if not torch.is_tensor(value):
        return value
    target_dtype = dtype or value.dtype
    return value.detach().to(device=device, dtype=target_dtype)


def _tcdecoder_mem_halo_latents(tcdecoder: torch.nn.Module) -> int:
    """Estimate the decoder's spatial reach from its layer list.

    Walks ``tcdecoder.taehv.decoder`` (or ``tcdecoder.decoder`` when there is
    no ``taehv`` attribute). Each Conv2d, including those inside a module
    whose class is named "MemBlock", adds ``(kernel - 1) / 2`` multiplied by
    the current step; each Upsample divides the step by its scale factor. The
    total is rounded up and never less than 1.
    """
    def _half_kernel(conv: torch.nn.Conv2d) -> float:
        size = conv.kernel_size
        first = size[0] if isinstance(size, tuple) else int(size)
        return (first - 1) / 2

    holder = tcdecoder.taehv if hasattr(tcdecoder, "taehv") else tcdecoder
    reach = 0.0
    step = 1.0
    for layer in holder.decoder:
        if isinstance(layer, torch.nn.Conv2d):
            reach += _half_kernel(layer) * step
        elif layer.__class__.__name__ == "MemBlock":
            for inner in layer.conv:
                if isinstance(inner, torch.nn.Conv2d):
                    reach += _half_kernel(inner) * step
        elif isinstance(layer, torch.nn.Upsample):
            factor = layer.scale_factor
            if isinstance(factor, tuple):
                factor = factor[0]
            # A missing scale factor is treated as 1, leaving the step unchanged.
            step /= float(factor or 1)
    return max(1, int(math.ceil(reach)))


def _report_progress(progress_callback, phase: str, current_step: int | None = None, total_steps: int | None = None) -> None:
    """Forward a progress update to ``progress_callback`` if it is callable.

    The callback receives ``(phase, current_step, total_steps)``; anything
    that is not callable (including None) is silently ignored.
    """
    if not callable(progress_callback):
        return
    progress_callback(phase, current_step, total_steps)


def _abort_requested(abort_callback) -> bool:
    """Return the callback's answer, or False when there is no callable.

    The callback's return value is passed through as-is, so a callback that
    returns a non-bool truthy/falsy value yields that value.
    """
    if not callable(abort_callback):
        return False
    return abort_callback()


def _apply_continue_cache(frames: torch.Tensor, continue_cache: Any) -> torch.Tensor:
    """Overwrite the leading frames of ``frames`` with the cached tail, in place.

    The cache is used only when it is a dict with a 4-D "tail_frames" tensor
    whose first dimension and last two dimensions match ``frames``. The
    number of frames written is the shorter of the two frame axes (dim 1),
    taken from the end of the tail. The tail is converted between the uint8
    [0, 255] and float [-1, 1] encodings as needed. ``frames`` itself is
    always returned.
    """
    tail = continue_cache.get("tail_frames") if isinstance(continue_cache, dict) else None
    if not torch.is_tensor(tail) or tail.ndim != 4:
        return frames
    if tail.shape[0] != frames.shape[0] or tail.shape[-2:] != frames.shape[-2:]:
        return frames
    overlap = min(int(tail.shape[1]), int(frames.shape[1]))
    if overlap <= 0:
        return frames
    head = frames[:, :overlap]
    if frames.dtype == torch.uint8:
        if tail.dtype != torch.uint8:
            # Float tails are stored in [-1, 1]; map them to [0, 255].
            tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
        head.copy_(tail[:, -overlap:].to(device=frames.device))
        return frames
    tail_was_bytes = tail.dtype == torch.uint8
    tail = tail.to(device=frames.device, dtype=frames.dtype)
    if tail_was_bytes:
        # uint8 tails are stored in [0, 255]; map them to [-1, 1].
        tail = tail.div(127.5).sub(1.0)
    head.copy_(tail[:, -overlap:])
    return frames


def _make_continue_cache(frames: torch.Tensor, scale: float, variant: str, overlap_frames: int = FLASHVSR_CONTINUE_CACHE_FRAMES) -> dict[str, Any]:
    """Snapshot the last frames of ``frames`` for continuing a later run.

    Args:
        frames: Tensor whose dim 1 is the frame axis.
        scale: Stored verbatim under "scale".
        variant: Stored verbatim under "variant".
        overlap_frames: Maximum number of trailing frames to keep. Zero or a
            negative value keeps no frames.

    Returns:
        A dict with "tail_frames" (contiguous uint8 CPU tensor; float input
        is mapped from [-1, 1] to [0, 255]), "scale" and "variant".
    """
    frame_total = frames.shape[1]
    tail_len = max(0, min(overlap_frames, frame_total))
    # Slice from an explicit start index: ``frames[:, -0:]`` would select every
    # frame instead of none when no overlap is requested.
    tail = frames[:, frame_total - tail_len:].detach().cpu()
    if tail.dtype != torch.uint8:
        tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
    return {"tail_frames": tail.contiguous(), "scale": scale, "variant": variant}


def _wavelet_color_fix(frames: torch.Tensor, lq_video: torch.Tensor) -> torch.Tensor:
    """Match the per-frame mean and standard deviation of ``frames`` to ``lq_video``, in place.

    Statistics are taken over dims 3 and 4 of 5-D tensors, in chunks of four
    frames along dim 2. Nothing happens when ``frames`` and ``lq_video``
    (truncated to the same frame count) differ in shape. The result is
    clamped to [-1, 1] and ``frames`` is returned.
    """
    frame_total = frames.shape[2]
    if frames.shape != lq_video[:, :, :frame_total].shape:
        return frames
    spatial = (3, 4)
    for first in range(0, frame_total, 4):
        last = min(first + 4, frame_total)
        # ``target`` is a view, so the in-place chain below updates ``frames``.
        target = frames[:, :, first:last]
        reference = lq_video[:, :, first:last].to(device=frames.device, dtype=frames.dtype)
        target_mean = target.mean(dim=spatial, keepdim=True)
        target_std = target.std(dim=spatial, keepdim=True).clamp_min_(1e-5)
        reference_mean = reference.mean(dim=spatial, keepdim=True)
        reference_std = reference.std(dim=spatial, keepdim=True).clamp_min_(1e-5)
        target.sub_(target_mean).div_(target_std).mul_(reference_std).add_(reference_mean).clamp_(-1.0, 1.0)
    return frames


def _wavelet_color_fix_from_sample(frames: torch.Tensor, sample: torch.Tensor, scale: float, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int) -> torch.Tensor:
    """Match the per-frame mean/std of ``frames`` to the source ``sample``, in place.

    Only statistics over the last two dimensions are used, so ``sample`` does
    not need the same spatial size as ``frames``. Frames beyond the shorter
    of the two frame axes (dim 2 of ``frames``, dim 1 of ``sample``) are left
    untouched. uint8 frames are processed one at a time in float32 on the
    0..255 scale; other dtypes are processed four frames at a time on the
    -1..1 scale.

    ``sample`` is never modified. ``scale`` and the four size arguments are
    accepted but not used.

    Returns:
        ``frames`` (the same object, updated in place).
    """
    frames_are_bytes = frames.dtype == torch.uint8
    step = 1 if frames_are_bytes else 4
    work_dtype = torch.float32 if frames_are_bytes else frames.dtype
    frame_total = min(int(frames.shape[2]), int(sample.shape[1]))
    for start in range(0, frame_total, step):
        end = min(start + step, frame_total)
        frame_chunk = frames[:, :, start:end]
        # copy=True: ``Tensor.to`` returns its input unchanged when device and
        # dtype already match, and the in-place ops below would then write
        # through the view into the caller's ``sample``.
        lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=work_dtype, copy=True)
        if frames_are_bytes:
            frame_float = frame_chunk.float()
            if sample.dtype != torch.uint8:
                # Bring a [-1, 1] float reference onto the 0..255 scale.
                lq_chunk.clamp_(-1.0, 1.0).add_(1.0).mul_(127.5)
            mean_frames = frame_float.mean(dim=(3, 4), keepdim=True)
            std_frames = frame_float.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
            mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True)
            std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
            frame_float.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).round_().clamp_(0, 255)
            frame_chunk.copy_(frame_float.to(torch.uint8))
            # Drop the temporaries before the next frame is processed.
            del frame_float, lq_chunk, mean_frames, std_frames, mean_lq, std_lq
            continue
        if sample.dtype == torch.uint8:
            # Bring a 0..255 reference onto the [-1, 1] scale.
            lq_chunk.div_(127.5).sub_(1.0)
        else:
            lq_chunk.clamp_(-1.0, 1.0)
        mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True)
        std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
        mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True)
        std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
        frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0)
        del lq_chunk, mean_frames, std_frames, mean_lq, std_lq
    return frames


def _denoise_stream_chunk(
    dit: WanModel,
    x: torch.Tensor,
    context: torch.Tensor | None,
    lq_layer_chunks: list[list[torch.Tensor | None]],
    block_cache_k: list[torch.Tensor | None],
    block_cache_v: list[torch.Tensor | None],
    chunk_index: int,
    timestep_embed: torch.Tensor,
    timestep_mod: torch.Tensor,
    *,
    topk_ratio: float = 2.0,
    kv_ratio: float = FLASHVSR_KV_CACHE_WINDOWS,
    local_range: int = 9,
    cache_next: bool = True,
    allow_short_start: bool = False,
    abort_callback=None,
) -> tuple[torch.Tensor | None, list[torch.Tensor | None], list[torch.Tensor | None]]:
    """Run one streamed chunk through every transformer block of ``dit``.

    Args:
        dit: Model providing ``patchify``, ``freqs``, ``blocks``, ``head`` and ``unpatchify``.
        x: Input for this chunk; patchified before the block loop.
        context: Forwarded unchanged to every block.
        lq_layer_chunks: One list per low-quality chunk, indexed by block id. Each
            entry is added to its token range of ``x`` and then set to None, so
            this argument is consumed (mutated) by the call.
        block_cache_k: Per-block key cache from the previous chunk; updated in place.
        block_cache_v: Per-block value cache from the previous chunk; updated in place.
        chunk_index: Position of this chunk in the stream; selects the temporal
            slice of the rotary table.
        timestep_embed: Passed to ``dit.head``.
        timestep_mod: Forwarded to every block.
        topk_ratio: Multiplier used to derive ``topk`` from the squared window count.
        kv_ratio: Truncated to an int (minimum 1) and forwarded as ``kv_len``.
        local_range: Forwarded to every block.
        cache_next: Forwarded to every block.
        allow_short_start: Forwarded to every block.
        abort_callback: Optional callable polled before and after each block.

    Returns:
        ``(output, block_cache_k, block_cache_v)``. ``output`` is the unpatchified
        result, or None when ``abort_callback`` requested an abort. The two cache
        lists are the same objects that were passed in.
    """
    x, (frames, height, width) = dit.patchify(x)
    win = (2, 8, 8)
    seqlen = frames // win[0]
    # 128 == 2 * 8 * 8, the element count of ``win``.
    # NOTE(review): window_size therefore looks like the number of windows per two-frame slab; confirm.
    window_size = win[0] * height * width // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = max(1, int(kv_ratio))
    if chunk_index == 0:
        freqs_t = dit.freqs[0][:frames]
    else:
        # NOTE(review): the offset 4 + 2 * chunk_index presumably reflects a longer first chunk followed by two-frame chunks; confirm with the caller.
        start = 4 + chunk_index * 2
        freqs_t = dit.freqs[0][start:start + frames]
    # Split each complex rotary table into (real, imag) tensors on x's device and dtype.
    freqs = tuple((freq.real.to(device=x.device, dtype=x.dtype), freq.imag.to(device=x.device, dtype=x.dtype)) for freq in (freqs_t, dit.freqs[1][:height], dit.freqs[2][:width]))
    for block_id, block in enumerate(dit.blocks):
        if _abort_requested(abort_callback):
            return None, block_cache_k, block_cache_v
        # Only the first len(lq_layer_chunks[0]) blocks receive low-quality features.
        if block_id < len(lq_layer_chunks[0]):
            offset = 0
            for chunk in lq_layer_chunks:
                lq = chunk[block_id].to(x.device, dtype=x.dtype)
                next_offset = offset + lq.shape[1]
                # Each chunk covers the next consecutive token range along dim 1.
                x[:, offset:next_offset].add_(lq)
                offset = next_offset
                # Drop the stored tensor as soon as it has been consumed.
                chunk[block_id] = None
                del lq
        cache_refs = None
        if block_cache_k[block_id] is not None:
            cache_refs = [block_cache_k[block_id].to(x.device, dtype=x.dtype), block_cache_v[block_id].to(x.device, dtype=x.dtype)]
            # Clear the list slots so the previous cache is referenced only through cache_refs.
            block_cache_k[block_id] = None
            block_cache_v[block_id] = None
        # x is handed over inside a list and the local name is cleared.
        # NOTE(review): presumably so the block can release its input early; confirm in the block implementation.
        x_ref = [x]
        x = None
        x, next_cache_k, next_cache_v = block(
            x_ref, context, timestep_mod, freqs, frames, height, width, seqlen, topk,
            block_id=block_id, kv_len=kv_len, is_stream=True,
            pre_cache_refs=cache_refs, local_range=local_range, cache_next=cache_next, allow_short_start=allow_short_start,
        )
        x_ref.clear()
        # Store the caches returned by the block for the next chunk.
        block_cache_k[block_id] = next_cache_k
        del next_cache_k
        block_cache_v[block_id] = next_cache_v
        del next_cache_v, cache_refs
        if _abort_requested(abort_callback):
            return None, block_cache_k, block_cache_v
    x = dit.head([x], timestep_embed)
    return dit.unpatchify([x], (frames, height, width)), block_cache_k, block_cache_v


class FlashVSRRuntime:
    def __init__(self) -> None:
        """Create an empty runtime; every model slot stays ``None`` until ``load()`` fills it."""
        # Execution settings: bfloat16 everywhere, CUDA when it is available.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.dtype = torch.bfloat16
        # Key describing what is currently loaded; compared by ``load()`` to skip reloads.
        self.variant: str | None = None
        self.profile = None
        # Model components and the offload handle that manages them.
        self.dit: WanModel | None = None
        self.lq_proj: Causal_LQ4x_Proj | None = None
        self.tcdecoder: torch.nn.Module | None = None
        self.vae: WanVAE | None = None
        self.offloadobj = None
        # Prompt embedding read from disk by ``load()``.
        self.prompt_context: torch.Tensor | None = None
        # Per-run timestep tensors, rebuilt by ``_prepare_run_state()``.
        self.timestep: torch.Tensor | None = None
        self.timestep_embed: torch.Tensor | None = None
        self.timestep_mod: torch.Tensor | None = None

    def load(self, paths: FlashVSRPaths, variant: str, profile, init_pipe) -> None:
        """Load the models for ``variant`` and register them with the offloader.

        The call is a no-op when a model set for the same ``variant`` and
        ``profile`` is already loaded; otherwise anything currently held is
        released first.

        Args:
            paths: Provides the checkpoint locations read here: ``transformer``,
                ``lq_proj``, ``posi_prompt`` and, depending on the variant,
                ``tcdecoder`` or ``vae``.
            variant: Variant identifier; a falsy value selects
                ``FLASHVSR_VARIANT_TINY_LONG``. The two tiny variants use the
                TCDecoder, any other value uses the Wan VAE.
            profile: Offload profile; forwarded to ``init_pipe`` and stored as
                part of the "already loaded" key.
            init_pipe: Callable invoked as ``init_pipe(pipe, kwargs, profile)``;
                its return value is used as ``profile_no`` for ``offload.profile``.

        Raises:
            Whatever the underlying loaders raise. On failure the runtime is
            released again so that a later ``load()`` starts from scratch
            instead of treating a half-built model set as ready.
        """
        require_sparge_attention()
        variant = variant or FLASHVSR_VARIANT_TINY_LONG
        if self.dit is not None and self.variant == variant and self.profile == profile:
            return
        self.release()
        self.variant = variant
        self.profile = profile
        try:
            # Build the module skeletons without allocating real weights; the
            # actual tensors are attached by ``offload.load_model_data`` below.
            with init_empty_weights(include_buffers=True), _default_dtype(self.dtype):
                self.dit = WanModel(**WAN_1_3B_CONFIG).eval()
                self.lq_proj = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).eval()
            self.dit._offload_hooks = ["reinit_cross_kv"]
            self.lq_proj._offload_hooks = ["stream_forward"]
            offload.load_model_data(self.dit, paths.transformer, writable_tensors=False, preprocess_sd=_preprocess_transformer_state_dict, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1)
            self.dit.freqs = precompute_freqs_cis_3d(WAN_1_3B_CONFIG["dim"] // WAN_1_3B_CONFIG["num_heads"])
            offload.load_model_data(self.lq_proj, paths.lq_proj, writable_tensors=False, default_dtype=self.dtype, verboseLevel=-1)
            self.dit.requires_grad_(False)
            self.lq_proj.requires_grad_(False)
            self.prompt_context = load_file(paths.posi_prompt, device="cpu")["context"].to(self.dtype)
            pipe = {"transformer": self.dit, "lq_proj": self.lq_proj}
            if variant in (FLASHVSR_VARIANT_TINY, FLASHVSR_VARIANT_TINY_LONG):
                self.tcdecoder = build_tcdecoder(new_channels=[512, 256, 128, 128], device="cpu", dtype=self.dtype, new_latent_channels=16 + 768).eval()
                self.tcdecoder._offload_hooks = ["decode_video"]
                offload.load_model_data(self.tcdecoder, paths.tcdecoder, writable_tensors=False, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1)
                self.tcdecoder.requires_grad_(False)
                pipe["tcdecoder"] = self.tcdecoder
            else:
                self.vae = WanVAE(vae_pth=paths.vae, dtype=self.dtype, upsampler_factor=1, device="cpu")
                self.vae.device = self.device
                self.vae.model.requires_grad_(False)
                pipe["vae"] = self.vae.model
            kwargs = {"coTenantsMap": FLASHVSR_COTENANTS_MAP}
            profile_no = init_pipe(pipe, kwargs, profile)
            self.offloadobj = offload.profile(pipe, profile_no=profile_no, quantizeTransformer=False, convertWeightsFloatTo=self.dtype, verboseLevel=-1, **kwargs)
        except BaseException:
            # ``variant``/``profile``/``dit`` are set before the weights and the
            # offloader are ready, so a failure here (including an interrupt
            # during the long weight load) would otherwise make the next
            # ``load()`` take the early return above on a half-built runtime.
            # Invalidate the key first so that holds even if cleanup fails.
            self.variant = None
            self.release()
            raise
        log_sparse_backend()

    def _prepare_run_state(self) -> None:
        """Prime the DiT for a run: cross-attention KV from the prompt plus the fixed timestep tensors.

        Raises:
            RuntimeError: If the runtime device is not CUDA.
        """
        on_cuda = self.device.type == "cuda"
        if not on_cuda:
            raise RuntimeError("FlashVSR requires CUDA.")
        prompt = self.prompt_context.to(self.device, dtype=self.dtype)
        dit = self.dit
        dit.reinit_cross_kv(prompt)
        # A single timestep value (1000) is used; its embedding and modulation
        # are computed once here and reused for the whole run.
        self.timestep = torch.tensor([1000.0], device=self.device, dtype=self.dtype)
        sinusoid = _sinusoidal_embedding_1d(dit.freq_dim, self.timestep)
        self.timestep_embed = dit.time_embedding(sinusoid)
        projected = dit.time_projection(self.timestep_embed)
        self.timestep_mod = projected.unflatten(1, (6, dit.dim))

    def _clear_runtime_caches(self) -> None:
        """Reset the cached state of every loaded model and drop the per-run timestep tensors."""
        # Each (model, method name) pair is reset only when that model is loaded.
        resets = (
            (self.dit, "clear_cross_kv"),
            (self.lq_proj, "clear_cache"),
            (self.tcdecoder, "clean_mem"),
        )
        for model, method_name in resets:
            if model is not None:
                getattr(model, method_name)()
        # The VAE wrapper keeps its cache on the inner ``model`` object.
        vae = self.vae
        if vae is not None:
            vae.model.clear_cache()
        self.timestep = self.timestep_embed = self.timestep_mod = None

    def _unload_mmgp(self) -> None:
        """Clear runtime caches, ask the offloader to unload everything, and free cached CUDA memory.

        The model objects themselves stay attached to the runtime.
        """
        self._clear_runtime_caches()
        offloader = self.offloadobj
        if offloader is not None:
            offloader.unload_all()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()

    def _decode_tcdecoder(self, latents: torch.Tensor, sample: torch.Tensor, lq_start: int, lq_end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, tile_size: int, tile_mems: dict[tuple[int, int], Any] | None, abort_callback=None, progress_callback=None, progress_step: int | None = None, progress_total: int | None = None) -> tuple[torch.Tensor | None, dict[tuple[int, int], Any] | None]:
        """Decode one chunk of latents into frames with the TCDecoder, optionally tile by tile.

        Args:
            latents: Latents for this chunk; the last two dimensions are sliced
                spatially in the tiled path.
            sample: Source video handed to ``_prepare_conditioning_range`` to
                build the decoder conditioning.
            lq_start, lq_end: Range forwarded to ``_prepare_conditioning_range``.
            output_height, output_width: Final (unpadded) frame size in pixels.
            padded_output_height, padded_output_width: Padded frame size in
                pixels; divided by 8 to obtain the latent grid size.
            tile_size: Tile edge in pixels. ``0``/``None``, or a padded frame
                that fits inside one tile, selects the single-pass path.
            tile_mems: Per-tile decoder memory keyed by ``(latent_y0, latent_x0)``;
                ``None`` starts an empty mapping in the tiled path.
            abort_callback: Polled before each tile via ``_abort_requested``.
            progress_callback, progress_step, progress_total: Forwarded to
                ``_report_progress`` before and after decoding.

        Returns:
            ``(frames, tile_mems)``. In the tiled path ``frames`` is a float32
            CPU tensor of size ``output_height`` x ``output_width``; in the
            single-pass path it is the decoder output as returned. ``frames``
            is ``None`` when an abort was requested during tiling.

        Raises:
            RuntimeError: If no TCDecoder is loaded.
        """
        if self.tcdecoder is None:
            raise RuntimeError("FlashVSR tiny variants require TCDecoder.")
        _report_progress(progress_callback, "TCDecoder Decoding", progress_step, progress_total)
        tile_size = int(tile_size or 0)
        # Conditioning frames for this chunk, with a leading batch dimension added.
        cur_lq = _prepare_conditioning_range(sample, lq_start, lq_end, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0)
        # Single-pass path: tiling is disabled or the padded frame already fits in one tile.
        if tile_size <= 0 or (padded_output_height <= tile_size and padded_output_width <= tile_size):
            cur_lq = cur_lq.to(self.device, dtype=self.dtype)
            # The decoder output is rescaled in place as x * 2 - 1
            # (presumably [0, 1] -> [-1, 1]; confirm against the decoder).
            frames = self.tcdecoder.decode_video(latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq).transpose(1, 2).mul_(2).sub_(1)
            del cur_lq
            _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total)
            return frames, tile_mems

        # Tiled path. All tile coordinates are in latent units; one latent cell
        # corresponds to 8 pixels here.
        halo = _tcdecoder_mem_halo_latents(self.tcdecoder)
        latent_tile = max(1, tile_size // 8)
        latent_height = padded_output_height // 8
        latent_width = padded_output_width // 8
        tile_mems = {} if tile_mems is None else tile_mems
        frames_out = None
        for latent_y0 in range(0, latent_height, latent_tile):
            latent_y1 = min(latent_y0 + latent_tile, latent_height)
            # Pixel rows this tile writes, clipped to the unpadded output.
            write_y0, write_y1 = latent_y0 * 8, min(latent_y1 * 8, output_height)
            if write_y1 <= write_y0:
                # Tile lies entirely beyond the unpadded output height.
                continue
            # Decode a region enlarged by ``halo`` on each side (clamped to the
            # grid); ``crop_y0`` is where the tile proper starts inside it.
            expanded_y0, expanded_y1 = max(0, latent_y0 - halo), min(latent_height, latent_y1 + halo)
            crop_y0 = (latent_y0 - expanded_y0) * 8
            for latent_x0 in range(0, latent_width, latent_tile):
                if _abort_requested(abort_callback):
                    del cur_lq
                    return None, tile_mems
                latent_x1 = min(latent_x0 + latent_tile, latent_width)
                write_x0, write_x1 = latent_x0 * 8, min(latent_x1 * 8, output_width)
                if write_x1 <= write_x0:
                    continue
                expanded_x0, expanded_x1 = max(0, latent_x0 - halo), min(latent_width, latent_x1 + halo)
                crop_x0 = (latent_x0 - expanded_x0) * 8
                tile_key = (latent_y0, latent_x0)
                # Restore the decoder memory saved for this tile, or start
                # from a clean decoder when the tile has none yet.
                saved_mem = tile_mems.get(tile_key)
                if saved_mem is None:
                    self.tcdecoder.clean_mem()
                else:
                    self.tcdecoder.mem = _nested_tensors_to(saved_mem, self.device, self.dtype)
                cur_lq_tile = cur_lq[:, :, :, expanded_y0 * 8:expanded_y1 * 8, expanded_x0 * 8:expanded_x1 * 8].contiguous().to(self.device, dtype=self.dtype)
                cur_latents = latents[:, :, :, expanded_y0:expanded_y1, expanded_x0:expanded_x1].to(self.device, dtype=self.dtype)
                tile_frames = self.tcdecoder.decode_video(cur_latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq_tile).transpose(1, 2).mul_(2).sub_(1)
                # Keep this tile's decoder memory for the next call (passed to
                # ``_nested_tensors_to`` with "cpu"), then reset the decoder so
                # the next tile does not inherit it.
                tile_mems[tile_key] = _nested_tensors_to(self.tcdecoder.mem, "cpu")
                self.tcdecoder.clean_mem()
                # Cut the halo away, leaving only the tile's own pixels.
                tile_frames = tile_frames[:, :, :, crop_y0:crop_y0 + latent_y1 * 8 - latent_y0 * 8, crop_x0:crop_x0 + latent_x1 * 8 - latent_x0 * 8]
                if frames_out is None:
                    # Allocated lazily because batch/channel/frame sizes come from the first decoded tile.
                    frames_out = torch.empty((tile_frames.shape[0], tile_frames.shape[1], tile_frames.shape[2], output_height, output_width), dtype=torch.float32, device="cpu")
                tile_cpu = tile_frames[:, :, :, :write_y1 - write_y0, :write_x1 - write_x0].detach().cpu().float()
                frames_out[:, :, :, write_y0:write_y1, write_x0:write_x1].copy_(tile_cpu)
                del cur_lq_tile, cur_latents, tile_frames, tile_cpu
        del cur_lq
        _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total)
        return frames_out, tile_mems

    def release(self) -> None:
        """Release the offloader, drop every model and cached tensor, and reclaim memory."""
        self._clear_runtime_caches()
        if self.offloadobj is not None:
            self.offloadobj.release()
            self.offloadobj = None
        # Reset every remaining slot so the runtime looks freshly constructed.
        cleared_slots = (
            "dit",
            "lq_proj",
            "tcdecoder",
            "vae",
            "prompt_context",
            "timestep",
            "timestep_embed",
            "timestep_mod",
            "variant",
            "profile",
        )
        for slot in cleared_slots:
            setattr(self, slot, None)
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()

    @torch.inference_mode()
    def upscale(
        self,
        sample: torch.Tensor,
        scale: float,
        *,
        seed: int = 0,
        continue_cache: Any = None,
        return_continue_cache: bool = False,
        persistent_models: bool = False,
        vae_tile_size: int | None = None,
        topk_ratio: float = FLASHVSR_TOPK_RATIO,
        still_image: bool = False,
        abort_callback=None,
        progress_callback=None,
    ) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
        """Upscale ``sample`` with the loaded models, streaming the video chunk by chunk.

        Args:
            sample: Input video tensor; ``sample.shape[1]`` is taken as the
                number of input frames.
            scale: Upscaling factor forwarded to ``_conditioning_sizes`` and
                the colour-correction / continue-cache helpers.
            seed: Seed for the CPU generator that draws the initial latent
                noise; ``None`` or a negative value is treated as ``0``.
            continue_cache: Opaque state forwarded to ``_apply_continue_cache``.
            return_continue_cache: When true, the second element of the result
                is built with ``_make_continue_cache``; otherwise it is ``None``.
            persistent_models: When false, ``release()`` is called once the run
                finishes or aborts. Models are always unloaded through
                ``_unload_mmgp()``.
            vae_tile_size: Tile size in pixels used for TCDecoder spatial
                tiling or for Wan VAE decoding; ``None``/``0`` disables
                TCDecoder tiling.
            topk_ratio: Sparse-attention top-k ratio, clamped to ``[0, 4]``.
                ``0``/``None`` selects an automatic ratio derived from the
                padded output area, floored at
                ``FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO``.
            still_image: Request still-image handling; only honoured when the
                input has exactly one frame.
            abort_callback: Polled through ``_abort_requested`` between steps.
            progress_callback: Forwarded to ``_report_progress``.

        Returns:
            ``(frames, cache)``, or ``(None, None)`` when the run was aborted.

        Raises:
            RuntimeError: If the models are not loaded, if the device is not
                CUDA, or if the full variant is selected without a Wan VAE.
        """
        if self.dit is None or self.lq_proj is None:
            raise RuntimeError("FlashVSR models are not loaded.")

        def abort_result():
            # Free model memory exactly as a completed run would, then signal the abort.
            self._unload_mmgp()
            if not persistent_models:
                self.release()
            return None, None

        input_frames = sample.shape[1]
        num_frames = _next_conditioning_frame_count(input_frames)
        output_height, output_width, padded_output_height, padded_output_width = _conditioning_sizes(sample, scale)
        configured_topk_ratio = max(0.0, min(4.0, float(topk_ratio or 0.0)))
        if configured_topk_ratio > 0:
            topk_ratio = configured_topk_ratio
            print(f"[FlashVSR] Sparse top-k ratio fixed to {topk_ratio:.3f}")
        else:
            # Automatic ratio: 2.0 up to a 768x1280 padded output, shrinking
            # with larger areas, but never below the configured minimum.
            raw_topk_ratio = min(2.0, 2.0 * 768 * 1280 / max(int(padded_output_height) * int(padded_output_width), 1))
            topk_ratio = max(raw_topk_ratio, FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO)
            if topk_ratio != raw_topk_ratio:
                print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height} (minimum; raw auto {raw_topk_ratio:.3f})")
            elif topk_ratio < 2.0:
                print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height}")
        self._prepare_run_state()
        self.lq_proj.clear_cache()
        if self.tcdecoder is not None:
            self.tcdecoder.clean_mem()
        if self.vae is not None:
            self.vae.model.clear_cache()
        print(f"[FlashVSR] Stream KV cache windows: {max(1, int(FLASHVSR_KV_CACHE_WINDOWS))}")
        tcdecoder_tile_size = int(vae_tile_size or 0) if self.tcdecoder is not None else 0
        tcdecoder_tile_mems = None
        if self.tcdecoder is not None:
            if tcdecoder_tile_size > 0 and (padded_output_height > tcdecoder_tile_size or padded_output_width > tcdecoder_tile_size):
                print(f"[FlashVSR] TCDecoder spatial tiling policy: tile_size={tcdecoder_tile_size}px, halo={_tcdecoder_mem_halo_latents(self.tcdecoder) * 8}px")
                tcdecoder_tile_mems = {}
            else:
                print("[FlashVSR] TCDecoder spatial tiling policy: tile_size=0px")
        generator = torch.Generator(device="cpu").manual_seed(0 if seed is None or seed < 0 else int(seed))
        still_image = bool(still_image and input_frames == 1)
        self.lq_proj.shift_start_prefix = still_image
        optimize_still_image = still_image and not FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS
        first_chunk_latent_frames = 2 if optimize_still_image else 6
        first_chunk_lq_steps = first_chunk_latent_frames + 1 if optimize_still_image else 7
        still_debug_frame_count = (first_chunk_latent_frames - 1) * 4 + 1
        still_output_frame = still_debug_frame_count - 1 if still_image and FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME else 0
        if optimize_still_image:
            print(f"[FlashVSR] Still image mode: denoising {first_chunk_latent_frames} startup latent frames instead of 6; returning decoded frame {still_output_frame}")
        elif still_image and FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS:
            print(f"[FlashVSR] Still image debug mode: image optimizations disabled; denoising original 6 startup latent frames; returning decoded frame {still_output_frame}")
        latent_frame_count = first_chunk_latent_frames if still_image else (num_frames - 1) // 4
        latents = torch.empty((1, 16, latent_frame_count, padded_output_height // 8, padded_output_width // 8), device="cpu", dtype=self.dtype)
        latents.normal_(generator=generator)
        process_total = (num_frames - 1) // 8 - 2
        pre_cache_k = [None] * len(self.dit.blocks)
        pre_cache_v = [None] * len(self.dit.blocks)
        frames_out = None
        frames_cursor = 0
        lq_pre_idx = 0
        lq_cur_idx = 0

        def decode_chunk(chunk_latents, chunk_lq_end, step_idx) -> bool:
            """Decode one latent chunk with the TCDecoder and append its frames to ``frames_out``.

            Shared by the streaming (tiny-long) path and the deferred (tiny)
            decode loop. Returns ``False`` when decoding was aborted.
            """
            nonlocal frames_out, frames_cursor, lq_pre_idx, tcdecoder_tile_mems
            # Only the first decoded chunk of a still image gets the special
            # conditioning range, debug dump and frame selection.
            first_still_chunk = still_image and frames_cursor == 0
            save_still_debug_video = first_still_chunk and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO
            decode_lq_end = still_debug_frame_count if first_still_chunk else chunk_lq_end
            cur_frames, tcdecoder_tile_mems = self._decode_tcdecoder(chunk_latents, sample, lq_pre_idx, decode_lq_end, output_height, output_width, padded_output_height, padded_output_width, tcdecoder_tile_size, tcdecoder_tile_mems, abort_callback=abort_callback, progress_callback=progress_callback, progress_step=step_idx, progress_total=process_total)
            if cur_frames is None:
                return False
            cur_frames = _crop_output_frames(cur_frames.detach().cpu(), output_height, output_width)
            if save_still_debug_video:
                _save_still_image_debug_video(cur_frames)
            if first_still_chunk:
                cur_frames = _select_still_image_frame(cur_frames, still_output_frame)
            # Never write past the number of frames the caller supplied.
            copy_frames = min(int(cur_frames.shape[2]), input_frames - frames_cursor)
            if copy_frames > 0:
                if frames_out is None:
                    frames_out = torch.empty((cur_frames.shape[0], cur_frames.shape[1], input_frames, output_height, output_width), dtype=torch.float32, device="cpu")
                frames_out[:, :, frames_cursor:frames_cursor + copy_frames].copy_(cur_frames[:, :, :copy_frames].float())
                frames_cursor += copy_frames
            lq_pre_idx = chunk_lq_end
            return True

        _report_progress(progress_callback, "Denoising", 0, process_total)
        for process_idx in tqdm(range(process_total), desc="FlashVSR"):
            if _abort_requested(abort_callback):
                return abort_result()
            lq_layer_chunks = []
            torch.cuda.empty_cache()
            # The first chunk consumes a longer startup window of conditioning
            # frames; every later chunk advances by two 4-frame steps.
            if process_idx == 0:
                lq_ranges = [(max(0, inner_idx * 4 - 3), (inner_idx + 1) * 4 - 3) for inner_idx in range(first_chunk_lq_steps)]
                lq_cur_idx = 1 if optimize_still_image else 21
                latent_start, latent_end = 0, first_chunk_latent_frames
            else:
                lq_base = process_idx * 8 + 17
                lq_ranges = [(lq_base + inner_idx * 4, lq_base + inner_idx * 4 + 4) for inner_idx in range(2)]
                lq_cur_idx = process_idx * 8 + 21
                latent_start, latent_end = 4 + process_idx * 2, 6 + process_idx * 2
            for lq_start, lq_end in lq_ranges:
                if _abort_requested(abort_callback):
                    return abort_result()
                lq_chunk = _prepare_conditioning_range(sample, lq_start, lq_end, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0).to(self.device, dtype=self.dtype)
                # Hand the tensor over inside a list and drop the local name so
                # this frame does not keep it alive during the projection.
                lq_list = [lq_chunk]
                del lq_chunk
                cur = self.lq_proj.stream_forward(lq_list)
                if cur is not None:
                    lq_layer_chunks.append(cur)
                del cur
            cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype)
            torch.cuda.empty_cache()

            noise_pred, pre_cache_k, pre_cache_v = _denoise_stream_chunk(
                self.dit, cur_latents, None, lq_layer_chunks, pre_cache_k, pre_cache_v, process_idx,
                self.timestep_embed, self.timestep_mod, topk_ratio=topk_ratio, cache_next=process_idx + 1 < process_total, allow_short_start=optimize_still_image and process_idx == 0, abort_callback=abort_callback,
            )
            if noise_pred is None:
                return abort_result()
            cur_latents = cur_latents - noise_pred
            _report_progress(progress_callback, "Denoising", process_idx + 1, process_total)
            if self.variant == FLASHVSR_VARIANT_TINY_LONG:
                # Streaming variant: decode right away instead of storing latents.
                if not decode_chunk(cur_latents, lq_cur_idx, process_idx):
                    return abort_result()
            else:
                latents[:, :, latent_start:latent_end].copy_(cur_latents.detach().cpu())
            lq_layer_chunks = None
        self.lq_proj.clear_cache()
        pre_cache_k = pre_cache_v = None
        self.dit.clear_cross_kv()
        gc.collect()
        if self.variant == FLASHVSR_VARIANT_TINY_LONG:
            frames = frames_out
        else:
            if self.variant == FLASHVSR_VARIANT_TINY:
                # Deferred TCDecoder pass over the stored denoised latents,
                # using the same chunk boundaries as the denoising loop.
                if _abort_requested(abort_callback):
                    return abort_result()
                self.tcdecoder.clean_mem()
                frames_out = None
                frames_cursor = 0
                lq_pre_idx = 0
                for decode_idx in range(process_total):
                    if _abort_requested(abort_callback):
                        return abort_result()
                    if decode_idx == 0:
                        lq_cur_idx = 1 if optimize_still_image else 21
                        latent_start, latent_end = 0, first_chunk_latent_frames
                    else:
                        lq_cur_idx = decode_idx * 8 + 21
                        latent_start, latent_end = 4 + decode_idx * 2, 6 + decode_idx * 2
                    cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype)
                    if not decode_chunk(cur_latents, lq_cur_idx, decode_idx):
                        return abort_result()
                    del cur_latents
                frames = frames_out
            else:
                if _abort_requested(abort_callback):
                    return abort_result()
                _report_progress(progress_callback, "VAE Decoding")
                if self.vae is None:
                    raise RuntimeError("FlashVSR full variant requires the Wan VAE.")
                vae_tile_size = int(vae_tile_size or 0)
                print(f"[FlashVSR] Wan VAE tiling policy: tile_size={vae_tile_size}px")
                save_still_debug_video = still_image and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO
                decode_latents = latents[0, :, :first_chunk_latent_frames].contiguous() if still_image else latents[0]
                frames = self.vae.decode_to_cpu_uint8([decode_latents], vae_tile_size, target_frames=None if save_still_debug_video else 1 if still_image else input_frames, target_height=output_height, target_width=output_width, frame_start=0 if save_still_debug_video or not still_image else still_output_frame)[0]
                if save_still_debug_video:
                    _save_still_image_debug_video(frames)
                if still_image:
                    frames = _select_still_image_frame(frames, still_output_frame if save_still_debug_video else 0)
        if self.tcdecoder is not None:
            self.tcdecoder.clean_mem()
        if self.vae is not None:
            self.vae.model.clear_cache()
        # Drop every large intermediate before post-processing so the garbage
        # collector can reclaim it.
        latents = frames_out = pre_cache_k = pre_cache_v = tcdecoder_tile_mems = None
        noise_pred = cur_latents = lq_layer_chunks = None
        if torch.is_tensor(frames) and frames.dtype == torch.uint8 and frames.ndim == 4:
            # Already a uint8 CPU result (Wan VAE path): only trim to the requested size.
            if frames.shape[1:] != (input_frames, output_height, output_width):
                frames = frames[:, :input_frames, :output_height, :output_width].contiguous()
        else:
            decoded_frames = frames
            frames = _decoded_frames_to_cpu(decoded_frames, input_frames, output_height, output_width)
            del decoded_frames
        gc.collect()
        _report_progress(progress_callback, "Color Correction")
        _wavelet_color_fix_from_sample(frames.unsqueeze(0), sample, scale, output_height, output_width, output_height, output_width)
        if frames.dtype != torch.uint8:
            frames.clamp_(-1.0, 1.0)
        frames = _apply_continue_cache(frames, continue_cache)
        cache = _make_continue_cache(frames, scale, self.variant) if return_continue_cache else None
        sample = None
        self._unload_mmgp()
        if not persistent_models:
            self.release()
        return frames, cache


# Module-level runtime instance shared by upscale_video() and release_models().
_RUNTIME = FlashVSRRuntime()


def upscale_video(
    sample: torch.Tensor,
    scale: float,
    paths: FlashVSRPaths,
    *,
    variant: str = FLASHVSR_VARIANT_TINY_LONG,
    seed: int = 0,
    continue_cache: Any = None,
    return_continue_cache: bool = False,
    persistent_models: bool = False,
    vae_tile_size: int | None = None,
    topk_ratio: float = FLASHVSR_TOPK_RATIO,
    init_pipe,
    profile,
    still_image: bool = False,
    two_pass: bool = False,
    abort_callback=None,
    progress_callback=None,
) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
    """Load the shared runtime if needed and upscale ``sample``, optionally as a shifted two-pass blend.

    With ``two_pass`` set and ``FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION`` enabled,
    the input is upscaled twice - once as given and once spatially shifted -
    and the two results are combined by ``_apply_still_image_shift_correction``.
    Otherwise a single ``FlashVSRRuntime.upscale`` pass is run.

    Returns:
        ``(frames, cache)``, or ``(None, None)`` when a pass was aborted.

    Raises:
        Re-raises any ``Exception`` from the upscaling passes after unloading
        (persistent) or releasing (non-persistent) the models.
    """
    def _free_models() -> None:
        # Persistent runs keep the models attached and only unload them;
        # otherwise the runtime is released completely.
        if persistent_models:
            _RUNTIME._unload_mmgp()
        else:
            _RUNTIME.release()

    _report_progress(progress_callback, "Caching")
    _RUNTIME.load(paths, variant, profile=profile, init_pipe=init_pipe)
    try:
        # Keyword arguments that are identical for every upscale pass.
        common = {
            "seed": seed,
            "return_continue_cache": return_continue_cache,
            "vae_tile_size": vae_tile_size,
            "topk_ratio": topk_ratio,
            "still_image": still_image,
            "abort_callback": abort_callback,
            "progress_callback": progress_callback,
        }
        if FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION and two_pass:
            result = (None, None)
            # Use the configured input shift, or derive a vertical shift of
            # half the correction period expressed in input pixels.
            input_shift = FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT
            if not input_shift:
                input_shift = (max(1, int(round(FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD * 0.5 / scale))), 0)
            shift_y, shift_x = input_shift
            out_shift_y = int(round(shift_y * scale))
            out_shift_x = int(round(shift_x * scale))
            print(f"[FlashVSR] x{scale:g} shifted two-pass blend: extra shifted pass ({shift_y}px input / {out_shift_y}px output), blend={FLASHVSR_STILL_IMAGE_SHIFT_BLEND:g}")
            # Both passes keep the models loaded; the final release is handled below.
            base, base_cache = _RUNTIME.upscale(sample, scale, continue_cache=continue_cache, persistent_models=True, **common)
            if base is not None:
                shifted_sample = _shift_spatial_replicate(sample, shift_y, shift_x)
                shifted_continue_cache = _two_pass_shifted_continue_cache(continue_cache, out_shift_y, out_shift_x)
                shifted, shifted_cache = _RUNTIME.upscale(shifted_sample, scale, continue_cache=shifted_continue_cache, persistent_models=True, **common)
                if shifted is not None:
                    # Shift the second result back into alignment before blending.
                    result = (
                        _apply_still_image_shift_correction(base, _shift_spatial_replicate(shifted, -out_shift_y, -out_shift_x), scale),
                        _make_two_pass_continue_cache(base_cache, shifted_cache, shift_y, shift_x, out_shift_y, out_shift_x),
                    )
                del shifted_sample, shifted
            del base
            if not persistent_models:
                _RUNTIME.release()
        else:
            result = _RUNTIME.upscale(sample, scale, continue_cache=continue_cache, persistent_models=persistent_models, **common)
        if result[0] is None:
            _free_models()
        return result
    except Exception:
        _free_models()
        raise


def release_models() -> None:
    """Release everything held by the module-level FlashVSR runtime."""
    _RUNTIME.release()