"""Run a batch of generation tasks for a session on a background worker thread.

The worker thread executes the tasks while the calling thread forwards the
worker's queued commands (progress, previews, status/info text, outputs,
errors) to the job's event stream and the session callbacks.
"""

import contextlib
import inspect
import sys
import threading
import time
from typing import Any

from shared.api import (
    GenerationError,
    GenerationResult,
    SessionJob,
    _GENERATION_LOCK,
    _OutputCapture,
    _pushd,
)
from shared.utils.thread_utils import AsyncStream


def run_cli_job(session, job: SessionJob, tasks: list[dict[str, Any]]) -> None:
    """Run ``tasks`` for ``session`` and resolve ``job`` with a ``GenerationResult``.

    Generation runs on a daemon worker thread while the calling thread holds
    ``_GENERATION_LOCK`` inside ``_pushd(runtime.root)`` and pumps the worker's
    command queue into ``job.events`` and the session callbacks.

    Whatever happens, the job is resolved via ``job._set_result``, its event
    stream is closed, and ``session._active_job`` is cleared if it still
    refers to ``job``. Failures are reported as ``GenerationError`` entries in
    the result (and as ``"error"`` events) rather than raised to the caller.

    Args:
        session: The owning session; its private ``_state``, runtime and
            callback helpers are used.
        job: The job to report events to and resolve.
        tasks: Task dictionaries, run in order.
    """
    total_tasks = len(tasks)
    runtime = None
    task_summary: dict[str, Any] = {
        "errors": [],
        "successful_tasks": 0,
        "failed_tasks": 0,
        "total_tasks": total_tasks,
    }
    try:
        # Setup lives inside the ``try`` so that a failure here still resolves
        # the job instead of leaving callers waiting for a result forever.
        stream = AsyncStream()
        gen = session._state["gen"]
        # Output-list sizes before the run; presumably lets ``_collect_outputs``
        # return only what this run added.
        base_file_count = len(gen["file_list"])
        base_audio_count = len(gen["audio_file_list"])
        runtime = session._ensure_runtime()
        with _GENERATION_LOCK, _pushd(runtime.root):
            session._configure_runtime(runtime)
            session._prepare_state_for_run(tasks)
            job.events.put("started", {"tasks": total_tasks})
            worker_thread = _start_worker_thread(session, job, runtime.module, tasks, stream, task_summary)
            try:
                _pump_worker_events(session, job, runtime.module, tasks, stream, worker_thread)
            except BaseException:
                # The pump failed (e.g. an event handler raised) while the worker
                # may still be generating. Stop it before the generation lock is
                # released and shared state is reset in the outer ``finally``.
                try:
                    session._request_cancel_unlocked(runtime.module)
                finally:
                    worker_thread.join()
                raise
            worker_thread.join(timeout=0.1)
            outputs = session._collect_outputs(base_file_count, base_audio_count)
            artifacts = session._consume_output_artifacts(tasks)
        # A cancelled run must never report success, even if no task recorded an error.
        if job.cancel_requested and not task_summary["errors"]:
            task_summary["errors"].append(GenerationError(message="Generation was cancelled", stage="cancelled"))
            task_summary["failed_tasks"] = max(task_summary["failed_tasks"], 1)
        result = GenerationResult(
            success=not task_summary["errors"],
            generated_files=outputs,
            errors=list(task_summary["errors"]),
            total_tasks=task_summary["total_tasks"],
            successful_tasks=task_summary["successful_tasks"],
            failed_tasks=task_summary["failed_tasks"],
            artifacts=artifacts,
        )
        job.events.put("completed", result)
        session._emit_callback("on_complete", result, job=job)
        job._set_result(result)
    except BaseException as exc:
        failure = session._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
        result = GenerationResult(
            success=False,
            generated_files=[],
            errors=[failure],
            total_tasks=total_tasks,
            successful_tasks=task_summary["successful_tasks"],
            failed_tasks=max(task_summary["failed_tasks"], 1 if total_tasks > 0 else 0),
            artifacts=(),
        )
        job.events.put("error", failure)
        session._emit_callback("on_error", failure, job=job)
        job.events.put("completed", result)
        session._emit_callback("on_complete", result, job=job)
        job._set_result(result)
    finally:
        job.events.close()
        if runtime is not None:
            session._reset_state_after_run()
        with session._job_lock:
            if session._active_job is job:
                session._active_job = None


def _start_worker_thread(
    session,
    job: SessionJob,
    wgp,
    tasks: list[dict[str, Any]],
    stream: AsyncStream,
    task_summary: dict[str, Any],
) -> threading.Thread:
    """Start the daemon thread that runs ``tasks`` with stdout/stderr captured.

    Captured lines are forwarded through ``session._emit_stream``. Any
    exception escaping the task loop is recorded in ``task_summary["errors"]``
    and pushed to ``stream`` as an ``"error"`` command. The thread's last
    action is always to push the ``"worker_exit"`` sentinel, even if
    flushing the captures fails.
    """

    def emit(stream_name: str, line: str) -> None:
        session._emit_stream(job, stream_name, line)

    def worker() -> None:
        captures: list[Any] = []
        try:
            # Built inside the ``try`` so a construction failure is reported and
            # still ends with the exit sentinel below.
            stdout_capture = _OutputCapture(
                "stdout",
                emit,
                console=sys.__stdout__ if session._console_output else None,
                console_isatty=session._console_isatty,
            )
            captures.append(stdout_capture)
            stderr_capture = _OutputCapture(
                "stderr",
                emit,
                console=sys.__stderr__ if session._console_output else None,
                console_isatty=session._console_isatty,
            )
            captures.append(stderr_capture)
            with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
                _run_tasks_worker(session, wgp, tasks, stream, job, task_summary)
        except BaseException as exc:
            failure = session._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
            task_summary["errors"].append(failure)
            stream.output_queue.push("error", failure)
        finally:
            try:
                for capture in captures:
                    capture.flush()
            finally:
                # Always tell the pump we are done, even if a flush raised.
                stream.output_queue.push("worker_exit", None)

    worker_thread = threading.Thread(target=worker, daemon=True, name="wangp-session-worker")
    worker_thread.start()
    return worker_thread


def _pump_worker_events(
    session,
    job: SessionJob,
    wgp,
    tasks: list[dict[str, Any]],
    stream: AsyncStream,
    worker_thread: threading.Thread,
) -> None:
    """Forward the worker's queued commands to ``job`` until the worker is done.

    Returns when the ``"worker_exit"`` sentinel is popped, or when the queue is
    empty and the worker thread is no longer alive. While
    ``job.cancel_requested`` is set, ``session._request_cancel_unlocked`` is
    called on every pass.
    """
    while True:
        if job.cancel_requested:
            session._request_cancel_unlocked(wgp)
        # Sample liveness BEFORE popping: if the thread was already dead, every
        # push it made has happened, so an empty pop really means "drained".
        # Checking after the pop could drop the worker's final events.
        worker_finished = not worker_thread.is_alive()
        item = stream.output_queue.pop()
        if item is None:
            if worker_finished:
                break
            time.sleep(0.01)
            continue
        command, data = item
        if command == "worker_exit":
            break
        _handle_command(session, job, wgp, tasks, command, data)


def _run_tasks_worker(
    session,
    wgp,
    tasks: list[dict[str, Any]],
    stream: AsyncStream,
    job: SessionJob,
    task_summary: dict[str, Any],
) -> None:
    """Run ``tasks`` in order through ``wgp`` and tally results into ``task_summary``.

    Each task is validated with ``wgp.validate_task`` and then executed with
    ``wgp.generate_video``. Every failure is appended to
    ``task_summary["errors"]`` as a ``GenerationError``, counted in
    ``task_summary["failed_tasks"]``, and pushed to ``stream`` as an
    ``"error"`` command. Successful tasks increment
    ``task_summary["successful_tasks"]``. The loop stops early when
    ``job.cancel_requested`` or ``session._state["gen"]["abort"]`` is set.
    """
    # Only settings that ``generate_video`` actually accepts are passed as keywords.
    expected_args = set(inspect.signature(wgp.generate_video).parameters)
    total_tasks = len(tasks)
    for task_index, task in enumerate(tasks, start=1):
        if job.cancel_requested:
            break
        session._state["gen"]["prompt_no"] = task_index
        session._state["gen"]["prompts_max"] = total_tasks
        session._state["gen"]["queue"] = tasks
        task_id = task.get("id")
        task_errors: list[GenerationError] = []

        def send_cmd(command: str, data: Any = None) -> None:
            # Callback handed to ``generate_video``. "error" payloads are
            # converted into task-attributed GenerationErrors; all other
            # commands are forwarded unchanged.
            if command == "error":
                failure = session._make_generation_error(data, task_index=task_index, task_id=task_id, stage="generation")
                task_errors.append(failure)
                stream.output_queue.push("error", failure)
                return
            stream.output_queue.push(command, data)

        try:
            validated_settings, validation_error = wgp.validate_task(task, session._state)
        except Exception as exc:
            # A crashing validator fails only this task (like a ``None`` result)
            # instead of aborting the rest of the batch uncounted.
            failure = session._make_generation_error(exc, task_index=task_index, task_id=task_id, stage="validation")
            task_summary["errors"].append(failure)
            task_summary["failed_tasks"] += 1
            stream.output_queue.push("error", failure)
            continue
        if validated_settings is None:
            failure = GenerationError(
                message=validation_error or f"Task {task_index} failed validation",
                task_index=task_index,
                task_id=task_id,
                stage="validation",
            )
            task_summary["errors"].append(failure)
            task_summary["failed_tasks"] += 1
            stream.output_queue.push("error", failure)
            continue

        task_settings = validated_settings.copy()
        task_settings["state"] = session._state
        filtered_params = {key: value for key, value in task_settings.items() if key in expected_args}
        plugin_data = task.get("plugin_data", {})
        try:
            success = wgp.generate_video(task, send_cmd, plugin_data=plugin_data, **filtered_params)
        except BaseException as exc:
            # If send_cmd already reported an error for this task, don't add a
            # second one for the resulting exception.
            if not task_errors:
                task_errors.append(session._make_generation_error(exc, task_index=task_index, task_id=task_id, stage="generation"))
                stream.output_queue.push("error", task_errors[-1])
            success = False

        # A cancel/abort observed after the task marks it as cancelled and stops the batch.
        if session._state["gen"].get("abort", False) or job.cancel_requested:
            task_errors.append(GenerationError(message="Generation was cancelled", task_index=task_index, task_id=task_id, stage="cancelled"))
            stream.output_queue.push("error", task_errors[-1])
            task_summary["errors"].extend(task_errors)
            task_summary["failed_tasks"] += 1
            break
        if task_errors:
            task_summary["errors"].extend(task_errors)
            task_summary["failed_tasks"] += 1
            continue
        if not success:
            failure = GenerationError(
                message=f"Task {task_index} did not complete successfully",
                task_index=task_index,
                task_id=task_id,
                stage="generation",
            )
            task_summary["errors"].append(failure)
            task_summary["failed_tasks"] += 1
            stream.output_queue.push("error", failure)
            continue
        task_summary["successful_tasks"] += 1


def _handle_command(session, job: SessionJob, wgp, tasks: list[dict[str, Any]], command: str, data: Any) -> None:
    """Translate one worker command into a ``job`` event and session callback.

    Handles ``"progress"``, ``"preview"`` (only emitted when
    ``_build_preview_update`` returns something), ``"status"``/``"info"`` (data
    coerced to ``str``, ``None`` -> ``""``), ``"output"``, ``"refresh_models"``
    (event only, no callback), and ``"error"`` (non-``GenerationError`` data is
    converted). Unknown commands are ignored.
    """
    if command == "progress":
        progress = session._build_progress_update(data)
        job.events.put("progress", progress)
        session._emit_callback("on_progress", progress, job=job)
        return
    if command == "preview":
        preview = session._build_preview_update(wgp, tasks, data)
        if preview is not None:
            job.events.put("preview", preview)
            session._emit_callback("on_preview", preview, job=job)
        return
    if command == "status":
        text = str(data or "")
        job.events.put("status", text)
        session._emit_callback("on_status", text, job=job)
        return
    if command == "info":
        text = str(data or "")
        job.events.put("info", text)
        session._emit_callback("on_info", text, job=job)
        return
    if command == "output":
        job.events.put("output", data)
        session._emit_callback("on_output", data, job=job)
        return
    if command == "refresh_models":
        job.events.put("refresh_models", data)
        return
    if command == "error":
        error = data if isinstance(data, GenerationError) else session._make_generation_error(data)
        job.events.put("error", error)
        session._emit_callback("on_error", error, job=job)
        return