k4d3 commited on
Commit
779901b
·
1 Parent(s): 1de94ef

Enhance image processing workflow in wdv3.py

Browse files

This commit introduces a new function, `process_image`, to streamline the image processing logic by ensuring the global processor is initialized before processing each image. The batch processing setup is updated to sort and deduplicate image paths, improving efficiency. Additionally, comments are added for clarity, and the number of workers remains set to one to minimize CUDA contention. These changes enhance the maintainability and robustness of the WDV3Processor class.

Files changed (1) hide show
  1. caption/wdv3.py +16 -2
caption/wdv3.py CHANGED
@@ -1,4 +1,7 @@
1
  #!/usr/bin/env python
 
 
 
2
 
3
  import sys
4
  import os
@@ -232,12 +235,19 @@ class WDV3Processor(BatchProcessor[Path, None]):
232
  if torch.cuda.is_available():
233
  torch.cuda.empty_cache()
234
 
 
 
 
 
 
 
 
235
  def main(opts: ScriptOptions):
236
  target_path = Path(opts.image_file).resolve()
237
 
238
  batch_opts = BatchOptions(
239
  batch_size=16,
240
- num_workers=1, # Single worker
241
  device="cpu" if opts.cpu else "cuda",
242
  debug=False,
243
  skip_existing=True,
@@ -253,8 +263,12 @@ def main(opts: ScriptOptions):
253
  image_paths.extend(target_path.rglob(f'*{ext}'))
254
  image_paths.extend(target_path.rglob(f'*{ext.upper()}'))
255
 
 
 
 
 
256
  with mp.Pool(processes=batch_opts.num_workers, initializer=initializer, initargs=(batch_opts, opts.model)) as pool:
257
- pool.map(lambda img: global_processor.process_item(img), sorted(set(image_paths)))
258
  pool.close()
259
  pool.join()
260
  else:
 
1
  #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # caption/wdv3.py
5
 
6
  import sys
7
  import os
 
235
  if torch.cuda.is_available():
236
  torch.cuda.empty_cache()
237
 
238
+ # Function to process a single image, required to be at the top level for pickling
239
+ def process_image(img_path: Path):
240
+ global global_processor
241
+ if global_processor is None:
242
+ raise ValueError("Global processor not initialized. Ensure initializer is set correctly.")
243
+ global_processor.process_item(img_path)
244
+
245
  def main(opts: ScriptOptions):
246
  target_path = Path(opts.image_file).resolve()
247
 
248
  batch_opts = BatchOptions(
249
  batch_size=16,
250
+ num_workers=1, # Single worker to minimize CUDA contention
251
  device="cpu" if opts.cpu else "cuda",
252
  debug=False,
253
  skip_existing=True,
 
263
  image_paths.extend(target_path.rglob(f'*{ext}'))
264
  image_paths.extend(target_path.rglob(f'*{ext.upper()}'))
265
 
266
+ # Sort and deduplicate image paths
267
+ sorted_unique_image_paths = sorted(set(image_paths))
268
+
269
+ # Initialize the multiprocessing Pool with the initializer
270
  with mp.Pool(processes=batch_opts.num_workers, initializer=initializer, initargs=(batch_opts, opts.model)) as pool:
271
+ pool.map(process_image, sorted_unique_image_paths)
272
  pool.close()
273
  pool.join()
274
  else: