Update app.py
Browse files
app.py
CHANGED
|
@@ -127,9 +127,14 @@ def initialize_model():
|
|
| 127 |
PRIVATE_MODEL,
|
| 128 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 129 |
use_safetensors=True,
|
| 130 |
-
trust_remote_code=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
)
|
| 132 |
-
print("Successfully loaded private NSFW Wan model!")
|
| 133 |
|
| 134 |
except Exception as private_error:
|
| 135 |
print(f"Private Wan model loading failed: {private_error}")
|
|
@@ -140,7 +145,9 @@ def initialize_model():
|
|
| 140 |
pipeline = WanPipeline.from_pretrained(
|
| 141 |
"Wan-AI/Wan2.2-T2V-A14B-Diffusers",
|
| 142 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 143 |
-
use_safetensors=True
|
|
|
|
|
|
|
| 144 |
)
|
| 145 |
print("Loaded official Wan2.2-Diffusers model")
|
| 146 |
except Exception as wan_error:
|
|
@@ -156,22 +163,24 @@ def initialize_model():
|
|
| 156 |
|
| 157 |
pipeline = pipeline.to(device)
|
| 158 |
|
| 159 |
-
# GPU优化 -
|
| 160 |
if torch.cuda.is_available():
|
| 161 |
try:
|
| 162 |
-
#
|
| 163 |
-
if hasattr(pipeline, 'enable_vae_tiling'):
|
| 164 |
-
pipeline.enable_vae_tiling()
|
| 165 |
if hasattr(pipeline, 'enable_model_cpu_offload'):
|
| 166 |
-
pipeline.enable_model_cpu_offload()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
# 通用内存优化
|
| 168 |
try:
|
| 169 |
pipeline.enable_xformers_memory_efficient_attention()
|
| 170 |
except:
|
| 171 |
pass
|
| 172 |
-
print("
|
| 173 |
except Exception as mem_error:
|
| 174 |
-
print(f"
|
| 175 |
|
| 176 |
# 初始化Compel
|
| 177 |
if COMPEL_AVAILABLE and hasattr(pipeline, 'tokenizer'):
|
|
|
|
| 127 |
PRIVATE_MODEL,
|
| 128 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 129 |
use_safetensors=True,
|
| 130 |
+
trust_remote_code=True,
|
| 131 |
+
# 更激进的内存优化
|
| 132 |
+
text_encoder_dtype=torch.float32,
|
| 133 |
+
device_map="balanced",
|
| 134 |
+
load_in_8bit=True, # 8bit量化
|
| 135 |
+
low_cpu_mem_usage=True # 低CPU内存使用
|
| 136 |
)
|
| 137 |
+
print("Successfully loaded private NSFW Wan model with memory optimization!")
|
| 138 |
|
| 139 |
except Exception as private_error:
|
| 140 |
print(f"Private Wan model loading failed: {private_error}")
|
|
|
|
| 145 |
pipeline = WanPipeline.from_pretrained(
|
| 146 |
"Wan-AI/Wan2.2-T2V-A14B-Diffusers",
|
| 147 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 148 |
+
use_safetensors=True,
|
| 149 |
+
text_encoder_dtype=torch.float32,
|
| 150 |
+
device_map="balanced"
|
| 151 |
)
|
| 152 |
print("Loaded official Wan2.2-Diffusers model")
|
| 153 |
except Exception as wan_error:
|
|
|
|
| 163 |
|
| 164 |
pipeline = pipeline.to(device)
|
| 165 |
|
| 166 |
+
# GPU优化 - Wan模型专用内存管理
|
| 167 |
if torch.cuda.is_available():
|
| 168 |
try:
|
| 169 |
+
# Wan模型特有的优化方法
|
|
|
|
|
|
|
| 170 |
if hasattr(pipeline, 'enable_model_cpu_offload'):
|
| 171 |
+
pipeline.enable_model_cpu_offload() # 将部分组件移至CPU
|
| 172 |
+
if hasattr(pipeline, 'enable_vae_tiling'):
|
| 173 |
+
pipeline.enable_vae_tiling() # VAE分块处理
|
| 174 |
+
if hasattr(pipeline, 'enable_sequential_cpu_offload'):
|
| 175 |
+
pipeline.enable_sequential_cpu_offload() # 顺序CPU卸载
|
| 176 |
# 通用内存优化
|
| 177 |
try:
|
| 178 |
pipeline.enable_xformers_memory_efficient_attention()
|
| 179 |
except:
|
| 180 |
pass
|
| 181 |
+
print("Wan model memory optimizations applied")
|
| 182 |
except Exception as mem_error:
|
| 183 |
+
print(f"Memory optimization warning: {mem_error}")
|
| 184 |
|
| 185 |
# 初始化Compel
|
| 186 |
if COMPEL_AVAILABLE and hasattr(pipeline, 'tokenizer'):
|