Spaces:
Sleeping
Sleeping
| import asyncio | |
| import contextlib | |
| import json | |
| import logging | |
| import time | |
| import uuid | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from io import BytesIO | |
| from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union | |
| from urllib.parse import urljoin, urlparse | |
| import aiohttp | |
| from aiohttp.client_exceptions import ClientError, ContentTypeError | |
| from pydantic import BaseModel | |
| from comfy import utils | |
| from comfy_api.latest import IO | |
| from comfy_api_nodes.apis import request_logger | |
| from server import PromptServer | |
| from ._helpers import ( | |
| default_base_url, | |
| get_auth_header, | |
| get_node_id, | |
| is_processing_interrupted, | |
| sleep_with_interrupt, | |
| ) | |
| from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted | |
# Type variable for Pydantic response models: lets sync_op/poll_op declare that they
# return an instance of exactly the model class passed as ``response_model``.
M = TypeVar("M", bound=BaseModel)
class ApiEndpoint:
    """Description of one HTTP endpoint: path, verb, and fixed query params / headers.

    ``path`` may be an absolute URL or a relative path. In this module ``_request_base``
    resolves relative paths against ``default_base_url()`` and adds auth headers only for them.
    """

    def __init__(
        self,
        path: str,
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
        *,
        query_params: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
    ):
        self.path = path
        self.method = method
        # None or empty mappings become a fresh dict, so instances never share a default.
        self.query_params = query_params if query_params else {}
        self.headers = headers if headers else {}
| class _RequestConfig: | |
| node_cls: type[IO.ComfyNode] | |
| endpoint: ApiEndpoint | |
| timeout: float | |
| content_type: str | |
| data: Optional[dict[str, Any]] | |
| files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] | |
| multipart_parser: Optional[Callable] | |
| max_retries: int | |
| retry_delay: float | |
| retry_backoff: float | |
| wait_label: str = "Waiting" | |
| monitor_progress: bool = True | |
| estimated_total: Optional[int] = None | |
| final_label_on_success: Optional[str] = "Completed" | |
| progress_origin_ts: Optional[float] = None | |
| class _PollUIState: | |
| started: float | |
| status_label: str = "Queued" | |
| is_queued: bool = True | |
| price: Optional[float] = None | |
| estimated_duration: Optional[int] = None | |
| base_processing_elapsed: float = 0.0 # sum of completed active intervals | |
| active_since: Optional[float] = None # start time of current active interval (None if queued) | |
| _RETRY_STATUS = {408, 429, 500, 502, 503, 504} | |
| COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] | |
| FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] | |
| QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] | |
async def sync_op(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    data: Optional[BaseModel] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> M:
    """Perform one API request and parse the JSON reply into ``response_model``.

    Typed wrapper over ``sync_op_raw``, which is always called with ``as_binary=False``;
    every other keyword is forwarded unchanged.

    Raises:
        Exception: if the decoded body is not a JSON object, or if it does not validate
            against ``response_model``. Errors from ``sync_op_raw`` propagate unchanged.
    """
    forwarded = dict(
        data=data,
        files=files,
        content_type=content_type,
        timeout=timeout,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        estimated_duration=estimated_duration,
        as_binary=False,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
        monitor_progress=monitor_progress,
    )
    response = await sync_op_raw(cls, endpoint, **forwarded)
    if isinstance(response, dict):
        return _validate_or_raise(response_model, response)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
async def poll_op(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    status_extractor: Callable[[M], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[BaseModel] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> M:
    """Poll ``poll_endpoint`` until a terminal status, returning the final reply as ``response_model``.

    The extractors receive validated ``response_model`` instances: each is adapted with
    ``_wrap_model_extractor`` before being handed to the dict-based ``poll_op_raw``.
    All other options are forwarded to ``poll_op_raw`` unchanged.

    Raises:
        Exception: if the final reply is not a JSON object or fails validation; errors
            from ``poll_op_raw`` propagate unchanged.
    """
    typed_status, typed_progress, typed_price = (
        _wrap_model_extractor(response_model, fn)
        for fn in (status_extractor, progress_extractor, price_extractor)
    )
    poll_settings = dict(
        completed_statuses=completed_statuses,
        failed_statuses=failed_statuses,
        queued_statuses=queued_statuses,
        data=data,
        poll_interval=poll_interval,
        max_poll_attempts=max_poll_attempts,
        timeout_per_poll=timeout_per_poll,
        max_retries_per_poll=max_retries_per_poll,
        retry_delay_per_poll=retry_delay_per_poll,
        retry_backoff_per_poll=retry_backoff_per_poll,
        estimated_duration=estimated_duration,
        cancel_endpoint=cancel_endpoint,
        cancel_timeout=cancel_timeout,
    )
    final = await poll_op_raw(
        cls,
        poll_endpoint=poll_endpoint,
        status_extractor=typed_status,
        progress_extractor=typed_progress,
        price_extractor=typed_price,
        **poll_settings,
    )
    if isinstance(final, dict):
        return _validate_or_raise(response_model, final)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
async def sync_op_raw(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    as_binary: bool = False,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
    """Perform one HTTP operation (retries included) and return the raw result.

    A Pydantic ``data`` model is dumped with ``exclude_none=True``, and top-level Enum values
    are replaced by their ``.value``; nested values are not converted here. Plain dicts are
    passed through untouched.

    Returns:
        ``bytes`` when ``as_binary`` is True; otherwise the decoded JSON, or
        ``{'_raw': '<text>'}`` when the body is not JSON.
    """
    if isinstance(data, BaseModel):
        dumped = data.model_dump(exclude_none=True)
        data = {key: (val.value if isinstance(val, Enum) else val) for key, val in dumped.items()}

    request_cfg = _RequestConfig(
        node_cls=cls,
        endpoint=endpoint,
        timeout=timeout,
        content_type=content_type,
        data=data,
        files=files,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        monitor_progress=monitor_progress,
        estimated_total=estimated_duration,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
    )
    result = await _request_base(request_cfg, expect_binary=as_binary)
    return result
async def poll_op_raw(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> dict[str, Any]:
    """
    Poll ``poll_endpoint`` until the task reaches a completed or failed status.

    Each poll is one ``sync_op_raw`` call with ``data``, ``timeout_per_poll`` and the per-poll
    retry settings, and with ``monitor_progress=False``. The value from ``status_extractor`` is
    normalized (strings stripped and lower-cased) and compared with ``completed_statuses``,
    ``failed_statuses`` and ``queued_statuses``. When one of these is None, the module defaults
    COMPLETED_STATUSES / FAILED_STATUSES / QUEUED_STATUSES are used.

    Status extraction is tolerant: if it raises, the error is logged and the status is None. None
    is neither queued nor terminal, so that poll counts as an active attempt.

    While polling, a background ticker refreshes the node UI once per second with the status
    label, total elapsed time, the last price from ``price_extractor``, and (while not queued,
    with a positive ``estimated_duration``) the estimated remaining time. Time spent in queued
    states is excluded from the processing time. Values from ``progress_extractor`` drive a
    ``comfy.utils.ProgressBar`` with total 100.

    If the user interrupts during a poll request or during the ``poll_interval`` sleep,
    ``cancel_endpoint`` (when given) is called once, best-effort (no retries, errors
    suppressed), and ProcessingInterrupted is re-raised.

    Only polls whose status is NOT queued count toward ``max_poll_attempts``. A task that stays
    queued is therefore polled with no upper bound.

    Returns:
        The final JSON dict returned by the poll endpoint.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError: propagated unchanged.
        Exception: "Polling aborted due to error: ..." wrapping every other failure, including
            a failed status ("Task failed: <json>") and exhausting ``max_poll_attempts``.
    """
    completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
    failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
    queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
    started = time.monotonic()
    consumed_attempts = 0  # counts only non-queued polls
    progress_bar = utils.ProgressBar(100) if progress_extractor else None
    last_progress: Optional[int] = None  # last value sent to progress_bar; avoids redundant updates
    state = _PollUIState(started=started, estimated_duration=estimated_duration)
    stop_ticker = asyncio.Event()

    async def _ticker():
        """Emit a UI update every second until stopped or the user interrupts processing."""
        try:
            while not stop_ticker.is_set():
                if is_processing_interrupted():
                    break
                now = time.monotonic()
                # Processing time = closed active intervals + the currently open one (if any).
                proc_elapsed = state.base_processing_elapsed + (
                    (now - state.active_since) if state.active_since is not None else 0.0
                )
                _display_time_progress(
                    cls,
                    status=state.status_label,
                    elapsed_seconds=int(now - state.started),
                    estimated_total=state.estimated_duration,
                    price=state.price,
                    is_queued=state.is_queued,
                    processing_elapsed_seconds=int(proc_elapsed),
                )
                await asyncio.sleep(1.0)
        except Exception as exc:
            # UI refresh is cosmetic; never let it break polling.
            logging.debug("Polling ticker exited: %s", exc)

    ticker_task = asyncio.create_task(_ticker())
    try:
        while consumed_attempts < max_poll_attempts:
            try:
                resp_json = await sync_op_raw(
                    cls,
                    poll_endpoint,
                    data=data,
                    timeout=timeout_per_poll,
                    max_retries=max_retries_per_poll,
                    retry_delay=retry_delay_per_poll,
                    retry_backoff=retry_backoff_per_poll,
                    wait_label="Checking",
                    estimated_duration=None,
                    as_binary=False,
                    final_label_on_success=None,
                    monitor_progress=False,  # the ticker above owns the UI during polling
                )
                if not isinstance(resp_json, dict):
                    raise Exception("Polling endpoint returned non-JSON response.")
            except ProcessingInterrupted:
                # Best-effort remote cancellation: single attempt, any failure ignored.
                if cancel_endpoint:
                    with contextlib.suppress(Exception):
                        await sync_op_raw(
                            cls,
                            cancel_endpoint,
                            timeout=cancel_timeout,
                            max_retries=0,
                            wait_label="Cancelling task",
                            estimated_duration=None,
                            as_binary=False,
                            final_label_on_success=None,
                            monitor_progress=False,
                        )
                raise

            try:
                status = _normalize_status_value(status_extractor(resp_json))
            except Exception as e:
                logging.error("Status extraction failed: %s", e)
                status = None
            # NOTE(review): unlike status extraction, errors from the price/progress extractors
            # are not caught here and end polling via the outer handler.
            if price_extractor:
                new_price = price_extractor(resp_json)
                if new_price is not None:
                    state.price = new_price
            if progress_extractor:
                new_progress = progress_extractor(resp_json)
                if new_progress is not None and last_progress != new_progress:
                    progress_bar.update_absolute(new_progress, total=100)
                    last_progress = new_progress

            now_ts = time.monotonic()
            is_queued = status in queued_states
            if is_queued:
                if state.active_since is not None:  # If we just moved from active -> queued, close the active interval
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
            else:
                if state.active_since is None:  # If we just moved from queued -> active, open a new active interval
                    state.active_since = now_ts
            state.is_queued = is_queued
            state.status_label = status or ("Queued" if is_queued else "Processing")

            if status in completed_states:
                # Close the last active interval so the final processing time is accurate.
                if state.active_since is not None:
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
                # Stop the ticker before the final display so it cannot overwrite it.
                stop_ticker.set()
                with contextlib.suppress(Exception):
                    await ticker_task
                if progress_bar and last_progress != 100:
                    progress_bar.update_absolute(100, total=100)
                _display_time_progress(
                    cls,
                    status=status if status else "Completed",
                    elapsed_seconds=int(now_ts - started),
                    estimated_total=estimated_duration,
                    price=state.price,
                    is_queued=False,
                    processing_elapsed_seconds=int(state.base_processing_elapsed),
                )
                return resp_json

            if status in failed_states:
                msg = f"Task failed: {json.dumps(resp_json)}"
                logging.error(msg)
                raise Exception(msg)

            try:
                await sleep_with_interrupt(poll_interval, cls, None, None, None)
            except ProcessingInterrupted:
                # Same best-effort cancellation as for an interrupted poll request.
                if cancel_endpoint:
                    with contextlib.suppress(Exception):
                        await sync_op_raw(
                            cls,
                            cancel_endpoint,
                            timeout=cancel_timeout,
                            max_retries=0,
                            wait_label="Cancelling task",
                            estimated_duration=None,
                            as_binary=False,
                            final_label_on_success=None,
                            monitor_progress=False,
                        )
                raise
            # Queued polls are free; only active (or unknown-status) polls use up the budget.
            if not is_queued:
                consumed_attempts += 1

        raise Exception(
            f"Polling timed out after {max_poll_attempts} non-queued attempts "
            f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
        )
    except ProcessingInterrupted:
        raise
    except (LocalNetworkError, ApiServerError):
        raise
    except Exception as e:
        # Everything else (including "Task failed" and the timeout above) gets this prefix.
        raise Exception(f"Polling aborted due to error: {e}") from e
    finally:
        stop_ticker.set()
        with contextlib.suppress(Exception):
            await ticker_task
| def _display_text( | |
| node_cls: type[IO.ComfyNode], | |
| text: Optional[str], | |
| *, | |
| status: Optional[Union[str, int]] = None, | |
| price: Optional[float] = None, | |
| ) -> None: | |
| display_lines: list[str] = [] | |
| if status: | |
| display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") | |
| if price is not None: | |
| display_lines.append(f"Price: ${float(price):,.4f}") | |
| if text is not None: | |
| display_lines.append(text) | |
| if display_lines: | |
| PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) | |
def _display_time_progress(
    node_cls: type[IO.ComfyNode],
    status: Optional[Union[str, int]],
    elapsed_seconds: int,
    estimated_total: Optional[int] = None,
    *,
    price: Optional[float] = None,
    is_queued: Optional[bool] = None,
    processing_elapsed_seconds: Optional[int] = None,
) -> None:
    """Show elapsed time (plus status and price) for a node via ``_display_text``.

    A "remaining" estimate is appended only when ``estimated_total`` is positive and
    ``is_queued`` is exactly False. Remaining time is ``estimated_total`` minus
    ``processing_elapsed_seconds`` (or ``elapsed_seconds`` when that is None), floored at 0.
    """
    line = f"Time elapsed: {int(elapsed_seconds)}s"
    if estimated_total is not None and estimated_total > 0 and is_queued is False:
        basis = elapsed_seconds if processing_elapsed_seconds is None else processing_elapsed_seconds
        remaining = max(0, int(estimated_total) - int(basis))
        line += f" (~{remaining}s remaining)"
    _display_text(node_cls, line, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    Returns a dict with two flags, both False unless the probe succeeds:
    - ``internet_accessible``: GET https://www.google.com answered with a status < 500.
    - ``api_accessible``: GET ``<scheme>://<netloc>/health`` of ``default_base_url()``
      answered with a status < 500. Only probed when the internet check succeeded.

    Each probe shares a 5-second total timeout. Connection errors and timeouts leave the
    flag False instead of raising.
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    # asyncio.TimeoutError (raised when ClientTimeout(total=...) expires) is listed explicitly:
    # it only became an OSError subclass in Python 3.11. On older interpreters it would otherwise
    # escape this best-effort probe and replace the connection error the caller is diagnosing.
    probe_errors = (ClientError, OSError, asyncio.TimeoutError)
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        with contextlib.suppress(*probe_errors):
            async with session.get("https://www.google.com") as resp:
                results["internet_accessible"] = resp.status < 500
        if not results["internet_accessible"]:
            return results

        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(*probe_errors):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
    return results
| def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: | |
| """Normalize (filename, value, content_type).""" | |
| if len(t) == 2: | |
| return t[0], t[1], "application/octet-stream" | |
| if len(t) == 3: | |
| return t[0], t[1], t[2] | |
| raise ValueError("files tuple must be (filename, file[, content_type])") | |
| def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: | |
| params = dict(endpoint_params or {}) | |
| if method.upper() == "GET" and data: | |
| for k, v in data.items(): | |
| if v is not None: | |
| params[k] = v | |
| return params | |
| def _friendly_http_message(status: int, body: Any) -> str: | |
| if status == 401: | |
| return "Unauthorized: Please login first to use this node." | |
| if status == 402: | |
| return "Payment Required: Please add credits to your account to use this node." | |
| if status == 409: | |
| return "There is a problem with your account. Please contact support@comfy.org." | |
| if status == 429: | |
| return "Rate Limit Exceeded: Please try again later." | |
| try: | |
| if isinstance(body, dict): | |
| err = body.get("error") | |
| if isinstance(err, dict): | |
| msg = err.get("message") | |
| typ = err.get("type") | |
| if msg and typ: | |
| return f"API Error: {msg} (Type: {typ})" | |
| if msg: | |
| return f"API Error: {msg}" | |
| return f"API Error: {json.dumps(body)}" | |
| else: | |
| txt = str(body) | |
| if len(txt) <= 200: | |
| return f"API Error (raw): {txt}" | |
| return f"API Error (status {status})" | |
| except Exception: | |
| return f"HTTP {status}: Unknown error" | |
| def _generate_operation_id(method: str, path: str, attempt: int) -> str: | |
| slug = path.strip("/").replace("/", "_") or "op" | |
| return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" | |
| def _snapshot_request_body_for_logging( | |
| content_type: str, | |
| method: str, | |
| data: Optional[dict[str, Any]], | |
| files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], | |
| ) -> Optional[Union[dict[str, Any], str]]: | |
| if method.upper() == "GET": | |
| return None | |
| if content_type == "multipart/form-data": | |
| form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) | |
| file_fields: list[dict[str, str]] = [] | |
| if files: | |
| file_iter = files if isinstance(files, list) else list(files.items()) | |
| for field_name, file_obj in file_iter: | |
| if file_obj is None: | |
| continue | |
| if isinstance(file_obj, tuple): | |
| filename = file_obj[0] | |
| else: | |
| filename = getattr(file_obj, "name", field_name) | |
| file_fields.append({"field": field_name, "filename": str(filename or "")}) | |
| return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} | |
| if content_type == "application/x-www-form-urlencoded": | |
| return data or {} | |
| return data or {} | |
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
    """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.

    Behaviour:
    - Relative ``cfg.endpoint.path`` values are joined onto ``default_base_url()``, and only
      those requests get ``get_auth_header(cfg.node_cls)`` headers.
    - GET: ``cfg.data`` is merged into the query string and no body is sent. Other methods
      encode ``cfg.data`` as multipart form (plus ``cfg.files``), as urlencoded form, or as
      JSON, depending on ``cfg.content_type``.
    - Retries: HTTP statuses in ``_RETRY_STATUS`` and connection errors/timeouts are retried up
      to ``cfg.max_retries`` times. The wait starts at ``cfg.retry_delay`` and is multiplied by
      ``cfg.retry_backoff`` after each retry. Each attempt uses its own ClientSession.
    - Interruption: when ``cfg.monitor_progress`` is True, a monitor task checks
      ``is_processing_interrupted()`` every second and aborts the in-flight request.
      Binary downloads check it once per second of streaming regardless.

    Returns:
        ``bytes`` when ``expect_binary``; otherwise the decoded JSON value. An empty body
        becomes ``{}`` and a non-JSON body becomes ``{"_raw": <text>}``.

    Raises:
        ProcessingInterrupted: the user interrupted processing.
        Exception: a non-retried (or retries exhausted) HTTP status >= 400, with a
            ``_friendly_http_message`` text.
        LocalNetworkError: connection failures persisted and the internet looks unreachable.
        ApiServerError: connection failures persisted while the internet is reachable.
        ValueError: ``cfg.multipart_parser`` did not return ``aiohttp.FormData``.
    """
    url = cfg.endpoint.path
    parsed_url = urlparse(url)
    if not parsed_url.scheme and not parsed_url.netloc:  # is URL relative?
        url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))

    method = cfg.endpoint.method
    params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)

    async def _monitor(stop_evt: asyncio.Event, start_ts: float):
        """Every second: update elapsed time and signal interruption (by returning)."""
        try:
            while not stop_evt.is_set():
                if is_processing_interrupted():
                    return
                if cfg.monitor_progress:
                    _display_time_progress(
                        cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
                    )
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            return  # normal shutdown

    start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
    attempt = 0
    delay = cfg.retry_delay
    operation_succeeded: bool = False
    final_elapsed_seconds: Optional[int] = None

    while True:
        attempt += 1
        stop_event = asyncio.Event()
        monitor_task: Optional[asyncio.Task] = None
        sess: Optional[aiohttp.ClientSession] = None

        operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
        logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)

        payload_headers = {"Accept": "*/*"}
        if not parsed_url.scheme and not parsed_url.netloc:  # is URL relative?
            payload_headers.update(get_auth_header(cfg.node_cls))
        if cfg.endpoint.headers:
            payload_headers.update(cfg.endpoint.headers)

        payload_kw: dict[str, Any] = {"headers": payload_headers}
        if method == "GET":
            payload_headers.pop("Content-Type", None)
        request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
        try:
            # NOTE: no monitor runs when monitor_progress is False. Interruption is then only noticed
            # in the binary download loop or in sleep_with_interrupt between retries. This is what
            # lets the best-effort cancel requests made after an interruption actually complete.
            if cfg.monitor_progress:
                monitor_task = asyncio.create_task(_monitor(stop_event, start_time))

            timeout = aiohttp.ClientTimeout(total=cfg.timeout)
            sess = aiohttp.ClientSession(timeout=timeout)

            if cfg.content_type == "multipart/form-data" and method != "GET":
                # aiohttp will set Content-Type boundary; remove any fixed Content-Type
                payload_headers.pop("Content-Type", None)
                if cfg.multipart_parser and cfg.data:
                    form = cfg.multipart_parser(cfg.data)
                    if not isinstance(form, aiohttp.FormData):
                        raise ValueError("multipart_parser must return aiohttp.FormData")
                else:
                    form = aiohttp.FormData(default_to_multipart=True)
                    if cfg.data:
                        for k, v in cfg.data.items():
                            if v is None:
                                continue
                            form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
                if cfg.files:
                    file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
                    for field_name, file_obj in file_iter:
                        if file_obj is None:
                            continue
                        if isinstance(file_obj, tuple):
                            filename, file_value, content_type = _unpack_tuple(file_obj)
                        else:
                            filename = getattr(file_obj, "name", field_name)
                            file_value = file_obj
                            content_type = "application/octet-stream"
                        # Attempt to rewind BytesIO for retries
                        if isinstance(file_value, BytesIO):
                            with contextlib.suppress(Exception):
                                file_value.seek(0)
                        form.add_field(field_name, file_value, filename=filename, content_type=content_type)
                payload_kw["data"] = form
            elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
                payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
                payload_kw["data"] = cfg.data or {}
            elif method != "GET":
                payload_headers["Content-Type"] = "application/json"
                payload_kw["json"] = cfg.data or {}

            try:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method=method,
                    request_url=url,
                    request_headers=dict(payload_headers) if payload_headers else None,
                    request_params=dict(params) if params else None,
                    request_data=request_body_log,
                )
            except Exception as _log_e:
                logging.debug("[DEBUG] request logging failed: %s", _log_e)

            req_coro = sess.request(method, url, params=params, **payload_kw)
            req_task = asyncio.create_task(req_coro)

            # Race: request vs. monitor (interruption)
            tasks = {req_task}
            if monitor_task:
                tasks.add(monitor_task)
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            if monitor_task and monitor_task in done:
                # Interrupted – cancel the request and abort
                if req_task in pending:
                    req_task.cancel()
                raise ProcessingInterrupted("Task cancelled")

            # Otherwise, request finished
            resp = await req_task
            async with resp:
                if resp.status >= 400:
                    try:
                        body = await resp.json()
                    except (ContentTypeError, json.JSONDecodeError):
                        body = await resp.text()
                    if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
                        logging.warning(
                            "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
                            method,
                            url,
                            resp.status,
                            delay,
                            attempt,
                            cfg.max_retries,
                        )
                        try:
                            request_logger.log_request_response(
                                operation_id=operation_id,
                                request_method=method,
                                request_url=url,
                                response_status_code=resp.status,
                                response_headers=dict(resp.headers),
                                response_content=body,
                                error_message=_friendly_http_message(resp.status, body),
                            )
                        except Exception as _log_e:
                            logging.debug("[DEBUG] response logging failed: %s", _log_e)

                        await sleep_with_interrupt(
                            delay,
                            cfg.node_cls,
                            cfg.wait_label if cfg.monitor_progress else None,
                            start_time if cfg.monitor_progress else None,
                            cfg.estimated_total,
                            display_callback=_display_time_progress if cfg.monitor_progress else None,
                        )
                        delay *= cfg.retry_backoff
                        continue
                    msg = _friendly_http_message(resp.status, body)
                    try:
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method=method,
                            request_url=url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=msg,
                        )
                    except Exception as _log_e:
                        logging.debug("[DEBUG] response logging failed: %s", _log_e)
                    raise Exception(msg)

                if expect_binary:
                    buff = bytearray()
                    last_tick = time.monotonic()
                    # Stream in 64 KiB chunks; at most once per second check for interruption and
                    # refresh the elapsed-time display.
                    async for chunk in resp.content.iter_chunked(64 * 1024):
                        buff.extend(chunk)
                        now = time.monotonic()
                        if now - last_tick >= 1.0:
                            last_tick = now
                            if is_processing_interrupted():
                                raise ProcessingInterrupted("Task cancelled")
                            if cfg.monitor_progress:
                                _display_time_progress(
                                    cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
                                )
                    bytes_payload = bytes(buff)
                    operation_succeeded = True
                    final_elapsed_seconds = int(time.monotonic() - start_time)
                    try:
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method=method,
                            request_url=url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=bytes_payload,
                        )
                    except Exception as _log_e:
                        logging.debug("[DEBUG] response logging failed: %s", _log_e)
                    return bytes_payload
                else:
                    try:
                        payload = await resp.json()
                        response_content_to_log: Any = payload
                    except (ContentTypeError, json.JSONDecodeError):
                        # Not served as JSON: try to parse the text anyway, else wrap it.
                        text = await resp.text()
                        try:
                            payload = json.loads(text) if text else {}
                        except json.JSONDecodeError:
                            payload = {"_raw": text}
                        response_content_to_log = payload if isinstance(payload, dict) else text
                    operation_succeeded = True
                    final_elapsed_seconds = int(time.monotonic() - start_time)
                    try:
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method=method,
                            request_url=url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=response_content_to_log,
                        )
                    except Exception as _log_e:
                        logging.debug("[DEBUG] response logging failed: %s", _log_e)
                    return payload

        except ProcessingInterrupted:
            logging.debug("Polling was interrupted by user")
            raise
        except (ClientError, OSError, asyncio.TimeoutError) as e:
            # asyncio.TimeoutError is what aiohttp raises when ClientTimeout(total=...) expires. It is
            # only an OSError subclass from Python 3.11 on, so it is listed explicitly to get the same
            # retry / LocalNetworkError / ApiServerError handling on older interpreters.
            if attempt <= cfg.max_retries:
                logging.warning(
                    "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
                    method,
                    url,
                    delay,
                    attempt,
                    cfg.max_retries,
                    str(e),
                )
                try:
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method=method,
                        request_url=url,
                        request_headers=dict(payload_headers) if payload_headers else None,
                        request_params=dict(params) if params else None,
                        request_data=request_body_log,
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                except Exception as _log_e:
                    logging.debug("[DEBUG] request error logging failed: %s", _log_e)
                await sleep_with_interrupt(
                    delay,
                    cfg.node_cls,
                    cfg.wait_label if cfg.monitor_progress else None,
                    start_time if cfg.monitor_progress else None,
                    cfg.estimated_total,
                    display_callback=_display_time_progress if cfg.monitor_progress else None,
                )
                delay *= cfg.retry_backoff
                continue

            # Retries exhausted: decide whether the problem is local or on the server side.
            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                try:
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method=method,
                        request_url=url,
                        request_headers=dict(payload_headers) if payload_headers else None,
                        request_params=dict(params) if params else None,
                        request_data=request_body_log,
                        error_message=f"LocalNetworkError: {str(e)}",
                    )
                except Exception as _log_e:
                    logging.debug("[DEBUG] final error logging failed: %s", _log_e)
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e

            try:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method=method,
                    request_url=url,
                    request_headers=dict(payload_headers) if payload_headers else None,
                    request_params=dict(params) if params else None,
                    request_data=request_body_log,
                    error_message=f"ApiServerError: {str(e)}",
                )
            except Exception as _log_e:
                logging.debug("[DEBUG] final error logging failed: %s", _log_e)
            raise ApiServerError(
                f"The API server at {default_base_url()} is currently unreachable. "
                f"The service may be experiencing issues."
            ) from e
        finally:
            stop_event.set()
            if monitor_task:
                monitor_task.cancel()
                # Wait without re-raising the task's outcome. A task cancelled before its first step
                # never reaches _monitor's CancelledError handler, so `await monitor_task` would raise
                # CancelledError (a BaseException that suppress(Exception) lets through) and replace
                # the exception currently propagating, e.g. the multipart_parser ValueError.
                await asyncio.wait({monitor_task})
                if not monitor_task.cancelled():
                    monitor_task.exception()  # retrieve (and drop) any monitor error
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()
            if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
                _display_time_progress(
                    cfg.node_cls,
                    status=cfg.final_label_on_success,
                    elapsed_seconds=(
                        final_elapsed_seconds
                        if final_elapsed_seconds is not None
                        else int(time.monotonic() - start_time)
                    ),
                    estimated_total=cfg.estimated_total,
                    price=None,
                    is_queued=False,
                    processing_elapsed_seconds=final_elapsed_seconds,
                )
| def _validate_or_raise(response_model: Type[M], payload: Any) -> M: | |
| try: | |
| return response_model.model_validate(payload) | |
| except Exception as e: | |
| logging.error( | |
| "Response validation failed for %s: %s", | |
| getattr(response_model, "__name__", response_model), | |
| e, | |
| ) | |
| raise Exception( | |
| f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" | |
| ) from e | |
| def _wrap_model_extractor( | |
| response_model: Type[M], | |
| extractor: Optional[Callable[[M], Any]], | |
| ) -> Optional[Callable[[dict[str, Any]], Any]]: | |
| """Wrap a typed extractor so it can be used by the dict-based poller. | |
| Validates the dict into `response_model` before invoking `extractor`. | |
| Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating | |
| the same response for multiple extractors in a single poll attempt. | |
| """ | |
| if extractor is None: | |
| return None | |
| _cache: dict[int, M] = {} | |
| def _wrapped(d: dict[str, Any]) -> Any: | |
| try: | |
| key = id(d) | |
| model = _cache.get(key) | |
| if model is None: | |
| model = response_model.model_validate(d) | |
| _cache[key] = model | |
| return extractor(model) | |
| except Exception as e: | |
| logging.error("Extractor failed (typed -> dict wrapper): %s", e) | |
| raise | |
| return _wrapped | |
| def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: | |
| if not values: | |
| return set() | |
| out: set[Union[str, int]] = set() | |
| for v in values: | |
| nv = _normalize_status_value(v) | |
| if nv is not None: | |
| out.add(nv) | |
| return out | |
| def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: | |
| if isinstance(val, str): | |
| return val.strip().lower() | |
| return val | |