Skip to content

Dependency Injection API

Module: wepositive_di.di

Depends = _DependsType() module-attribute

registry = containers.DynamicContainer() module-attribute

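register_provider(name=None, singleton=False, context_manager=False)

Register a provider function (sync or async) in the registry. Returns a decorator; the decorated function itself is returned unchanged.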

Source code in src/wepositive_di/di.py
def register_provider(
    name: str | None = None,
    singleton: bool = False,
    context_manager: bool = False,
):
    def decorator(func: Callable[..., Any]):
        """Register a provider function (sync or async) in the registry.

        Args:
            name: Optional name for the provider (defaults to function name)
            singleton: If True, caches and reuses the first created instance.
            context_manager: If True, enter and exit the provider's context manager
                when resolving dependencies. Context-manager handling is opt-in.
        """
        provider_name = name or func.__name__
        is_async_func = inspect.iscoroutinefunction(func)

        if is_async_func and not context_manager and singleton:
            raise ValueError(
                f"Async provider '{provider_name}' cannot be a singleton. "
                f"The dependency-injector library doesn't support singleton caching for Coroutine providers. "
                f"Make your provider a sync function instead: def {func.__name__}(...)"
            )

        provider = _create_provider(
            func,
            provider_name=provider_name,
            singleton=singleton,
            context_manager=context_manager,
        )
        setattr(registry, provider_name, provider)
        if context_manager:
            _context_manager_providers.add(provider_name)
        else:
            _context_manager_providers.discard(provider_name)

        module = inspect.getmodule(func)
        if module is not None:  # pragma: no branch
            _registered_modules.add(module)
        return func

    return decorator

setup(overrides=None)

Wire the dependency injection system.

Parameters:

Name Type Description Default
overrides dict[Callable[..., Any] | str, Callable[..., Any]] | None

Optional dictionary of provider overrides to apply before wiring. Maps original providers to their override implementations.

None

Usage:

setup()

def redis_storage() -> ContextStorage:
    return RedisContextStorage()

setup(overrides={context_storage_singleton: redis_storage})
Source code in src/wepositive_di/di.py
def setup(
    overrides: dict[Callable[..., Any] | str, Callable[..., Any]] | None = None,
):
    """Apply optional provider overrides, then wire the registry.

    Args:
        overrides: Optional mapping from an original provider (the function
            object or its name) to the function that should replace it. The
            overrides are recorded before wiring happens.

    Usage:

    ```python
    setup()

    def redis_storage() -> ContextStorage:
        return RedisContextStorage()

    setup(overrides={context_storage_singleton: redis_storage})
    ```
    """
    for target, replacement in (overrides or {}).items():
        # Overrides are keyed by provider name; a callable contributes its
        # function name.
        if isinstance(target, str):
            key = target
        else:
            key = target.__name__

        uses_cm = _uses_context_manager(key)
        _provider_overrides[key] = _create_provider(
            replacement,
            provider_name=key,
            context_manager=uses_cm,
        )

    # Wire every module that contributed at least one registered provider.
    modules_to_wire = list(_registered_modules)
    registry.wire(modules=modules_to_wire)

override_provider(original, override=None)

override_provider(
    original: Callable[..., Any] | str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]
override_provider(
    original: Callable[..., Any] | str,
    override: Callable[..., Any],
) -> None

Override a provider with a new implementation.

Can be used as a function or a decorator.

Parameters:

Name Type Description Default
original Callable[..., Any] | str

The original provider function or its name

required
override Callable[..., Any] | None

The new provider function (when used as a function call)

None

Returns:

Type Description
None | Callable[[Callable[..., Any]], Callable[..., Any]]

None when used as a function, decorator when used as @override_provider(original)

Usage:

@register_provider()
async def config() -> Config:
    return Config()

async def prod_config() -> Config:
    return Config(db_url="production")

override_provider(config, prod_config)

Or, used as a decorator:

@override_provider(config)
async def prod_config() -> Config:
    return Config(db_url="production")
Source code in src/wepositive_di/di.py
def override_provider(
    original: Callable[..., Any] | str,
    override: Callable[..., Any] | None = None,
) -> None | Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Override a provider with a new implementation.

    Can be used as a function or a decorator.

    Args:
        original: The original provider function or its name
        override: The new provider function (when used as a function call)

    Returns:
        None when used as a function, decorator when used as @override_provider(original)

    Usage:

    ```python
    @register_provider()
    async def config() -> Config:
        return Config()

    async def prod_config() -> Config:
        return Config(db_url="production")

    override_provider(config, prod_config)
    ```

    ```python
    @override_provider(config)
    async def prod_config() -> Config:
        return Config(db_url="production")
    ```
    """
    provider_name = original if isinstance(original, str) else original.__name__

    # Used as @override_provider(original)
    if override is None:

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            _provider_overrides[provider_name] = _create_provider(
                func,
                provider_name=provider_name,
                context_manager=_uses_context_manager(provider_name),
            )
            return func

        return decorator

    # Used as override_provider(original, override_func)
    _provider_overrides[provider_name] = _create_provider(
        override,
        provider_name=provider_name,
        context_manager=_uses_context_manager(provider_name),
    )
    return None

clear_overrides()

Clear all provider overrides.

Source code in src/wepositive_di/di.py
def clear_overrides() -> None:
    """Remove every provider override that is currently recorded.

    The override mapping is emptied via its ``clear()`` method rather than
    being rebound to a new object.
    """
    _provider_overrides.clear()
    return None

provider_overrides(overrides)

Context manager to temporarily override providers for testing.

Parameters:

Name Type Description Default
overrides dict[Callable[..., Any] | str, Callable[..., Any]]

Dictionary mapping original providers to their overrides

required

Usage:

async def test_config() -> Config:
    return Config(sqlalchemy_db_uri=SecretStr("sqlite:///:memory:"))

with provider_overrides({config: test_config}):
    result = await my_function()
Source code in src/wepositive_di/di.py
@contextmanager
def provider_overrides(
    overrides: dict[Callable[..., Any] | str, Callable[..., Any]],
):
    """Context manager to temporarily override providers for testing.

    The current overrides are snapshotted on entry and restored on exit,
    whether the body finishes normally or raises. The snapshot is also
    restored when building one of the override providers fails, so a
    partially applied set of overrides never remains in place.

    Args:
        overrides: Dictionary mapping original providers (function objects or
            provider names) to their overrides

    Usage:

    ```python
    async def test_config() -> Config:
        return Config(sqlalchemy_db_uri=SecretStr("sqlite:///:memory:"))

    with provider_overrides({config: test_config}):
        result = await my_function()
    ```
    """
    # Save current state
    old_overrides = _provider_overrides.copy()

    try:
        # Apply new overrides inside the try block so that a failure while
        # creating any provider still restores the saved state below.
        for original, override in overrides.items():
            provider_name = (
                original if isinstance(original, str) else original.__name__
            )
            _provider_overrides[provider_name] = _create_provider(
                override,
                provider_name=provider_name,
                context_manager=_uses_context_manager(provider_name),
            )

        yield
    finally:
        # Restore previous state by mutating the same mapping object.
        _provider_overrides.clear()
        _provider_overrides.update(old_overrides)

inject(func)

inject(
    func: Callable[..., Coroutine[Any, Any, Any]],
) -> Callable[..., Coroutine[Any, Any, Any]]
inject(func: Callable[..., Any]) -> Callable[..., Any]

Decorator that resolves Depends markers in function arguments.

Works with both sync and async functions.

The decorator:

  • Inspects the function signature for _DependsMarker defaults.
  • At call time, resolves each marker by calling the registry provider.
  • Handles context manager providers transparently: enters the context manager, passes the yielded value to the function, and exits on completion.
  • Returns the appropriate wrapper based on the function type.

Provider types and how they are resolved:

  • AsyncCMFactory: await coroutine, get context manager, await __aenter__.
  • CMFactory: get context manager, call __enter__.
  • providers.Coroutine: await result.
  • providers.Factory / providers.Singleton: use result directly.

Usage:

@inject
def my_func(config: Config = Depends[config]):
    return config.value

@inject
async def my_async_func(session: AsyncSession = Depends[async_session]):
    return session.query(...)
Source code in src/wepositive_di/di.py
def inject(func: T) -> T:
    """Decorator that resolves Depends markers in function arguments.

    Works with both sync and async functions.

    The decorator:

    - Inspects the function signature for `_DependsMarker` defaults.
    - At call time, resolves each marker by calling the registry provider.
    - Handles context manager providers transparently: enters the context manager,
      passes the yielded value to the function, and exits on completion.
    - Returns the appropriate wrapper based on the function type.

    Provider types and how they are resolved:

    - AsyncCMFactory: await coroutine, get context manager, await `__aenter__`.
    - CMFactory: get context manager, call `__enter__`.
    - providers.Coroutine: await result.
    - providers.Factory / providers.Singleton: use result directly.

    Context manager lifecycle (follows the semantics of nested ``with`` blocks):

    - Entered context managers are exited in reverse order of entry.
    - Any exception raised by the function or by a later dependency resolution,
      including ``BaseException`` subclasses such as cancellation, is passed
      to ``__exit__`` / ``__aexit__``.
    - If an exit method returns a true value the exception is suppressed and
      the wrapper returns ``None``.
    - If an exit method raises, the remaining context managers are still
      exited and the new exception propagates.

    The wrapped function is called with the bound positional and keyword
    arguments, so positional-only, ``*args`` and ``**kwargs`` parameters are
    forwarded correctly.

    Usage:

    ```python
    @inject
    def my_func(config: Config = Depends[config]):
        return config.value

    @inject
    async def my_async_func(session: AsyncSession = Depends[async_session]):
        return session.query(...)
    ```

    """
    sig = inspect.signature(func)
    # NOTE(review): async generator functions are routed to the coroutine
    # wrapper below, which awaits the call result; awaiting an async generator
    # object is not valid. Confirm whether async generators are meant to be
    # supported here.
    dependant_is_async = inspect.iscoroutinefunction(
        func
    ) or inspect.isasyncgenfunction(func)

    def _exc_details(exc: BaseException | None) -> tuple[Any, Any, Any]:
        """Build the (type, value, traceback) triple expected by exit methods."""
        if exc is None:
            return (None, None, None)
        return (type(exc), exc, exc.__traceback__)

    async def _resolve_dependencies_async(
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifecycles: list[tuple[Any, bool]],
    ) -> inspect.BoundArguments:
        """Bind the call arguments and replace every marker with its value.

        Each successfully entered context manager is appended to
        ``lifecycles`` as ``(context_manager, is_async)``. The list is owned
        by the caller so it can still unwind when resolution fails midway.
        """
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()

        for param_name, value in list(bound.arguments.items()):
            if not isinstance(value, _DependsMarker):
                continue
            provider = _lookup_provider(value.name)
            result = provider()

            if isinstance(provider, _AsyncSingletonCMFactory):
                result = await result
            elif isinstance(provider, _SyncSingletonCMFactory):
                pass  # sync Resource returns the cached value directly
            elif isinstance(provider, _AsyncCMFactory):
                cm = await result  # await async wrapper to get the CM object
                result = await cm.__aenter__()
                # Recorded only after a successful enter, like ``async with``.
                lifecycles.append((cm, True))
            elif isinstance(provider, _CMFactory):
                cm = result  # sync wrapper returns the CM directly
                result = cm.__enter__()
                lifecycles.append((cm, False))
            elif isinstance(provider, providers.Coroutine):
                result = await result

            bound.arguments[param_name] = result

        return bound

    def _resolve_dependencies_sync(
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifecycles: list[tuple[Any, bool]],
    ) -> inspect.BoundArguments:
        """Synchronous counterpart of ``_resolve_dependencies_async``.

        Awaitable provider results are driven to completion on a loop
        obtained from ``_create_event_loop``.
        """
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()

        for param_name, value in list(bound.arguments.items()):
            if not isinstance(value, _DependsMarker):
                continue
            provider = _lookup_provider(value.name)
            result = provider()

            if isinstance(provider, _AsyncSingletonCMFactory):
                with _create_event_loop(param_name, func.__name__) as loop:
                    result = loop.run_until_complete(result)
            elif isinstance(provider, _SyncSingletonCMFactory):
                pass  # sync Resource returns the cached value directly
            elif isinstance(provider, _AsyncCMFactory):
                with _create_event_loop(param_name, func.__name__) as loop:
                    cm = loop.run_until_complete(result)  # await async wrapper → CM
                    result = loop.run_until_complete(cm.__aenter__())
                    # Recorded only after a successful enter.
                    lifecycles.append((cm, True))
            elif isinstance(provider, _CMFactory):
                cm = result
                result = cm.__enter__()
                lifecycles.append((cm, False))
            elif isinstance(provider, providers.Coroutine):
                with _create_event_loop(param_name, func.__name__) as loop:
                    result = loop.run_until_complete(result)

            bound.arguments[param_name] = result

        return bound

    async def _exit_lifecycles_async(
        lifecycles: list[tuple[Any, bool]],
        exc: BaseException | None,
    ) -> BaseException | None:
        """Exit context managers in LIFO order; return the pending exception.

        Returns ``None`` when there is nothing left to raise, either because
        no error occurred or because an exit method suppressed it.
        """
        pending = exc
        for cm, is_async in reversed(lifecycles):
            details = _exc_details(pending)
            try:
                if is_async:
                    suppressed = await cm.__aexit__(*details)
                else:
                    suppressed = cm.__exit__(*details)
            except BaseException as exit_exc:
                # A failing exit replaces the pending exception, but the
                # remaining context managers must still be exited.
                pending = exit_exc
            else:
                if suppressed:
                    pending = None
        return pending

    def _exit_lifecycles_sync(
        lifecycles: list[tuple[Any, bool]],
        exc: BaseException | None,
    ) -> BaseException | None:
        """Synchronous counterpart of ``_exit_lifecycles_async``."""
        pending = exc
        for cm, is_async in reversed(lifecycles):
            details = _exc_details(pending)
            try:
                if is_async:
                    cm_name = getattr(cm, "__wrapped__", type(cm)).__name__
                    with _create_event_loop(cm_name, func.__name__) as loop:
                        suppressed = loop.run_until_complete(cm.__aexit__(*details))
                else:
                    suppressed = cm.__exit__(*details)
            except BaseException as exit_exc:
                pending = exit_exc
            else:
                if suppressed:
                    pending = None
        return pending

    if dependant_is_async:

        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            lifecycles: list[tuple[Any, bool]] = []
            try:
                bound = await _resolve_dependencies_async(args, kwargs, lifecycles)
                result = await func(*bound.args, **bound.kwargs)
            except BaseException as exc:
                pending = await _exit_lifecycles_async(lifecycles, exc)
                if pending is exc:
                    raise  # not suppressed: re-raise with the original traceback
                if pending is not None:
                    raise pending  # an exit method raised a different exception
                return None  # suppressed by a context manager

            pending = await _exit_lifecycles_async(lifecycles, None)
            if pending is not None:
                raise pending
            return result

        return async_wrapper  # type: ignore
    else:

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            lifecycles: list[tuple[Any, bool]] = []
            try:
                bound = _resolve_dependencies_sync(args, kwargs, lifecycles)
                result = func(*bound.args, **bound.kwargs)
            except BaseException as exc:
                pending = _exit_lifecycles_sync(lifecycles, exc)
                if pending is exc:
                    raise  # not suppressed: re-raise with the original traceback
                if pending is not None:
                    raise pending  # an exit method raised a different exception
                return None  # suppressed by a context manager

            pending = _exit_lifecycles_sync(lifecycles, None)
            if pending is not None:
                raise pending
            return result

        return sync_wrapper  # type: ignore

The custom provider wrapper classes used internally by register_provider() are documented in Provider Classes.