# Source code for design_research_problems.problems._mcp_problem

"""MCP-backed problem family that proxies upstream MCP tools."""

from __future__ import annotations

import asyncio
import inspect
import sys
from collections.abc import Mapping, Sequence
from contextlib import AsyncExitStack
from dataclasses import dataclass
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, cast

from design_research_problems._exceptions import ProblemEvaluationError
from design_research_problems._optional import import_optional_module
from design_research_problems.problems._assets import PackageResourceBundle
from design_research_problems.problems._mcp import (
    create_fastmcp_server,
    register_design_brief_resource,
    to_json_value,
)
from design_research_problems.problems._metadata import ProblemMetadata
from design_research_problems.problems._problem import Problem

if TYPE_CHECKING:
    from mcp.server.fastmcp import FastMCP
    from mcp.types import CallToolResult
    from mcp.types import Tool as MCPTool

    from design_research_problems._catalog._manifest import ProblemManifest


@dataclass(frozen=True)
class MCPStdioConfig:
    """Parsed stdio transport parameters for one MCP-backed problem."""

    # Executable used to launch the upstream MCP server subprocess.
    command: str
    # Positional command-line arguments passed to that subprocess.
    args: tuple[str, ...]
    # Optional working directory for the subprocess; None inherits the caller's.
    cwd: str | None
    # Environment variables for the subprocess (stored as a read-only mapping
    # by parse_mcp_stdio_parameters).
    env: Mapping[str, str]


def _resolve_stdio_command(command: str) -> str:
    """Resolve special stdio command markers to executable paths.

    Args:
        command: Raw command value from constructor or manifest parameters.

    Returns:
        Resolved executable command string.
    """
    if command == "__python_executable__":
        return sys.executable
    return command


def parse_mcp_stdio_parameters(parameters: Mapping[str, object]) -> MCPStdioConfig:
    """Parse MCP stdio configuration from manifest parameters.

    Args:
        parameters: Manifest ``[parameters]`` payload.

    Returns:
        Parsed stdio configuration.

    Raises:
        ValueError: If the payload does not match the expected schema.
    """
    transport_raw = parameters.get("transport")
    if not isinstance(transport_raw, str) or transport_raw.strip().lower() != "stdio":
        raise ValueError("mcp parameters must declare transport='stdio'.")

    command_raw = parameters.get("command")
    if not isinstance(command_raw, str) or not command_raw.strip():
        raise ValueError("mcp stdio parameters require a non-empty string command.")
    command = command_raw.strip()

    args_raw = parameters.get("args", ())
    if isinstance(args_raw, str | bytes) or not isinstance(args_raw, Sequence):
        raise ValueError("mcp stdio parameters require args as a sequence of strings.")
    parsed_args: list[str] = []
    for arg in args_raw:
        if not isinstance(arg, str):
            raise ValueError("mcp stdio args entries must be strings.")
        parsed_args.append(arg)

    cwd_raw = parameters.get("cwd")
    cwd: str | None = None
    if cwd_raw is not None:
        if not isinstance(cwd_raw, str) or not cwd_raw.strip():
            raise ValueError("mcp stdio cwd must be a non-empty string when provided.")
        cwd = cwd_raw.strip()

    env_raw = parameters.get("env", {})
    if env_raw is None:
        env_raw = {}
    if not isinstance(env_raw, Mapping):
        raise ValueError("mcp stdio env must be a mapping of string keys and values.")
    env: dict[str, str] = {}
    for key, value in env_raw.items():
        if not isinstance(key, str) or not isinstance(value, str):
            raise ValueError("mcp stdio env entries must use string keys and string values.")
        env[key] = value

    return MCPStdioConfig(
        command=command,
        args=tuple(parsed_args),
        cwd=cwd,
        env=MappingProxyType(env),
    )


class _LoopBoundUpstreamSession:
    """Lazy persistent upstream session bound to one running event loop."""

    def __init__(self, stdio_config: MCPStdioConfig) -> None:
        self._stdio_config = stdio_config
        self._bound_loop: asyncio.AbstractEventLoop | None = None
        self._loop_lock: asyncio.Lock | None = None
        self._exit_stack: AsyncExitStack | None = None
        self._session: object | None = None
        self._closed = False

    @property
    def has_active_session(self) -> bool:
        """Return whether an upstream session has been initialized."""
        return self._session is not None

    async def call_tool(self, name: str, arguments: Mapping[str, object]) -> CallToolResult:
        """Call one upstream tool through the persistent session.

        Args:
            name: Upstream MCP tool name.
            arguments: Tool arguments.

        Returns:
            Upstream call result object.
        """
        session = await self._ensure_session()
        session_obj = cast(Any, session)
        return cast("CallToolResult", await session_obj.call_tool(name, dict(arguments)))

    async def aclose(self) -> None:
        """Close the persistent upstream session on its bound event loop.

        Raises:
            RuntimeError: If called from a different event loop than the session
                that initialized the upstream client.
        """
        current_loop = asyncio.get_running_loop()
        if self._bound_loop is not None and self._bound_loop is not current_loop:
            raise RuntimeError(
                "This wrapped MCP server is bound to a different event loop. "
                "Call await server.aclose_upstream_session() from the same event loop used for tool calls."
            )
        if self._loop_lock is None:
            self._loop_lock = asyncio.Lock()

        async with self._loop_lock:
            exit_stack = self._exit_stack
            self._session = None
            self._exit_stack = None
            self._closed = True

        if exit_stack is not None:
            await exit_stack.aclose()

    async def _ensure_session(self) -> object:
        """Create or return the loop-bound upstream session.

        Returns:
            Initialized ``ClientSession`` object.

        Raises:
            RuntimeError: If reused across different event loops.
        """
        if self._closed:
            raise RuntimeError(
                "The upstream MCP session has been closed. "
                "Create a new wrapped server by calling to_mcp_server() again."
            )
        current_loop = asyncio.get_running_loop()
        if self._bound_loop is None:
            self._bound_loop = current_loop
        elif self._bound_loop is not current_loop:
            raise RuntimeError(
                "This wrapped MCP server is bound to a different event loop. "
                "Create a new server by calling to_mcp_server() again."
            )

        if self._loop_lock is None:
            self._loop_lock = asyncio.Lock()

        async with self._loop_lock:
            if self._session is not None:
                return self._session

            _anyio_module, client_session_cls, stdio_module = _import_mcp_client_modules()
            stdio_module_any = cast(Any, stdio_module)
            client_session_cls_any = cast(Any, client_session_cls)
            stdio_client = stdio_module_any.stdio_client
            params = stdio_module_any.StdioServerParameters(
                command=self._stdio_config.command,
                args=list(self._stdio_config.args),
                cwd=self._stdio_config.cwd,
                env=dict(self._stdio_config.env) if self._stdio_config.env else None,
            )

            exit_stack = AsyncExitStack()
            read_stream, write_stream = await exit_stack.enter_async_context(stdio_client(params))
            session = await exit_stack.enter_async_context(client_session_cls_any(read_stream, write_stream))
            await cast(Any, session).initialize()
            self._session = session
            self._exit_stack = exit_stack
            return session


class MCPProblem(Problem):
    """Problem wrapper that ingests one external MCP server over stdio."""

    def __init__(
        self,
        *,
        metadata: ProblemMetadata,
        statement_markdown: str = "",
        command: str,
        args: tuple[str, ...] = (),
        cwd: str | None = None,
        env: Mapping[str, str] | None = None,
        resource_bundle: PackageResourceBundle | None = None,
    ) -> None:
        """Store metadata and upstream stdio server launch configuration.

        Args:
            metadata: Shared packaged metadata.
            statement_markdown: Canonical Markdown statement.
            command: Executable command used to start the upstream MCP server.
            args: Command-line arguments passed to the upstream server.
            cwd: Optional working directory for the subprocess.
            env: Optional environment variables merged over inherited defaults.
            resource_bundle: Optional package-resource loader for problem assets.
        """
        super().__init__(
            metadata=metadata,
            statement_markdown=statement_markdown,
            resource_bundle=resource_bundle,
        )
        # Only the stdio transport is supported for upstream MCP servers.
        self.transport: Literal["stdio"] = "stdio"
        # The "__python_executable__" marker resolves to sys.executable.
        self.command = _resolve_stdio_command(command)
        self.args = tuple(args)
        self.cwd = cwd
        # Read-only copy so callers cannot mutate the launch environment.
        self.env = MappingProxyType(dict(env or {}))
        # Frozen snapshot shared with the upstream session helpers.
        self._stdio_config = MCPStdioConfig(
            command=self.command,
            args=self.args,
            cwd=self.cwd,
            env=self.env,
        )
[docs] @classmethod def from_stdio( cls, *, metadata: ProblemMetadata, command: str, args: Sequence[str] = (), cwd: str | None = None, env: Mapping[str, str] | None = None, statement_markdown: str = "", resource_bundle: PackageResourceBundle | None = None, ) -> MCPProblem: """Construct one MCP-backed problem directly from stdio launch parameters. Args: metadata: Shared packaged metadata. command: Executable command used to start the upstream MCP server. args: Command-line arguments passed to the upstream server. cwd: Optional working directory for the subprocess. env: Optional environment variables merged over inherited defaults. statement_markdown: Canonical Markdown statement. resource_bundle: Optional package-resource loader for problem assets. Returns: MCP problem instance with validated stdio configuration. """ parsed = parse_mcp_stdio_parameters( { "transport": "stdio", "command": command, "args": list(args), "cwd": cwd, "env": dict(env or {}), } ) return cls( metadata=metadata, statement_markdown=statement_markdown, command=parsed.command, args=parsed.args, cwd=parsed.cwd, env=parsed.env, resource_bundle=resource_bundle, )
    @classmethod
    def from_manifest(cls, manifest: ProblemManifest) -> MCPProblem:
        """Construct one MCP-backed problem directly from a packaged manifest.

        Args:
            manifest: Parsed manifest used to initialize the problem instance.

        Returns:
            Problem instance populated from the manifest data.

        Raises:
            ProblemEvaluationError: If the manifest parameters are invalid.
        """
        try:
            config = parse_mcp_stdio_parameters(manifest.parameters)
        except ValueError as exc:
            # Surface schema problems as the package's domain error, keeping
            # the original ValueError chained as the cause.
            raise ProblemEvaluationError(
                f"Invalid MCP stdio parameters for {manifest.metadata.problem_id!r}: {exc}"
            ) from exc
        return cls(
            metadata=manifest.metadata,
            statement_markdown=manifest.statement_markdown,
            command=config.command,
            args=config.args,
            cwd=config.cwd,
            env=config.env,
            resource_bundle=cls.resource_bundle_from_manifest(manifest),
        )
    def to_mcp_server(
        self,
        *,
        server_name: str | None = None,
        include_citation: bool = True,
        citation_mode: Literal["summary", "summary+raw", "raw"] = "summary",
    ) -> FastMCP:
        """Expose this MCP-backed problem through a local FastMCP proxy server.

        The exported server exposes the standard ``problem://design-brief``
        resource and proxies upstream MCP tools one-for-one.

        Args:
            server_name: Optional explicit server name.
            include_citation: Whether the design brief includes citations.
            citation_mode: Citation rendering mode for the design brief.

        Returns:
            Configured FastMCP server.

        Raises:
            ProblemEvaluationError: If discovered upstream tool schemas are not
                compatible with keyword-argument proxy wrapping.
        """
        server = create_fastmcp_server(self, server_name=server_name)
        register_design_brief_resource(
            server,
            brief_text=self.render_brief(include_citation=include_citation, citation_mode=citation_mode),
        )
        upstream_tools = self._discover_upstream_tools()
        if len({tool.name for tool in upstream_tools}) != len(upstream_tools):
            raise ProblemEvaluationError("Upstream MCP tool names must be unique.")

        # One lazy session shared by every proxy tool; it opens on first call
        # and binds to whichever event loop made that call.
        persistent_session = _LoopBoundUpstreamSession(self._stdio_config)
        upstream_tool_names = {tool.name for tool in upstream_tools}
        # If the upstream server already provides a final-submission tool under
        # either canonical name, skip the local fallback added further below.
        expose_submit_final = "submit_final" in upstream_tool_names or "final_answer" in upstream_tool_names
        exposed_names: set[str] = set()
        for tool in upstream_tools:
            proxy_tool = self._build_proxy_tool(
                tool_name=tool.name, input_schema=tool.inputSchema, session=persistent_session
            )
            description = (tool.description or "").strip() or f"Proxy call to upstream MCP tool {tool.name!r}."
            # Alias upstream "final_answer" to the canonical "submit_final"
            # name, unless the upstream also exposes "submit_final" itself.
            exposed_name = (
                "submit_final"
                if tool.name == "final_answer" and "submit_final" not in upstream_tool_names
                else tool.name
            )
            if exposed_name in exposed_names:
                raise ProblemEvaluationError(f"Duplicate exposed MCP tool name: {exposed_name!r}")
            exposed_names.add(exposed_name)
            server.add_tool(
                proxy_tool,
                name=exposed_name,
                title=tool.title,
                description=description,
            )

        if not expose_submit_final:

            def submit_final(answer: str, justification: str | None = None) -> dict[str, object]:
                """Submit one free-text final answer.

                Args:
                    answer: Free-text final answer string.
                    justification: Optional rationale for the submission.

                Returns:
                    MCP-ready submission payload for the provided answer.

                Raises:
                    ValueError: If the answer is empty after trimming whitespace.
                """
                normalized_answer = answer.strip()
                if not normalized_answer:
                    raise ValueError("answer must be a non-empty string.")
                return {
                    "problem_id": self.metadata.problem_id,
                    "problem_kind": self.metadata.kind.value,
                    "answer": normalized_answer,
                    "justification": None if justification is None else justification.strip() or None,
                }

            server.add_tool(
                submit_final,
                name="submit_final",
                title="Submit Final Answer",
                description="Submit a free-text final answer for this design brief.",
            )

        async def aclose_upstream_session() -> None:
            """Close the proxied upstream MCP session for this server instance."""
            await persistent_session.aclose()

        def close_upstream_session() -> None:
            """Close the proxied session when no upstream session has been opened.

            If tool calls have already initialized the upstream session, callers
            must use ``await server.aclose_upstream_session()`` on the same event
            loop used for those tool calls.
            """
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: safe to return, unless a live session exists
                # that can only be torn down asynchronously on its own loop.
                if persistent_session.has_active_session:
                    raise RuntimeError(
                        "An upstream session is active. "
                        "Use await server.aclose_upstream_session() from the same event loop used for tool calls."
                    ) from None
                return
            raise RuntimeError(
                "close_upstream_session() cannot be called from an active event loop. "
                "Use await server.aclose_upstream_session() instead."
            )

        # FastMCP offers no lifecycle hook for proxied sessions, so the close
        # helpers are attached directly onto the server object for callers.
        server_any = cast(Any, server)
        server_any.aclose_upstream_session = aclose_upstream_session
        server_any.close_upstream_session = close_upstream_session
        return server
    def _discover_upstream_tools(self) -> tuple[MCPTool, ...]:
        """List upstream MCP tools using one short-lived discovery session.

        Returns:
            Discovered upstream tool definitions.

        Raises:
            RuntimeError: If invoked while an event loop is already running.
        """
        anyio_module, client_session_cls, stdio_module = _import_mcp_client_modules()
        stdio_module_any = cast(Any, stdio_module)
        client_session_cls_any = cast(Any, client_session_cls)
        stdio_client = stdio_module_any.stdio_client
        params = stdio_module_any.StdioServerParameters(
            command=self._stdio_config.command,
            args=list(self._stdio_config.args),
            cwd=self._stdio_config.cwd,
            # None means "inherit the default environment" upstream.
            env=dict(self._stdio_config.env) if self._stdio_config.env else None,
        )

        async def _discover() -> tuple[MCPTool, ...]:
            # Spawn the server, list its tools, then tear everything down.
            async with stdio_client(params) as streams:
                read_stream, write_stream = streams
                async with client_session_cls_any(read_stream, write_stream) as session:
                    await cast(Any, session).initialize()
                    listed = await cast(Any, session).list_tools()
                    return tuple(cast(Any, listed).tools)

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: safe to drive discovery synchronously below.
            pass
        else:
            raise RuntimeError(
                "MCP discovery must run outside an active event loop. Call to_mcp_server() from synchronous setup code."
            )
        return tuple(cast(Any, anyio_module).run(_discover))

    def _build_proxy_tool(
        self,
        *,
        tool_name: str,
        input_schema: Mapping[str, object],
        session: _LoopBoundUpstreamSession,
    ) -> Any:
        """Build one proxy tool wrapper from an upstream schema.

        Args:
            tool_name: Upstream tool name.
            input_schema: Upstream JSON schema for tool inputs.
            session: Shared persistent upstream session handle.

        Returns:
            Proxy function suitable for ``server.add_tool``.
        """
        signature = _proxy_signature(tool_name=tool_name, input_schema=input_schema)
        # Optional schema properties default to None; record their names so
        # unset values are dropped instead of being forwarded as explicit nulls.
        optional_none_keys = tuple(
            parameter.name for parameter in signature.parameters.values() if parameter.default is None
        )

        async def proxy_tool(**kwargs: object) -> dict[str, object]:
            """Call the upstream tool and normalize the response payload."""
            forwarded = dict(kwargs)
            for key in optional_none_keys:
                if forwarded.get(key) is None:
                    forwarded.pop(key, None)
            result = await session.call_tool(tool_name, forwarded)
            if result.isError:
                message = _upstream_error_message(result)
                raise ValueError(f"Upstream MCP tool {tool_name!r} failed: {message}")
            structured = result.structuredContent
            if isinstance(structured, dict):
                # Structured dict responses are returned directly as JSON data.
                normalized = to_json_value(structured)
                return cast(dict[str, object], normalized)
            # Otherwise wrap the raw content blocks in an envelope payload.
            return {
                "tool_name": tool_name,
                "structured_content": to_json_value(structured),
                "content": [_serialize_content_block(block) for block in result.content],
                "is_error": False,
            }

        proxy_tool.__name__ = f"proxy_{_safe_identifier(tool_name)}"
        proxy_tool_any = cast(Any, proxy_tool)
        # Expose the schema-derived signature so FastMCP registers named
        # keyword arguments instead of a single **kwargs parameter.
        proxy_tool_any.__signature__ = signature
        return proxy_tool
def _import_mcp_client_modules() -> tuple[Any, Any, Any]:
    """Import MCP client modules lazily.

    Returns:
        Tuple of imported modules/classes for anyio + MCP client stdio support.

    Raises:
        MissingOptionalDependencyError: If optional MCP dependencies are missing.
    """
    anyio_module = import_optional_module(
        "anyio",
        required_for="MCP server export and ingestion proxying",
        extras=("mcp",),
        dependency_label="anyio",
    )
    session_module = import_optional_module(
        "mcp.client.session",
        required_for="MCP server export and ingestion proxying",
        extras=("mcp",),
        dependency_label="mcp",
    )
    stdio_module = import_optional_module(
        "mcp.client.stdio",
        required_for="MCP server export and ingestion proxying",
        extras=("mcp",),
        dependency_label="mcp",
    )
    return anyio_module, session_module.ClientSession, stdio_module


def _proxy_signature(*, tool_name: str, input_schema: Mapping[str, object]) -> inspect.Signature:
    """Build an inspect signature from one upstream MCP tool schema.

    Args:
        tool_name: Upstream tool name.
        input_schema: Upstream JSON schema payload.

    Returns:
        Keyword-only signature for proxy tool registration.

    Raises:
        ProblemEvaluationError: If the schema is unsupported.
    """
    schema_type = input_schema.get("type")
    if schema_type not in (None, "object"):
        raise ProblemEvaluationError(
            f"Unsupported input schema for upstream tool {tool_name!r}: "
            "expected an object schema with named properties."
        )
    properties = input_schema.get("properties")
    if not isinstance(properties, Mapping):
        raise ProblemEvaluationError(
            f"Unsupported input schema for upstream tool {tool_name!r}: missing object-style named properties."
        )
    # "required" may be absent or explicitly null; both mean "none required".
    required_raw = input_schema.get("required", ())
    if required_raw is None:
        required_raw = ()
    if isinstance(required_raw, str | bytes) or not isinstance(required_raw, Sequence):
        raise ProblemEvaluationError(
            f"Unsupported input schema for upstream tool {tool_name!r}: required must be a sequence of property names."
        )
    required_names: set[str] = set()
    for item in required_raw:
        if not isinstance(item, str):
            raise ProblemEvaluationError(
                f"Unsupported input schema for upstream tool {tool_name!r}: required entries must be strings."
            )
        required_names.add(item)
    property_names = {str(name) for name in properties}
    unknown_required = required_names - property_names
    if unknown_required:
        raise ProblemEvaluationError(
            f"Unsupported input schema for upstream tool {tool_name!r}: "
            f"required contains unknown properties {sorted(unknown_required)!r}."
        )

    parameters: list[inspect.Parameter] = []
    for raw_name, schema in properties.items():
        if not isinstance(raw_name, str):
            raise ProblemEvaluationError(
                f"Unsupported input schema for upstream tool {tool_name!r}: property names must be strings."
            )
        # Property names become Python keyword parameters, so they must be
        # valid identifiers for the proxy signature to be constructible.
        if not raw_name.isidentifier():
            raise ProblemEvaluationError(
                f"Unsupported input schema for upstream tool {tool_name!r}: "
                f"property name {raw_name!r} is not a valid Python identifier."
            )
        annotation = _annotation_from_schema(schema)
        if raw_name in required_names:
            # Public sentinel for "no default" (was the private inspect._empty).
            default = inspect.Parameter.empty
        else:
            # Optional properties accept None and default to it; the proxy
            # wrapper strips None values before forwarding upstream.
            annotation = annotation | None
            default = None
        parameters.append(
            inspect.Parameter(
                raw_name,
                kind=inspect.Parameter.KEYWORD_ONLY,
                annotation=annotation,
                default=default,
            )
        )
    return inspect.Signature(parameters=parameters, return_annotation=dict[str, object])


# JSON-schema primitive type names mapped to Python annotations.
_SCHEMA_TYPE_ANNOTATIONS: dict[object, Any] = {
    "string": str,
    "number": float,
    "integer": int,
    "boolean": bool,
    "array": list[object],
    "object": dict[str, object],
}


def _annotation_from_schema(schema: object) -> Any:
    """Map one JSON-schema field descriptor to a Python annotation."""
    if not isinstance(schema, Mapping):
        return object
    schema_type = schema.get("type")
    if isinstance(schema_type, list):
        # Union types: drop "null" (optionality is handled by the caller)
        # and keep the first concrete entry.
        non_null = [entry for entry in schema_type if entry != "null"]
        schema_type = non_null[0] if non_null else None
    return _SCHEMA_TYPE_ANNOTATIONS.get(schema_type, object)


def _serialize_content_block(block: object) -> object:
    """Convert one upstream MCP content block into JSON-safe data."""
    # Pydantic-style models expose model_dump; otherwise coerce generically.
    dump_method = getattr(block, "model_dump", None)
    if callable(dump_method):
        return to_json_value(dump_method(mode="json"))
    return to_json_value(block)


def _upstream_error_message(result: CallToolResult) -> str:
    """Extract a readable upstream error message from one MCP result."""
    structured = result.structuredContent
    if isinstance(structured, Mapping):
        # Prefer structured fields commonly used for error details.
        for key in ("message", "error", "detail"):
            value = structured.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
    # Fall back to the first non-empty text content block.
    for block in result.content:
        text = getattr(block, "text", None)
        if isinstance(text, str) and text.strip():
            return text.strip()
    return "upstream tool returned an error result without details."


def _safe_identifier(name: str) -> str:
    """Convert one arbitrary string into a Python identifier-like token."""
    token = "".join(char if (char.isalnum() or char == "_") else "_" for char in name)
    return token or "tool"


__all__ = ["MCPProblem", "MCPStdioConfig", "parse_mcp_stdio_parameters"]