"""Provider-specific LLM client classes with constructor-first defaults."""
from __future__ import annotations
import json
import sys
from collections.abc import Iterator, Mapping, Sequence
from hashlib import sha256
from types import TracebackType
from typing import Self
from design_research_agents._contracts._llm import (
BackendCapabilities,
LLMChatParams,
LLMClient,
LLMDelta,
LLMMessage,
LLMRequest,
LLMResponse,
LLMStreamEvent,
)
from design_research_agents._tracing import (
emit_model_request_observed,
emit_model_response_observed,
)
from .._backends._base import BaseLLMBackend
from .._backends._providers._llama_cpp import LlamaCppBackend
from .._backends._providers._llama_cpp_server import (
create_backend as create_llama_cpp_server,
)
from .._backends._providers._mlx_local import MlxLocalBackend
from .._backends._providers._ollama_local import OllamaLocalBackend
from .._backends._providers._ollama_server import (
create_backend as create_ollama_server,
)
from .._backends._providers._sglang_local import SglangLocalBackend
from .._backends._providers._sglang_server import (
create_backend as create_sglang_server,
)
from .._backends._providers._transformers_local import TransformersLocalBackend
from .._backends._providers._vllm_local import VllmLocalBackend
from .._backends._providers._vllm_server import (
create_backend as create_vllm_server,
)
from ._managed_port_reservations import (
_reserve_managed_server_port,
)
from ._snapshot_helpers import (
capabilities_to_dict,
llama_config_snapshot,
llama_server_snapshot,
mlx_config_snapshot,
normalize_snapshot_mapping,
ollama_config_snapshot,
ollama_server_snapshot,
sglang_config_snapshot,
sglang_server_snapshot,
transformers_config_snapshot,
vllm_config_snapshot,
vllm_server_snapshot,
)
class _SingleBackendLLMClient(LLMClient):
"""LLM client wrapper that delegates to one concrete backend."""
def __init__(
self,
*,
backend: BaseLLMBackend,
config_snapshot: Mapping[str, object] | None = None,
server_snapshot: Mapping[str, object] | None = None,
) -> None:
"""Initialize the client with a configured backend instance.
Args:
backend: The LLM backend instance to delegate calls to. This should be fully configured
and ready to use, as the client will not perform any additional setup.
config_snapshot: Additional client/backend config metadata fields.
server_snapshot: Managed server metadata fields when server lifecycle is owned.
"""
self._backend = backend
snapshot: dict[str, object] = {
"name": backend.name,
"kind": backend.kind,
"default_model": backend.default_model,
"base_url": backend.base_url,
"config_hash": backend.config_hash,
"max_retries": backend.max_retries,
"model_patterns": list(backend.model_patterns),
}
if config_snapshot is not None:
snapshot.update(normalize_snapshot_mapping(config_snapshot))
# Snapshot fields are intentionally plain JSON-compatible scalars for deterministic example output.
self._config_snapshot = snapshot
self._server_snapshot = normalize_snapshot_mapping(server_snapshot) if server_snapshot is not None else None
def close(self) -> None:
"""Release client-owned resources.
The shared default is a no-op so non-managed clients still support a
uniform lifecycle contract. Managed clients override this method.
"""
return None
def __enter__(self) -> Self:
"""Return this client for ``with``-statement usage."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
"""Close the client on ``with``-statement exit."""
del exc_type, exc, tb
self.close()
return None
def generate(self, request: LLMRequest) -> LLMResponse:
"""Generate one response using the configured backend.
Args:
request: Provider-neutral request payload.
Returns:
Normalized completion response.
"""
emit_model_request_observed(
source="_SingleBackendLLMClient.generate",
model=request.model,
request_payload=request,
metadata={"backend_name": self._backend.name, "backend_kind": self._backend.kind},
)
try:
response = self._backend.generate(request)
except Exception as exc:
emit_model_response_observed(
source="_SingleBackendLLMClient.generate",
error=str(exc),
metadata={"backend_name": self._backend.name, "backend_kind": self._backend.kind},
)
raise
emit_model_response_observed(
source="_SingleBackendLLMClient.generate",
response_payload=response,
metadata={"backend_name": self._backend.name, "backend_kind": self._backend.kind},
)
return response
def chat(
self,
messages: Sequence[LLMMessage],
*,
model: str,
params: LLMChatParams,
) -> LLMResponse:
"""Build and execute a request-object call from chat-style inputs.
Args:
messages: Ordered chat messages for this completion request.
model: Target model identifier.
params: Generation controls (temperature, max tokens, schema, options).
Returns:
Normalized completion response.
"""
request = LLMRequest(
messages=messages,
model=model,
temperature=params.temperature,
max_tokens=params.max_tokens,
tools=(),
response_schema=params.response_schema,
response_format=None,
metadata={},
provider_options=dict(params.provider_options),
task_profile=None,
)
# Reuse generate() so chat() inherits shared tracing and error-observation behavior.
return self.generate(request)
def stream(self, request: LLMRequest) -> Iterator[LLMDelta]:
"""Stream response deltas for one request.
Args:
request: Provider-neutral request payload.
Returns:
Iterator yielding normalized response deltas.
"""
emit_model_request_observed(
source="_SingleBackendLLMClient.stream",
model=request.model,
request_payload=request,
metadata={"backend_name": self._backend.name, "backend_kind": self._backend.kind},
)
stream = self._backend.stream(request)
def _observed_stream() -> Iterator[LLMDelta]:
delta_count = 0
text_delta_count = 0
try:
for delta in stream:
delta_count += 1
if delta.text_delta:
text_delta_count += 1
yield delta
except Exception as exc:
emit_model_response_observed(
source="_SingleBackendLLMClient.stream",
error=str(exc),
metadata={
"backend_name": self._backend.name,
"backend_kind": self._backend.kind,
"delta_count": delta_count,
"text_delta_count": text_delta_count,
},
)
raise
emit_model_response_observed(
source="_SingleBackendLLMClient.stream",
response_payload=getattr(stream, "response", None),
metadata={
"backend_name": self._backend.name,
"backend_kind": self._backend.kind,
"delta_count": delta_count,
"text_delta_count": text_delta_count,
},
)
return _observed_stream()
def stream_chat(
self,
messages: Sequence[LLMMessage],
*,
model: str,
params: LLMChatParams,
) -> Iterator[LLMStreamEvent]:
"""Build and execute a streaming request from chat-style inputs.
Args:
messages: Ordered chat messages for this completion request.
model: Target model identifier.
params: Generation controls (temperature, max tokens, schema, options).
Yields:
Streaming events containing deltas and one final completed response.
"""
request = LLMRequest(
messages=messages,
model=model,
temperature=params.temperature,
max_tokens=params.max_tokens,
tools=(),
response_schema=params.response_schema,
response_format=None,
metadata={},
provider_options=dict(params.provider_options),
task_profile=None,
)
stream = self.stream(request)
full_text = ""
for delta in stream:
if delta.text_delta:
full_text += delta.text_delta
yield LLMStreamEvent(kind="delta", delta_text=delta.text_delta)
completed = getattr(stream, "response", None)
if not isinstance(completed, LLMResponse):
# If backend streaming does not expose a final response object, synthesize one from deltas.
completed = LLMResponse(
text=full_text,
model=model,
provider=None,
finish_reason=None,
usage=None,
latency_ms=None,
raw_output=None,
tool_calls=(),
raw=None,
provenance=None,
)
yield LLMStreamEvent(kind="completed", response=completed)
def default_model(self) -> str:
"""Return the configured backend default model.
Returns:
Non-empty model identifier configured as backend default.
Raises:
ValueError: If the backend default model is missing or blank.
"""
default_model = self._backend.default_model
if not isinstance(default_model, str) or not default_model.strip():
raise ValueError("LLM backend default_model is not configured.")
return default_model
def capabilities(self) -> BackendCapabilities:
"""Return declared backend capabilities for this client.
Returns:
Backend capability payload.
"""
return self._backend.capabilities()
def config_snapshot(self) -> dict[str, object]:
"""Return stable client/backend configuration metadata.
Returns:
Snapshot dictionary safe for diagnostics and example output.
"""
return normalize_snapshot_mapping(self._config_snapshot)
def server_snapshot(self) -> dict[str, object] | None:
"""Return managed-server metadata when this client owns a server.
Returns:
Server snapshot dictionary, or ``None`` when not managed.
"""
if self._server_snapshot is None:
return None
return normalize_snapshot_mapping(self._server_snapshot)
def describe(self) -> dict[str, object]:
"""Return composed client configuration and capability summary.
Returns:
JSON-serializable description of client runtime characteristics.
"""
return {
"client_class": self.__class__.__name__,
"default_model": self.default_model(),
"backend": self.config_snapshot(),
"capabilities": capabilities_to_dict(self.capabilities()),
"server": self.server_snapshot(),
}
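# Usage sketch (kept as a comment so nothing executes at import time): every concrete
# client below shares this lifecycle contract, so it can be driven through the ``with``
# statement plus the chat/stream_chat helpers. The keyword arguments used to build
# ``LLMMessage`` and ``LLMChatParams`` here are illustrative assumptions about those
# contract dataclasses, not guaranteed field names.
#
#     params = LLMChatParams(temperature=0.2, max_tokens=256)  # assumed constructor
#     with OllamaLLMClient(manage_server=False) as client:
#         messages = [LLMMessage(role="user", content="Summarize the brief.")]  # assumed fields
#         response = client.chat(messages, model=client.default_model(), params=params)
#         print(response.text)
#         for event in client.stream_chat(messages, model=client.default_model(), params=params):
#             if event.kind == "delta":
#                 print(event.delta_text, end="")
#             elif event.kind == "completed":
#                 final = event.response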
class LlamaCppServerLLMClient(_SingleBackendLLMClient):
"""Client for a managed local ``llama_cpp.server`` backend."""
def __init__(
self,
*,
name: str = "llama-local",
model: str = "Qwen2.5-1.5B-Instruct-Q4_K_M.gguf",
hf_model_repo_id: str | None = "bartowski/Qwen2.5-1.5B-Instruct-GGUF",
api_model: str = "qwen2-1.5b-q4",
host: str = "127.0.0.1",
port: int = 8001,
context_window: int = 4096,
startup_timeout_seconds: float = 60.0,
request_timeout_seconds: float = 60.0,
poll_interval_seconds: float = 0.25,
python_executable: str = sys.executable,
extra_server_args: tuple[str, ...] = (),
max_retries: int = 2,
model_patterns: tuple[str, ...] | None = None,
) -> None:
"""Initialize a local llama-cpp client with sensible defaults.
Args:
name: Logical name for this client instance, used in logging and provenance.
model: Local model identifier or path for llama_cpp.server to load.
hf_model_repo_id: Optional Hugging Face repo ID to auto-download the model from if not
found locally.
api_model: The model name to report in API responses, which can differ from the local
model name.
host: Host interface for the local server to bind to.
port: Port for the local server to listen on.
context_window: Context window size (n_ctx) to configure the llama_cpp.server with.
startup_timeout_seconds: Max time to wait for the server process to start and become
healthy.
request_timeout_seconds: HTTP timeout for generate and stream requests.
poll_interval_seconds: Time interval between health check polls during startup.
python_executable: Python executable to use for running the server process.
extra_server_args: Additional command-line arguments to pass when starting the server
process.
max_retries: Number of times to retry a request in case of failure before giving up.
model_patterns: Optional tuple of model name patterns supported by this client, used for
routing decisions. If None, defaults to (api_model,).
"""
combined_server_args = ("--n_ctx", str(context_window), *extra_server_args)
reserved_port = _reserve_managed_server_port(host=host, requested_port=port)
resolved_port = reserved_port.port
self._llama_server = create_llama_cpp_server(
model=model,
hf_model_repo_id=hf_model_repo_id,
api_model=api_model,
host=host,
port=resolved_port,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
python_executable=python_executable,
extra_server_args=combined_server_args,
)
self._llama_server.set_port_reservation(reserved_port.reservation_socket)
config_hash = _config_hash(
{
"kind": "llama_cpp",
"name": name,
"model": model,
"hf_model_repo_id": hf_model_repo_id,
"api_model": api_model,
"host": host,
"port": resolved_port,
"context_window": context_window,
"startup_timeout_seconds": startup_timeout_seconds,
"request_timeout_seconds": request_timeout_seconds,
"poll_interval_seconds": poll_interval_seconds,
"python_executable": python_executable,
"extra_server_args": combined_server_args,
"max_retries": max_retries,
}
)
backend = LlamaCppBackend(
name=name,
llama_backend=self._llama_server,
default_model=api_model,
config_hash=config_hash,
request_timeout_seconds=request_timeout_seconds,
max_retries=max_retries,
model_patterns=_resolve_model_patterns(model_patterns, api_model),
)
super().__init__(
backend=backend,
config_snapshot=llama_config_snapshot(
model=model,
hf_model_repo_id=hf_model_repo_id,
api_model=api_model,
context_window=context_window,
startup_timeout_seconds=startup_timeout_seconds,
request_timeout_seconds=request_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
python_executable=python_executable,
extra_server_args=combined_server_args,
),
server_snapshot=llama_server_snapshot(server=self._llama_server),
)
def close(self) -> None:
"""Stop the managed local server process when present."""
server = getattr(self, "_llama_server", None)
if server is not None:
server.close()
def __del__(self) -> None: # pragma: no cover - defensive cleanup.
"""Best-effort cleanup for managed server process during GC."""
self.close()
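# Managed-server sketch (comment only, assuming the ``llama_cpp.server`` module and the
# default GGUF weights are available): the client reserves a port, launches the server
# process with ``--n_ctx``, and tears it down again on ``close()``.
#
#     with LlamaCppServerLLMClient(port=8001, context_window=4096) as client:
#         print(client.describe()["server"])  # managed-server snapshot, or None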
class VLLMServerLLMClient(_SingleBackendLLMClient):
"""Client for local or self-hosted vLLM OpenAI-compatible inference."""
def __init__(
self,
*,
name: str = "vllm-local",
model: str = "Qwen/Qwen2.5-1.5B-Instruct",
api_model: str = "qwen2.5-1.5b-instruct",
host: str = "127.0.0.1",
port: int = 8002,
manage_server: bool = True,
startup_timeout_seconds: float = 90.0,
poll_interval_seconds: float = 0.5,
python_executable: str = sys.executable,
extra_server_args: tuple[str, ...] = (),
base_url: str | None = None,
request_timeout_seconds: float = 60.0,
max_retries: int = 2,
model_patterns: tuple[str, ...] | None = None,
) -> None:
"""Initialize a vLLM client in managed-server or connect mode.
Args:
name: Logical name for this client instance.
model: Model identifier passed to managed vLLM server startup.
api_model: Model alias exposed by vLLM OpenAI-compatible API.
host: Host interface used in managed mode.
port: TCP port used in managed mode.
manage_server: Whether this client manages the vLLM server lifecycle.
startup_timeout_seconds: Maximum startup wait time in managed mode.
poll_interval_seconds: Delay between readiness probes in managed mode.
python_executable: Python executable used to launch managed vLLM process.
extra_server_args: Additional CLI flags forwarded to vLLM server.
base_url: Optional connect-mode endpoint URL, used only when ``manage_server`` is
False (remote or self-managed deployments); defaults to ``http://{host}:{port}/v1``.
request_timeout_seconds: HTTP timeout for generate and stream requests.
max_retries: Number of retries for retryable provider/transport errors.
model_patterns: Optional tuple of model patterns for routing decisions.
Raises:
ValueError: If ``manage_server`` and ``base_url`` are both configured.
"""
if manage_server and base_url is not None:
raise ValueError("base_url cannot be provided when manage_server is True.")
reserved_vllm_port = _reserve_managed_server_port(host=host, requested_port=port) if manage_server else None
resolved_port = reserved_vllm_port.port if reserved_vllm_port is not None else port
self._vllm_server = (
create_vllm_server(
model=model,
api_model=api_model,
host=host,
port=resolved_port,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
python_executable=python_executable,
extra_server_args=extra_server_args,
)
if manage_server
else None
)
if self._vllm_server is not None:
reservation_socket = reserved_vllm_port.reservation_socket if reserved_vllm_port is not None else None
self._vllm_server.set_port_reservation(reservation_socket)
resolved_base_url = (
self._vllm_server.base_url if self._vllm_server is not None else (base_url or f"http://{host}:{port}/v1")
)
config_hash = _config_hash(
{
"kind": "vllm_local",
"name": name,
"model": model,
"api_model": api_model,
"host": host,
"port": resolved_port,
"manage_server": manage_server,
"startup_timeout_seconds": startup_timeout_seconds,
"poll_interval_seconds": poll_interval_seconds,
"python_executable": python_executable,
"extra_server_args": extra_server_args,
"base_url": base_url,
"request_timeout_seconds": request_timeout_seconds,
"max_retries": max_retries,
}
)
backend = VllmLocalBackend(
name=name,
base_url=resolved_base_url,
default_model=api_model,
request_timeout_seconds=request_timeout_seconds,
managed_server=self._vllm_server,
config_hash=config_hash,
max_retries=max_retries,
model_patterns=_resolve_model_patterns(model_patterns, api_model),
)
super().__init__(
backend=backend,
config_snapshot=vllm_config_snapshot(
model=model,
api_model=api_model,
host=host,
port=resolved_port,
manage_server=manage_server,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
python_executable=python_executable,
extra_server_args=extra_server_args,
request_timeout_seconds=request_timeout_seconds,
),
server_snapshot=vllm_server_snapshot(server=self._vllm_server),
)
def close(self) -> None:
"""Stop the managed vLLM server process when present."""
server = getattr(self, "_vllm_server", None)
if server is not None:
server.close()
def __del__(self) -> None: # pragma: no cover - defensive cleanup.
"""Best-effort managed server cleanup during garbage collection."""
self.close()
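# Mode-selection sketch (comment only): managed mode launches a local vLLM process on a
# reserved port, while connect mode targets an endpoint that is already running; the two
# are mutually exclusive with ``base_url``.
#
#     managed = VLLMServerLLMClient()  # spawns and owns a local vLLM server
#     remote = VLLMServerLLMClient(
#         manage_server=False,
#         base_url="http://inference.internal:8000/v1",  # hypothetical existing endpoint
#     )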
class OllamaLLMClient(_SingleBackendLLMClient):
"""Client for local or self-hosted Ollama chat inference."""
def __init__(
self,
*,
name: str = "ollama-local",
default_model: str = "qwen2.5:1.5b-instruct",
host: str = "127.0.0.1",
port: int = 11434,
manage_server: bool = True,
ollama_executable: str = "ollama",
auto_pull_model: bool = False,
startup_timeout_seconds: float = 60.0,
poll_interval_seconds: float = 0.25,
request_timeout_seconds: float = 60.0,
max_retries: int = 2,
model_patterns: tuple[str, ...] | None = None,
) -> None:
"""Initialize an Ollama client in managed-server or connect mode.
Args:
name: Logical name for this client instance.
default_model: Default model id used when requests omit model.
host: Host interface used in managed mode or connect mode.
port: TCP port used in managed mode or connect mode.
manage_server: Whether this client manages ``ollama serve`` lifecycle.
ollama_executable: Executable used to invoke ``ollama`` commands.
auto_pull_model: Whether to pull ``default_model`` after startup.
startup_timeout_seconds: Maximum startup wait time in managed mode.
poll_interval_seconds: Delay between readiness probes in managed mode.
request_timeout_seconds: HTTP timeout for generate and stream requests.
max_retries: Number of retries for retryable provider/transport errors.
model_patterns: Optional tuple of model patterns for routing decisions.
"""
reserved_ollama_port = _reserve_managed_server_port(host=host, requested_port=port) if manage_server else None
resolved_port = reserved_ollama_port.port if reserved_ollama_port is not None else port
self._ollama_server = (
create_ollama_server(
host=host,
port=resolved_port,
ollama_executable=ollama_executable,
auto_pull_model=auto_pull_model,
default_model=default_model,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
)
if manage_server
else None
)
if self._ollama_server is not None:
reservation_socket = reserved_ollama_port.reservation_socket if reserved_ollama_port is not None else None
self._ollama_server.set_port_reservation(reservation_socket)
resolved_base_url = self._ollama_server.base_url if self._ollama_server is not None else f"http://{host}:{port}"
config_hash = _config_hash(
{
"kind": "ollama_local",
"name": name,
"default_model": default_model,
"host": host,
"port": resolved_port,
"manage_server": manage_server,
"ollama_executable": ollama_executable,
"auto_pull_model": auto_pull_model,
"startup_timeout_seconds": startup_timeout_seconds,
"poll_interval_seconds": poll_interval_seconds,
"request_timeout_seconds": request_timeout_seconds,
"max_retries": max_retries,
}
)
backend = OllamaLocalBackend(
name=name,
base_url=resolved_base_url,
default_model=default_model,
request_timeout_seconds=request_timeout_seconds,
managed_server=self._ollama_server,
config_hash=config_hash,
max_retries=max_retries,
model_patterns=_resolve_model_patterns(model_patterns, default_model),
)
super().__init__(
backend=backend,
config_snapshot=ollama_config_snapshot(
host=host,
port=resolved_port,
manage_server=manage_server,
ollama_executable=ollama_executable,
auto_pull_model=auto_pull_model,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
request_timeout_seconds=request_timeout_seconds,
),
server_snapshot=ollama_server_snapshot(server=self._ollama_server),
)
def close(self) -> None:
"""Stop the managed Ollama daemon when present."""
server = getattr(self, "_ollama_server", None)
if server is not None:
server.close()
def __del__(self) -> None: # pragma: no cover - defensive cleanup.
"""Best-effort managed daemon cleanup during garbage collection."""
self.close()
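# Connect-mode sketch (comment only, assuming an ``ollama serve`` daemon is already
# listening on the default port): with ``manage_server=False`` the client only builds the
# base URL and never spawns or stops the daemon.
#
#     client = OllamaLLMClient(manage_server=False)
#     print(client.default_model())  # "qwen2.5:1.5b-instruct"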
class SGLangServerLLMClient(_SingleBackendLLMClient):
"""Client for local or self-hosted SGLang OpenAI-compatible inference."""
def __init__(
self,
*,
name: str = "sglang-local",
model: str = "Qwen/Qwen2.5-1.5B-Instruct",
host: str = "127.0.0.1",
port: int = 30000,
manage_server: bool = True,
startup_timeout_seconds: float = 90.0,
poll_interval_seconds: float = 0.5,
python_executable: str = sys.executable,
extra_server_args: tuple[str, ...] = (),
base_url: str | None = None,
request_timeout_seconds: float = 60.0,
max_retries: int = 2,
model_patterns: tuple[str, ...] | None = None,
) -> None:
"""Initialize an SGLang client in managed-server or connect mode.
Args:
name: Logical name for this client instance.
model: Model identifier passed to managed SGLang server startup.
host: Host interface used in managed mode.
port: TCP port used in managed mode.
manage_server: Whether this client manages the SGLang server lifecycle.
startup_timeout_seconds: Maximum startup wait time in managed mode.
poll_interval_seconds: Delay between readiness probes in managed mode.
python_executable: Python executable used to launch managed SGLang process.
extra_server_args: Additional CLI flags forwarded to SGLang server.
base_url: Optional connect-mode endpoint URL, used only when ``manage_server`` is
False (remote or self-managed deployments); defaults to ``http://{host}:{port}/v1``.
request_timeout_seconds: HTTP timeout for generate and stream requests.
max_retries: Number of retries for retryable provider/transport errors.
model_patterns: Optional tuple of model patterns for routing decisions.
Raises:
ValueError: If ``manage_server`` and ``base_url`` are both configured.
"""
if manage_server and base_url is not None:
raise ValueError("base_url cannot be provided when manage_server is True.")
reserved_sglang_port = _reserve_managed_server_port(host=host, requested_port=port) if manage_server else None
resolved_port = reserved_sglang_port.port if reserved_sglang_port is not None else port
self._sglang_server = (
create_sglang_server(
model=model,
host=host,
port=resolved_port,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
python_executable=python_executable,
extra_server_args=extra_server_args,
)
if manage_server
else None
)
if self._sglang_server is not None:
reservation_socket = reserved_sglang_port.reservation_socket if reserved_sglang_port is not None else None
self._sglang_server.set_port_reservation(reservation_socket)
resolved_base_url = (
self._sglang_server.base_url
if self._sglang_server is not None
else (base_url or f"http://{host}:{port}/v1")
)
config_hash = _config_hash(
{
"kind": "sglang_local",
"name": name,
"model": model,
"host": host,
"port": resolved_port,
"manage_server": manage_server,
"startup_timeout_seconds": startup_timeout_seconds,
"poll_interval_seconds": poll_interval_seconds,
"python_executable": python_executable,
"extra_server_args": extra_server_args,
"base_url": base_url,
"request_timeout_seconds": request_timeout_seconds,
"max_retries": max_retries,
}
)
backend = SglangLocalBackend(
name=name,
base_url=resolved_base_url,
default_model=model,
request_timeout_seconds=request_timeout_seconds,
managed_server=self._sglang_server,
config_hash=config_hash,
max_retries=max_retries,
model_patterns=_resolve_model_patterns(model_patterns, model),
)
super().__init__(
backend=backend,
config_snapshot=sglang_config_snapshot(
model=model,
host=host,
port=resolved_port,
manage_server=manage_server,
startup_timeout_seconds=startup_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
python_executable=python_executable,
extra_server_args=extra_server_args,
request_timeout_seconds=request_timeout_seconds,
),
server_snapshot=sglang_server_snapshot(server=self._sglang_server),
)
def close(self) -> None:
"""Stop the managed SGLang server process when present."""
server = getattr(self, "_sglang_server", None)
if server is not None:
server.close()
def __del__(self) -> None: # pragma: no cover - defensive cleanup.
"""Best-effort managed server cleanup during garbage collection."""
self.close()
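# Connect-mode sketch (comment only): as with the vLLM client, an already running
# SGLang OpenAI-compatible endpoint can be targeted directly instead of spawning a
# managed process.
#
#     client = SGLangServerLLMClient(
#         manage_server=False,
#         base_url="http://127.0.0.1:30000/v1",  # hypothetical existing endpoint
#     )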
class MLXLocalLLMClient(_SingleBackendLLMClient):
"""Client for Apple MLX local inference."""
def __init__(
self,
*,
name: str = "mlx-local",
model_id: str = "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
default_model: str = "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
quantization: str = "none",
max_retries: int = 2,
model_patterns: tuple[str, ...] | None = None,
) -> None:
"""Initialize an MLX local client with sensible defaults.
Args:
name: Logical name for this client instance, used in logging and provenance.
model_id: Identifier for the MLX model to load
(e.g. "mlx-community/Qwen2.5-1.5B-Instruct-4bit").
default_model: Default model name for prompts that don't specify one.
quantization: Quantization level to use when loading the model
(e.g. "4-bit", "8-bit", "fp16").
max_retries: Number of times to retry a request in case of failure before giving up.
model_patterns: Optional tuple of model name patterns supported by this client, used for
routing decisions. If None, defaults to (default_model,).
"""
config_hash = _config_hash(
{
"kind": "mlx_local",
"name": name,
"model_id": model_id,
"default_model": default_model,
"quantization": quantization,
"max_retries": max_retries,
}
)
backend = MlxLocalBackend(
name=name,
model_id=model_id,
default_model=default_model,
quantization=quantization,
config_hash=config_hash,
max_retries=max_retries,
model_patterns=_resolve_model_patterns(model_patterns, default_model),
)
super().__init__(
backend=backend,
config_snapshot=mlx_config_snapshot(model_id=model_id, quantization=quantization),
)
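# In-process sketch (comment only, assuming an Apple-silicon host with the MLX weights
# available): this client owns no managed server, so a request can also be issued through
# the lower-level generate() path with an explicit LLMRequest. The message construction
# remains an assumption about the contract dataclass, as noted above.
#
#     client = MLXLocalLLMClient()
#     request = LLMRequest(
#         messages=(LLMMessage(role="user", content="Hello"),),  # assumed fields
#         model=client.default_model(),
#         temperature=0.2,
#         max_tokens=64,
#         tools=(),
#         response_schema=None,
#         response_format=None,
#         metadata={},
#         provider_options={},
#         task_profile=None,
#     )
#     response = client.generate(request)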
def _resolve_model_patterns(
model_patterns: tuple[str, ...] | None,
default_model: str,
) -> tuple[str, ...]:
"""Resolve model patterns, defaulting to the configured default model.
Args:
model_patterns: Optional explicit model-match patterns.
default_model: Fallback model used when no patterns are provided.
Returns:
Non-empty tuple of model patterns used by selectors/routers.
"""
if model_patterns is not None:
return model_patterns
return (default_model,)
def _config_hash(config_payload: dict[str, object]) -> str:
"""Create a short stable hash for backend configuration payloads.
Args:
config_payload: JSON-serializable configuration mapping.
Returns:
Deterministic 12-character SHA-256 prefix.
"""
encoded = json.dumps(config_payload, sort_keys=True, default=str).encode("utf-8")
return sha256(encoded).hexdigest()[:12]
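# Example: because keys are sorted before hashing, payloads that differ only in key order
# produce the same 12-character digest, keeping config hashes stable across constructor
# refactors.
#
#     _config_hash({"kind": "mlx_local", "name": "a"}) == _config_hash({"name": "a", "kind": "mlx_local"})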
__all__ = [
"LlamaCppServerLLMClient",
"MLXLocalLLMClient",
"OllamaLLMClient",
"SGLangServerLLMClient",
"TransformersLocalLLMClient",
"VLLMServerLLMClient",
"_SingleBackendLLMClient",
]