from __future__ import annotations import argparse from typing import List, Literal, Union, Any, Type, TypeVar from pydantic import BaseModel def _get_base_type(annotation: Type[Any]) -> Type[Any]: if getattr(annotation, "__origin__", None) is Literal: assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore return type(annotation.__args__[0]) # type: ignore elif getattr(annotation, "__origin__", None) is Union: assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore non_optional_args: List[Type[Any]] = [ arg for arg in annotation.__args__ if arg is not type(None) # type: ignore ] if non_optional_args: return _get_base_type(non_optional_args[0]) elif ( getattr(annotation, "__origin__", None) is list or getattr(annotation, "__origin__", None) is List ): assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore return _get_base_type(annotation.__args__[0]) # type: ignore return annotation def _contains_list_type(annotation: Type[Any] | None) -> bool: origin = getattr(annotation, "__origin__", None) if origin is list or origin is List: return True elif origin in (Literal, Union): return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore else: return False def _parse_bool_arg(arg: str | bytes | bool) -> bool: if isinstance(arg, bytes): arg = arg.decode("utf-8") true_values = {"1", "on", "t", "true", "y", "yes"} false_values = {"0", "off", "f", "false", "n", "no"} arg_str = str(arg).lower().strip() if arg_str in true_values: return True elif arg_str in false_values: return False else: raise ValueError(f"Invalid boolean argument: {arg}") def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]): """Add arguments from a pydantic model to an argparse parser.""" for name, field in model.model_fields.items(): description = field.description if field.default and description and not field.is_required(): description += f" (default: {field.default})" base_type = ( _get_base_type(field.annotation) if field.annotation is not None else str ) list_type = _contains_list_type(field.annotation) if base_type is not bool: parser.add_argument( f"--{name}", dest=name, nargs="*" if list_type else None, type=base_type, help=description, ) if base_type is bool: parser.add_argument( f"--{name}", dest=name, type=_parse_bool_arg, help=f"{description}", ) T = TypeVar("T", bound=Type[BaseModel]) def parse_model_from_args(model: T, args: argparse.Namespace) -> T: """Parse a pydantic model from an argparse namespace.""" return model( **{ k: v for k, v in vars(args).items() if v is not None and k in model.model_fields } )