Update model/pipeline.py
Browse files- model/pipeline.py +120 -32
model/pipeline.py
CHANGED
@@ -42,72 +42,154 @@ class CatVTONPipeline:
|
|
42 |
self.weight_dtype = weight_dtype
|
43 |
self.skip_safety_check = skip_safety_check
|
44 |
|
|
|
|
|
45 |
# TF32 sadece CUDA'da anlamlı
|
46 |
if use_tf32 and self.device.type == "cuda":
|
47 |
torch.set_float32_matmul_precision("high")
|
48 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
49 |
|
50 |
# Scheduler
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# VAE yükleme (hata durumunda fallback)
|
54 |
try:
|
55 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
|
56 |
self.device, dtype=self.weight_dtype
|
57 |
)
|
|
|
58 |
except Exception as e:
|
59 |
print(f"[WARN] VAE yüklenirken hata: {e}. float32 fallback yapılıyor.")
|
60 |
-
|
61 |
-
self.
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
# Safety checker
|
65 |
if not skip_safety_check:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
else:
|
71 |
self.feature_extractor = None
|
72 |
self.safety_checker = None
|
|
|
73 |
|
74 |
# UNet ve adapter
|
75 |
-
|
76 |
-
self.
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
# Compile (isteğe bağlı)
|
83 |
if compile:
|
84 |
try:
|
|
|
85 |
self.unet = torch.compile(self.unet)
|
86 |
self.vae = torch.compile(self.vae, mode="reduce-overhead")
|
|
|
87 |
except Exception as e:
|
88 |
print(f"[WARN] Compile sırasında hata: {e}. Compile edilmemiş sürüm devam ediyor.")
|
89 |
|
|
|
|
|
90 |
def auto_attn_ckpt_load(self, attn_ckpt, version):
|
91 |
sub_folder = {
|
92 |
"mix": "mix-48k-1024",
|
93 |
-
"vitonhd": "vitonhd-16k-512",
|
94 |
"dresscode": "dresscode-16k-512",
|
95 |
}[version]
|
|
|
96 |
if os.path.exists(attn_ckpt):
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
98 |
else:
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def run_safety_checker(self, image):
|
104 |
-
if self.safety_checker is None:
|
105 |
has_nsfw_concept = None
|
106 |
else:
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
111 |
return image, has_nsfw_concept
|
112 |
|
113 |
def check_inputs(self, image, condition_image, mask, width, height):
|
@@ -220,19 +302,25 @@ class CatVTONPipeline:
|
|
220 |
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
221 |
image = numpy_to_pil(image)
|
222 |
|
223 |
-
|
|
|
224 |
current_script_directory = os.path.dirname(os.path.realpath(__file__))
|
225 |
nsfw_image_path = os.path.join(os.path.dirname(current_script_directory), "resource", "img", "NSFW.jpg")
|
226 |
try:
|
227 |
nsfw_image = PIL.Image.open(nsfw_image_path).resize(image[0].size)
|
228 |
except Exception:
|
229 |
nsfw_image = None
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
return image
|
238 |
|
|
|
42 |
self.weight_dtype = weight_dtype
|
43 |
self.skip_safety_check = skip_safety_check
|
44 |
|
45 |
+
print(f"[INFO] Pipeline başlatılıyor - Device: {self.device}, dtype: {self.weight_dtype}")
|
46 |
+
|
47 |
# TF32 sadece CUDA'da anlamlı
|
48 |
if use_tf32 and self.device.type == "cuda":
|
49 |
torch.set_float32_matmul_precision("high")
|
50 |
torch.backends.cuda.matmul.allow_tf32 = True
|
51 |
+
print("[INFO] TF32 etkinleştirildi")
|
52 |
|
53 |
# Scheduler
|
54 |
+
try:
|
55 |
+
self.noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler")
|
56 |
+
print("[INFO] Scheduler yüklendi")
|
57 |
+
except Exception as e:
|
58 |
+
print(f"[ERROR] Scheduler yüklenirken hata: {e}")
|
59 |
+
raise
|
60 |
|
61 |
# VAE yükleme (hata durumunda fallback)
|
62 |
try:
|
63 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
|
64 |
self.device, dtype=self.weight_dtype
|
65 |
)
|
66 |
+
print("[INFO] VAE yüklendi")
|
67 |
except Exception as e:
|
68 |
print(f"[WARN] VAE yüklenirken hata: {e}. float32 fallback yapılıyor.")
|
69 |
+
try:
|
70 |
+
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
|
71 |
+
self.device, dtype=torch.float32
|
72 |
+
)
|
73 |
+
print("[INFO] VAE float32 ile yüklendi")
|
74 |
+
except Exception as e2:
|
75 |
+
print(f"[ERROR] VAE yüklenemedi: {e2}")
|
76 |
+
raise
|
77 |
|
78 |
+
# Safety checker - ZeroGPU uyumlu hata yakalama
|
79 |
if not skip_safety_check:
|
80 |
+
try:
|
81 |
+
self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder="feature_extractor")
|
82 |
+
|
83 |
+
# Safety checker için çoklu deneme stratejisi
|
84 |
+
safety_checker_loaded = False
|
85 |
+
|
86 |
+
# 1. Önce safetensors formatını dene
|
87 |
+
try:
|
88 |
+
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
89 |
+
base_ckpt,
|
90 |
+
subfolder="safety_checker",
|
91 |
+
use_safetensors=True,
|
92 |
+
torch_dtype=self.weight_dtype
|
93 |
+
).to(self.device, dtype=self.weight_dtype)
|
94 |
+
safety_checker_loaded = True
|
95 |
+
print("[INFO] Safety checker (safetensors) yüklendi")
|
96 |
+
except Exception as e1:
|
97 |
+
print(f"[WARN] Safetensors safety checker yüklenemedi: {e1}")
|
98 |
+
|
99 |
+
# 2. Eğer safetensors başarısızsa, farklı model dene
|
100 |
+
if not safety_checker_loaded:
|
101 |
+
try:
|
102 |
+
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
103 |
+
"CompVis/stable-diffusion-safety-checker",
|
104 |
+
torch_dtype=self.weight_dtype
|
105 |
+
).to(self.device, dtype=self.weight_dtype)
|
106 |
+
safety_checker_loaded = True
|
107 |
+
print("[INFO] Safety checker (CompVis) yüklendi")
|
108 |
+
except Exception as e2:
|
109 |
+
print(f"[WARN] CompVis safety checker yüklenemedi: {e2}")
|
110 |
+
|
111 |
+
# 3. Son çare olarak safety checker'ı devre dışı bırak
|
112 |
+
if not safety_checker_loaded:
|
113 |
+
print("[WARN] Safety checker yüklenemedi, devre dışı bırakılıyor")
|
114 |
+
self.safety_checker = None
|
115 |
+
self.feature_extractor = None
|
116 |
+
self.skip_safety_check = True
|
117 |
+
|
118 |
+
except Exception as e:
|
119 |
+
print(f"[WARN] Safety checker başlatılamadı: {e}. Devre dışı bırakılıyor.")
|
120 |
+
self.feature_extractor = None
|
121 |
+
self.safety_checker = None
|
122 |
+
self.skip_safety_check = True
|
123 |
else:
|
124 |
self.feature_extractor = None
|
125 |
self.safety_checker = None
|
126 |
+
print("[INFO] Safety checker devre dışı")
|
127 |
|
128 |
# UNet ve adapter
|
129 |
+
try:
|
130 |
+
self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder="unet").to(
|
131 |
+
self.device, dtype=self.weight_dtype
|
132 |
+
)
|
133 |
+
print("[INFO] UNet yüklendi")
|
134 |
+
|
135 |
+
init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor)
|
136 |
+
self.attn_modules = get_trainable_module(self.unet, "attention")
|
137 |
+
print("[INFO] Adapter başlatıldı")
|
138 |
+
|
139 |
+
self.auto_attn_ckpt_load(attn_ckpt, attn_ckpt_version)
|
140 |
+
print("[INFO] Attention checkpoint yüklendi")
|
141 |
+
except Exception as e:
|
142 |
+
print(f"[ERROR] UNet yüklenirken hata: {e}")
|
143 |
+
raise
|
144 |
|
145 |
# Compile (isteğe bağlı)
|
146 |
if compile:
|
147 |
try:
|
148 |
+
print("[INFO] Model compile ediliyor...")
|
149 |
self.unet = torch.compile(self.unet)
|
150 |
self.vae = torch.compile(self.vae, mode="reduce-overhead")
|
151 |
+
print("[INFO] Model compile edildi")
|
152 |
except Exception as e:
|
153 |
print(f"[WARN] Compile sırasında hata: {e}. Compile edilmemiş sürüm devam ediyor.")
|
154 |
|
155 |
+
print("[INFO] Pipeline başarıyla başlatıldı")
|
156 |
+
|
157 |
def auto_attn_ckpt_load(self, attn_ckpt, version):
    """Load the attention weights for ``version`` into ``self.attn_modules``.

    Args:
        attn_ckpt: Either a path that exists on disk, or (when it does not
            exist locally) an identifier passed to ``snapshot_download`` as
            ``repo_id`` -- presumably a Hugging Face Hub repo id; confirm
            against the file's imports.
        version: One of ``"mix"``, ``"vitonhd"`` or ``"dresscode"``; selects
            the sub-folder that holds the ``attention`` checkpoint.

    Raises:
        KeyError: If ``version`` is not one of the known keys.
        Exception: Any error raised while downloading or loading the remote
            checkpoint is logged and re-raised unchanged.

    Note:
        When ``attn_ckpt`` exists locally but the expected sub-folder does
        not, only a warning is printed and nothing is loaded.
    """
    version_dirs = {
        "mix": "mix-48k-1024",
        "vitonhd": "vitonhd-16k-512",
        "dresscode": "dresscode-16k-512",
    }
    sub_folder = version_dirs[version]

    if not os.path.exists(attn_ckpt):
        # Not a local path: fetch a snapshot and load from inside it.
        try:
            weights_dir = os.path.join(
                snapshot_download(repo_id=attn_ckpt), sub_folder, "attention"
            )
            load_checkpoint_in_model(self.attn_modules, weights_dir)
            print(f"[INFO] Downloaded checkpoint yüklendi: {weights_dir}")
        except Exception as err:
            print(f"[ERROR] Checkpoint indirilemedi: {err}")
            raise
        return

    weights_dir = os.path.join(attn_ckpt, sub_folder, "attention")
    if not os.path.exists(weights_dir):
        # Best-effort: a missing local sub-folder is reported, not fatal.
        print(f"[WARN] Local checkpoint bulunamadı: {weights_dir}")
        return

    load_checkpoint_in_model(self.attn_modules, weights_dir)
    print(f"[INFO] Local checkpoint yüklendi: {weights_dir}")
|
180 |
|
181 |
def run_safety_checker(self, image):
    """Run the optional safety checker on ``image``.

    Args:
        image: The image batch handed to both ``self.feature_extractor`` and
            ``self.safety_checker`` (exact type is not fixed by this method).

    Returns:
        A ``(image, has_nsfw_concept)`` tuple. ``has_nsfw_concept`` is
        ``None`` when the checker is absent, when ``self.skip_safety_check``
        is set, or when the check itself raised; in those cases ``image`` is
        returned untouched.
    """
    if self.safety_checker is None or self.skip_safety_check:
        return image, None

    try:
        clip_batch = self.feature_extractor(image, return_tensors="pt").to(self.device)
        # The checker's CLIP input is cast to the pipeline's weight dtype.
        pixel_values = clip_batch.pixel_values.to(self.weight_dtype)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=pixel_values
        )
    except Exception as err:
        # A failing checker must not break generation: log and report "unknown".
        print(f"[WARN] Safety checker çalıştırılamadı: {err}")
        has_nsfw_concept = None
    return image, has_nsfw_concept
|
194 |
|
195 |
def check_inputs(self, image, condition_image, mask, width, height):
|
|
|
302 |
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
303 |
image = numpy_to_pil(image)
|
304 |
|
305 |
+
# Safety checker kontrolü
|
306 |
+
if not self.skip_safety_check and self.safety_checker is not None:
|
307 |
current_script_directory = os.path.dirname(os.path.realpath(__file__))
|
308 |
nsfw_image_path = os.path.join(os.path.dirname(current_script_directory), "resource", "img", "NSFW.jpg")
|
309 |
try:
|
310 |
nsfw_image = PIL.Image.open(nsfw_image_path).resize(image[0].size)
|
311 |
except Exception:
|
312 |
nsfw_image = None
|
313 |
+
|
314 |
+
try:
|
315 |
+
image_np = np.array(image)
|
316 |
+
_, has_nsfw_concept = self.run_safety_checker(image=image_np)
|
317 |
+
if has_nsfw_concept is not None:
|
318 |
+
for i, not_safe in enumerate(has_nsfw_concept):
|
319 |
+
if not_safe and nsfw_image is not None:
|
320 |
+
image[i] = nsfw_image
|
321 |
+
print(f"[WARN] NSFW içerik tespit edildi, görüntü {i} değiştirildi")
|
322 |
+
except Exception as e:
|
323 |
+
print(f"[WARN] Safety check sırasında hata: {e}")
|
324 |
|
325 |
return image
|
326 |
|