File size: 45,942 Bytes
7344bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 | ##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep !
# I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
from __future__ import annotations
import gc
import math
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn.functional as F
from accelerate import init_empty_weights
from einops import rearrange
from safetensors.torch import load_file
from tqdm import tqdm
from mmgp import offload
from models.wan.modules.vae import WanVAE
from .attention_backend import log_sparse_backend, require_sparge_attention
from .tcdecoder import build_tcdecoder
from .utils import Causal_LQ4x_Proj
from .wan_video_dit import WanModel, precompute_freqs_cis_3d
# Variant names accepted by FlashVSRRuntime.load(): "tiny" and "tiny-long" decode with a
# TCDecoder, any other value (i.e. "full") decodes with WanVAE. load() falls back to
# "tiny-long" when no variant is given.
FLASHVSR_VARIANT_TINY_LONG = "tiny-long"
FLASHVSR_VARIANT_TINY = "tiny"
FLASHVSR_VARIANT_FULL = "full"
FLASHVSR_TOPK_RATIO = 0.0 # 0 = auto area-scaled ratio; >0 = fixed sparse attention ratio.
# Lower bound applied to the auto (area-scaled) sparse top-k ratio in upscale().
FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO = 1.5
FLASHVSR_KV_CACHE_WINDOWS = 1 # Stream cache windows kept between denoise chunks; each window is two latent frames.
# Default number of trailing output frames stored by _make_continue_cache().
FLASHVSR_CONTINUE_CACHE_FRAMES = 11
# mmgp coTenantsMap handed to init_pipe()/offload.profile() in load(); presumably lets
# lq_proj share GPU residency with the transformer -- confirm against mmgp docs.
FLASHVSR_COTENANTS_MAP = {"lq_proj": ["transformer"]}
# Debug switch: when True, _save_still_image_debug_video() writes an MP4.
FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO = False
# When True, upscale() never enables its still-image optimisation path.
FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS = False
# Still-image shift-correction settings. They are not read in the visible part of this
# file (presumably consumed by the still-image path of upscale() -- confirm).
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION = True
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT = None # None = half-period output phase shift.
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD = 16
FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME = True
# lerp weight used by _apply_still_image_shift_correction(): 0 keeps the base pass,
# 1 takes the shifted pass.
FLASHVSR_STILL_IMAGE_SHIFT_BLEND = 0.5
# Output file and frame rate of the optional still-image debug video.
FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH = "flashvsr_still_image_debug.mp4"
FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS = 4
# Architecture hyper-parameters for WanModel(**WAN_1_3B_CONFIG) in load(). dim // num_heads
# (1536 // 12 = 128) also sizes the rotary frequency tables.
WAN_1_3B_CONFIG = {
    "has_image_input": False,
    "patch_size": (1, 2, 2),
    "in_dim": 16,
    "dim": 1536,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "text_dim": 4096,
    "out_dim": 16,
    "num_heads": 12,
    "num_layers": 30,
    "eps": 1e-6,
}
@contextmanager
def _default_dtype(dtype: torch.dtype):
previous_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(previous_dtype)
@dataclass
class FlashVSRPaths:
    """Checkpoint file locations consumed by FlashVSRRuntime.load()."""
    transformer: str  # DiT weights, loaded into WanModel via offload.load_model_data
    lq_proj: str  # Causal_LQ4x_Proj weights
    posi_prompt: str  # safetensors file; its "context" tensor is passed to dit.reinit_cross_kv
    tcdecoder: str | None = None  # TCDecoder weights; only loaded for the "tiny"/"tiny-long" variants
    vae: str | None = None  # WanVAE checkpoint; only loaded for the other ("full") variant
def _preprocess_transformer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Convert a raw transformer checkpoint with WanModel's ``from_civitai`` converter.

    The converter returns a pair; only the converted state dict (first element) is
    kept, and the second element is discarded.
    """
    converted_state_dict, _unused = WanModel.state_dict_converter().from_civitai(state_dict)
    return converted_state_dict
def _sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
sinusoid = torch.outer(position.type(torch.float64), torch.pow(10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2)))
return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(position.dtype)
def _next_conditioning_frame_count(frame_count: int) -> int:
padded = max(25, frame_count + 4)
remainder = padded % 8
if remainder != 1:
padded += (1 - remainder) % 8
return padded
def _aligned_output_size(height: int, width: int, scale: float) -> tuple[int, int]:
target_h = max(1, int(height * scale))
target_w = max(1, int(width * scale))
return max(128, math.ceil(target_h / 128) * 128), max(128, math.ceil(target_w / 128) * 128)
def _conditioning_sizes(sample: torch.Tensor, scale: float) -> tuple[int, int, int, int]:
    """Compute the target output size and its 128-aligned padded canvas.

    *sample* must be 4-D; its last two dims are taken as (height, width).

    Returns:
        ``(output_height, output_width, padded_output_height, padded_output_width)``.
        A notice is printed when padding is needed.
    """
    _, _, in_height, in_width = sample.shape
    output_height = max(1, int(in_height * scale))
    output_width = max(1, int(in_width * scale))
    padded_output_height, padded_output_width = _aligned_output_size(in_height, in_width, scale)
    if padded_output_height != output_height or padded_output_width != output_width:
        print(f"[FlashVSR] Edge padding output canvas {output_width}x{output_height} -> {padded_output_width}x{padded_output_height}; final crop restores {output_width}x{output_height}")
    return output_height, output_width, padded_output_height, padded_output_width
def _prepare_conditioning_range(sample: torch.Tensor, start: int, end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
frames = int(sample.shape[1])
pad_h = padded_output_height - output_height
pad_w = padded_output_width - output_width
frame_indices = [min(max(frame_idx, 0), frames - 1) for frame_idx in range(start, end)]
lq = sample[:, frame_indices]
if lq.dtype == torch.uint8:
lq = lq.float().div_(127.5).sub_(1.0)
else:
lq = lq.detach().float().clamp_(-1.0, 1.0)
lq = F.interpolate(lq.permute(1, 0, 2, 3).contiguous(), size=(output_height, output_width), mode="bicubic", align_corners=False)
if pad_h or pad_w:
lq = F.pad(lq, (0, pad_w, 0, pad_h), mode="replicate")
return lq.clamp_(-1.0, 1.0).to(dtype=dtype).permute(1, 0, 2, 3).contiguous()
def _pad_conditioning_frames(lq_video: torch.Tensor, target_frames: int) -> torch.Tensor:
missing = target_frames - lq_video.shape[2]
if missing <= 0:
return lq_video[:, :, :target_frames]
tail = lq_video[:, :, -1:].repeat(1, 1, missing, 1, 1)
return torch.cat([lq_video, tail], dim=2)
def _crop_output_frames(frames: torch.Tensor, height: int, width: int) -> torch.Tensor:
if frames.shape[-2:] == (height, width):
return frames
return frames[..., :height, :width].contiguous()
def _shift_spatial_replicate(tensor: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor:
if shift_y == 0 and shift_x == 0:
return tensor.clone()
height, width = tensor.shape[-2:]
shift_y = max(1 - height, min(height - 1, int(shift_y)))
shift_x = max(1 - width, min(width - 1, int(shift_x)))
crop = tensor[..., max(0, -shift_y):height - max(0, shift_y), max(0, -shift_x):width - max(0, shift_x)]
return F.pad(crop, (max(0, shift_x), max(0, -shift_x), max(0, shift_y), max(0, -shift_y)), mode="replicate")
def _apply_still_image_shift_correction(base: torch.Tensor, shifted: torch.Tensor, scale: float) -> torch.Tensor:
    """Blend *base* toward *shifted* with weight ``FLASHVSR_STILL_IMAGE_SHIFT_BLEND``.

    The blend is done in float32 and returned in ``base.dtype``. uint8 results are
    rounded and clamped to 0..255; other dtypes are clamped to [-1, 1]. *scale* is
    accepted for call-site compatibility but unused. *base* is never modified.
    """
    weight = float(FLASHVSR_STILL_IMAGE_SHIFT_BLEND)
    blended = torch.lerp(base.to(dtype=torch.float32), shifted.to(dtype=torch.float32), weight)
    if base.dtype != torch.uint8:
        return blended.clamp_(-1.0, 1.0).to(dtype=base.dtype)
    return blended.round_().clamp_(0, 255).to(torch.uint8)
def _shift_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any:
if not isinstance(continue_cache, dict):
return continue_cache
tail = continue_cache.get("tail_frames")
if not torch.is_tensor(tail) or tail.ndim != 4:
return continue_cache
shifted_cache = dict(continue_cache)
shifted_cache["tail_frames"] = _shift_spatial_replicate(tail, shift_y, shift_x)
return shifted_cache
def _two_pass_shifted_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any:
if not isinstance(continue_cache, dict):
return continue_cache
tail = continue_cache.get("tail_frames_shifted")
if not torch.is_tensor(tail) or tail.ndim != 4:
return _shift_continue_cache(continue_cache, shift_y, shift_x)
shifted_cache = dict(continue_cache)
shifted_cache["tail_frames"] = tail
return shifted_cache
def _make_two_pass_continue_cache(base_cache: Any, shifted_cache: Any, shift_y: int, shift_x: int, out_shift_y: int, out_shift_x: int) -> Any:
if not isinstance(base_cache, dict):
return base_cache
cache = dict(base_cache)
shifted_tail = shifted_cache.get("tail_frames") if isinstance(shifted_cache, dict) else None
if torch.is_tensor(shifted_tail) and shifted_tail.ndim == 4:
cache["tail_frames_shifted"] = shifted_tail.contiguous()
cache.update({"two_pass": True, "shift_y": shift_y, "shift_x": shift_x, "out_shift_y": out_shift_y, "out_shift_x": out_shift_x})
return cache
def _select_still_image_frame(frames: torch.Tensor, frame_index: int) -> torch.Tensor:
return frames[:, :, frame_index:frame_index + 1].contiguous() if frames.ndim == 5 else frames[:, frame_index:frame_index + 1].contiguous()
def _decoded_frames_to_cpu(frames: torch.Tensor, frame_count: int, height: int, width: int) -> torch.Tensor:
frames = frames.detach()[0, :, :frame_count, :height, :width]
if frames.device.type == "cpu" and frames.dtype == torch.float32 and frames.is_contiguous():
return frames
frames_cpu = torch.empty(tuple(frames.shape), dtype=torch.float32, device="cpu")
frames_cpu.copy_(frames)
return frames_cpu
def _save_still_image_debug_video(frames: torch.Tensor) -> None:
    """Best-effort MP4 dump of still-image debug frames (only when the debug flag is on).

    The video is written to ``FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH`` via
    ``shared.utils.audio_video.save_video``. Any exception, including a failed
    import, is reported with a print and not propagated, so debugging cannot
    break an upscale.
    """
    if FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO:
        path = os.path.abspath(FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH)
        try:
            from shared.utils.audio_video import save_video
            debug_frames = frames.detach().cpu()
            encode_options = dict(fps=FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS, nrow=1, normalize=True, value_range=(-1, 1), codec_type="libx264_8", container="mp4")
            save_video(tensor=debug_frames, save_file=path, **encode_options)
            print(f"[FlashVSR] Still image debug video saved to {path} ({int(debug_frames.shape[2])} frames)")
            del debug_frames
        except Exception as exc:
            print(f"[FlashVSR] Failed to save still image debug video: {exc}")
def _nested_tensors_to(value: Any, device: torch.device | str, dtype: torch.dtype | None = None) -> Any:
if torch.is_tensor(value):
return value.detach().to(device=device, dtype=dtype or value.dtype)
if isinstance(value, list):
return [_nested_tensors_to(item, device, dtype) for item in value]
return value
def _tcdecoder_mem_halo_latents(tcdecoder: torch.nn.Module) -> int:
radius = 0.0
jump = 1.0
decoder = tcdecoder.taehv.decoder if hasattr(tcdecoder, "taehv") else tcdecoder.decoder
for module in decoder:
if isinstance(module, torch.nn.Conv2d):
kernel = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else int(module.kernel_size)
radius += ((kernel - 1) / 2) * jump
elif module.__class__.__name__ == "MemBlock":
for submodule in module.conv:
if isinstance(submodule, torch.nn.Conv2d):
kernel = submodule.kernel_size[0] if isinstance(submodule.kernel_size, tuple) else int(submodule.kernel_size)
radius += ((kernel - 1) / 2) * jump
elif isinstance(module, torch.nn.Upsample):
scale = module.scale_factor[0] if isinstance(module.scale_factor, tuple) else module.scale_factor
jump /= float(scale or 1)
return max(1, int(math.ceil(radius)))
def _report_progress(progress_callback, phase: str, current_step: int | None = None, total_steps: int | None = None) -> None:
if callable(progress_callback):
progress_callback(phase, current_step, total_steps)
def _abort_requested(abort_callback) -> bool:
return callable(abort_callback) and abort_callback()
def _apply_continue_cache(frames: torch.Tensor, continue_cache: Any) -> torch.Tensor:
if not isinstance(continue_cache, dict):
return frames
tail = continue_cache.get("tail_frames")
if not torch.is_tensor(tail) or tail.ndim != 4:
return frames
if tail.shape[0] != frames.shape[0] or tail.shape[-2:] != frames.shape[-2:]:
return frames
overlap = min(int(tail.shape[1]), int(frames.shape[1]))
if overlap <= 0:
return frames
if frames.dtype == torch.uint8:
if tail.dtype != torch.uint8:
tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
frames[:, :overlap].copy_(tail[:, -overlap:].to(device=frames.device))
return frames
if tail.dtype == torch.uint8:
tail = tail.to(device=frames.device, dtype=frames.dtype).div(127.5).sub(1.0)
else:
tail = tail.to(device=frames.device, dtype=frames.dtype)
frames[:, :overlap].copy_(tail[:, -overlap:])
return frames
def _make_continue_cache(frames: torch.Tensor, scale: float, variant: str, overlap_frames: int = FLASHVSR_CONTINUE_CACHE_FRAMES) -> dict[str, Any]:
    """Snapshot the trailing frames of an output clip so a later call can continue from it.

    Args:
        frames: output frames with the frame axis at dim 1. uint8 is kept as-is;
            other dtypes are treated as [-1, 1] and mapped to 0..255.
        scale: upscale factor, stored alongside the tail.
        variant: model variant name, stored alongside the tail.
        overlap_frames: number of trailing frames to keep. It is clamped to
            ``[0, frames.shape[1]]``.

    Returns:
        ``{"tail_frames": <contiguous uint8 CPU tensor>, "scale": scale, "variant": variant}``.
    """
    total = int(frames.shape[1])
    tail_len = max(0, min(int(overlap_frames), total))
    # Slice from an explicit start index: ``frames[:, -0:]`` would select *every* frame
    # when tail_len is 0, and a negative overlap would slice from the front.
    tail = frames[:, total - tail_len:].detach().cpu()
    if tail.dtype != torch.uint8:
        tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
    return {"tail_frames": tail.contiguous(), "scale": scale, "variant": variant}
def _wavelet_color_fix(frames: torch.Tensor, lq_video: torch.Tensor) -> torch.Tensor:
if frames.shape != lq_video[:, :, :frames.shape[2]].shape:
return frames
for start in range(0, frames.shape[2], 4):
end = min(start + 4, frames.shape[2])
frame_chunk = frames[:, :, start:end]
lq_chunk = lq_video[:, :, start:end].to(device=frames.device, dtype=frames.dtype)
mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True)
std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True)
std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0)
return frames
def _wavelet_color_fix_from_sample(frames: torch.Tensor, sample: torch.Tensor, scale: float, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int) -> torch.Tensor:
step = 1 if frames.dtype == torch.uint8 else 4
for start in range(0, min(int(frames.shape[2]), int(sample.shape[1])), step):
end = min(start + step, int(frames.shape[2]), int(sample.shape[1]))
frame_chunk = frames[:, :, start:end]
if frames.dtype == torch.uint8:
frame_float = frame_chunk.float()
lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=torch.float32)
if sample.dtype != torch.uint8:
lq_chunk.clamp_(-1.0, 1.0).add_(1.0).mul_(127.5)
mean_frames = frame_float.mean(dim=(3, 4), keepdim=True)
std_frames = frame_float.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True)
std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
frame_float.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).round_().clamp_(0, 255)
frame_chunk.copy_(frame_float.to(torch.uint8))
del frame_float, lq_chunk, mean_frames, std_frames, mean_lq, std_lq
continue
lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=frames.dtype)
if sample.dtype == torch.uint8:
lq_chunk.div_(127.5).sub_(1.0)
else:
lq_chunk.clamp_(-1.0, 1.0)
mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True)
std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True)
std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5)
frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0)
del lq_chunk, mean_frames, std_frames, mean_lq, std_lq
return frames
def _denoise_stream_chunk(
    dit: WanModel,
    x: torch.Tensor,
    context: torch.Tensor | None,
    lq_layer_chunks: list[list[torch.Tensor | None]],
    block_cache_k: list[torch.Tensor | None],
    block_cache_v: list[torch.Tensor | None],
    chunk_index: int,
    timestep_embed: torch.Tensor,
    timestep_mod: torch.Tensor,
    *,
    topk_ratio: float = 2.0,
    kv_ratio: float = FLASHVSR_KV_CACHE_WINDOWS,
    local_range: int = 9,
    cache_next: bool = True,
    allow_short_start: bool = False,
    abort_callback=None,
) -> tuple[torch.Tensor | None, list[torch.Tensor | None], list[torch.Tensor | None]]:
    """Run one streaming chunk of latents through every DiT block, then the output head.

    Args:
        dit: the WanModel whose patchify/blocks/head/unpatchify are used.
        x: latent chunk, fed to ``dit.patchify`` (layout is whatever WanModel.patchify
            expects -- not established here).
        context: forwarded unchanged to every block (may be None).
        lq_layer_chunks: list of per-chunk lists indexed by block id. Before block
            ``b`` runs, each ``chunk[b]`` is added in place into consecutive
            slices of ``x`` along dim 1, and the entry is then set to None to
            free it.
        block_cache_k / block_cache_v: per-block K/V caches carried between chunks.
            Each entry is moved to x's device, cleared, passed to the block, and
            replaced by the cache the block returns. The lists are mutated in place
            and also returned.
        chunk_index: position of this chunk in the stream; selects the temporal slice
            of ``dit.freqs[0]``.
        timestep_embed: passed to ``dit.head``.
        timestep_mod: passed to every block.
        topk_ratio: scales the sparse-attention budget
            ``int(window_size ** 2 * topk_ratio) - 1``.
        kv_ratio: converted to ``kv_len = max(1, int(kv_ratio))`` for every block.
        local_range, cache_next, allow_short_start: forwarded to every block.
        abort_callback: polled before each block and after the last one.

    Returns:
        ``(output, block_cache_k, block_cache_v)``. ``output`` is None when an abort
        was requested.
    """
    x, (frames, height, width) = dit.patchify(x)
    # Only win[0] (2 latent frames per window) is read below; the 8x8 spatial part is unused here.
    win = (2, 8, 8)
    seqlen = frames // win[0]
    window_size = win[0] * height * width // 128
    # Sparse-attention top-k budget handed to every block.
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = max(1, int(kv_ratio))
    # Temporal rotary positions: chunk 0 starts at 0, later chunks at 4 + 2 * chunk_index
    # (presumably because the first chunk carries extra latent frames -- confirm with caller).
    if chunk_index == 0:
        freqs_t = dit.freqs[0][:frames]
    else:
        start = 4 + chunk_index * 2
        freqs_t = dit.freqs[0][start:start + frames]
    # Split the complex (time, height, width) frequency tables into (real, imag) pairs on x's device/dtype.
    freqs = tuple((freq.real.to(device=x.device, dtype=x.dtype), freq.imag.to(device=x.device, dtype=x.dtype)) for freq in (freqs_t, dit.freqs[1][:height], dit.freqs[2][:width]))
    for block_id, block in enumerate(dit.blocks):
        if _abort_requested(abort_callback):
            return None, block_cache_k, block_cache_v
        # Inject this block's LQ-projection features, chunk by chunk along dim 1, freeing each once used.
        if block_id < len(lq_layer_chunks[0]):
            offset = 0
            for chunk in lq_layer_chunks:
                lq = chunk[block_id].to(x.device, dtype=x.dtype)
                next_offset = offset + lq.shape[1]
                x[:, offset:next_offset].add_(lq)
                offset = next_offset
                chunk[block_id] = None
                del lq
        # Bring the previous chunk's K/V for this block onto x's device and drop the stored copy.
        cache_refs = None
        if block_cache_k[block_id] is not None:
            cache_refs = [block_cache_k[block_id].to(x.device, dtype=x.dtype), block_cache_v[block_id].to(x.device, dtype=x.dtype)]
            block_cache_k[block_id] = None
            block_cache_v[block_id] = None
        # x is handed over inside a list and the local name dropped, presumably so the block
        # can release the input activation early; x_ref.clear() drops the last reference.
        x_ref = [x]
        x = None
        x, next_cache_k, next_cache_v = block(
            x_ref, context, timestep_mod, freqs, frames, height, width, seqlen, topk,
            block_id=block_id, kv_len=kv_len, is_stream=True,
            pre_cache_refs=cache_refs, local_range=local_range, cache_next=cache_next, allow_short_start=allow_short_start,
        )
        x_ref.clear()
        block_cache_k[block_id] = next_cache_k
        del next_cache_k
        block_cache_v[block_id] = next_cache_v
        del next_cache_v, cache_refs
    if _abort_requested(abort_callback):
        return None, block_cache_k, block_cache_v
    x = dit.head([x], timestep_embed)
    return dit.unpatchify([x], (frames, height, width)), block_cache_k, block_cache_v
class FlashVSRRuntime:
def __init__(self) -> None:
self.variant: str | None = None
self.dtype = torch.bfloat16
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dit: WanModel | None = None
self.lq_proj: Causal_LQ4x_Proj | None = None
self.tcdecoder: torch.nn.Module | None = None
self.vae: WanVAE | None = None
self.offloadobj = None
self.prompt_context: torch.Tensor | None = None
self.timestep: torch.Tensor | None = None
self.timestep_embed: torch.Tensor | None = None
self.timestep_mod: torch.Tensor | None = None
self.profile = None
def load(self, paths: FlashVSRPaths, variant: str, profile, init_pipe) -> None:
    """Load (or reuse) the FlashVSR models for *variant* and register them with mmgp.

    Args:
        paths: checkpoint locations.
        variant: a FLASHVSR_VARIANT_* name; a falsy value falls back to "tiny-long".
            "tiny"/"tiny-long" build a TCDecoder; any other value builds a WanVAE.
        profile: mmgp profile. It is compared with the loaded one to decide whether
            the current models can be reused.
        init_pipe: called as ``init_pipe(pipe, kwargs, profile)`` and must return the
            mmgp profile number. ``kwargs`` (pre-filled with coTenantsMap) is then
            forwarded to ``offload.profile``.

    Any previously loaded models are released first unless they can be reused.
    """
    require_sparge_attention()
    variant = variant or FLASHVSR_VARIANT_TINY_LONG
    # NOTE(review): `paths` is not part of this reuse check, so calling load() again with
    # different checkpoints but the same variant/profile keeps the old weights -- confirm intended.
    if self.dit is not None and self.variant == variant and self.profile == profile:
        return
    self.release()
    self.variant = variant
    self.profile = profile
    # Build module skeletons without materialising weights (accelerate's init_empty_weights),
    # with self.dtype as the default dtype; real tensors come from load_model_data below.
    with init_empty_weights(include_buffers=True), _default_dtype(self.dtype):
        self.dit = WanModel(**WAN_1_3B_CONFIG).eval()
        self.lq_proj = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).eval()
    # Method names that mmgp is told to hook for offloading (exact hook semantics live in mmgp).
    self.dit._offload_hooks = ["reinit_cross_kv"]
    self.lq_proj._offload_hooks = ["stream_forward"]
    offload.load_model_data(self.dit, paths.transformer, writable_tensors=False, preprocess_sd=_preprocess_transformer_state_dict, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1)
    # Rotary frequency tables sized by the per-head dim (1536 // 12 = 128).
    self.dit.freqs = precompute_freqs_cis_3d(WAN_1_3B_CONFIG["dim"] // WAN_1_3B_CONFIG["num_heads"])
    offload.load_model_data(self.lq_proj, paths.lq_proj, writable_tensors=False, default_dtype=self.dtype, verboseLevel=-1)
    self.dit.requires_grad_(False)
    self.lq_proj.requires_grad_(False)
    # Fixed prompt context, kept on CPU until _prepare_run_state() moves it to the device.
    self.prompt_context = load_file(paths.posi_prompt, device="cpu")["context"].to(self.dtype)
    pipe = {"transformer": self.dit, "lq_proj": self.lq_proj}
    if variant in (FLASHVSR_VARIANT_TINY, FLASHVSR_VARIANT_TINY_LONG):
        self.tcdecoder = build_tcdecoder(new_channels=[512, 256, 128, 128], device="cpu", dtype=self.dtype, new_latent_channels=16 + 768).eval()
        self.tcdecoder._offload_hooks = ["decode_video"]
        offload.load_model_data(self.tcdecoder, paths.tcdecoder, writable_tensors=False, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1)
        self.tcdecoder.requires_grad_(False)
        pipe["tcdecoder"] = self.tcdecoder
    else:
        # NOTE(review): assumes paths.vae is set for the "full" variant; it defaults to None.
        self.vae = WanVAE(vae_pth=paths.vae, dtype=self.dtype, upsampler_factor=1, device="cpu")
        self.vae.device = self.device
        self.vae.model.requires_grad_(False)
        pipe["vae"] = self.vae.model
    kwargs = {"coTenantsMap": FLASHVSR_COTENANTS_MAP}
    profile_no = init_pipe(pipe, kwargs, profile)
    self.offloadobj = offload.profile(pipe, profile_no=profile_no, quantizeTransformer=False, convertWeightsFloatTo=self.dtype, verboseLevel=-1, **kwargs)
    log_sparse_backend()
def _prepare_run_state(self) -> None:
if self.device.type != "cuda":
raise RuntimeError("FlashVSR requires CUDA.")
context = self.prompt_context.to(self.device, dtype=self.dtype)
self.dit.reinit_cross_kv(context)
self.timestep = torch.tensor([1000.0], device=self.device, dtype=self.dtype)
self.timestep_embed = self.dit.time_embedding(_sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.timestep_mod = self.dit.time_projection(self.timestep_embed).unflatten(1, (6, self.dit.dim))
def _clear_runtime_caches(self) -> None:
if self.dit is not None:
self.dit.clear_cross_kv()
if self.lq_proj is not None:
self.lq_proj.clear_cache()
if self.tcdecoder is not None:
self.tcdecoder.clean_mem()
if self.vae is not None:
self.vae.model.clear_cache()
self.timestep = None
self.timestep_embed = None
self.timestep_mod = None
def _unload_mmgp(self) -> None:
self._clear_runtime_caches()
if self.offloadobj is not None:
self.offloadobj.unload_all()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _decode_tcdecoder(self, latents: torch.Tensor, sample: torch.Tensor, lq_start: int, lq_end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, tile_size: int, tile_mems: dict[tuple[int, int], Any] | None, abort_callback=None, progress_callback=None, progress_step: int | None = None, progress_total: int | None = None) -> tuple[torch.Tensor | None, dict[tuple[int, int], Any] | None]:
    """Decode *latents* with the TCDecoder, conditioned on input frames ``[lq_start, lq_end)``.

    Untiled path (``tile_size <= 0`` or the padded canvas fits in one tile):
        decodes the whole canvas in one call and returns the decoder output at the
        padded size, on the decoder's device.

    Tiled path:
        1. The latent grid is split into ``tile_size // 8`` latent tiles.
        2. Each tile is expanded by the receptive-field halo
           (``_tcdecoder_mem_halo_latents``) and decoded.
        3. Only the tile's own region is written into a float32 CPU tensor cropped
           to ``(output_height, output_width)``.

        The decoder's temporal memory (``tcdecoder.mem``) is saved per tile in
        *tile_mems* (on CPU) and restored for the same tile on the next call, so
        streaming decodes stay temporally consistent per tile.

    Both paths map the decoder output with ``* 2 - 1`` (presumably [0, 1] -> [-1, 1]
    -- confirm against TCDecoder).

    Returns:
        ``(frames, tile_mems)``. ``frames`` is None if *abort_callback* requested an
        abort mid-tiling.

    Raises:
        RuntimeError: if no TCDecoder is loaded.
    """
    if self.tcdecoder is None:
        raise RuntimeError("FlashVSR tiny variants require TCDecoder.")
    _report_progress(progress_callback, "TCDecoder Decoding", progress_step, progress_total)
    tile_size = int(tile_size or 0)
    # Conditioning frames resized to the padded canvas, with a leading batch dim added.
    cur_lq = _prepare_conditioning_range(sample, lq_start, lq_end, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0)
    if tile_size <= 0 or (padded_output_height <= tile_size and padded_output_width <= tile_size):
        cur_lq = cur_lq.to(self.device, dtype=self.dtype)
        # NOTE(review): unlike the tiled path, latents are not moved to self.device here --
        # assumes the caller already placed them there.
        frames = self.tcdecoder.decode_video(latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq).transpose(1, 2).mul_(2).sub_(1)
        del cur_lq
        _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total)
        return frames, tile_mems
    halo = _tcdecoder_mem_halo_latents(self.tcdecoder)
    # One latent pixel corresponds to 8 output pixels throughout this function.
    latent_tile = max(1, tile_size // 8)
    latent_height = padded_output_height // 8
    latent_width = padded_output_width // 8
    tile_mems = {} if tile_mems is None else tile_mems
    frames_out = None
    for latent_y0 in range(0, latent_height, latent_tile):
        latent_y1 = min(latent_y0 + latent_tile, latent_height)
        # Output rows owned by this tile row, clipped to the unpadded height.
        write_y0, write_y1 = latent_y0 * 8, min(latent_y1 * 8, output_height)
        if write_y1 <= write_y0:
            continue
        # Halo-expanded decode window and the offset of the owned region inside it.
        expanded_y0, expanded_y1 = max(0, latent_y0 - halo), min(latent_height, latent_y1 + halo)
        crop_y0 = (latent_y0 - expanded_y0) * 8
        for latent_x0 in range(0, latent_width, latent_tile):
            if _abort_requested(abort_callback):
                del cur_lq
                return None, tile_mems
            latent_x1 = min(latent_x0 + latent_tile, latent_width)
            write_x0, write_x1 = latent_x0 * 8, min(latent_x1 * 8, output_width)
            if write_x1 <= write_x0:
                continue
            expanded_x0, expanded_x1 = max(0, latent_x0 - halo), min(latent_width, latent_x1 + halo)
            crop_x0 = (latent_x0 - expanded_x0) * 8
            # Restore this tile's temporal memory from the previous call (or start fresh).
            tile_key = (latent_y0, latent_x0)
            saved_mem = tile_mems.get(tile_key)
            if saved_mem is None:
                self.tcdecoder.clean_mem()
            else:
                self.tcdecoder.mem = _nested_tensors_to(saved_mem, self.device, self.dtype)
            cur_lq_tile = cur_lq[:, :, :, expanded_y0 * 8:expanded_y1 * 8, expanded_x0 * 8:expanded_x1 * 8].contiguous().to(self.device, dtype=self.dtype)
            cur_latents = latents[:, :, :, expanded_y0:expanded_y1, expanded_x0:expanded_x1].to(self.device, dtype=self.dtype)
            tile_frames = self.tcdecoder.decode_video(cur_latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq_tile).transpose(1, 2).mul_(2).sub_(1)
            # Park the updated memory on CPU for the next call, then reset the decoder.
            tile_mems[tile_key] = _nested_tensors_to(self.tcdecoder.mem, "cpu")
            self.tcdecoder.clean_mem()
            # Drop the halo: keep only the tile's own latent footprint.
            tile_frames = tile_frames[:, :, :, crop_y0:crop_y0 + latent_y1 * 8 - latent_y0 * 8, crop_x0:crop_x0 + latent_x1 * 8 - latent_x0 * 8]
            if frames_out is None:
                frames_out = torch.empty((tile_frames.shape[0], tile_frames.shape[1], tile_frames.shape[2], output_height, output_width), dtype=torch.float32, device="cpu")
            tile_cpu = tile_frames[:, :, :, :write_y1 - write_y0, :write_x1 - write_x0].detach().cpu().float()
            frames_out[:, :, :, write_y0:write_y1, write_x0:write_x1].copy_(tile_cpu)
            del cur_lq_tile, cur_latents, tile_frames, tile_cpu
    del cur_lq
    _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total)
    return frames_out, tile_mems
def release(self) -> None:
self._clear_runtime_caches()
if self.offloadobj is not None:
self.offloadobj.release()
self.offloadobj = None
self.dit = None
self.lq_proj = None
self.tcdecoder = None
self.vae = None
self.prompt_context = None
self.timestep = None
self.timestep_embed = None
self.timestep_mod = None
self.variant = None
self.profile = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@torch.inference_mode()
def upscale(
    self,
    sample: torch.Tensor,
    scale: float,
    *,
    seed: int = 0,
    continue_cache: Any = None,
    return_continue_cache: bool = False,
    persistent_models: bool = False,
    vae_tile_size: int | None = None,
    topk_ratio: float = FLASHVSR_TOPK_RATIO,
    still_image: bool = False,
    abort_callback=None,
    progress_callback=None,
) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
    """Upscale ``sample`` by ``scale`` with the currently loaded FlashVSR models.

    Random latents are drawn on the CPU and denoised chunk by chunk by the
    streaming DiT, conditioned on projected low-quality frames. They are then
    decoded according to ``self.variant``:

    * tiny-long: each chunk is decoded by the TCDecoder right after denoising;
    * tiny: all chunks are denoised first, then decoded by the TCDecoder;
    * otherwise (full): the whole latent tensor is decoded by the Wan VAE.

    The decoded frames are colour-corrected against ``sample`` and passed
    through ``_apply_continue_cache``.

    Args:
        sample: input clip. ``sample.shape[1]`` is used as the input frame
            count (layout presumably (C, T, H, W); TODO confirm with callers).
        scale: spatial upscale factor.
        seed: seed for the CPU generator that draws the initial latents;
            ``None`` or negative values fall back to 0.
        continue_cache: forwarded to ``_apply_continue_cache``.
        return_continue_cache: when True, a cache is built with
            ``_make_continue_cache`` and returned.
        persistent_models: when False, the models are released at the end
            of the run (also on abort).
        vae_tile_size: tile size in pixels. For TCDecoder variants, spatial
            tiling is only enabled when this is > 0 and smaller than the
            padded output; for the full variant it is passed to the Wan VAE.
        topk_ratio: sparse-attention top-k ratio, clamped to [0, 4]. 0 selects
            an automatic ratio derived from the padded output area.
        still_image: enable still-image handling; only honoured when the
            input has exactly one frame.
        abort_callback: polled between steps; an abort yields ``(None, None)``.
        progress_callback: receives progress reports via ``_report_progress``.

    Returns:
        ``(frames, cache)``, where ``cache`` is None unless
        ``return_continue_cache``. Returns ``(None, None)`` when aborted.

    Raises:
        RuntimeError: if the DiT / LQ projector are not loaded, or the full
            variant is used without a Wan VAE.
    """
    if self.dit is None or self.lq_proj is None:
        raise RuntimeError("FlashVSR models are not loaded.")

    def abort_result():
        # Abort path: unload offloaded weights, optionally free everything,
        # and report the abort to the caller as (None, None).
        self._unload_mmgp()
        if not persistent_models:
            self.release()
        return None, None

    input_frames = sample.shape[1]
    num_frames = _next_conditioning_frame_count(input_frames)
    output_height, output_width, padded_output_height, padded_output_width = _conditioning_sizes(sample, scale)

    # Sparse top-k: an explicit ratio is clamped to [0, 4]. A ratio of 0 derives
    # it from the padded output area: 2.0 up to 768x1280 pixels, shrinking
    # inversely with area above that, floored at FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO.
    configured_topk_ratio = max(0.0, min(4.0, float(topk_ratio or 0.0)))
    if configured_topk_ratio > 0:
        topk_ratio = configured_topk_ratio
        print(f"[FlashVSR] Sparse top-k ratio fixed to {topk_ratio:.3f}")
    else:
        raw_topk_ratio = min(2.0, 2.0 * 768 * 1280 / max(int(padded_output_height) * int(padded_output_width), 1))
        topk_ratio = max(raw_topk_ratio, FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO)
        if topk_ratio != raw_topk_ratio:
            print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height} (minimum; raw auto {raw_topk_ratio:.3f})")
        elif topk_ratio < 2.0:
            print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height}")

    # Start from clean streaming state in every model that keeps one.
    self._prepare_run_state()
    self.lq_proj.clear_cache()
    if self.tcdecoder is not None:
        self.tcdecoder.clean_mem()
    if self.vae is not None:
        self.vae.model.clear_cache()
    print(f"[FlashVSR] Stream KV cache windows: {max(1, int(FLASHVSR_KV_CACHE_WINDOWS))}")

    # TCDecoder tiling is active only when a tile size is set and the padded
    # output exceeds it; tile_mems then carries per-tile decoder memory
    # between calls to _decode_tcdecoder.
    tcdecoder_tile_size = int(vae_tile_size or 0) if self.tcdecoder is not None else 0
    tcdecoder_tile_mems = None
    if self.tcdecoder is not None:
        if tcdecoder_tile_size > 0 and (padded_output_height > tcdecoder_tile_size or padded_output_width > tcdecoder_tile_size):
            print(f"[FlashVSR] TCDecoder spatial tiling policy: tile_size={tcdecoder_tile_size}px, halo={_tcdecoder_mem_halo_latents(self.tcdecoder) * 8}px")
            tcdecoder_tile_mems = {}
        else:
            print("[FlashVSR] TCDecoder spatial tiling policy: tile_size=0px")

    generator = torch.Generator(device="cpu").manual_seed(0 if seed is None or seed < 0 else int(seed))
    still_image = bool(still_image and input_frames == 1)
    self.lq_proj.shift_start_prefix = still_image
    optimize_still_image = still_image and not FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS
    first_chunk_latent_frames = 2 if optimize_still_image else 6
    first_chunk_lq_steps = first_chunk_latent_frames + 1 if optimize_still_image else 7
    # Frame count matching first_chunk_latent_frames latents: the inverse of
    # the (num_frames - 1) // 4 latent-count formula used for videos below.
    still_debug_frame_count = (first_chunk_latent_frames - 1) * 4 + 1
    still_output_frame = still_debug_frame_count - 1 if still_image and FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME else 0
    if optimize_still_image:
        print(f"[FlashVSR] Still image mode: denoising {first_chunk_latent_frames} startup latent frames instead of 6; returning decoded frame {still_output_frame}")
    elif still_image and FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS:
        print(f"[FlashVSR] Still image debug mode: image optimizations disabled; denoising original 6 startup latent frames; returning decoded frame {still_output_frame}")

    latent_frame_count = first_chunk_latent_frames if still_image else (num_frames - 1) // 4
    latents = torch.empty((1, 16, latent_frame_count, padded_output_height // 8, padded_output_width // 8), device="cpu", dtype=self.dtype)
    latents.normal_(generator=generator)
    process_total = (num_frames - 1) // 8 - 2
    pre_cache_k = [None] * len(self.dit.blocks)
    pre_cache_v = [None] * len(self.dit.blocks)
    frames_out = None
    frames_cursor = 0
    lq_pre_idx = 0

    def chunk_window(idx):
        """Return ``(lq_cur_idx, latent_start, latent_end)`` for stream chunk ``idx``.

        Shared by the denoise loop and the tiny-variant decode loop so both
        walk the latents and LQ frames with identical windows.
        """
        if idx == 0:
            return (1 if optimize_still_image else 21), 0, first_chunk_latent_frames
        return idx * 8 + 21, 4 + idx * 2, 6 + idx * 2

    def stream_lq_range(start, end, chunks):
        """Prepare conditioning frames ``start``..``end`` and feed them to the LQ projector.

        Any output that ``lq_proj.stream_forward`` returns (it may return
        None) is appended to ``chunks``.
        """
        lq_chunk = _prepare_conditioning_range(sample, start, end, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0).to(self.device, dtype=self.dtype)
        lq_list = [lq_chunk]
        del lq_chunk
        cur = self.lq_proj.stream_forward(lq_list)
        if cur is not None:
            chunks.append(cur)

    def decode_chunk(chunk_latents, chunk_lq_cur_idx, step):
        """Decode one denoised latent chunk with the TCDecoder and append it to ``frames_out``.

        Returns False when ``_decode_tcdecoder`` produced no frames (abort),
        True otherwise. Advances ``frames_cursor`` and ``lq_pre_idx``.
        """
        nonlocal frames_out, frames_cursor, lq_pre_idx, tcdecoder_tile_mems
        # The first chunk of a still image is decoded against the
        # still_debug_frame_count LQ range and then reduced to one frame.
        first_still_chunk = still_image and frames_cursor == 0
        save_still_debug_video = first_still_chunk and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO
        decode_lq_cur_idx = still_debug_frame_count if first_still_chunk else chunk_lq_cur_idx
        cur_frames, tcdecoder_tile_mems = self._decode_tcdecoder(chunk_latents, sample, lq_pre_idx, decode_lq_cur_idx, output_height, output_width, padded_output_height, padded_output_width, tcdecoder_tile_size, tcdecoder_tile_mems, abort_callback=abort_callback, progress_callback=progress_callback, progress_step=step, progress_total=process_total)
        if cur_frames is None:
            return False
        cur_frames = _crop_output_frames(cur_frames.detach().cpu(), output_height, output_width)
        if save_still_debug_video:
            _save_still_image_debug_video(cur_frames)
        if first_still_chunk:
            cur_frames = _select_still_image_frame(cur_frames, still_output_frame)
        # Never write past input_frames; surplus decoded frames are dropped.
        copy_frames = min(int(cur_frames.shape[2]), input_frames - frames_cursor)
        if copy_frames > 0:
            if frames_out is None:
                frames_out = torch.empty((cur_frames.shape[0], cur_frames.shape[1], input_frames, output_height, output_width), dtype=torch.float32, device="cpu")
            frames_out[:, :, frames_cursor:frames_cursor + copy_frames].copy_(cur_frames[:, :, :copy_frames].float())
            frames_cursor += copy_frames
        lq_pre_idx = chunk_lq_cur_idx
        return True

    _report_progress(progress_callback, "Denoising", 0, process_total)
    for process_idx in tqdm(range(process_total), desc="FlashVSR"):
        if _abort_requested(abort_callback):
            return abort_result()
        lq_layer_chunks = []
        torch.cuda.empty_cache()
        if process_idx == 0:
            # Warm-up chunk: the first window starts at frame 0 and later
            # windows are 4 frames wide.
            lq_ranges = [(max(0, i * 4 - 3), (i + 1) * 4 - 3) for i in range(first_chunk_lq_steps)]
        else:
            lq_ranges = [(process_idx * 8 + 17 + i * 4, process_idx * 8 + 21 + i * 4) for i in range(2)]
        for lq_start, lq_end in lq_ranges:
            if _abort_requested(abort_callback):
                return abort_result()
            stream_lq_range(lq_start, lq_end, lq_layer_chunks)
        lq_cur_idx, latent_start, latent_end = chunk_window(process_idx)
        cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype)
        torch.cuda.empty_cache()
        noise_pred, pre_cache_k, pre_cache_v = _denoise_stream_chunk(
            self.dit, cur_latents, None, lq_layer_chunks, pre_cache_k, pre_cache_v, process_idx,
            self.timestep_embed, self.timestep_mod, topk_ratio=topk_ratio, cache_next=process_idx + 1 < process_total, allow_short_start=optimize_still_image and process_idx == 0, abort_callback=abort_callback,
        )
        if noise_pred is None:
            return abort_result()
        cur_latents = cur_latents - noise_pred
        _report_progress(progress_callback, "Denoising", process_idx + 1, process_total)
        if self.variant == FLASHVSR_VARIANT_TINY_LONG:
            # Decode right away instead of writing the chunk back into `latents`.
            if not decode_chunk(cur_latents, lq_cur_idx, process_idx):
                return abort_result()
        else:
            # Decoded after the loop; stash the denoised chunk on the CPU.
            latents[:, :, latent_start:latent_end].copy_(cur_latents.detach().cpu())

    # Denoising is finished: drop streaming caches before decoding.
    lq_layer_chunks = None
    self.lq_proj.clear_cache()
    pre_cache_k = pre_cache_v = None
    self.dit.clear_cross_kv()
    gc.collect()

    if self.variant == FLASHVSR_VARIANT_TINY_LONG:
        frames = frames_out
    elif self.variant == FLASHVSR_VARIANT_TINY:
        if _abort_requested(abort_callback):
            return abort_result()
        # Decode all denoised chunks from a fresh TCDecoder state.
        self.tcdecoder.clean_mem()
        frames_out = None
        frames_cursor = 0
        lq_pre_idx = 0
        for decode_idx in range(process_total):
            if _abort_requested(abort_callback):
                return abort_result()
            lq_cur_idx, latent_start, latent_end = chunk_window(decode_idx)
            cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype)
            if not decode_chunk(cur_latents, lq_cur_idx, decode_idx):
                return abort_result()
            del cur_latents
        frames = frames_out
    else:
        if _abort_requested(abort_callback):
            return abort_result()
        _report_progress(progress_callback, "VAE Decoding")
        if self.vae is None:
            raise RuntimeError("FlashVSR full variant requires the Wan VAE.")
        vae_tile_size = int(vae_tile_size or 0)
        print(f"[FlashVSR] Wan VAE tiling policy: tile_size={vae_tile_size}px")
        save_still_debug_video = still_image and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO
        decode_latents = latents[0, :, :first_chunk_latent_frames].contiguous() if still_image else latents[0]
        # The debug video needs every decoded frame; otherwise a still image
        # decodes only still_output_frame and a video decodes input_frames.
        frames = self.vae.decode_to_cpu_uint8([decode_latents], vae_tile_size, target_frames=None if save_still_debug_video else 1 if still_image else input_frames, target_height=output_height, target_width=output_width, frame_start=0 if save_still_debug_video or not still_image else still_output_frame)[0]
        if save_still_debug_video:
            _save_still_image_debug_video(frames)
        if still_image:
            frames = _select_still_image_frame(frames, still_output_frame if save_still_debug_video else 0)

    if self.tcdecoder is not None:
        self.tcdecoder.clean_mem()
    if self.vae is not None:
        self.vae.model.clear_cache()
    # Drop large intermediates before post-processing to lower peak memory.
    latents = frames_out = pre_cache_k = pre_cache_v = tcdecoder_tile_mems = None
    noise_pred = cur_latents = lq_layer_chunks = None
    if torch.is_tensor(frames) and frames.dtype == torch.uint8 and frames.ndim == 4:
        if frames.shape[1:] != (input_frames, output_height, output_width):
            frames = frames[:, :input_frames, :output_height, :output_width].contiguous()
    else:
        decoded_frames = frames
        frames = _decoded_frames_to_cpu(decoded_frames, input_frames, output_height, output_width)
        del decoded_frames
    gc.collect()
    _report_progress(progress_callback, "Color Correction")
    _wavelet_color_fix_from_sample(frames.unsqueeze(0), sample, scale, output_height, output_width, output_height, output_width)
    if frames.dtype != torch.uint8:
        frames.clamp_(-1.0, 1.0)
    frames = _apply_continue_cache(frames, continue_cache)
    cache = _make_continue_cache(frames, scale, self.variant) if return_continue_cache else None
    sample = None
    self._unload_mmgp()
    if not persistent_models:
        self.release()
    return frames, cache
# Module-level FlashVSRRuntime instance shared by upscale_video() and release_models().
_RUNTIME = FlashVSRRuntime()
def upscale_video(
    sample: torch.Tensor,
    scale: float,
    paths: FlashVSRPaths,
    *,
    variant: str = FLASHVSR_VARIANT_TINY_LONG,
    seed: int = 0,
    continue_cache: Any = None,
    return_continue_cache: bool = False,
    persistent_models: bool = False,
    vae_tile_size: int | None = None,
    topk_ratio: float = FLASHVSR_TOPK_RATIO,
    init_pipe,
    profile,
    still_image: bool = False,
    two_pass: bool = False,
    abort_callback=None,
    progress_callback=None,
) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
    """Load ``variant`` into the shared runtime and upscale ``sample`` by ``scale``.

    When both FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION and ``two_pass`` are
    truthy, the clip is upscaled twice: once as-is and once spatially shifted
    by ``_shift_spatial_replicate``. The shifted output is shifted back and
    combined with the first via ``_apply_still_image_shift_correction``.
    Otherwise a single ``FlashVSRRuntime.upscale`` call is made.

    Returns ``(frames, continue_cache)``; ``(None, None)`` signals an abort.
    On abort, and on any exception (which is re-raised), the runtime is
    cleaned up. With ``persistent_models`` only offloaded weights are
    unloaded; without it all models are released.
    """
    _report_progress(progress_callback, "Caching")
    _RUNTIME.load(paths, variant, profile=profile, init_pipe=init_pipe)
    # Keyword arguments common to every FlashVSRRuntime.upscale call below.
    shared_kwargs = dict(
        seed=seed,
        return_continue_cache=return_continue_cache,
        vae_tile_size=vae_tile_size,
        topk_ratio=topk_ratio,
        still_image=still_image,
        abort_callback=abort_callback,
        progress_callback=progress_callback,
    )

    def cleanup_runtime():
        # Persistent runs keep models around and only unload offloaded weights.
        if persistent_models:
            _RUNTIME._unload_mmgp()
        else:
            _RUNTIME.release()

    try:
        if not (FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION and two_pass):
            result = _RUNTIME.upscale(sample, scale, continue_cache=continue_cache, persistent_models=persistent_models, **shared_kwargs)
        else:
            input_shift = FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT
            if not input_shift:
                # Default: vertical shift of half the correction period (>= 1 input pixel).
                input_shift = (max(1, int(round(FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD * 0.5 / scale))), 0)
            shift_y, shift_x = input_shift
            out_shift_y = int(round(shift_y * scale))
            out_shift_x = int(round(shift_x * scale))
            print(f"[FlashVSR] x{scale:g} shifted two-pass blend: extra shifted pass ({shift_y}px input / {out_shift_y}px output), blend={FLASHVSR_STILL_IMAGE_SHIFT_BLEND:g}")
            # Both passes run with persistent_models=True (presumably so the
            # second pass reuses the loaded models); release is decided below.
            base, base_cache = _RUNTIME.upscale(sample, scale, continue_cache=continue_cache, persistent_models=True, **shared_kwargs)
            result = (None, None)
            if base is not None:
                shifted_sample = _shift_spatial_replicate(sample, shift_y, shift_x)
                shifted_continue_cache = _two_pass_shifted_continue_cache(continue_cache, out_shift_y, out_shift_x)
                shifted, shifted_cache = _RUNTIME.upscale(shifted_sample, scale, continue_cache=shifted_continue_cache, persistent_models=True, **shared_kwargs)
                if shifted is not None:
                    result = (
                        _apply_still_image_shift_correction(base, _shift_spatial_replicate(shifted, -out_shift_y, -out_shift_x), scale),
                        _make_two_pass_continue_cache(base_cache, shifted_cache, shift_y, shift_x, out_shift_y, out_shift_x),
                    )
                del shifted_sample, shifted
            del base
            if not persistent_models:
                _RUNTIME.release()
        if result[0] is None:
            cleanup_runtime()
        return result
    except Exception:
        cleanup_runtime()
        raise
def release_models() -> None:
    """Release every model held by the module-level FlashVSR runtime."""
    runtime = _RUNTIME
    runtime.release()
|