alexander00001 commited on
Commit
1aa6f43
·
verified ·
1 Parent(s): 740dcb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -127,9 +127,14 @@ def initialize_model():
127
  PRIVATE_MODEL,
128
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
129
  use_safetensors=True,
130
- trust_remote_code=True # 私人仓库需要
 
 
 
 
 
131
  )
132
- print("Successfully loaded private NSFW Wan model!")
133
 
134
  except Exception as private_error:
135
  print(f"Private Wan model loading failed: {private_error}")
@@ -140,7 +145,9 @@ def initialize_model():
140
  pipeline = WanPipeline.from_pretrained(
141
  "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
142
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
143
- use_safetensors=True
 
 
144
  )
145
  print("Loaded official Wan2.2-Diffusers model")
146
  except Exception as wan_error:
@@ -156,22 +163,24 @@ def initialize_model():
156
 
157
  pipeline = pipeline.to(device)
158
 
159
- # GPU优化 - CogVideoX专用
160
  if torch.cuda.is_available():
161
  try:
162
- # CogVideoX特有的优化方法
163
- if hasattr(pipeline, 'enable_vae_tiling'):
164
- pipeline.enable_vae_tiling()
165
  if hasattr(pipeline, 'enable_model_cpu_offload'):
166
- pipeline.enable_model_cpu_offload()
 
 
 
 
167
  # 通用内存优化
168
  try:
169
  pipeline.enable_xformers_memory_efficient_attention()
170
  except:
171
  pass
172
- print(" CogVideoX memory optimizations applied")
173
  except Exception as mem_error:
174
- print(f"⚠️ Memory optimization warning: {mem_error}")
175
 
176
  # 初始化Compel
177
  if COMPEL_AVAILABLE and hasattr(pipeline, 'tokenizer'):
 
127
  PRIVATE_MODEL,
128
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
129
  use_safetensors=True,
130
+ trust_remote_code=True,
131
+ # 更激进的内存优化
132
+ text_encoder_dtype=torch.float32,
133
+ device_map="balanced",
134
+ load_in_8bit=True, # 8bit量化
135
+ low_cpu_mem_usage=True # 低CPU内存使用
136
  )
137
+ print("Successfully loaded private NSFW Wan model with memory optimization!")
138
 
139
  except Exception as private_error:
140
  print(f"Private Wan model loading failed: {private_error}")
 
145
  pipeline = WanPipeline.from_pretrained(
146
  "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
147
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
148
+ use_safetensors=True,
149
+ text_encoder_dtype=torch.float32,
150
+ device_map="balanced"
151
  )
152
  print("Loaded official Wan2.2-Diffusers model")
153
  except Exception as wan_error:
 
163
 
164
  pipeline = pipeline.to(device)
165
 
166
+ # GPU优化 - Wan模型专用内存管理
167
  if torch.cuda.is_available():
168
  try:
169
+ # Wan模型特有的优化方法
 
 
170
  if hasattr(pipeline, 'enable_model_cpu_offload'):
171
+ pipeline.enable_model_cpu_offload() # 将部分组件移至CPU
172
+ if hasattr(pipeline, 'enable_vae_tiling'):
173
+ pipeline.enable_vae_tiling() # VAE分块处理
174
+ if hasattr(pipeline, 'enable_sequential_cpu_offload'):
175
+ pipeline.enable_sequential_cpu_offload() # 顺序CPU卸载
176
  # 通用内存优化
177
  try:
178
  pipeline.enable_xformers_memory_efficient_attention()
179
  except:
180
  pass
181
+ print("Wan model memory optimizations applied")
182
  except Exception as mem_error:
183
+ print(f"Memory optimization warning: {mem_error}")
184
 
185
  # 初始化Compel
186
  if COMPEL_AVAILABLE and hasattr(pipeline, 'tokenizer'):