| import asyncio |
| import concurrent.futures |
| import contextvars |
| import functools |
| import inspect |
| import logging |
| import os |
| import textwrap |
| import threading |
| from enum import Enum |
| from typing import Optional, Type, get_origin, get_args |
|
|
|
|
class TypeTracker:
    """Collect the non-builtin types referenced while a stub file is generated.

    Each type is keyed by its ``__name__`` and stored with its defining module
    (and qualname), so :meth:`get_imports` can later emit one
    ``from <module> import <names>`` line per module.
    """

    def __init__(self):
        # type name -> (module, qualname) for every type that still needs an import
        self.discovered_types = {}
        # Names that never get an import line of their own: typing helpers and builtins.
        self.builtin_types = {
            # typing helpers
            "Any",
            "Dict",
            "List",
            "Optional",
            "Tuple",
            "Union",
            "Set",
            "Sequence",
            "cast",
            "NamedTuple",
            # builtins
            "str",
            "int",
            "float",
            "bool",
            "None",
            "bytes",
            "object",
            "type",
            "dict",
            "list",
            "tuple",
            "set",
        }
        # Names that some other import line already covers; callers add to this.
        self.already_imported = set()

    def track_type(self, annotation):
        """Remember ``annotation`` if it is a type that the stub must import.

        ``None``/``NoneType``, builtin or already-imported names, anything from
        ``typing``, the ``types.UnionType``/``types.GenericAlias`` helpers, and
        types living in ``builtins`` or ``__main__`` are ignored.
        """
        if annotation is None or annotation is type(None):
            return

        type_name = getattr(annotation, "__name__", None)
        if type_name and (
            type_name in self.builtin_types or type_name in self.already_imported
        ):
            return

        module = getattr(annotation, "__module__", None)
        qualname = getattr(annotation, "__qualname__", type_name or "")

        is_typing_construct = module == "typing" or (
            module == "types" and type_name in ("UnionType", "GenericAlias")
        )
        if is_typing_construct:
            return
        if not module or module in ("builtins", "__main__") or not type_name:
            return

        self.discovered_types[type_name] = (module, qualname)

    def get_imports(self, main_module_name: str) -> list[str]:
        """Build ``from <module> import ...`` lines for every tracked type.

        Types defined in ``main_module_name`` are skipped (when it is non-empty).
        Output is sorted by module; names within a line are sorted as well.
        """
        names_by_module = {}
        for type_name, (module, _qualname) in sorted(self.discovered_types.items()):
            if main_module_name and module == main_module_name:
                continue
            module_names = names_by_module.setdefault(module, [])
            if type_name not in module_names:
                module_names.append(type_name)

        lines = []
        for module in sorted(names_by_module):
            module_names = names_by_module[module]
            if len(module_names) == 1:
                joined = module_names[0]
            else:
                joined = ", ".join(sorted(set(module_names)))
            lines.append(f"from {module} import {joined}")
        return lines
|
|
|
|
class AsyncToSyncConverter:
    """
    Provides utilities to convert async classes to sync classes with proper type hints.

    Coroutine methods of a wrapped class are run on a fresh event loop inside a
    shared thread pool, so the generated sync methods block the caller until the
    coroutine finishes. The converter can also write a ``.pyi`` stub describing
    the generated sync class for IDEs.
    """

    # Shared executor used by every converted class; created lazily by get_thread_pool().
    _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
    # Guards the one-time creation of _thread_pool.
    _thread_pool_lock = threading.Lock()
    # Set to True only after _thread_pool is assigned, so later calls can skip the lock.
    _thread_pool_initialized = False

    # ------------------------------------------------------------------
    # Thread pool and coroutine execution
    # ------------------------------------------------------------------

    @classmethod
    def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
        """Return the shared thread pool, creating it on first use.

        Args:
            max_workers: ``max_workers`` for the ``ThreadPoolExecutor``. Only the
                call that actually creates the pool uses it; later values are ignored.

        Returns:
            The shared executor (worker threads are named ``async_to_sync_*``).
        """
        # Fast path: once the pool exists no locking is needed.
        if cls._thread_pool_initialized:
            assert cls._thread_pool is not None, "Thread pool should be initialized"
            return cls._thread_pool

        with cls._thread_pool_lock:
            # Re-check under the lock: another thread may have created the pool
            # between the unlocked check above and acquiring the lock.
            if not cls._thread_pool_initialized:
                cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers, thread_name_prefix="async_to_sync_"
                )
                cls._thread_pool_initialized = True

        assert cls._thread_pool is not None
        return cls._thread_pool

    @classmethod
    def run_async_in_thread(cls, coro_func, *args, **kwargs):
        """Run ``coro_func(*args, **kwargs)`` on a pool thread and return its result.

        Blocks the calling thread until the coroutine completes. The coroutine
        runs on a brand-new event loop inside a copy of the caller's
        ``contextvars`` context, so variables set by the caller are visible to
        it; changes it makes are not copied back to the caller.

        NOTE(review): the pool is bounded, so calling this from code that is
        itself running on a pool worker waits for another worker; with every
        worker busy that could deadlock — confirm callers never nest it.

        Raises:
            Whatever the coroutine raised, re-raised in the calling thread by
            ``Future.result()``.
        """
        context = contextvars.copy_context()

        def run_in_thread():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                async def run_with_context():
                    return await coro_func(*args, **kwargs)

                # Running run_until_complete inside the copied context makes the
                # Task it creates inherit the caller's context variables.
                return context.run(loop.run_until_complete, run_with_context())
            finally:
                cls._shutdown_event_loop(loop)

        return cls.get_thread_pool().submit(run_in_thread).result()

    @staticmethod
    def _shutdown_event_loop(loop) -> None:
        """Cancel leftover tasks, finalize async generators, then close ``loop``."""
        try:
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                # Let cancelled tasks run their cleanup; their errors are collected, not raised.
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception:
            # Best effort: cleanup failures must not mask the coroutine's outcome.
            logging.debug("Error while cleaning up async_to_sync event loop", exc_info=True)
        finally:
            try:
                loop.close()
            finally:
                asyncio.set_event_loop(None)

    # ------------------------------------------------------------------
    # Sync class construction
    # ------------------------------------------------------------------

    @classmethod
    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
        """
        Creates a new class with synchronous versions of all async methods.

        Public coroutine methods become blocking calls through
        :meth:`run_async_in_thread`; public plain functions are forwarded
        unchanged; properties are forwarded (coroutine-function values returned
        by a property are wrapped as blocking callables).

        Args:
            async_class: The async class to convert
            thread_pool_size: Size of thread pool to use (only effective if the
                shared pool has not been created yet)

        Returns:
            A new class with sync versions of all async methods
        """
        sync_class_name = "ComfyAPISyncStub"
        cls.get_thread_pool(thread_pool_size)

        sync_class_dict = {
            "__doc__": async_class.__doc__,
            "__module__": async_class.__module__,
            "__qualname__": sync_class_name,
            "__orig_class__": async_class,
        }

        def __init__(self, *args, **kwargs):
            self._async_instance = async_class(*args, **kwargs)
            cls._bind_annotated_attributes(self, async_class)
            cls._bind_singleton_members(self)

        sync_class_dict["__init__"] = __init__

        for name, method in inspect.getmembers(async_class, predicate=inspect.isfunction):
            if name.startswith("_"):
                continue
            if inspect.iscoroutinefunction(method):
                sync_class_dict[name] = cls._make_blocking_method(name, method)
            else:
                sync_class_dict[name] = cls._make_proxy_method(name, method)

        for name, prop in inspect.getmembers(async_class, lambda x: isinstance(x, property)):
            sync_class_dict[name] = cls._make_sync_property(name, prop)

        return type(sync_class_name, (object,), sync_class_dict)

    @staticmethod
    def _collect_annotations(klass) -> dict:
        """Merge ``__annotations__`` along the MRO; subclasses override bases."""
        annotations = {}
        for base_class in reversed(inspect.getmro(klass)):
            if hasattr(base_class, "__annotations__"):
                annotations.update(base_class.__annotations__)
        return annotations

    @staticmethod
    def _proxied_singleton_type():
        """Return ``ProxiedSingleton``, imported lazily.

        NOTE(review): kept as a function-level import, presumably to avoid an
        import cycle with ``comfy_api.internal`` — confirm.
        """
        from comfy_api.internal.singleton import ProxiedSingleton

        return ProxiedSingleton

    @classmethod
    def _is_proxied_singleton(cls, obj) -> bool:
        """True when ``obj`` is a ``ProxiedSingleton`` instance."""
        return isinstance(obj, cls._proxied_singleton_type())

    @classmethod
    def _wrap_in_sync_proxy(cls, async_obj, async_type):
        """Build a sync proxy for ``async_obj`` without running the proxy's ``__init__``."""
        sync_type = cls.create_sync_class(async_type)
        proxy = object.__new__(sync_type)
        proxy._async_instance = async_obj
        return proxy

    @classmethod
    def _bind_annotated_attributes(cls, sync_self, async_class) -> None:
        """Mirror annotated attributes of the async instance onto ``sync_self``.

        Existing singleton values are wrapped in sync proxies. For annotations
        whose class is also a class attribute of ``async_class`` but that the
        instance lacks, a new instance (or the singleton) is created and bound
        on both objects.
        """
        async_instance = sync_self._async_instance
        for attr_name, attr_type in cls._collect_annotations(async_class).items():
            if hasattr(async_instance, attr_name):
                attr = getattr(async_instance, attr_name)
                if cls._is_proxied_singleton(attr):
                    try:
                        setattr(sync_self, attr_name, cls._wrap_in_sync_proxy(attr, attr.__class__))
                    except Exception:
                        # Best effort: fall back to exposing the async object itself.
                        setattr(sync_self, attr_name, attr)
                else:
                    setattr(sync_self, attr_name, attr)
            elif isinstance(attr_type, type) and hasattr(async_class, attr_type.__name__):
                inner_class = getattr(async_class, attr_type.__name__)
                proxied_singleton = cls._proxied_singleton_type()
                try:
                    if issubclass(inner_class, proxied_singleton):
                        inner_instance = inner_class.get_instance()
                    else:
                        inner_instance = inner_class()
                    setattr(sync_self, attr_name, cls._wrap_in_sync_proxy(inner_instance, inner_class))
                    # Keep the async side consistent with the sync proxy.
                    setattr(async_instance, attr_name, inner_instance)
                except Exception as e:
                    logging.warning(f"Failed to create instance for {attr_name}: {e}")

    @classmethod
    def _bind_singleton_members(cls, sync_self) -> None:
        """Expose remaining public singleton members of the async instance as sync proxies."""
        for name, attr in inspect.getmembers(sync_self._async_instance):
            if name.startswith("_") or hasattr(sync_self, name):
                continue
            # Plain data values never need wrapping.
            if isinstance(attr, (str, int, float, bool, list, dict, tuple)):
                continue
            if cls._is_proxied_singleton(attr):
                try:
                    setattr(sync_self, name, cls._wrap_in_sync_proxy(attr, attr.__class__))
                except Exception:
                    setattr(sync_self, name, attr)

    @staticmethod
    def _make_blocking_method(name, method):
        """Wrap coroutine method ``name`` so calling it blocks until it completes."""

        # `_method_name` keeps the signature the generated methods have always had.
        @functools.wraps(method)
        def sync_method(self, *args, _method_name=name, **kwargs):
            async_method = getattr(self._async_instance, _method_name)
            return AsyncToSyncConverter.run_async_in_thread(async_method, *args, **kwargs)

        return sync_method

    @staticmethod
    def _make_proxy_method(name, method):
        """Forward plain (non-coroutine) method ``name`` to the async instance."""

        @functools.wraps(method)
        def proxy_method(self, *args, _method_name=name, **kwargs):
            target = getattr(self._async_instance, _method_name)
            return target(*args, **kwargs)

        return proxy_method

    @staticmethod
    def _make_sync_property(name, prop_obj):
        """Forward property ``name``; a setter is added only if the original has one."""

        def getter(self):
            value = getattr(self._async_instance, name)
            if inspect.iscoroutinefunction(value):

                def sync_fn(*args, **kwargs):
                    return AsyncToSyncConverter.run_async_in_thread(value, *args, **kwargs)

                return sync_fn
            return value

        def setter(self, value):
            setattr(self._async_instance, name, value)

        return property(getter, setter if prop_obj.fset else None)

    # ------------------------------------------------------------------
    # Stub formatting helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_type_hints(obj) -> dict:
        """``typing.get_type_hints(obj)``, or ``{}`` if the hints cannot be resolved."""
        from typing import get_type_hints

        try:
            return get_type_hints(obj)
        except Exception:
            return {}

    @classmethod
    def _format_literal_value(cls, value, type_tracker: Optional["TypeTracker"] = None) -> str:
        """Render one ``Literal[...]`` argument as source text."""
        if isinstance(value, Enum):
            if type_tracker:
                type_tracker.track_type(type(value))
            return f"{type(value).__name__}.{value.name}"
        return repr(value)

    @classmethod
    def _format_type_annotation(
        cls, annotation, type_tracker: Optional["TypeTracker"] = None
    ) -> str:
        """Convert a type annotation to its string representation for stub files.

        Every type encountered is reported to ``type_tracker`` (when given) so
        the matching imports can be generated later. Missing annotations render
        as ``Any``; ``NoneType`` as ``None``.
        """
        from typing import Annotated, ForwardRef, Literal

        if annotation is inspect.Parameter.empty or annotation is inspect.Signature.empty:
            return "Any"
        if annotation is type(None):
            return "None"
        # Callable[..., T] and tuple[T, ...] carry a literal Ellipsis argument.
        if annotation is Ellipsis:
            return "..."
        # Unresolved string references, e.g. the argument of List["Foo"].
        if isinstance(annotation, ForwardRef):
            return annotation.__forward_arg__
        # The parameter list of Callable[[A, B], R] arrives as a plain list.
        if isinstance(annotation, (list, tuple)):
            inner = ", ".join(cls._format_type_annotation(item, type_tracker) for item in annotation)
            return f"[{inner}]"

        if type_tracker:
            type_tracker.track_type(annotation)

        try:
            origin = get_origin(annotation)
            args = get_args(annotation)
            if origin is not None:
                if type_tracker:
                    type_tracker.track_type(origin)

                # Annotated[T, meta]: the metadata is irrelevant for the stub.
                if origin is Annotated:
                    return cls._format_type_annotation(args[0], type_tracker)
                # Literal arguments are values, not types: render them as source literals.
                if origin is Literal:
                    values = ", ".join(cls._format_literal_value(v, type_tracker) for v in args)
                    return f"Literal[{values}]"

                origin_name = getattr(origin, "__name__", str(origin))
                if "." in origin_name:
                    origin_name = origin_name.split(".")[-1]
                # PEP 604 unions (X | Y) are rendered with typing.Union.
                if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
                    origin_name = "Union"

                if not args:
                    return origin_name
                formatted_args = []
                for arg in args:
                    if type_tracker:
                        type_tracker.track_type(arg)
                    formatted_args.append(cls._format_type_annotation(arg, type_tracker))
                return f"{origin_name}[{', '.join(formatted_args)}]"
        except (AttributeError, TypeError):
            # Fall through to the attribute/str based heuristics below.
            pass

        if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
            origin = annotation.__origin__
            origin_name = (
                origin.__name__ if hasattr(origin, "__name__") else str(origin).split("'")[1]
            )
            formatted = [cls._format_type_annotation(arg, type_tracker) for arg in annotation.__args__]
            return f"{origin_name}[{', '.join(formatted)}]"

        if hasattr(annotation, "__name__"):
            return annotation.__name__

        if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
            return annotation.__qualname__

        type_str = str(annotation)
        if type_str.startswith("<class '") and type_str.endswith("'>"):
            type_str = type_str[8:-2]
        for prefix in ["typing.", "builtins.", "types."]:
            if type_str.startswith(prefix):
                type_str = type_str[len(prefix):]
        if type_str in ("_empty", "inspect._empty"):
            return "None"
        if type_str == "NoneType":
            return "None"
        return type_str

    @classmethod
    def _extract_coroutine_return_type(cls, annotation):
        """Return ``T`` for ``Coroutine[Any, Any, T]``; any other annotation unchanged.

        Only real ``Coroutine`` annotations are unwrapped: e.g. a
        ``Tuple[int, str, bytes]`` return type must not collapse to ``bytes``.
        """
        import collections.abc

        if get_origin(annotation) is collections.abc.Coroutine:
            args = get_args(annotation)
            if len(args) == 3:
                return args[2]
        return annotation

    @classmethod
    def _format_parameter_default(cls, default_value) -> str:
        """Format a parameter's default value for stub files.

        Returns ``""`` when there is no default, ``" = <repr>"`` when the repr is
        a valid Python literal (strings get their quotes), and ``" = ..."``
        otherwise — the usual stub placeholder for non-literal defaults.
        """
        import ast

        if default_value is inspect.Parameter.empty:
            return ""
        try:
            text = repr(default_value)
            # Only emit defaults that round-trip as source literals.
            ast.literal_eval(text)
        except Exception:
            return " = ..."
        return f" = {text}"

    @classmethod
    def _format_method_parameters(
        cls,
        sig: inspect.Signature,
        skip_self: bool = True,
        type_hints: Optional[dict] = None,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Format method parameters for stub files.

        Preserves the call shape: ``*args``/``**kwargs`` prefixes, a bare ``*``
        before keyword-only parameters and ``/`` after positional-only ones.
        A leading ``self`` is emitted without annotation when ``skip_self``.
        Resolved ``type_hints`` take precedence over raw annotations.
        """
        hints = type_hints or {}
        params = []
        star_emitted = False  # a "*args" or bare "*" has already been written
        in_positional_only = False  # inside a run of positional-only parameters

        for index, (param_name, param) in enumerate(sig.parameters.items()):
            kind = param.kind
            if in_positional_only and kind is not inspect.Parameter.POSITIONAL_ONLY:
                params.append("/")
                in_positional_only = False
            if kind is inspect.Parameter.POSITIONAL_ONLY:
                in_positional_only = True
            if kind is inspect.Parameter.KEYWORD_ONLY and not star_emitted:
                params.append("*")
                star_emitted = True

            if index == 0 and param_name == "self" and skip_self:
                params.append("self")
                continue

            annotation = hints.get(param_name, param.annotation)
            type_str = cls._format_type_annotation(annotation, type_tracker)
            prefix = ""
            if kind is inspect.Parameter.VAR_POSITIONAL:
                prefix = "*"
                star_emitted = True
            elif kind is inspect.Parameter.VAR_KEYWORD:
                prefix = "**"
            default_str = cls._format_parameter_default(param.default)
            params.append(f"{prefix}{param_name}: {type_str}{default_str}")

        if in_positional_only:
            params.append("/")
        return ", ".join(params)

    @classmethod
    def _generate_method_signature(
        cls,
        method_name: str,
        method,
        is_async: bool = False,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Generate a one-line ``def name(...) -> T: ...`` stub for ``method``.

        For coroutine functions (with ``is_async``) the awaited type is used as
        the return type. A missing return annotation is rendered as ``None``.
        """
        sig = inspect.signature(method)
        type_hints = cls._safe_type_hints(method)

        return_annotation = type_hints.get("return", sig.return_annotation)
        if is_async and inspect.iscoroutinefunction(method):
            return_annotation = cls._extract_coroutine_return_type(return_annotation)

        params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)

        return_type = cls._format_type_annotation(return_annotation, type_tracker)
        if return_annotation is inspect.Signature.empty:
            return_type = "None"
        return f"def {method_name}({params_str}) -> {return_type}: ..."

    @classmethod
    def _generate_imports(cls, async_class: Type, type_tracker: "TypeTracker") -> list[str]:
        """Generate import statements for the stub file.

        Emits the fixed typing import line, an import of ``async_class`` plus any
        NamedTuple/Enum classes from its module (exported via ``__all__`` or
        referenced by the stub), the imports collected by ``type_tracker``, and
        an ``import <top-level package>`` fallback for dotted module names.
        """
        imports = [
            "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple, Literal"
        ]
        module = inspect.getmodule(async_class)

        if async_class.__module__ != "builtins":
            additional_types = []
            if module:
                module_all = getattr(module, "__all__", None)
                for name, obj in sorted(inspect.getmembers(module)):
                    if not isinstance(obj, type):
                        continue
                    # Respect __all__, unless the stub itself references the type.
                    if (
                        module_all is not None
                        and name not in module_all
                        and name not in type_tracker.discovered_types
                    ):
                        continue
                    is_named_tuple = issubclass(obj, tuple) and hasattr(obj, "_fields")
                    is_enum = issubclass(obj, Enum) and name != "Enum"
                    if is_named_tuple or is_enum:
                        additional_types.append(name)
                        type_tracker.already_imported.add(name)

            type_imports = ", ".join([async_class.__name__] + additional_types)
            imports.append(f"from {async_class.__module__} import {type_imports}")

        imports.extend(type_tracker.get_imports(main_module_name=async_class.__module__))

        module_name = getattr(module, "__name__", None)
        if module_name and "." in module_name:
            base_module = module_name.split(".")[0]
            if not any(imp.startswith(f"from {base_module}") for imp in imports):
                imports.append(f"import {base_module}")

        return imports

    @classmethod
    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
        """Extract class attributes that are classes themselves.

        Includes public attributes whose value is a class, and attributes whose
        own ``__annotations__`` entry is a class.
        """
        class_attributes = []
        for name, attr in sorted(inspect.getmembers(async_class)):
            if isinstance(attr, type) and not name.startswith("_"):
                class_attributes.append((name, attr))
            elif hasattr(async_class, "__annotations__") and name in async_class.__annotations__:
                annotation = async_class.__annotations__[name]
                if isinstance(annotation, type):
                    class_attributes.append((name, annotation))
        return class_attributes

    @classmethod
    def _generate_init_stub(
        cls, klass, indent: str, type_tracker: Optional["TypeTracker"] = None
    ) -> list[str]:
        """Stub lines (docstring + ``def __init__``) for ``klass.__init__``.

        Falls back to ``(self, *args, **kwargs)`` when the signature cannot be inspected.
        """
        lines = []
        try:
            init_method = klass.__init__
            init_sig = inspect.signature(init_method)
            params_str = cls._format_method_parameters(
                init_sig, type_hints=cls._safe_type_hints(init_method), type_tracker=type_tracker
            )
            if getattr(init_method, "__doc__", None):
                lines.extend(cls._format_docstring_for_stub(init_method.__doc__, indent))
            lines.append(f"{indent}def __init__({params_str}) -> None: ...")
        except (ValueError, TypeError):
            lines.append(f"{indent}def __init__(self, *args, **kwargs) -> None: ...")
        return lines

    @classmethod
    def _generate_inner_class_stub(
        cls,
        name: str,
        attr: Type,
        indent: str = "    ",
        type_tracker: Optional["TypeTracker"] = None,
    ) -> list[str]:
        """Generate the ``<name>Sync`` inner-class stub for class ``attr``."""
        member_indent = f"{indent}    "
        stub_lines = [f"{indent}class {name}Sync:"]

        if getattr(attr, "__doc__", None):
            stub_lines.extend(cls._format_docstring_for_stub(attr.__doc__, member_indent))

        if hasattr(attr, "__init__"):
            stub_lines.extend(cls._generate_init_stub(attr, member_indent, type_tracker))

        has_methods = False
        for method_name, method in sorted(inspect.getmembers(attr, predicate=inspect.isfunction)):
            if method_name.startswith("_"):
                continue
            has_methods = True
            try:
                if method.__doc__:
                    stub_lines.extend(cls._format_docstring_for_stub(method.__doc__, member_indent))
                method_sig = cls._generate_method_signature(
                    method_name, method, is_async=True, type_tracker=type_tracker
                )
                stub_lines.append(f"{member_indent}{method_sig}")
            except (ValueError, TypeError):
                stub_lines.append(f"{member_indent}def {method_name}(self, *args, **kwargs): ...")

        if not has_methods:
            stub_lines.append(f"{member_indent}pass")
        return stub_lines

    @classmethod
    def _format_docstring_for_stub(cls, docstring: str, indent: str = "    ") -> list[str]:
        """Format a docstring for inclusion in a stub file with proper indentation.

        The text is normalised with ``inspect.cleandoc`` (PEP 257 rules), and
        backslashes and triple double quotes are escaped so the emitted literal
        evaluates back to the original text. Blank lines are emitted empty.
        """
        if not docstring:
            return []

        cleaned = inspect.cleandoc(docstring)
        escaped = cleaned.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')

        result = [f'{indent}"""']
        for line in escaped.split("\n"):
            result.append(f"{indent}{line}" if line.strip() else "")
        result.append(f'{indent}"""')
        return result

    @classmethod
    def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
        """Append ``-> None`` to any stub ``def`` line that lacks a return annotation."""
        processed = []
        for line in stub_content:
            if line.startswith(("from ", "import ")):
                processed.append(line)
                continue
            stripped = line.strip()
            if stripped.startswith("def ") and stripped.endswith(": ...") and ") -> " not in line:
                line = line.replace(": ...", " -> None: ...")
            processed.append(line)
        return processed

    @classmethod
    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
        """
        Generate a .pyi stub file for the sync class to help IDEs with type checking.

        The stub is written (UTF-8) to
        ``<directory of async_class's module>/generated/<sync_class.__name__>.pyi``.
        Nothing is written for classes from ``__main__`` or from modules without
        a ``__file__``. Errors are logged, never raised.
        """
        try:
            if async_class.__module__ == "__main__":
                return
            module = inspect.getmodule(async_class)
            if not module:
                return
            # Built-in modules have no __file__ attribute at all.
            module_path = getattr(module, "__file__", None)
            if not module_path:
                return

            stub_dir = os.path.join(os.path.dirname(module_path), "generated")
            os.makedirs(stub_dir, exist_ok=True)
            sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")

            type_tracker = TypeTracker()

            # Line 0 is a placeholder for the import block: imports can only be
            # computed after every annotation below has been fed to type_tracker.
            imports_placeholder_index = 0
            stub_content = [""]

            stub_content.append(f"class {sync_class.__name__}:")
            if async_class.__doc__:
                stub_content.extend(cls._format_docstring_for_stub(async_class.__doc__, "    "))
            stub_content.extend(cls._generate_init_stub(async_class, "    ", type_tracker))
            stub_content.append("")

            # Nested classes become "<Name>Sync" inner classes.
            class_attributes = cls._get_class_attributes(async_class)
            for name, attr in class_attributes:
                stub_content.extend(
                    cls._generate_inner_class_stub(name, attr, type_tracker=type_tracker)
                )
                stub_content.append("")

            # Public methods (getmembers yields each name once).
            for name, method in sorted(inspect.getmembers(async_class, predicate=inspect.isfunction)):
                if name.startswith("_"):
                    continue
                try:
                    method_sig = cls._generate_method_signature(
                        name, method, is_async=True, type_tracker=type_tracker
                    )
                    if method.__doc__:
                        stub_content.extend(cls._format_docstring_for_stub(method.__doc__, "    "))
                    stub_content.append(f"    {method_sig}")
                    stub_content.append("")
                except (ValueError, TypeError):
                    stub_content.append(f"    def {name}(self, *args, **kwargs): ...")
                    stub_content.append("")

            for name, prop in sorted(
                inspect.getmembers(async_class, lambda x: isinstance(x, property))
            ):
                stub_content.append("    @property")
                stub_content.append(f"    def {name}(self) -> Any: ...")
                if prop.fset:
                    stub_content.append(f"    @{name}.setter")
                    stub_content.append(f"    def {name}(self, value: Any) -> None: ...")
                stub_content.append("")

            # Map each inner class to the annotated attribute that holds it
            # (matched by type object, by __name__, or by string annotation).
            attribute_mappings = {}
            all_annotations = cls._collect_annotations(async_class)
            for attr_name, attr_type in sorted(all_annotations.items()):
                for class_name, class_type in class_attributes:
                    if (
                        attr_type == class_type
                        or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
                        or (isinstance(attr_type, str) and attr_type == class_name)
                    ):
                        attribute_mappings[class_name] = attr_name

            for class_name, _class_type in class_attributes:
                attr_name = attribute_mappings.get(class_name, class_name)
                stub_content.append(f"    {attr_name}: {class_name}Sync")
            stub_content.append("")

            seen = set()
            unique_imports = []
            for imp in cls._generate_imports(async_class, type_tracker):
                if imp in seen:
                    logging.warning(f"Duplicate import detected: {imp}")
                    continue
                seen.add(imp)
                unique_imports.append(imp)

            stub_content[imports_placeholder_index : imports_placeholder_index + 1] = unique_imports
            stub_content = cls._post_process_stub_content(stub_content)

            # Explicit UTF-8: docstrings may contain non-ASCII text.
            with open(sync_stub_path, "w", encoding="utf-8") as f:
                f.write("\n".join(stub_content) + "\n")

            logging.info(f"Generated stub file: {sync_stub_path}")

        except Exception as e:
            # Stub generation is a developer convenience; never let it break callers.
            logging.exception(f"Error generating stub file for {sync_class.__name__}: {str(e)}")
|
|
|
|
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
    """Build a synchronous (blocking) counterpart of ``async_class``.

    Module-level shortcut for ``AsyncToSyncConverter.create_sync_class``.

    Args:
        async_class: The class whose coroutine methods should become blocking calls.
        thread_pool_size: ``max_workers`` for the shared converter thread pool;
            it only matters if that pool has not been created yet.

    Returns:
        The generated synchronous class.
    """
    converter = AsyncToSyncConverter
    return converter.create_sync_class(async_class, thread_pool_size)
|
|