Source code for rath.llm.embedding

"""Synchronous OpenAI-compatible embedding client (thin SDK wrapper).

Mirrors :class:`rath.llm.openai.client.RathOpenAIChatClient` in style:

* :class:`EmbeddingProvider` carries credentials + model + optional output
  dimension; the only required field is ``model`` (the OpenAI SDK refuses
  to pick one for you).
* :class:`RathOpenAIEmbeddingClient` wraps ``openai.OpenAI().embeddings``.
* Credential resolution: ``EmbeddingProvider.api_key`` →
  ``OPENAI_API_KEY`` env → ``llm.embedding_provider`` config entry →
  ``llm.default_provider`` config entry.

When the ``Provider`` (chat) and ``EmbeddingProvider`` share credentials,
:meth:`EmbeddingProvider.from_config` is the recommended constructor — it
reads both ``llm.embedding_provider`` (preferred) and ``llm.default_provider``
from ``~/.openrath/config.json`` and falls back gracefully.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Sequence

from openai import (
    APIConnectionError,
    APITimeoutError,
    InternalServerError,
    OpenAI,
    RateLimitError,
)

from rath.llm.credentials import resolve_credential
from rath.llm.retry import retry_with_backoff

if TYPE_CHECKING:
    from rath.config.store import ConfigStore

#: Names exported by a star-import of this module.
__all__ = [
    "EmbeddingProvider",
    "RathOpenAIEmbeddingClient",
    "DEFAULT_EMBEDDING_MODEL",
]


#: Default model used by :meth:`EmbeddingProvider.from_config` when no
#: provider entry is found, when the looked-up entry has no ``model`` set, or
#: when the lookup fell back to ``llm.default_provider`` (whose chat model is
#: replaced). Picked to match the open-source default rather than a
#: vendor-specific id.
#: NOTE(review): "text-embedding-3-small" looks like an OpenAI model id —
#: confirm the "open-source default" rationale above still holds.
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"


#: Exception types handed to ``retry_with_backoff`` as ``retryable=`` by
#: :meth:`RathOpenAIEmbeddingClient.embed`. All four are imported from the
#: ``openai`` SDK at the top of this module.
_EMBEDDING_RETRYABLE: tuple[type[BaseException], ...] = (
    RateLimitError,
    APIConnectionError,
    APITimeoutError,
    InternalServerError,
)


[docs] @dataclass(frozen=True, kw_only=True, slots=True) class EmbeddingProvider: """Routing + credentials for an OpenAI-compatible embeddings endpoint. The chat ``Provider`` (in :mod:`rath.llm.provider`) is intentionally *not* reused: embedding endpoints frequently live under a different base_url / model namespace even when the api_key is shared. """ model: str base_url: str | None = None api_key: str | None = None #: When set, request a truncated/projected embedding vector. The OpenAI #: SDK passes this as ``dimensions=``. ``None`` means use the model's #: native dimension. dimensions: int | None = None #: Same retry knobs as :class:`Provider`; ``None`` uses built-in defaults. retry_max_attempts: int | None = None retry_base_seconds: float | None = None def __str__(self) -> str: return self.model def __repr__(self) -> str: return self.__str__()
[docs] @classmethod def from_config( cls, name: str | None = None, *, store: "ConfigStore | None" = None, **overrides: Any, ) -> "EmbeddingProvider": """Build an :class:`EmbeddingProvider` from ``~/.openrath/config.json``. Lookup order: 1. ``name`` if given. 2. ``llm.embedding_provider`` if set. 3. ``llm.default_provider`` (chat fallback) — uses its credentials but replaces ``model`` with :data:`DEFAULT_EMBEDDING_MODEL` since the chat model is unsuitable for embeddings. Raises :class:`KeyError` only when ``name`` is given explicitly and the entry is missing. """ from rath.config.store import ConfigStore # local — see Provider.from_config s = store or ConfigStore.load() entry = None use_default_fallback = False if name is not None: entry = s.get_llm_provider(name) else: embed_name = getattr(s.config.llm, "embedding_provider", None) if embed_name is not None and embed_name in s.config.llm.providers: entry = s.config.llm.providers[embed_name] elif s.config.llm.default_provider is not None: entry = s.config.llm.providers.get(s.config.llm.default_provider) use_default_fallback = True if entry is None: base = cls(model=DEFAULT_EMBEDDING_MODEL) else: model = entry.model if use_default_fallback or not model: model = DEFAULT_EMBEDDING_MODEL base = cls( model=model, api_key=entry.api_key, base_url=entry.base_url, ) if not overrides: return base return replace(base, **overrides)
def _resolve_api_key(provider: EmbeddingProvider) -> str:
    """Pick the API key for ``provider``.

    Hands two candidates to ``resolve_credential``, in this order: the
    provider's own ``api_key``, then the ``OPENAI_API_KEY`` environment
    variable. (Presumably the first usable candidate wins — confirm against
    ``rath.llm.credentials.resolve_credential``.)
    """
    explicit_key = provider.api_key
    env_key = os.environ.get("OPENAI_API_KEY")
    return resolve_credential(explicit_key, env_key)


def _resolve_base_url(provider: EmbeddingProvider) -> str:
    """Pick the endpoint base URL for ``provider``.

    Hands two candidates to ``resolve_credential``, in this order: the
    provider's own ``base_url``, then the ``OPENAI_BASE_URL`` environment
    variable. The caller treats a falsy result as "no override".
    """
    explicit_url = provider.base_url
    env_url = os.environ.get("OPENAI_BASE_URL")
    return resolve_credential(explicit_url, env_url)
class RathOpenAIEmbeddingClient:
    """Thin wrapper around ``openai.OpenAI().embeddings.create``.

    Construct once per :class:`EmbeddingProvider`; the underlying SDK client
    is created up-front and reused across calls.
    """

    def __init__(self, provider: EmbeddingProvider) -> None:
        """Resolve credentials for ``provider`` and build the SDK client.

        Raises :class:`ValueError` when no API key can be resolved.
        """
        key = _resolve_api_key(provider)
        if not key:
            raise ValueError(
                "No API key for EmbeddingProvider: set EmbeddingProvider.api_key, "
                "export OPENAI_API_KEY, or configure llm.embedding_provider / "
                "llm.default_provider in ~/.openrath/config.json.",
            )
        self._provider = provider
        init_kw: dict[str, Any] = {"api_key": key}
        base_url = _resolve_base_url(provider)
        if base_url:
            # Forward base_url only when one was resolved; otherwise the SDK
            # client is constructed without the argument.
            init_kw["base_url"] = base_url
        self._client: OpenAI = OpenAI(**init_kw)

    @property
    def provider(self) -> EmbeddingProvider:
        """The :class:`EmbeddingProvider` this client was built from."""
        return self._provider

    def embed(self, texts: Sequence[str]) -> tuple[tuple[float, ...], ...]:
        """Embed an arbitrary number of texts; returns one vector per input.

        ``texts`` must be a collection of strings (list, tuple, ...). An
        empty ``texts`` short-circuits to ``()`` without an API call.

        The request carries the provider's ``model`` and, when set, its
        ``dimensions``. The call goes through ``retry_with_backoff`` with
        :data:`_EMBEDDING_RETRYABLE` as the retryable exception types and the
        provider's retry knobs.

        Raises :class:`TypeError` when ``texts`` is a bare ``str``. A string
        is itself a sequence of one-character strings, so it would otherwise
        be split and embedded character by character; use :meth:`embed_one`
        for a single text.
        """
        if isinstance(texts, str):
            raise TypeError(
                "embed() expects a sequence of strings, not a bare str; "
                "use embed_one() for a single text.",
            )
        if not texts:
            return ()

        kwargs: dict[str, Any] = {
            "model": self._provider.model,
            "input": list(texts),
        }
        if self._provider.dimensions is not None:
            kwargs["dimensions"] = self._provider.dimensions

        def _call() -> tuple[tuple[float, ...], ...]:
            resp = self._client.embeddings.create(**kwargs)
            # Normalise every component to a plain float and freeze the
            # result into nested tuples.
            return tuple(tuple(float(x) for x in d.embedding) for d in resp.data)

        return retry_with_backoff(
            _call,
            retryable=_EMBEDDING_RETRYABLE,
            max_attempts=self._provider.retry_max_attempts,
            base_seconds=self._provider.retry_base_seconds,
        )

    def embed_one(self, text: str) -> tuple[float, ...]:
        """Convenience for the single-text case.

        Wraps ``text`` in a one-element tuple, delegates to :meth:`embed`,
        and returns the first vector, or ``()`` if nothing came back.
        """
        vectors = self.embed((text,))
        if not vectors:
            return ()
        return vectors[0]