# Source code for design_research_agents._model_selection._selector

"""Public model selection facade with flattened constructor-first ergonomics."""

from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Literal, cast

from design_research_agents._contracts._llm import LLMClient
from design_research_agents.llm import (
    AzureOpenAIServiceLLMClient,
    LlamaCppServerLLMClient,
    MLXLocalLLMClient,
    OllamaLLMClient,
    OpenAICompatibleHTTPLLMClient,
    OpenAIServiceLLMClient,
    SGLangServerLLMClient,
    TransformersLocalLLMClient,
    VLLMServerLLMClient,
)

from ._catalog import ModelCatalog
from ._hardware import HardwareProfile
from ._policy import ModelSelectionPolicy
from ._types import (
    ModelSelectionConstraints,
    ModelSelectionDecision,
    ModelSelectionIntent,
    ModelSelectionPolicyConfig,
)

# Selection priority labels forwarded to the policy layer via ModelSelectionIntent.
type Priority = Literal["quality", "balanced", "speed"]
# Output modes accepted by ModelSelector.select(): live client, raw decision,
# or an uninstantiated client configuration payload.
type SelectionOutput = Literal["client", "decision", "client_config"]
# Caller-supplied hook mapping a selection decision to a payload of the form
# {'client_class': <name>, 'kwargs': {...}} for providers the selector does not know.
type LocalClientResolver = Callable[[ModelSelectionDecision], dict[str, object]]

# Allow-list of instantiable LLM client classes, keyed by class name.
# Resolved configurations are validated against this mapping before any
# client is constructed, so resolver hooks cannot inject arbitrary classes.
_CLIENT_CLASSES: dict[str, type[object]] = {
    "AzureOpenAIServiceLLMClient": AzureOpenAIServiceLLMClient,
    "LlamaCppServerLLMClient": LlamaCppServerLLMClient,
    "OpenAIServiceLLMClient": OpenAIServiceLLMClient,
    "OpenAICompatibleHTTPLLMClient": OpenAICompatibleHTTPLLMClient,
    "TransformersLocalLLMClient": TransformersLocalLLMClient,
    "MLXLocalLLMClient": MLXLocalLLMClient,
    "VLLMServerLLMClient": VLLMServerLLMClient,
    "OllamaLLMClient": OllamaLLMClient,
    "SGLangServerLLMClient": SGLangServerLLMClient,
}


class ModelSelector:
    """Flat model selection interface with client/config resolution helpers."""

    def __init__(
        self,
        *,
        catalog: ModelCatalog | None = None,
        prefer_local: bool = True,
        ram_reserve_gb: float = 2.0,
        vram_reserve_gb: float = 0.5,
        max_load_ratio: float = 0.85,
        remote_cost_floor_usd: float = 0.02,
        default_max_latency_ms: int | None = None,
        local_client_resolver: LocalClientResolver | None = None,
    ) -> None:
        """Initialize model selector policy controls and optional resolver hook.

        Args:
            catalog: Optional model catalog; falls back to ``ModelCatalog.default()``.
            prefer_local: Prefer local models over remote ones when all else is equal.
            ram_reserve_gb: RAM (GB) reserved when evaluating local candidates.
            vram_reserve_gb: GPU VRAM (GB) reserved when evaluating local candidates.
            max_load_ratio: Maximum system load ratio (0.0-1.0) for local viability.
            remote_cost_floor_usd: Minimum cost threshold (USD) for remote candidates.
            default_max_latency_ms: Default latency ceiling (ms) applied when a
                selection does not specify its own constraint.
            local_client_resolver: Optional callable mapping a
                ``ModelSelectionDecision`` to ``{'client_class': ..., 'kwargs': {...}}``
                for providers the built-in resolver does not recognize, enabling
                custom local providers without modifying this class.
        """
        selection_config = ModelSelectionPolicyConfig(
            policy_id="default",
            prefer_local=prefer_local,
            ram_reserve_gb=ram_reserve_gb,
            vram_reserve_gb=vram_reserve_gb,
            max_load_ratio=max_load_ratio,
            remote_cost_floor_usd=remote_cost_floor_usd,
            default_max_latency_ms=default_max_latency_ms,
        )
        self._policy = ModelSelectionPolicy(
            catalog=catalog or ModelCatalog.default(),
            config=selection_config,
        )
        self._local_client_resolver = local_client_resolver

    def select(
        self,
        *,
        task: str,
        priority: Priority = "balanced",
        require_local: bool = False,
        preferred_provider: str | None = None,
        max_cost_usd: float | None = None,
        max_latency_ms: int | None = None,
        hardware_profile: Mapping[str, object] | HardwareProfile | None = None,
        output: SelectionOutput = "client",
    ) -> LLMClient | ModelSelectionDecision | dict[str, object]:
        """Select a model and return a decision, config mapping, or live client.

        Args:
            task: Description of the task the model is selected for.
            priority: Quality/latency/cost trade-off hint passed to the policy.
            require_local: If True, only local models are considered viable.
            preferred_provider: Optional provider name to prioritize.
            max_cost_usd: Optional cost ceiling (USD) for candidates.
            max_latency_ms: Optional latency ceiling (ms) for candidates.
            hardware_profile: Optional mapping or ``HardwareProfile`` describing
                the current hardware state for local-candidate evaluation.
            output: ``"client"`` returns an instantiated ``LLMClient``;
                ``"decision"`` returns the raw ``ModelSelectionDecision``;
                ``"client_config"`` returns a dict (``client_class``, ``kwargs``
                and selection metadata) usable to build a client later.

        Returns:
            An ``LLMClient``, a ``ModelSelectionDecision``, or a config dict,
            depending on ``output``.

        Raises:
            ValueError: If ``output`` is unsupported or selection/config
                coercion fails.
        """
        if output not in {"client", "decision", "client_config"}:
            raise ValueError("output must be one of: 'client', 'decision', 'client_config'.")
        intent = ModelSelectionIntent(task=task, priority=priority)
        constraints = ModelSelectionConstraints(
            require_local=require_local,
            preferred_provider=preferred_provider,
            max_cost_usd=max_cost_usd,
            max_latency_ms=max_latency_ms,
        )
        chosen = self._policy.select_model(
            intent=intent,
            constraints=constraints,
            hardware_profile=_coerce_hardware_profile(hardware_profile),
        )
        if output == "decision":
            return chosen
        config_payload = self._resolve_client_config(chosen)
        if output == "client_config":
            return config_payload
        # Only the "client" output instantiates anything; the other modes
        # remain side-effect free.
        return _build_client_from_config(config_payload)

    @staticmethod
    def _builtin_provider_config(provider: str, model_id: str) -> dict[str, object] | None:
        """Return the built-in client config for ``provider``, or None if unknown.

        Args:
            provider: Normalized (stripped) provider identifier.
            model_id: Selected model identifier to embed in constructor kwargs.

        Returns:
            A ``{'client_class': ..., 'kwargs': {...}}`` payload, or ``None``
            when the provider has no built-in mapping.
        """
        # Providers whose clients only need a default_model kwarg.
        simple_clients = {
            "openai": "OpenAIServiceLLMClient",
            "azure": "AzureOpenAIServiceLLMClient",
            "azure_openai": "AzureOpenAIServiceLLMClient",
            "azure-openai": "AzureOpenAIServiceLLMClient",
            "openai_compatible_http": "OpenAICompatibleHTTPLLMClient",
            "openai-compatible-http": "OpenAICompatibleHTTPLLMClient",
            "openai-compatible": "OpenAICompatibleHTTPLLMClient",
        }
        if provider in simple_clients:
            return {
                "client_class": simple_clients[provider],
                "kwargs": {"default_model": model_id},
            }
        # In-process local backends take both model_id and default_model.
        in_process_clients = {
            "transformers_local": "TransformersLocalLLMClient",
            "mlx_local": "MLXLocalLLMClient",
        }
        if provider in in_process_clients:
            return {
                "client_class": in_process_clients[provider],
                "kwargs": {"model_id": model_id, "default_model": model_id},
            }
        # Server-backed local providers each need bespoke kwargs.
        if provider == "vllm_local":
            return {
                "client_class": "VLLMServerLLMClient",
                "kwargs": {
                    "api_model": model_id,
                    "manage_server": False,
                    "base_url": "http://127.0.0.1:8002/v1",
                },
            }
        if provider == "ollama_local":
            return {
                "client_class": "OllamaLLMClient",
                "kwargs": {"default_model": model_id, "manage_server": False},
            }
        if provider == "sglang_local":
            return {
                "client_class": "SGLangServerLLMClient",
                "kwargs": {
                    "model": model_id,
                    "manage_server": False,
                    "base_url": "http://127.0.0.1:30000/v1",
                },
            }
        return None

    def _resolve_client_config(self, decision: ModelSelectionDecision) -> dict[str, object]:
        """Resolve one selection decision into a concrete client configuration payload.

        Args:
            decision: Selection decision produced by the policy layer.

        Returns:
            Configuration payload containing ``client_class`` and ``kwargs``
            plus metadata used for downstream tracing and reporting.

        Raises:
            ValueError: If resolved ``client_class`` or ``kwargs`` are invalid.
        """
        normalized_provider = decision.provider.strip()
        candidate = self._builtin_provider_config(normalized_provider, decision.model_id)
        if candidate is None:
            # Unknown providers are delegated to the caller-supplied resolver
            # hook for extensibility.
            candidate = self._resolve_local_client_config(decision)
        client_class = candidate.get("client_class")
        kwargs = candidate.get("kwargs")
        if not isinstance(client_class, str) or client_class not in _CLIENT_CLASSES:
            supported = ", ".join(sorted(_CLIENT_CLASSES))
            raise ValueError(
                f"ModelSelector resolver returned unsupported client_class. Expected one of: {supported}."
            )
        if not isinstance(kwargs, dict):
            raise ValueError("ModelSelector resolver returned invalid kwargs (must be a dict).")
        enriched = dict(candidate)
        enriched.update(
            {
                "provider": decision.provider,
                "model_id": decision.model_id,
                "client_class": client_class,
                "kwargs": dict(kwargs),
                "rationale": decision.rationale,
                "policy_id": decision.policy_id,
                "catalog_signature": decision.catalog_signature,
            }
        )
        return enriched

    def _resolve_local_client_config(self, decision: ModelSelectionDecision) -> dict[str, object]:
        """Resolve local-provider client configuration via user-supplied resolver.

        Args:
            decision: Selection decision that could not be mapped by built-in providers.

        Returns:
            Resolver payload containing ``client_class`` and ``kwargs``.

        Raises:
            ValueError: If resolver is missing, returns a non-dict, or misses required keys.
        """
        if self._local_client_resolver is None:
            raise ValueError(
                "ModelSelector cannot map selected provider "
                f"'{decision.provider}' (model '{decision.model_id}') to a client config. "
                "Provide local_client_resolver returning {'client_class': ..., 'kwargs': {...}}."
            )
        payload = self._local_client_resolver(decision)
        if not isinstance(payload, dict):
            raise ValueError("local_client_resolver must return a dict payload.")
        if "client_class" not in payload or "kwargs" not in payload:
            raise ValueError("local_client_resolver result must include 'client_class' and 'kwargs'.")
        return payload
def _build_client_from_config(config: dict[str, object]) -> LLMClient:
    """Instantiate an ``LLMClient`` from a resolved configuration payload.

    Args:
        config: Mapping with ``client_class`` and constructor ``kwargs``.

    Returns:
        Instantiated ``LLMClient`` for immediate use.

    Raises:
        ValueError: If client class name or kwargs payload is invalid.
    """
    client_class = config.get("client_class")
    kwargs = config.get("kwargs")
    # Validate against the allow-list before constructing anything.
    if not isinstance(client_class, str) or client_class not in _CLIENT_CLASSES:
        raise ValueError("client_config has unsupported client_class.")
    if not isinstance(kwargs, dict):
        raise ValueError("client_config has invalid kwargs (must be dict).")
    client_ctor = _CLIENT_CLASSES[client_class]
    client = client_ctor(**kwargs)
    return cast(LLMClient, client)


def _coerce_hardware_profile(
    value: Mapping[str, object] | HardwareProfile | None,
) -> HardwareProfile | None:
    """Normalize optional hardware-profile input into ``HardwareProfile``.

    Args:
        value: Existing ``HardwareProfile`` instance, mapping, or ``None``.

    Returns:
        Normalized hardware profile instance or ``None``.

    Raises:
        ValueError: If ``value`` has unsupported type or invalid field shapes.
    """
    if value is None:
        return None
    if isinstance(value, HardwareProfile):
        return value
    if not isinstance(value, Mapping):
        raise ValueError("hardware_profile must be a mapping, HardwareProfile, or None.")
    load_average = _coerce_load_average(value.get("load_average"))
    return HardwareProfile(
        total_ram_gb=_coerce_optional_float(value.get("total_ram_gb")),
        available_ram_gb=_coerce_optional_float(value.get("available_ram_gb")),
        cpu_count=_coerce_optional_int(value.get("cpu_count")),
        load_average=load_average,
        gpu_present=_coerce_optional_bool(value.get("gpu_present")),
        gpu_vram_gb=_coerce_optional_float(value.get("gpu_vram_gb")),
        gpu_name=_coerce_optional_str(value.get("gpu_name")),
        platform_name=_coerce_optional_str(value.get("platform_name")),
    )


def _coerce_load_average(raw: object) -> tuple[float, float, float] | None:
    """Normalize optional load-average values to a 3-float tuple.

    Args:
        raw: Optional 3-item sequence representing system load averages.

    Returns:
        ``None`` when unset, otherwise a ``(1m, 5m, 15m)`` float tuple.

    Raises:
        ValueError: If value is not a 3-item sequence of float-compatible items.
    """
    if raw is None:
        return None
    if not isinstance(raw, (tuple, list)) or len(raw) != 3:
        raise ValueError("hardware_profile.load_average must be a 3-item sequence when provided.")
    try:
        coerced = tuple(float(item) for item in raw)
    except (TypeError, ValueError) as exc:
        # Fix: non-numeric items (e.g. None) previously leaked TypeError past
        # the documented ValueError contract; surface them as ValueError.
        raise ValueError(
            "hardware_profile.load_average must contain float-compatible items."
        ) from exc
    return cast(tuple[float, float, float], coerced)


def _coerce_optional_float(raw: object) -> float | None:
    """Coerce an optional value into ``float``.

    Args:
        raw: Input value to normalize.

    Returns:
        ``None`` when unset, otherwise parsed float.

    Raises:
        ValueError: If value cannot be interpreted as float.
    """
    if raw is None:
        return None
    # bool is a subclass of int; reject it explicitly so True doesn't become 1.0.
    if isinstance(raw, bool):
        raise ValueError("Expected float-compatible value, got bool.")
    if isinstance(raw, (int, float, str, bytes, bytearray)):
        try:
            return float(raw)
        except (TypeError, ValueError) as exc:
            raise ValueError("Expected float-compatible value.") from exc
    raise ValueError("Expected float-compatible value.")


def _coerce_optional_int(raw: object) -> int | None:
    """Coerce an optional value into ``int``.

    Args:
        raw: Input value to normalize.

    Returns:
        ``None`` when unset, otherwise parsed int.

    Raises:
        ValueError: If value cannot be interpreted as int.
    """
    if raw is None:
        return None
    # bool is a subclass of int; reject it explicitly so True doesn't become 1.
    if isinstance(raw, bool):
        raise ValueError("Expected int-compatible value, got bool.")
    if isinstance(raw, (int, float, str, bytes, bytearray)):
        try:
            return int(raw)
        except (TypeError, ValueError) as exc:
            raise ValueError("Expected int-compatible value.") from exc
    raise ValueError("Expected int-compatible value.")


def _coerce_optional_bool(raw: object) -> bool | None:
    """Coerce an optional value into ``bool``.

    Args:
        raw: Input value to normalize.

    Returns:
        ``None`` when unset, otherwise bool value.

    Raises:
        ValueError: If provided value is not a bool.
    """
    if raw is None:
        return None
    if not isinstance(raw, bool):
        raise ValueError("Expected bool value when provided.")
    return raw


def _coerce_optional_str(raw: object) -> str | None:
    """Coerce an optional value into normalized ``str``.

    Args:
        raw: Input value to normalize.

    Returns:
        ``None`` when unset or blank, otherwise stripped string value.

    Raises:
        ValueError: If provided value is not a string.
    """
    if raw is None:
        return None
    if not isinstance(raw, str):
        raise ValueError("Expected str value when provided.")
    normalized = raw.strip()
    # Blank strings collapse to None so downstream code sees one "unset" value.
    return normalized or None


__all__ = ["ModelSelector"]