# Source code for linkd.solver

# -*- coding: utf-8 -*-
# Copyright (c) 2025-present tandemdude
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations

__all__ = ["DI_CONTAINER", "DI_ENABLED", "INJECTED", "AutoInjecting", "DependencyInjectionManager", "with_di"]

import collections
import contextlib
import contextvars
import inspect
import logging
import os
import sys
import typing as t
from collections.abc import Coroutine

from linkd import conditions
from linkd import container
from linkd import context as context_
from linkd import exceptions
from linkd import registry
from linkd import utils

if t.TYPE_CHECKING:
    from collections.abc import AsyncIterator
    from collections.abc import Awaitable
    from collections.abc import Callable

    from linkd import types

# Generic placeholders used by the type signatures in this module.
P = t.ParamSpec("P")
R = t.TypeVar("R")
T = t.TypeVar("T")
# Bound to callables that return a coroutine, i.e. ``async def`` functions.
AsyncFnT = t.TypeVar("AsyncFnT", bound=t.Callable[..., Coroutine[t.Any, t.Any, t.Any]])

# Dependency injection is on unless the LINKD_DI_DISABLED environment variable is "true" (case-insensitive).
DI_ENABLED: t.Final[bool] = os.environ.get("LINKD_DI_DISABLED", "false").lower() != "true"
# Holds the container for the injection context that is currently entered; ``None`` when no context is active.
DI_CONTAINER: contextvars.ContextVar[container.Container | None] = contextvars.ContextVar(
    "linkd_container", default=None
)
LOGGER = logging.getLogger(__name__)

# Sentinel default value; parameters whose default is this object are still treated as injectable.
INJECTED: t.Final[t.Any] = utils.Marker("INJECTED")
"""
Flag value used to explicitly mark that a function parameter should be dependency-injected.

This exists to stop type checkers complaining that function arguments are not provided when calling
dependency-injection enabled functions.

Example:

    .. code-block:: python

        @linkd.with_di
        async def foo(bar: SomeClass = linkd.INJECTED) -> None:
            ...

        # Type-checker shouldn't error that a parameter is missing
        await foo()
"""


class _NoOpContainer(container.Container):
    """Container stand-in that ignores every registration and refuses every lookup."""

    __slots__ = ()

    def add_factory(
        self,
        typ: type[T],
        factory: Callable[..., types.MaybeAwaitable[T]],
        *,
        teardown: Callable[[T], types.MaybeAwaitable[None]] | None = None,
    ) -> None:
        """Accept the factory registration and discard it."""

    def add_value(
        self,
        typ: type[T],
        value: T,
        *,
        teardown: Callable[[T], types.MaybeAwaitable[None]] | None = None,
    ) -> None:
        """Accept the value registration and discard it."""

    def _get(self, dependency_id: str) -> t.Any:
        """Always fail: nothing can be resolved from this container."""
        reason = "dependency injection is globally disabled"
        raise exceptions.DependencyNotSatisfiableException(reason)


# Single shared instance handed out by ``enter_context`` when dependency injection is globally disabled.
_NOOP_CONTAINER = _NoOpContainer(registry.Registry(), tag=context_.Contexts.DEFAULT)


class DependencyInjectionManager:
    """Holds one registry per injection context and manages entering and leaving those contexts."""

    __slots__ = ("_default_container", "_registries")

    def __init__(self) -> None:
        # No default container exists until a context is entered for the first time.
        self._default_container: container.Container | None = None
        # Registries are created on demand the first time a context is looked up.
        self._registries: dict[context_.Context, registry.Registry] = collections.defaultdict(registry.Registry)

    @property
    def default_container(self) -> container.Container | None:
        """
        The container that provides dependencies for the :attr:`~Contexts.DEFAULT` context.

        Stays :obj:`None` until an injection context has been entered for the first time.
        """
        return self._default_container

    def registry_for(self, context: context_.Context, /) -> registry.Registry:
        """
        Look up the dependency registry belonging to a context, creating it when it does not exist yet.

        Args:
            context: The injection context whose registry is wanted.

        Returns:
            The dependency registry for the given context.
        """
        return self._registries[context]

    @contextlib.asynccontextmanager
    async def enter_context(
        self, context: context_.Context = context_.Contexts.DEFAULT, /
    ) -> AsyncIterator[container.Container]:
        """
        Async context manager that makes a dependency injection context available to the nested code.

        Args:
            context: The context to enter. When a container for this context is already active
                it is reused rather than a new one being created.

        Yields:
            :obj:`~linkd.container.Container`: The container that has been entered.

        Example:

            .. code-block:: python

                # Enter a specific context ('manager' is your DependencyInjectionManager instance)
                async with manager.enter_context(linkd.Contexts.DEFAULT):
                    await some_function_that_needs_dependencies()

        Note:
            To enter several contexts - e.g. a command context that needs the default context
            to be available first - call this once per required context.

            .. code-block:: python

                async with (
                    manager.enter_context(linkd.Contexts.DEFAULT),
                    manager.enter_context(SOME_COMMAND_CONTEXT)
                ):
                    ...
        """
        if not DI_ENABLED:
            # With DI globally disabled, hand out the no-op container: it never stores
            # dependencies and raises as soon as anything is requested from it.
            yield _NOOP_CONTAINER
            return

        LOGGER.debug("attempting to enter context %r", context)

        current = DI_CONTAINER.get(None)
        target: container.Container | None = None
        freshly_created = False

        if current is not None:
            LOGGER.debug("searching for existing container for context %r", context)

            # Walk from the active container up through its parents looking for a matching tag.
            node = current
            while node:
                if node._tag == context:
                    LOGGER.debug("existing container found for context %r", context)
                    target = node
                    break

                node = node._parent

        if target is None:
            is_default = context == context_.Contexts.DEFAULT

            if is_default and self._default_container is not None:
                LOGGER.debug("reusing existing container for context %r", context)
                target = self._default_container
            else:
                LOGGER.debug("creating new container for context %r", context)

                target = container.Container(self._registries[context], parent=current, tag=context)
                # If the context has a registered type, register the container itself under that type.
                context_type = context_.global_context_registry.type_for(context)
                if context_type is not None:
                    target.add_value(context_type, target)

                if is_default:
                    self._default_container = target

                freshly_created = True

        token = DI_CONTAINER.set(target)
        LOGGER.debug("entered context %r", context)

        try:
            # Only containers created by this call - and never the default one - are
            # entered (and therefore exited) here; everything else is yielded as-is.
            if freshly_created and target is not self._default_container:
                async with target:
                    yield target
            else:
                yield target
        finally:
            DI_CONTAINER.reset(token)
            LOGGER.debug("cleared context %r", context)

    async def close(self) -> None:
        """
        Close the default dependency injection context.

        This **must** be called if the teardown functions of dependencies registered for the
        default registry are expected to run.

        Returns:
            :obj:`None`
        """
        default = self._default_container
        if default is None:
            return

        await default.close()
        self._default_container = None
# Placeholder stored for positional parameters that must never be dependency-injected.
CANNOT_INJECT: t.Final[t.Any] = utils.Marker("CANNOT_INJECT")


def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str, t.Any]], dict[str, t.Any]]:
    """
    Inspect ``func`` and work out which of its parameters can be dependency-injected.

    Args:
        func: The callable whose signature should be analysed.

    Returns:
        A two-tuple. The first item lists ``(name, expression)`` pairs for positional-only and
        positional-or-keyword parameters in declaration order, using ``CANNOT_INJECT`` in place
        of the expression for parameters that are not injectable. The second item maps the name
        of each injectable keyword-only parameter to its dependency expression.
    """
    kinds = inspect.Parameter
    never_injectable_kinds = (kinds.POSITIONAL_ONLY, kinds.VAR_POSITIONAL, kinds.VAR_KEYWORD)
    positional_kinds = (kinds.POSITIONAL_ONLY, kinds.POSITIONAL_OR_KEYWORD)

    positional: list[tuple[str, t.Any]] = []
    keyword_only: dict[str, t.Any] = {}

    # String annotations are evaluated with the 'linkd' package available as a local name.
    signature = inspect.signature(func, locals={"linkd": sys.modules["linkd"]}, eval_str=True)
    for param in signature.parameters.values():
        lacks_annotation = param.annotation is kinds.empty
        wrong_kind = param.kind in never_injectable_kinds
        # A default other than INJECTED means the parameter already has a value to fall back on.
        has_own_default = param.default is not kinds.empty and param.default is not INJECTED

        if lacks_annotation or wrong_kind or has_own_default:
            # Positional slots are still recorded so that positions stay aligned with the signature.
            if param.kind in positional_kinds:
                positional.append((param.name, CANNOT_INJECT))
            continue

        expression = conditions.DependencyExpression.create(param.annotation)
        if param.kind is kinds.POSITIONAL_OR_KEYWORD:
            positional.append((param.name, expression))
        else:
            # Every other kind was filtered out above, so this is keyword-only.
            keyword_only[param.name] = expression

    return positional, keyword_only
class AutoInjecting:
    """
    Wrapper for a callable that implements dependency injection. When called, resolves the required
    dependencies and calls the original callable. Supports both synchronous and asynchronous functions,
    however this cannot be called synchronously - synchronous functions will need to be awaited.

    Attribute reads and writes that do not concern the wrapper's own slots are forwarded
    to the wrapped callable.

    You should generally never have to instantiate this yourself - you should instead use one of the
    decorators that applies this to the target automatically.

    See Also:
        :meth:`~with_di`
    """

    __slots__ = ("_func", "_kw_only_params", "_pos_or_kw_params", "_self")

    def __init__(
        self,
        func: Callable[..., Awaitable[t.Any]],
        self_: t.Any = None,
        _cached_pos_or_kw_params: list[tuple[str, t.Any]] | None = None,
        _cached_kw_only_params: dict[str, t.Any] | None = None,
    ) -> None:
        """
        Args:
            func: The callable to wrap.
            self_: The instance the callable is bound to, if it is being used as a method.
            _cached_pos_or_kw_params: Previously parsed positional parameter info. Only used
                when ``_cached_kw_only_params`` is given as well.
            _cached_kw_only_params: Previously parsed keyword-only parameter info. Only used
                when ``_cached_pos_or_kw_params`` is given as well.
        """
        self._func = func
        self._self: t.Any = self_

        if _cached_pos_or_kw_params is not None and _cached_kw_only_params is not None:
            # Reuse already-parsed signature data (see __get__) instead of inspecting again.
            self._pos_or_kw_params = _cached_pos_or_kw_params
            self._kw_only_params = _cached_kw_only_params
        else:
            self._pos_or_kw_params, self._kw_only_params = _parse_injectable_params(func)

    def __get__(self, instance: t.Any, _: type[t.Any] | None = None) -> AutoInjecting:
        """Bind the wrapper to ``instance`` when accessed through an instance; otherwise return it unchanged."""
        if instance is not None:
            return AutoInjecting(self._func, instance, self._pos_or_kw_params, self._kw_only_params)
        return self

    def __getattr__(self, item: str) -> t.Any:
        """
        Forward lookups of unknown attributes to the wrapped callable.

        Raises:
            AttributeError: If ``item`` is one of the wrapper's own slots that has not been set,
                or the wrapped callable does not have the attribute.
        """
        # __getattr__ only runs when normal lookup failed. If the missing name is one of our own
        # slots the instance is not initialised yet (e.g. while being copied or unpickled);
        # forwarding would need self._func, which would re-enter this method forever.
        if item in self.__slots__:
            raise AttributeError(item)
        return getattr(self._func, item)

    def __setattr__(self, key: str, value: t.Any) -> None:
        """Set the wrapper's own slots directly; forward every other attribute to the wrapped callable."""
        if key in self.__slots__:
            return super().__setattr__(key, value)

        setattr(self._func, key, value)

    async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """
        Resolve any injectable parameters that were not supplied, then call the wrapped callable.

        Args:
            *args: Positional arguments passed through to the wrapped callable.
            **kwargs: Keyword arguments passed through to the wrapped callable. Parameters
                given here are never injected.

        Returns:
            The (awaited, if necessary) result of the wrapped callable.

        Raises:
            DependencyNotSatisfiableException: If a parameter needs injecting but no
                dependency injection context is active.
        """
        new_kwargs: dict[str, t.Any] = dict(kwargs)

        di_container: container.Container | None = DI_CONTAINER.get(None)

        # Positional parameters already covered by ``args`` (plus the bound instance, if any)
        # are skipped; only the remaining ones are candidates for injection.
        consumed = len(args) + (self._self is not None)
        injectables = {name: expr for name, expr in self._pos_or_kw_params[consumed:] if name not in new_kwargs}
        injectables.update({name: expr for name, expr in self._kw_only_params.items() if name not in new_kwargs})

        for name, type_expr in injectables.items():
            # Skip any arguments that we can't inject
            if type_expr is CANNOT_INJECT:
                continue

            if di_container is None:
                raise exceptions.DependencyNotSatisfiableException("no DI context is available")

            assert isinstance(type_expr, conditions.DependencyExpression)

            LOGGER.debug("requesting dependency matching %r", type_expr)  # type: ignore[reportUnknownArgumentType]
            new_kwargs[name] = await type_expr.resolve(di_container)

        # Only build the display name when something was injected and it will actually be logged.
        if len(new_kwargs) > len(kwargs) and LOGGER.isEnabledFor(logging.DEBUG):
            owner = (type(self._self).__name__ + ".") if self._self is not None else ""
            # Not every callable has __name__ (e.g. functools.partial objects).
            func_name = owner + getattr(self._func, "__name__", repr(self._func))
            LOGGER.debug("calling function %r with resolved dependencies", func_name)

        if self._self is not None:
            return await utils.maybe_await(self._func(self._self, *args, **new_kwargs))
        return await utils.maybe_await(self._func(*args, **new_kwargs))
@t.overload
def with_di(func: AsyncFnT) -> AsyncFnT: ...
@t.overload
def with_di(func: Callable[P, R]) -> Callable[P, Coroutine[t.Any, t.Any, R]]: ...


def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Coroutine[t.Any, t.Any, R]]:
    """
    Decorator that switches on dependency injection for the decorated function.

    When dependency injection is globally disabled, or the function has already been wrapped,
    the object that was passed in is returned untouched.

    Args:
        func: The function to enable dependency injection for.

    Returns:
        The function wrapped in :obj:`AutoInjecting`, or the same object if DI is globally
        disabled or it is already wrapped.

    Warning:
        Dependency injection relies on a context being available when the function is called. Refer
        to library documentation for which flows will have a dependency injection context set-up for you
        automatically. Otherwise, you will have to set up the context yourself using the helper context manager
        :meth:`~DependencyInjectionManager.enter_context`.
    """
    already_wrapped = isinstance(func, AutoInjecting)
    if not DI_ENABLED or already_wrapped:
        return func  # type: ignore[reportReturnType]

    return AutoInjecting(func)  # type: ignore[reportReturnType]