1Smurf1 committed on
Commit
3da6b77
·
verified ·
1 Parent(s): 71ccd86

Update model/pipeline.py

Browse files
Files changed (1) hide show
  1. model/pipeline.py +120 -32
model/pipeline.py CHANGED
@@ -42,72 +42,154 @@ class CatVTONPipeline:
42
  self.weight_dtype = weight_dtype
43
  self.skip_safety_check = skip_safety_check
44
 
 
 
45
  # TF32 sadece CUDA'da anlamlı
46
  if use_tf32 and self.device.type == "cuda":
47
  torch.set_float32_matmul_precision("high")
48
  torch.backends.cuda.matmul.allow_tf32 = True
 
49
 
50
  # Scheduler
51
- self.noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler")
 
 
 
 
 
52
 
53
  # VAE yükleme (hata durumunda fallback)
54
  try:
55
  self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
56
  self.device, dtype=self.weight_dtype
57
  )
 
58
  except Exception as e:
59
  print(f"[WARN] VAE yüklenirken hata: {e}. float32 fallback yapılıyor.")
60
- self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
61
- self.device, dtype=torch.float32
62
- )
 
 
 
 
 
63
 
64
- # Safety checker
65
  if not skip_safety_check:
66
- self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder="feature_extractor")
67
- self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
68
- base_ckpt, subfolder="safety_checker"
69
- ).to(self.device, dtype=self.weight_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  else:
71
  self.feature_extractor = None
72
  self.safety_checker = None
 
73
 
74
  # UNet ve adapter
75
- self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder="unet").to(
76
- self.device, dtype=self.weight_dtype
77
- )
78
- init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor)
79
- self.attn_modules = get_trainable_module(self.unet, "attention")
80
- self.auto_attn_ckpt_load(attn_ckpt, attn_ckpt_version)
 
 
 
 
 
 
 
 
 
81
 
82
  # Compile (isteğe bağlı)
83
  if compile:
84
  try:
 
85
  self.unet = torch.compile(self.unet)
86
  self.vae = torch.compile(self.vae, mode="reduce-overhead")
 
87
  except Exception as e:
88
  print(f"[WARN] Compile sırasında hata: {e}. Compile edilmemiş sürüm devam ediyor.")
89
 
 
 
90
  def auto_attn_ckpt_load(self, attn_ckpt, version):
91
  sub_folder = {
92
  "mix": "mix-48k-1024",
93
- "vitonhd": "vitonhd-16k-512",
94
  "dresscode": "dresscode-16k-512",
95
  }[version]
 
96
  if os.path.exists(attn_ckpt):
97
- load_checkpoint_in_model(self.attn_modules, os.path.join(attn_ckpt, sub_folder, "attention"))
 
 
 
 
 
98
  else:
99
- repo_path = snapshot_download(repo_id=attn_ckpt)
100
- print(f"Downloaded {attn_ckpt} to {repo_path}")
101
- load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, "attention"))
 
 
 
 
 
102
 
103
  def run_safety_checker(self, image):
104
- if self.safety_checker is None:
105
  has_nsfw_concept = None
106
  else:
107
- safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(self.device)
108
- image, has_nsfw_concept = self.safety_checker(
109
- images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)
110
- )
 
 
 
 
111
  return image, has_nsfw_concept
112
 
113
  def check_inputs(self, image, condition_image, mask, width, height):
@@ -220,19 +302,25 @@ class CatVTONPipeline:
220
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
221
  image = numpy_to_pil(image)
222
 
223
- if not self.skip_safety_check:
 
224
  current_script_directory = os.path.dirname(os.path.realpath(__file__))
225
  nsfw_image_path = os.path.join(os.path.dirname(current_script_directory), "resource", "img", "NSFW.jpg")
226
  try:
227
  nsfw_image = PIL.Image.open(nsfw_image_path).resize(image[0].size)
228
  except Exception:
229
  nsfw_image = None
230
- image_np = np.array(image)
231
- _, has_nsfw_concept = self.run_safety_checker(image=image_np)
232
- if has_nsfw_concept is not None:
233
- for i, not_safe in enumerate(has_nsfw_concept):
234
- if not_safe and nsfw_image is not None:
235
- image[i] = nsfw_image
 
 
 
 
 
236
 
237
  return image
238
 
 
42
  self.weight_dtype = weight_dtype
43
  self.skip_safety_check = skip_safety_check
44
 
45
+ print(f"[INFO] Pipeline başlatılıyor - Device: {self.device}, dtype: {self.weight_dtype}")
46
+
47
  # TF32 sadece CUDA'da anlamlı
48
  if use_tf32 and self.device.type == "cuda":
49
  torch.set_float32_matmul_precision("high")
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
+ print("[INFO] TF32 etkinleştirildi")
52
 
53
  # Scheduler
54
+ try:
55
+ self.noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler")
56
+ print("[INFO] Scheduler yüklendi")
57
+ except Exception as e:
58
+ print(f"[ERROR] Scheduler yüklenirken hata: {e}")
59
+ raise
60
 
61
  # VAE yükleme (hata durumunda fallback)
62
  try:
63
  self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
64
  self.device, dtype=self.weight_dtype
65
  )
66
+ print("[INFO] VAE yüklendi")
67
  except Exception as e:
68
  print(f"[WARN] VAE yüklenirken hata: {e}. float32 fallback yapılıyor.")
69
+ try:
70
+ self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
71
+ self.device, dtype=torch.float32
72
+ )
73
+ print("[INFO] VAE float32 ile yüklendi")
74
+ except Exception as e2:
75
+ print(f"[ERROR] VAE yüklenemedi: {e2}")
76
+ raise
77
 
78
+ # Safety checker - ZeroGPU uyumlu hata yakalama
79
  if not skip_safety_check:
80
+ try:
81
+ self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder="feature_extractor")
82
+
83
+ # Safety checker için çoklu deneme stratejisi
84
+ safety_checker_loaded = False
85
+
86
+ # 1. Önce safetensors formatını dene
87
+ try:
88
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
89
+ base_ckpt,
90
+ subfolder="safety_checker",
91
+ use_safetensors=True,
92
+ torch_dtype=self.weight_dtype
93
+ ).to(self.device, dtype=self.weight_dtype)
94
+ safety_checker_loaded = True
95
+ print("[INFO] Safety checker (safetensors) yüklendi")
96
+ except Exception as e1:
97
+ print(f"[WARN] Safetensors safety checker yüklenemedi: {e1}")
98
+
99
+ # 2. Eğer safetensors başarısızsa, farklı model dene
100
+ if not safety_checker_loaded:
101
+ try:
102
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
103
+ "CompVis/stable-diffusion-safety-checker",
104
+ torch_dtype=self.weight_dtype
105
+ ).to(self.device, dtype=self.weight_dtype)
106
+ safety_checker_loaded = True
107
+ print("[INFO] Safety checker (CompVis) yüklendi")
108
+ except Exception as e2:
109
+ print(f"[WARN] CompVis safety checker yüklenemedi: {e2}")
110
+
111
+ # 3. Son çare olarak safety checker'ı devre dışı bırak
112
+ if not safety_checker_loaded:
113
+ print("[WARN] Safety checker yüklenemedi, devre dışı bırakılıyor")
114
+ self.safety_checker = None
115
+ self.feature_extractor = None
116
+ self.skip_safety_check = True
117
+
118
+ except Exception as e:
119
+ print(f"[WARN] Safety checker başlatılamadı: {e}. Devre dışı bırakılıyor.")
120
+ self.feature_extractor = None
121
+ self.safety_checker = None
122
+ self.skip_safety_check = True
123
  else:
124
  self.feature_extractor = None
125
  self.safety_checker = None
126
+ print("[INFO] Safety checker devre dışı")
127
 
128
  # UNet ve adapter
129
+ try:
130
+ self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder="unet").to(
131
+ self.device, dtype=self.weight_dtype
132
+ )
133
+ print("[INFO] UNet yüklendi")
134
+
135
+ init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor)
136
+ self.attn_modules = get_trainable_module(self.unet, "attention")
137
+ print("[INFO] Adapter başlatıldı")
138
+
139
+ self.auto_attn_ckpt_load(attn_ckpt, attn_ckpt_version)
140
+ print("[INFO] Attention checkpoint yüklendi")
141
+ except Exception as e:
142
+ print(f"[ERROR] UNet yüklenirken hata: {e}")
143
+ raise
144
 
145
  # Compile (isteğe bağlı)
146
  if compile:
147
  try:
148
+ print("[INFO] Model compile ediliyor...")
149
  self.unet = torch.compile(self.unet)
150
  self.vae = torch.compile(self.vae, mode="reduce-overhead")
151
+ print("[INFO] Model compile edildi")
152
  except Exception as e:
153
  print(f"[WARN] Compile sırasında hata: {e}. Compile edilmemiş sürüm devam ediyor.")
154
 
155
+ print("[INFO] Pipeline başarıyla başlatıldı")
156
+
157
  def auto_attn_ckpt_load(self, attn_ckpt, version):
158
  sub_folder = {
159
  "mix": "mix-48k-1024",
160
+ "vitonhd": "vitonhd-16k-512",
161
  "dresscode": "dresscode-16k-512",
162
  }[version]
163
+
164
  if os.path.exists(attn_ckpt):
165
+ checkpoint_path = os.path.join(attn_ckpt, sub_folder, "attention")
166
+ if os.path.exists(checkpoint_path):
167
+ load_checkpoint_in_model(self.attn_modules, checkpoint_path)
168
+ print(f"[INFO] Local checkpoint yüklendi: {checkpoint_path}")
169
+ else:
170
+ print(f"[WARN] Local checkpoint bulunamadı: {checkpoint_path}")
171
  else:
172
+ try:
173
+ repo_path = snapshot_download(repo_id=attn_ckpt)
174
+ checkpoint_path = os.path.join(repo_path, sub_folder, "attention")
175
+ load_checkpoint_in_model(self.attn_modules, checkpoint_path)
176
+ print(f"[INFO] Downloaded checkpoint yüklendi: {checkpoint_path}")
177
+ except Exception as e:
178
+ print(f"[ERROR] Checkpoint indirilemedi: {e}")
179
+ raise
180
 
181
  def run_safety_checker(self, image):
182
+ if self.safety_checker is None or self.skip_safety_check:
183
  has_nsfw_concept = None
184
  else:
185
+ try:
186
+ safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(self.device)
187
+ image, has_nsfw_concept = self.safety_checker(
188
+ images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)
189
+ )
190
+ except Exception as e:
191
+ print(f"[WARN] Safety checker çalıştırılamadı: {e}")
192
+ has_nsfw_concept = None
193
  return image, has_nsfw_concept
194
 
195
  def check_inputs(self, image, condition_image, mask, width, height):
 
302
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
303
  image = numpy_to_pil(image)
304
 
305
+ # Safety checker kontrolü
306
+ if not self.skip_safety_check and self.safety_checker is not None:
307
  current_script_directory = os.path.dirname(os.path.realpath(__file__))
308
  nsfw_image_path = os.path.join(os.path.dirname(current_script_directory), "resource", "img", "NSFW.jpg")
309
  try:
310
  nsfw_image = PIL.Image.open(nsfw_image_path).resize(image[0].size)
311
  except Exception:
312
  nsfw_image = None
313
+
314
+ try:
315
+ image_np = np.array(image)
316
+ _, has_nsfw_concept = self.run_safety_checker(image=image_np)
317
+ if has_nsfw_concept is not None:
318
+ for i, not_safe in enumerate(has_nsfw_concept):
319
+ if not_safe and nsfw_image is not None:
320
+ image[i] = nsfw_image
321
+ print(f"[WARN] NSFW içerik tespit edildi, görüntü {i} değiştirildi")
322
+ except Exception as e:
323
+ print(f"[WARN] Safety check sırasında hata: {e}")
324
 
325
  return image
326