Spaces:
Runtime error
Runtime error
File size: 3,194 Bytes
ffa9e8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from typing import Any, List, Callable
import psutil
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue
from .image import ChainImgProcessor
from tqdm import tqdm
import cv2
def create_queue(temp_frame_paths: List[str]) -> Queue[str]:
    """Build a FIFO queue holding every entry of ``temp_frame_paths`` in order.

    The queue is created with the default (unbounded) size, so each ``put``
    returns immediately.
    """
    pending: Queue[str] = Queue()
    for path in temp_frame_paths:
        pending.put(path)
    return pending
def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]:
    """Remove and return up to ``queue_per_future`` items from the front of ``queue``.

    Emptiness is checked before every take, so the call never blocks. A
    shorter (possibly empty) list is returned once the queue runs dry.
    """
    # The filter runs before each get(), so the call order matches a
    # check-then-take loop over range(queue_per_future).
    return [queue.get() for _ in range(queue_per_future) if not queue.empty()]
class ChainBatchImageProcessor(ChainImgProcessor):
    """Run a processing chain over a batch of image files using a thread pool.

    Each source image is loaded with ``cv2.imread`` and passed through
    ``run_chain`` (inherited from ``ChainImgProcessor``). When the chain
    returns an image, it is written with ``cv2.imwrite`` to the target path
    at the same position in ``target_files``.
    """

    # Configuration for the current batch. run_batch_chain() sets these and
    # the worker threads read them.
    chain = None
    func_params_gen = None
    num_threads = 1

    def __init__(self):
        ChainImgProcessor.__init__(self)

    def init_with_plugins(self):
        """Load the "core" plugins, display init info and initialise start-up processors.

        ``self.init_on_start`` is set outside this class (presumably by
        ``ChainImgProcessor``; confirm). It is treated as a comma-separated
        list of processor ids. Whitespace around each id is stripped and blank
        entries are skipped, so an empty setting does not request a processor
        named "".
        """
        self.init_plugins(["core"])
        self.display_init_info()
        for proc_id in self.init_on_start.split(","):
            proc_id = proc_id.strip()
            if proc_id:
                self.init_processor(proc_id)

    def update_progress(self, progress: Any = None) -> None:
        """Advance the tqdm bar ``progress`` by one step.

        The postfix shows the process's resident memory (GB) and the thread
        count. Does nothing when ``progress`` is None, the default.
        """
        if progress is None:
            return
        process = psutil.Process(os.getpid())
        memory_usage = process.memory_info().rss / 1024 / 1024 / 1024
        progress.set_postfix({
            'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB',
            'execution_threads': self.num_threads,
        })
        progress.refresh()
        progress.update(1)

    def _run_chain_on_file(self, path: str) -> Any:
        """Load ``path`` and run ``self.chain`` on it; return the output image or None.

        Returns None when ``cv2.imread`` cannot load the file, or when
        ``run_chain`` returns None as its image.
        """
        frame = cv2.imread(path)
        if frame is None:
            return None
        # Per-frame params come from the caller-supplied generator, if any.
        params = self.func_params_gen(None, frame) if self.func_params_gen else {}
        result, _ = self.run_chain(frame, params, self.chain)
        return result

    def _process_indices(self, source_files: List[str], target_files: List[str],
                         indices: Any, update: Callable[[], None]) -> None:
        """Process ``source_files[i]`` into ``target_files[i]`` for every ``i`` in ``indices``.

        ``update`` (if truthy) is called once per index, whether or not an
        image was written.
        """
        for i in indices:
            result = self._run_chain_on_file(source_files[i])
            if result is not None:
                cv2.imwrite(target_files[i], result)
            if update:
                update()

    def process_frames(self, source_files: List[str], target_files: List[str], current_files, update: Callable[[], None]) -> None:
        """Process each path in ``current_files`` and write results to the matching targets.

        A path is written to the target at the index of its FIRST occurrence
        in ``source_files`` (the same mapping as ``list.index``). The lookup
        uses a dict built once per call rather than a linear scan per frame.
        ``update`` (if truthy) is called once per path, including paths that
        could not be read or produced no image, so a progress bar reaches its
        total.

        Raises:
            ValueError: a processed path that produced an image is not in
                ``source_files``.
        """
        first_index = {}
        for i, path in enumerate(source_files):
            first_index.setdefault(path, i)
        for path in current_files:
            result = self._run_chain_on_file(path)
            if result is not None:
                try:
                    i = first_index[path]
                except KeyError:
                    raise ValueError(f"{path!r} is not in source_files") from None
                cv2.imwrite(target_files[i], result)
            if update:
                update()

    def run_batch_chain(self, source_files, target_files, threads: int = 1, chain=None, params_frame_gen_func=None):
        """Run ``chain`` over ``source_files`` and write results to ``target_files``.

        Args:
            source_files: image paths to read.
            target_files: output paths; ``target_files[i]`` receives the
                result for ``source_files[i]``.
            threads: number of worker threads in the pool.
            chain: passed to ``run_chain`` for every frame.
            params_frame_gen_func: optional callable invoked as
                ``(None, frame)`` to build per-frame params. ``{}`` is used
                when it is None.

        The files are split into contiguous chunks of
        ``max(len(source_files) // threads, 1)`` and each chunk is submitted
        as one task. Work is dispatched by position, not by path, so
        duplicate source paths each get their own target. A worker exception
        is re-raised here by ``future.result()``.
        """
        self.chain = chain
        self.func_params_gen = params_frame_gen_func
        progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
        total = len(source_files)
        self.num_threads = threads
        with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
            def update() -> None:
                self.update_progress(progress)

            with ThreadPoolExecutor(max_workers=threads) as executor:
                chunk_size = max(total // threads, 1)
                futures = [
                    executor.submit(
                        self._process_indices,
                        source_files,
                        target_files,
                        range(start, min(start + chunk_size, total)),
                        update,
                    )
                    for start in range(0, total, chunk_size)
                ]
                for future in as_completed(futures):
                    future.result()
|