# Module: design_research_agents._memory._embedding

"""Embedding provider abstractions for memory retrieval backends."""

from __future__ import annotations

import inspect
from collections.abc import Callable, Sequence
from typing import Protocol, cast, runtime_checkable

from design_research_agents._contracts._llm import EmbeddingResult, LLMCapabilityError


@runtime_checkable
class EmbeddingProvider(Protocol):
    """Protocol for converting text payloads into dense vectors.

    Implementations are duck-typed: any object exposing a ``model_name``
    property and an ``embed`` method satisfies this protocol, and the
    ``@runtime_checkable`` decorator allows ``isinstance`` checks against it.
    """

    @property
    def model_name(self) -> str:
        """Return stable embedding model identifier.

        Returns:
            Stable embedding model identifier.
        """
        ...

    def embed(self, texts: Sequence[str]) -> list[list[float]] | None:
        """Embed one or more texts.

        Args:
            texts: Input texts to embed.

        Returns:
            Embedding vectors in input order, or ``None`` when embeddings
            are unavailable and lexical fallback should be used.
        """
        ...
class LLMEmbeddingProvider:
    """Embedding adapter that delegates to an LLM client/backend when available."""

    def __init__(
        self,
        *,
        llm_client: object,
        model: str | None = None,
        enable_lexical_fallback: bool = True,
    ) -> None:
        """Initialize one embedding adapter from an existing LLM client.

        Args:
            llm_client: LLM client instance that may expose embeddings.
            model: Optional explicit embedding model identifier.
            enable_lexical_fallback: Return ``None`` instead of raising when
                embeddings are unsupported.
        """
        self._llm_client = llm_client
        self._enable_lexical_fallback = enable_lexical_fallback
        resolved_model_name = (model or "").strip()
        if not resolved_model_name:
            # No explicit model: ask the client for its default, best-effort.
            default_model_callable = getattr(llm_client, "default_model", None)
            if callable(default_model_callable):
                try:
                    resolved_model_name = str(default_model_callable()).strip()
                except Exception:
                    # Best-effort discovery; fall through to the fallback name.
                    resolved_model_name = ""
        # Fix: previously an empty/whitespace result from default_model()
        # left model_name as "" — always fall back to "default" instead.
        self._model_name = resolved_model_name or "default"

    @property
    def model_name(self) -> str:
        """Return configured embedding model identifier.

        Returns:
            Resolved embedding model identifier.
        """
        return self._model_name

    def embed(self, texts: Sequence[str]) -> list[list[float]] | None:
        """Embed texts through available client/backend embedding APIs.

        Args:
            texts: Input texts to embed.

        Returns:
            List of vectors, or ``None`` when lexical fallback is enabled
            and embeddings are unavailable.

        Raises:
            Exception: Raised when embeddings fail and lexical fallback is
                disabled.
        """
        if not texts:
            return []
        embed_callable = self._resolve_embed_callable()
        if embed_callable is None:
            if self._enable_lexical_fallback:
                return None
            raise RuntimeError("No embedding-capable callable found on llm_client/backend.")
        try:
            result = self._invoke_embed(embed_callable, texts)
        except LLMCapabilityError:
            # Kept separate from the generic handler in case LLMCapabilityError
            # is not an Exception subclass; both paths behave identically.
            if self._enable_lexical_fallback:
                return None
            raise
        except Exception:
            if self._enable_lexical_fallback:
                return None
            raise
        vectors = result.vectors if isinstance(result, EmbeddingResult) else result
        # Coerce every component to float so callers get plain list[list[float]].
        normalized_vectors: list[list[float]] = [
            [float(value) for value in vector] for vector in vectors
        ]
        if len(normalized_vectors) != len(texts):
            if self._enable_lexical_fallback:
                return None
            raise RuntimeError("Embedding provider returned mismatched vector count.")
        return normalized_vectors

    def _resolve_embed_callable(
        self,
    ) -> Callable[..., EmbeddingResult | Sequence[Sequence[float]]] | None:
        """Return first available embedding callable from client or backend.

        Returns:
            Embed callable when available, otherwise ``None``.
        """
        direct_embed = getattr(self._llm_client, "embed", None)
        if callable(direct_embed):
            return cast(
                Callable[..., EmbeddingResult | Sequence[Sequence[float]]],
                direct_embed,
            )
        # Fall back to the client's private backend, if it exposes one.
        backend = getattr(self._llm_client, "_backend", None)
        backend_embed = getattr(backend, "embed", None)
        if callable(backend_embed):
            return cast(
                Callable[..., EmbeddingResult | Sequence[Sequence[float]]],
                backend_embed,
            )
        return None

    def _invoke_embed(
        self,
        embed_callable: Callable[..., EmbeddingResult | Sequence[Sequence[float]]],
        texts: Sequence[str],
    ) -> EmbeddingResult | Sequence[Sequence[float]]:
        """Invoke embed callable while adapting signature variants.

        Args:
            embed_callable: Resolved embedding callable.
            texts: Texts to embed.

        Returns:
            Raw embedding result from the resolved callable.
        """
        # NOTE(review): inspect.signature can raise for some C-implemented
        # callables; that propagates into embed()'s generic handler above.
        signature = inspect.signature(embed_callable)
        parameters = signature.parameters
        if "model" in parameters:
            return embed_callable(texts, model=self._model_name)
        return embed_callable(texts)
# Public API: the structural protocol and its LLM-backed implementation.
__all__ = ["EmbeddingProvider", "LLMEmbeddingProvider"]