Source code for rath.llm.registry

"""Registry mapping :attr:`Provider.provider_kind` to a :class:`ChatClient` factory.

Built-in adapters (``"openai"``, ``"anthropic"``) self-register on import of
:mod:`rath.llm.openai` / :mod:`rath.llm.anthropic`, which :mod:`rath.llm` does
eagerly. Third parties can call :func:`register_chat_client` to add their own
without modifying core.

The single dispatch point :func:`chat_client_for` replaces the previous
``provider.provider_kind == "anthropic"`` string check that lived in
:mod:`rath.session.loop`.
"""

from __future__ import annotations

import threading
from typing import Callable

from rath.llm.base import ChatClient
from rath.llm.provider import Provider

__all__ = [
    "ChatClientFactory",
    "register_chat_client",
    "chat_client_for",
    "registered_kinds",
]

ChatClientFactory = Callable[[Provider], ChatClient]

_FACTORIES: dict[str, ChatClientFactory] = {}
# Guards reads from / writes to ``_FACTORIES`` only. Deliberately does
# **not** wrap ``factory(provider)`` in :func:`chat_client_for` — built-in
# factories (``RathOpenAIChatClient``, ``RathAnthropicChatClient``) are
# lightweight wrappers around the underlying SDK clients and serializing
# their construction would block parallel callers for no benefit. If you
# register a factory that needs serialization (e.g. one that calls out
# to a remote service), wrap that side effect with your own lock inside
# the factory.
_FACTORIES_LOCK = threading.Lock()


[docs] def register_chat_client(kind: str, factory: ChatClientFactory) -> None: """Register ``factory(provider) -> ChatClient`` under ``kind``. Overwrites any previous registration silently — late imports therefore win. Built-in kinds (``"openai"``, ``"anthropic"``) are registered when their subpackages are imported by :mod:`rath.llm`. """ with _FACTORIES_LOCK: _FACTORIES[kind] = factory
[docs] def chat_client_for(provider: Provider) -> ChatClient: """Return the :class:`ChatClient` for ``provider.provider_kind``. ``provider.provider_kind=None`` defaults to ``"openai"``. Unknown kinds raise ``ValueError`` listing what is currently registered. """ kind = provider.provider_kind or "openai" with _FACTORIES_LOCK: try: factory = _FACTORIES[kind] except KeyError as e: raise ValueError( f"unknown provider_kind={kind!r}; " f"registered kinds: {sorted(_FACTORIES)}", ) from e return factory(provider)
[docs] def registered_kinds() -> tuple[str, ...]: """Snapshot of currently registered kinds (useful for diagnostics / tests).""" with _FACTORIES_LOCK: return tuple(sorted(_FACTORIES))