| import asyncio
|
| import concurrent.futures
|
| import contextvars
|
| import functools
|
| import inspect
|
| import logging
|
| import os
|
| import textwrap
|
| import threading
|
| from enum import Enum
|
| from typing import Optional, Type, get_origin, get_args
|
|
|
|
|
class TypeTracker:
    """Remember which non-builtin types a generated stub refers to.

    Every annotation rendered into a stub is passed to :meth:`track_type`;
    :meth:`get_imports` then converts what was collected into
    ``from <module> import <names>`` lines.
    """

    def __init__(self):
        # bare type name -> (defining module, __qualname__)
        self.discovered_types = {}
        # Names that never need an extra import line (typing helpers and
        # Python builtins).
        self.builtin_types = {
            "Any",
            "Dict",
            "List",
            "Optional",
            "Tuple",
            "Union",
            "Set",
            "Sequence",
            "cast",
            "NamedTuple",
            "str",
            "int",
            "float",
            "bool",
            "None",
            "bytes",
            "object",
            "type",
            "dict",
            "list",
            "tuple",
            "set",
        }
        # Names some other part of the stub already imports; track_type()
        # ignores them.
        self.already_imported = set()

    def track_type(self, annotation):
        """Record where *annotation* is defined if the stub will need to import it.

        Ignored: ``None``/``NoneType``; names listed in ``builtin_types`` or
        ``already_imported``; anything whose module is ``typing``; the
        ``types.UnionType``/``types.GenericAlias`` helpers; objects from
        ``builtins`` or ``__main__``; and objects without a ``__name__`` or
        ``__module__``.
        """
        if annotation is None or annotation is type(None):
            return

        name = getattr(annotation, "__name__", None)
        if name and (name in self.builtin_types or name in self.already_imported):
            return

        module = getattr(annotation, "__module__", None)
        qualified_name = getattr(annotation, "__qualname__", name or "")

        ignored = (
            module == "typing"
            or (module == "types" and name in ("UnionType", "GenericAlias"))
            or not module
            or module in ("builtins", "__main__")
            or not name
        )
        if not ignored:
            self.discovered_types[name] = (module, qualified_name)

    def get_imports(self, main_module_name: str) -> list[str]:
        """Return one ``from <module> import ...`` line per module, sorted by module.

        Types defined in *main_module_name* are left out; the caller imports
        those together with the main class. Names within a line are sorted.
        """
        names_by_module = {}
        for type_name, (module, _qualname) in sorted(self.discovered_types.items()):
            if main_module_name and module == main_module_name:
                continue
            # discovered_types is keyed by name, so names never repeat here.
            names_by_module.setdefault(module, []).append(type_name)

        return [
            f"from {module} import {', '.join(sorted(names))}"
            for module, names in sorted(names_by_module.items())
        ]
|
|
|
|
|
class AsyncToSyncConverter:
    """
    Provides utilities to convert async classes to sync classes with proper type hints.
    """

    # Shared worker pool that runs coroutines off the caller's thread. It is
    # created lazily by get_thread_pool(); only the first creation's
    # max_workers value takes effect.
    _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
    _thread_pool_lock = threading.Lock()
    _thread_pool_initialized = False

    @classmethod
    def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
        """Return the shared thread pool, creating it on first use.

        Args:
            max_workers: Worker count used only when the pool is created; later
                calls return the existing pool and ignore this value.
        """
        # Fast path: once created, the pool is never replaced.
        if cls._thread_pool_initialized:
            assert cls._thread_pool is not None, "Thread pool should be initialized"
            return cls._thread_pool

        with cls._thread_pool_lock:
            # Re-check under the lock so concurrent first callers create one pool.
            if not cls._thread_pool_initialized:
                cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers, thread_name_prefix="async_to_sync_"
                )
                cls._thread_pool_initialized = True

        assert cls._thread_pool is not None
        return cls._thread_pool

    @staticmethod
    def _shutdown_loop(loop) -> None:
        """Cancel leftover tasks and release loop resources, like ``asyncio.run`` does.

        Best effort: errors are logged at debug level and never propagate, so
        they cannot mask the coroutine's own result or exception.
        """
        try:
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.run_until_complete(loop.shutdown_default_executor())
        except Exception:
            logging.debug("Error while shutting down async_to_sync event loop", exc_info=True)

    @classmethod
    def run_async_in_thread(cls, coro_func, *args, **kwargs):
        """
        Run an async function in a separate thread from the thread pool.

        Blocks until ``coro_func(*args, **kwargs)`` completes, then returns its
        result or re-raises its exception. The coroutine runs on a fresh event
        loop inside a copy of the caller's contextvars context, so context
        variables set by the caller are visible to it; changes it makes are not
        propagated back.

        NOTE(review): if this is called from code already running on a worker
        of this pool and every worker is busy, the call will wait forever.
        """
        context = contextvars.copy_context()

        async def run_with_context():
            # coro_func is called inside the task so any synchronous prelude it
            # runs also sees the copied context.
            return await coro_func(*args, **kwargs)

        def run_in_thread():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return context.run(loop.run_until_complete, run_with_context())
            finally:
                try:
                    cls._shutdown_loop(loop)
                finally:
                    loop.close()
                    asyncio.set_event_loop(None)

        future = cls.get_thread_pool().submit(run_in_thread)
        # Future.result() re-raises any exception raised by run_in_thread.
        return future.result()

    @classmethod
    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
        """
        Creates a new class with synchronous versions of all async methods.

        Public coroutine methods become blocking methods that run through
        run_async_in_thread(). Other public functions and properties are
        forwarded to a wrapped ``async_class`` instance. Every generated class
        is named ``ComfyAPISyncStub``.

        Args:
            async_class: The async class to convert
            thread_pool_size: Size of thread pool to use (only honoured if the
                shared pool has not been created yet)

        Returns:
            A new class with sync versions of all async methods
        """
        sync_class_name = "ComfyAPISyncStub"
        cls.get_thread_pool(thread_pool_size)

        sync_class_dict = {
            "__doc__": async_class.__doc__,
            "__module__": async_class.__module__,
            "__qualname__": sync_class_name,
            "__orig_class__": async_class,
        }

        def wrap_async_instance(async_instance, async_type):
            # Build a sync wrapper around an existing async object. __new__ is
            # used so the wrapper's __init__ does not construct a second one.
            sync_wrapper = object.__new__(cls.create_sync_class(async_type))
            sync_wrapper._async_instance = async_instance
            return sync_wrapper

        def __init__(self, *args, **kwargs):
            self._async_instance = async_class(*args, **kwargs)

            # Merge annotations over the MRO, base classes first so subclasses win.
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            for attr_name, attr_type in all_annotations.items():
                if hasattr(self._async_instance, attr_name):
                    # Attribute exists: expose it, wrapping singletons in a sync proxy.
                    attr = getattr(self._async_instance, attr_name)
                    from comfy_api.internal.singleton import ProxiedSingleton

                    if isinstance(attr, ProxiedSingleton):
                        try:
                            setattr(self, attr_name, wrap_async_instance(attr, attr.__class__))
                        except Exception:
                            # Could not build a sync wrapper; expose the original.
                            setattr(self, attr_name, attr)
                    else:
                        setattr(self, attr_name, attr)
                elif isinstance(attr_type, type) and hasattr(async_class, attr_type.__name__):
                    # Annotated but missing (e.g. ``execution: Execution``):
                    # instantiate the nested class and expose a sync wrapper.
                    inner_class = getattr(async_class, attr_type.__name__)
                    from comfy_api.internal.singleton import ProxiedSingleton

                    try:
                        if issubclass(inner_class, ProxiedSingleton):
                            async_instance = inner_class.get_instance()
                        else:
                            async_instance = inner_class()
                        setattr(self, attr_name, wrap_async_instance(async_instance, inner_class))
                        # Keep the async instance consistent with the sync view.
                        setattr(self._async_instance, attr_name, async_instance)
                    except Exception as e:
                        logging.warning(
                            f"Failed to create instance for {attr_name}: {e}"
                        )

            # Wrap un-annotated public singleton attributes as well.
            for name, attr in inspect.getmembers(self._async_instance):
                if name.startswith("_") or hasattr(self, name):
                    continue
                if isinstance(attr, (str, int, float, bool, list, dict, tuple)):
                    continue
                from comfy_api.internal.singleton import ProxiedSingleton

                if isinstance(attr, ProxiedSingleton):
                    try:
                        setattr(self, name, wrap_async_instance(attr, attr.__class__))
                    except Exception:
                        setattr(self, name, attr)

        sync_class_dict["__init__"] = __init__

        for name, method in inspect.getmembers(async_class, predicate=inspect.isfunction):
            if name.startswith("_"):
                continue

            if inspect.iscoroutinefunction(method):

                @functools.wraps(method)
                def sync_method(self, *args, _method_name=name, **kwargs):
                    # _method_name is bound at definition time (avoids late binding).
                    async_method = getattr(self._async_instance, _method_name)
                    return AsyncToSyncConverter.run_async_in_thread(
                        async_method, *args, **kwargs
                    )

                sync_class_dict[name] = sync_method
            else:

                @functools.wraps(method)
                def proxy_method(self, *args, _method_name=name, **kwargs):
                    return getattr(self._async_instance, _method_name)(*args, **kwargs)

                sync_class_dict[name] = proxy_method

        for name, prop in inspect.getmembers(
            async_class, lambda x: isinstance(x, property)
        ):

            def make_property(name, prop_obj):
                def getter(self):
                    value = getattr(self._async_instance, name)
                    if inspect.iscoroutinefunction(value):

                        def sync_fn(*args, **kwargs):
                            return AsyncToSyncConverter.run_async_in_thread(
                                value, *args, **kwargs
                            )

                        return sync_fn
                    return value

                def setter(self, value):
                    setattr(self._async_instance, name, value)

                return property(getter, setter if prop_obj.fset else None)

            sync_class_dict[name] = make_property(name, prop)

        return type(sync_class_name, (object,), sync_class_dict)

    @classmethod
    def _format_literal_value(
        cls, value, type_tracker: Optional["TypeTracker"] = None
    ) -> str:
        """Render one ``Literal[...]`` argument as source text (values, not types)."""
        if isinstance(value, Enum):
            enum_type = type(value)
            if type_tracker:
                type_tracker.track_type(enum_type)
            return f"{enum_type.__name__}.{value.name}"
        return repr(value)

    @classmethod
    def _format_type_annotation(
        cls, annotation, type_tracker: Optional["TypeTracker"] = None
    ) -> str:
        """Convert a type annotation to its string representation for stub files.

        Every type encountered is reported to *type_tracker* (when given) so the
        caller can emit matching imports. A missing annotation renders as
        ``Any``; ``NoneType`` renders as ``None``.
        """
        from typing import Literal

        if (
            annotation is inspect.Parameter.empty
            or annotation is inspect.Signature.empty
        ):
            return "Any"
        if annotation is type(None):
            return "None"
        # Callable[..., R] and tuple[T, ...] carry a bare Ellipsis argument.
        if annotation is Ellipsis:
            return "..."
        # get_args(Callable[[A, B], R]) returns the parameter list as a list.
        if isinstance(annotation, list):
            inner = ", ".join(
                cls._format_type_annotation(arg, type_tracker) for arg in annotation
            )
            return f"[{inner}]"

        if type_tracker:
            type_tracker.track_type(annotation)

        try:
            origin = get_origin(annotation)
            args = get_args(annotation)

            if origin is Literal:
                values = ", ".join(
                    cls._format_literal_value(arg, type_tracker) for arg in args
                )
                return f"Literal[{values}]"

            if origin is not None:
                if type_tracker:
                    type_tracker.track_type(origin)

                origin_name = getattr(origin, "__name__", str(origin))
                if "." in origin_name:
                    origin_name = origin_name.split(".")[-1]
                # PEP 604 unions (X | Y) have origin types.UnionType.
                if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
                    origin_name = "Union"

                if not args:
                    return origin_name
                # The recursive calls track each argument type themselves.
                formatted_args = ", ".join(
                    cls._format_type_annotation(arg, type_tracker) for arg in args
                )
                return f"{origin_name}[{formatted_args}]"
        except (AttributeError, TypeError):
            pass

        # Fallback for objects exposing __origin__/__args__ that get_origin()
        # did not recognise.
        if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
            origin = annotation.__origin__
            origin_name = getattr(origin, "__name__", None)
            if origin_name is None:
                origin_text = str(origin)
                # "<class 'pkg.Name'>" -> "pkg.Name"; otherwise the last dotted part.
                if "'" in origin_text:
                    origin_name = origin_text.split("'")[1]
                else:
                    origin_name = origin_text.rsplit(".", 1)[-1]
            formatted_args = ", ".join(
                cls._format_type_annotation(arg, type_tracker)
                for arg in annotation.__args__
            )
            return f"{origin_name}[{formatted_args}]"

        if hasattr(annotation, "__name__"):
            return annotation.__name__

        if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
            return annotation.__qualname__

        type_str = str(annotation)
        if type_str.startswith("<class '") and type_str.endswith("'>"):
            type_str = type_str[8:-2]
        for prefix in ("typing.", "builtins.", "types."):
            if type_str.startswith(prefix):
                type_str = type_str[len(prefix):]
        if type_str in ("_empty", "inspect._empty", "NoneType"):
            return "None"
        return type_str

    @classmethod
    def _extract_coroutine_return_type(cls, annotation):
        """Return ``R`` for ``Coroutine[Y, S, R]``; return any other annotation unchanged.

        Only a genuine ``Coroutine`` is unwrapped. An ``async def`` return
        annotation such as ``Union[A, B, None]`` also has three args but is
        already the awaited type.
        """
        import collections.abc

        if get_origin(annotation) is collections.abc.Coroutine:
            args = get_args(annotation)
            if len(args) == 3:
                return args[2]
        return annotation

    @classmethod
    def _format_parameter_default(cls, default_value) -> str:
        """Format a parameter's default value (``" = <value>"``) for stub files.

        Returns ``""`` when there is no default. Strings are repr()'d so they
        stay quoted, and enum members render as ``EnumType.MEMBER``.
        """
        if default_value is inspect.Parameter.empty:
            return ""
        if default_value is None:
            return " = None"
        if isinstance(default_value, Enum) and default_value.name:
            return f" = {type(default_value).__name__}.{default_value.name}"
        if isinstance(default_value, str):
            return f" = {default_value!r}"
        if isinstance(default_value, dict) and not default_value:
            return " = {}"
        if isinstance(default_value, list) and not default_value:
            return " = []"
        return f" = {default_value}"

    @classmethod
    def _format_method_parameters(
        cls,
        sig: inspect.Signature,
        skip_self: bool = True,
        type_hints: Optional[dict] = None,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Format method parameters for stub files.

        Each parameter is rendered as ``name: Type[ = default]``. A leading
        ``self`` stays bare when *skip_self* is true. Parameter kinds are
        preserved: ``*args``/``**kwargs`` keep their stars, and the ``/`` and
        ``*`` markers are inserted for positional-only and keyword-only
        parameters.

        Args:
            sig: Signature to render.
            skip_self: Emit a first parameter named ``self`` without annotation.
            type_hints: Resolved hints (e.g. from ``get_type_hints``) that take
                precedence over the raw signature annotations.
            type_tracker: Receives every type referenced by the annotations.
        """
        kinds = inspect.Parameter
        hints = type_hints or {}
        rendered = []
        positional_only_open = False
        star_emitted = False

        for index, (param_name, param) in enumerate(sig.parameters.items()):
            kind = param.kind
            if positional_only_open and kind is not kinds.POSITIONAL_ONLY:
                rendered.append("/")
                positional_only_open = False
            if kind is kinds.KEYWORD_ONLY and not star_emitted:
                rendered.append("*")
                star_emitted = True

            if index == 0 and param_name == "self" and skip_self:
                text = "self"
            else:
                annotation = hints.get(param_name, param.annotation)
                # A missing annotation is rendered as "Any" by the formatter.
                type_str = cls._format_type_annotation(annotation, type_tracker)
                text = f"{param_name}: {type_str}{cls._format_parameter_default(param.default)}"

            if kind is kinds.VAR_POSITIONAL:
                text = f"*{text}"
                star_emitted = True
            elif kind is kinds.VAR_KEYWORD:
                text = f"**{text}"
            elif kind is kinds.POSITIONAL_ONLY:
                positional_only_open = True
            rendered.append(text)

        if positional_only_open:
            rendered.append("/")
        return ", ".join(rendered)

    @classmethod
    def _generate_method_signature(
        cls,
        method_name: str,
        method,
        is_async: bool = False,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Generate a complete method signature (``def name(...) -> R: ...``) for stub files.

        Resolved type hints are preferred over raw annotations. An unannotated
        return is rendered as ``None``.
        """
        sig = inspect.signature(method)

        try:
            from typing import get_type_hints

            type_hints = get_type_hints(method)
        except Exception:
            # Unresolvable forward references: fall back to raw annotations.
            type_hints = {}

        return_annotation = type_hints.get("return", sig.return_annotation)
        if is_async and inspect.iscoroutinefunction(method):
            return_annotation = cls._extract_coroutine_return_type(return_annotation)

        params_str = cls._format_method_parameters(
            sig, type_hints=type_hints, type_tracker=type_tracker
        )

        return_type = cls._format_type_annotation(return_annotation, type_tracker)
        if return_annotation is inspect.Signature.empty:
            return_type = "None"

        return f"def {method_name}({params_str}) -> {return_type}: ..."

    @classmethod
    def _generate_imports(
        cls, async_class: Type, type_tracker: "TypeTracker"
    ) -> list[str]:
        """Generate import statements for the stub file.

        The lines come in this order:

        1. A fixed ``typing`` import.
        2. An import of *async_class* from its module, together with that
           module's NamedTuple/Enum classes and any top-level classes from it
           that the signatures reference.
        3. The imports recorded by *type_tracker*.
        4. A plain ``import <top-level package>``, but only if nothing imports
           from that package yet.
        """
        imports = [
            "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple, Literal"
        ]
        main_module_name = async_class.__module__

        if main_module_name != "builtins":
            module = inspect.getmodule(async_class)
            names_from_main = [async_class.__name__]

            if module:
                module_all = getattr(module, "__all__", None)
                for name, obj in inspect.getmembers(module):
                    if not isinstance(obj, type):
                        continue
                    # Respect __all__, but still import types the stub refers to.
                    if (
                        module_all is not None
                        and name not in module_all
                        and name not in type_tracker.discovered_types
                    ):
                        continue
                    is_named_tuple = issubclass(obj, tuple) and hasattr(obj, "_fields")
                    is_enum = issubclass(obj, Enum) and name != "Enum"
                    if is_named_tuple or is_enum:
                        names_from_main.append(name)
                        type_tracker.already_imported.add(name)

            # TypeTracker.get_imports() skips main-module types, so top-level
            # classes from this module that appear in signatures are imported
            # here; otherwise they would be unresolved names in the stub.
            for type_name, (type_module, qualname) in sorted(
                type_tracker.discovered_types.items()
            ):
                if (
                    type_module == main_module_name
                    and qualname == type_name
                    and type_name not in names_from_main
                ):
                    names_from_main.append(type_name)

            imports.append(f"from {main_module_name} import {', '.join(names_from_main)}")

        imports.extend(type_tracker.get_imports(main_module_name=main_module_name))

        module_name = getattr(inspect.getmodule(async_class), "__name__", None)
        if module_name and "." in module_name:
            base_module = module_name.split(".")[0]
            if not any(imp.startswith(f"from {base_module}") for imp in imports):
                imports.append(f"import {base_module}")

        return imports

    @classmethod
    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
        """Extract class attributes that are classes themselves.

        A member is included when it is a class whose name does not start with
        ``_``, or when its name is annotated on *async_class* with a class.
        Pairs come back in name order, which inspect.getmembers() already
        guarantees.
        """
        class_attributes = []
        for name, attr in inspect.getmembers(async_class):
            if isinstance(attr, type) and not name.startswith("_"):
                class_attributes.append((name, attr))
            elif (
                hasattr(async_class, "__annotations__")
                and name in async_class.__annotations__
            ):
                annotation = async_class.__annotations__[name]
                if isinstance(annotation, type):
                    class_attributes.append((name, annotation))
        return class_attributes

    @classmethod
    def _generate_init_stub(
        cls,
        target_class: Type,
        indent: str,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> list[str]:
        """Return stub lines (docstring, then signature) for ``target_class.__init__``.

        Falls back to a ``*args, **kwargs`` signature when introspection fails.
        The inherited ``object.__init__`` docstring is not copied because it
        says nothing about *target_class*.
        """
        init_method = target_class.__init__
        try:
            init_signature = inspect.signature(init_method)
            try:
                from typing import get_type_hints

                init_hints = get_type_hints(init_method)
            except Exception:
                init_hints = {}
            params_str = cls._format_method_parameters(
                init_signature, type_hints=init_hints, type_tracker=type_tracker
            )
        except (ValueError, TypeError):
            return [f"{indent}def __init__(self, *args, **kwargs) -> None: ..."]

        lines = []
        if init_method is not object.__init__ and getattr(init_method, "__doc__", None):
            lines.extend(cls._format_docstring_for_stub(init_method.__doc__, indent))
        lines.append(f"{indent}def __init__({params_str}) -> None: ...")
        return lines

    @classmethod
    def _generate_method_stub(
        cls,
        name: str,
        method,
        indent: str,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> list[str]:
        """Return stub lines (docstring, then signature) for one public method."""
        lines = (
            cls._format_docstring_for_stub(method.__doc__, indent)
            if method.__doc__
            else []
        )
        try:
            signature = cls._generate_method_signature(
                name, method, is_async=True, type_tracker=type_tracker
            )
        except (ValueError, TypeError):
            # Unknown signature: accept anything, promise nothing about the result.
            signature = f"def {name}(self, *args, **kwargs) -> Any: ..."
        lines.append(f"{indent}{signature}")
        return lines

    @classmethod
    def _generate_inner_class_stub(
        cls,
        name: str,
        attr: Type,
        indent: str = "    ",
        type_tracker: Optional["TypeTracker"] = None,
    ) -> list[str]:
        """Generate the stub for a nested class, emitted as ``<name>Sync``."""
        member_indent = f"{indent}    "
        stub_lines = [f"{indent}class {name}Sync:"]

        if getattr(attr, "__doc__", None):
            stub_lines.extend(cls._format_docstring_for_stub(attr.__doc__, member_indent))

        stub_lines.extend(cls._generate_init_stub(attr, member_indent, type_tracker))

        has_methods = False
        for method_name, method in inspect.getmembers(attr, predicate=inspect.isfunction):
            if method_name.startswith("_"):
                continue
            has_methods = True
            stub_lines.extend(
                cls._generate_method_stub(method_name, method, member_indent, type_tracker)
            )

        if not has_methods:
            stub_lines.append(f"{member_indent}pass")

        return stub_lines

    @classmethod
    def _format_docstring_for_stub(
        cls, docstring: str, indent: str = "    "
    ) -> list[str]:
        """Format a docstring as indented triple-quoted lines for a stub file.

        The docstring is normalised with ``inspect.cleandoc`` (PEP 257
        indentation rules). Backslashes and embedded triple quotes are escaped
        so the text survives being re-parsed. Empty or whitespace-only input
        yields ``[]``.
        """
        if not docstring:
            return []

        cleaned = inspect.cleandoc(docstring)
        if not cleaned:
            return []

        result = [f'{indent}"""']
        for line in cleaned.split("\n"):
            if line.strip():
                escaped = line.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
                result.append(f"{indent}{escaped}")
            else:
                result.append("")
        result.append(f'{indent}"""')
        return result

    @classmethod
    def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
        """Add ``-> None`` to any ``def ...: ...`` line that lacks a return annotation.

        Import lines are passed through untouched.
        """
        processed = []
        for line in stub_content:
            if line.startswith(("from ", "import ")):
                processed.append(line)
                continue

            stripped = line.strip()
            if (
                stripped.startswith("def ")
                and stripped.endswith(": ...")
                and ") -> " not in line
            ):
                # Only the trailing ": ..." is the stub body; earlier
                # occurrences may belong to default values.
                line = line[: line.rfind(": ...")] + " -> None: ..."
            processed.append(line)
        return processed

    @classmethod
    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
        """
        Generate a .pyi stub file for the sync class to help IDEs with type checking.

        The stub is written as UTF-8 to ``generated/<sync_class.__name__>.pyi``
        next to the module that defines *async_class*. Nothing is written for
        classes from ``__main__`` or from modules without a source file. Errors
        are logged, never raised.
        """
        try:
            if async_class.__module__ == "__main__":
                return

            module = inspect.getmodule(async_class)
            # Built-in or namespace modules may lack __file__ or set it to None.
            module_path = getattr(module, "__file__", None) if module else None
            if not module_path:
                return

            stub_dir = os.path.join(os.path.dirname(module_path), "generated")
            os.makedirs(stub_dir, exist_ok=True)
            sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")

            type_tracker = TypeTracker()

            # Slot 0 is replaced by the import block at the end, once
            # type_tracker has seen every annotation.
            imports_placeholder_index = 0
            stub_content = [""]

            stub_content.append(f"class {sync_class.__name__}:")
            if async_class.__doc__:
                stub_content.extend(
                    cls._format_docstring_for_stub(async_class.__doc__, "    ")
                )
            stub_content.extend(cls._generate_init_stub(async_class, "    ", type_tracker))
            stub_content.append("")

            class_attributes = cls._get_class_attributes(async_class)
            for name, attr in class_attributes:
                stub_content.extend(
                    cls._generate_inner_class_stub(name, attr, type_tracker=type_tracker)
                )
                stub_content.append("")

            for name, method in inspect.getmembers(async_class, predicate=inspect.isfunction):
                if name.startswith("_"):
                    continue
                stub_content.extend(
                    cls._generate_method_stub(name, method, "    ", type_tracker)
                )
                stub_content.append("")

            for name, prop in inspect.getmembers(
                async_class, lambda x: isinstance(x, property)
            ):
                stub_content.append("    @property")
                stub_content.append(f"    def {name}(self) -> Any: ...")
                if prop.fset:
                    stub_content.append(f"    @{name}.setter")
                    stub_content.append(
                        f"    def {name}(self, value: Any) -> None: ..."
                    )
                stub_content.append("")

            # Map each nested class to the attribute annotated with it
            # (e.g. ``execution: Execution``) so the stub exposes that name.
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            attribute_mappings = {}
            for attr_name, attr_type in sorted(all_annotations.items()):
                for class_name, class_type in class_attributes:
                    if (
                        attr_type == class_type
                        or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
                        or (isinstance(attr_type, str) and attr_type == class_name)
                    ):
                        attribute_mappings[class_name] = attr_name

            for class_name, _class_type in class_attributes:
                attr_name = attribute_mappings.get(class_name, class_name)
                stub_content.append(f"    {attr_name}: {class_name}Sync")
            stub_content.append("")

            unique_imports = []
            seen = set()
            for imp in cls._generate_imports(async_class, type_tracker):
                if imp in seen:
                    logging.warning(f"Duplicate import detected: {imp}")
                    continue
                seen.add(imp)
                unique_imports.append(imp)

            stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
                unique_imports
            )

            stub_content = cls._post_process_stub_content(stub_content)

            with open(sync_stub_path, "w", encoding="utf-8") as f:
                f.write("\n".join(stub_content))

            logging.info(f"Generated stub file: {sync_stub_path}")

        except Exception as e:
            logging.exception(
                f"Error generating stub file for {sync_class.__name__}: {str(e)}"
            )
|
|
|
|
|
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
    """
    Build a blocking counterpart of an async class.

    Thin module-level convenience wrapper that delegates to
    ``AsyncToSyncConverter.create_sync_class``.

    Args:
        async_class: Class whose coroutine methods should become blocking calls
        thread_pool_size: Worker count requested for the shared thread pool

    Returns:
        The generated synchronous wrapper class
    """
    converter = AsyncToSyncConverter
    return converter.create_sync_class(async_class, thread_pool_size=thread_pool_size)
|
|
|