# Source code for design_research_agents._implementations._patterns._rag_pattern

"""RAG-first reasoning pattern built from workflow memory primitives."""

from __future__ import annotations

import json
from collections.abc import Mapping

from design_research_agents._contracts._delegate import Delegate, ExecutionResult
from design_research_agents._contracts._memory import MemoryStore
from design_research_agents._contracts._workflow import (
    DelegateStep,
    DelegateTarget,
    MemoryReadStep,
    MemoryWriteStep,
    WorkflowStep,
)
from design_research_agents._runtime._patterns import (
    MODE_RAG,
    build_compiled_pattern_execution,
    build_pattern_execution_result,
    resolve_pattern_run_context,
)
from design_research_agents._tracing import Tracer
from design_research_agents.workflow import CompiledExecution
from design_research_agents.workflow.workflow import Workflow


class RAGPattern(Delegate):
    """Reasoning pattern orchestrated as memory read -> reason -> memory write."""

    def __init__(
        self,
        *,
        reasoning_delegate: DelegateTarget,
        memory_store: MemoryStore | None,
        memory_namespace: str = "default",
        memory_top_k: int = 5,
        memory_min_score: float | None = None,
        write_back: bool = True,
        tracer: Tracer | None = None,
    ) -> None:
        """Initialize RAG reasoning pattern.

        Args:
            reasoning_delegate: Delegate object that performs reasoning with
                retrieved context.
            memory_store: Memory store used for retrieval and optional
                write-back.
            memory_namespace: Namespace partition for reads/writes.
            memory_top_k: Number of retrieved matches for reasoning context.
            memory_min_score: Optional minimum retrieval score threshold.
            write_back: Whether to persist one summary record after reasoning.
            tracer: Optional tracer dependency.

        Raises:
            ValueError: Raised when ``memory_top_k`` is less than one.
        """
        if memory_top_k < 1:
            raise ValueError("memory_top_k must be >= 1.")
        self._reasoning_delegate = reasoning_delegate
        self._memory_store = memory_store
        # A blank or whitespace-only namespace silently falls back to "default".
        self._memory_namespace = memory_namespace.strip() or "default"
        self._memory_top_k = memory_top_k
        self._memory_min_score = memory_min_score
        self._write_back = write_back
        self._tracer = tracer
        # Most recently built workflow; populated by _build_workflow() so
        # callers can inspect the compiled step graph after a run.
        self.workflow: Workflow | None = None

    def run(
        self,
        prompt: str,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> ExecutionResult:
        """Execute memory retrieval, delegated reasoning, and optional write-back."""
        # run() is a thin convenience wrapper over compile(): the compiled
        # execution owns tracing and result finalization.
        return self.compile(
            prompt=prompt,
            request_id=request_id,
            dependencies=dependencies,
        ).run()

    def compile(
        self,
        prompt: str,
        *,
        request_id: str | None = None,
        dependencies: Mapping[str, object] | None = None,
    ) -> CompiledExecution:
        """Compile the read/reason/write workflow.

        Args:
            prompt: Task prompt.
            request_id: Optional request identifier; resolved to a default
                by ``resolve_pattern_run_context`` when omitted.
            dependencies: Optional dependency mapping, normalized the same way.

        Returns:
            A compiled execution whose ``run()`` produces the final RAG result.
        """
        run_context = resolve_pattern_run_context(
            default_request_id_prefix=None,
            default_dependencies={},
            request_id=request_id,
            dependencies=dependencies,
        )
        # Echo the effective configuration into the traced input payload.
        input_payload = {
            "prompt": prompt,
            "mode": MODE_RAG,
            "memory_namespace": self._memory_namespace,
            "memory_top_k": self._memory_top_k,
            "write_back": self._write_back,
        }
        workflow = self._build_workflow(
            prompt,
            request_id=run_context.request_id,
            dependencies=run_context.dependencies,
        )
        return build_compiled_pattern_execution(
            workflow=workflow,
            pattern_name="RAGPattern",
            request_id=run_context.request_id,
            dependencies=run_context.dependencies,
            tracer=self._tracer,
            input_payload=input_payload,
            workflow_request_id=f"{run_context.request_id}:rag_reasoning",
            # finalize maps the raw workflow result into the pattern-level
            # ExecutionResult once the compiled execution completes.
            finalize=lambda workflow_result: _build_rag_result(
                workflow_result=workflow_result,
                request_id=run_context.request_id,
                dependencies=run_context.dependencies,
                memory_namespace=self._memory_namespace,
                memory_top_k=self._memory_top_k,
                write_back=self._write_back,
            ),
        )

    def _build_workflow(
        self,
        prompt: str,
        *,
        request_id: str,
        dependencies: Mapping[str, object],
    ) -> Workflow:
        """Build the read/reason/write workflow for one resolved run context."""
        # Signature kept symmetric with the other pattern builders; the run
        # context is threaded through Workflow.run instead of used here.
        del request_id, dependencies
        workflow_steps: list[WorkflowStep] = [
            MemoryReadStep(
                step_id="memory_read",
                query_builder=lambda context: str(context.get("prompt", "")),
                namespace=self._memory_namespace,
                top_k=self._memory_top_k,
                min_score=self._memory_min_score,
            ),
            DelegateStep(
                step_id="reason",
                dependencies=("memory_read",),
                delegate=self._reasoning_delegate,
                prompt_builder=lambda context: _build_reasoning_prompt(
                    task_prompt=str(context.get("prompt", "")),
                    memory_read_step_output=_extract_dependency_output(
                        context=context,
                        dependency_id="memory_read",
                    ),
                ),
            ),
        ]
        if self._write_back:
            # Optional third step persists one summary record derived from
            # the reasoning output and retrieval stats.
            workflow_steps.append(
                MemoryWriteStep(
                    step_id="memory_write",
                    dependencies=("memory_read", "reason"),
                    namespace=self._memory_namespace,
                    records_builder=lambda context: _build_write_back_records(
                        task_prompt=str(context.get("prompt", "")),
                        reason_step_output=_extract_dependency_output(
                            context=context,
                            dependency_id="reason",
                        ),
                        memory_read_step_output=_extract_dependency_output(
                            context=context,
                            dependency_id="memory_read",
                        ),
                    ),
                )
            )
        workflow = Workflow(
            tool_runtime=None,
            memory_store=self._memory_store,
            tracer=self._tracer,
            input_schema={"type": "object"},
            base_context={"prompt": prompt},
            steps=workflow_steps,
        )
        self.workflow = workflow
        return workflow

    def _run_rag_pattern(
        self,
        *,
        prompt: str,
        request_id: str,
        dependencies: Mapping[str, object],
    ) -> ExecutionResult:
        """Execute underlying workflow for read/reason/write orchestration.

        NOTE(review): direct execution path that bypasses compile(); appears
        unused by run() — confirm external callers before removing.

        Args:
            prompt: Task prompt.
            request_id: Resolved request identifier.
            dependencies: Normalized dependency mapping.

        Returns:
            Aggregated workflow result.
        """
        workflow = self._build_workflow(
            prompt,
            request_id=request_id,
            dependencies=dependencies,
        )
        workflow_result = workflow.run(
            input={},
            execution_mode="sequential",
            request_id=f"{request_id}:rag_reasoning",
            dependencies=dependencies,
        )
        return _build_rag_result(
            workflow_result=workflow_result,
            request_id=request_id,
            dependencies=dependencies,
            memory_namespace=self._memory_namespace,
            memory_top_k=self._memory_top_k,
            write_back=self._write_back,
        )
def _build_rag_result(
    *,
    workflow_result: ExecutionResult,
    request_id: str,
    dependencies: Mapping[str, object],
    memory_namespace: str,
    memory_top_k: int,
    write_back: bool,
) -> ExecutionResult:
    """Build final RAG result from one workflow execution.

    Args:
        workflow_result: Aggregated workflow execution result.
        request_id: Resolved request identifier.
        dependencies: Normalized dependency mapping.
        memory_namespace: Namespace used for retrieval/write-back.
        memory_top_k: Retrieval match budget echoed into metadata.
        write_back: Whether write-back was enabled, echoed into metadata.

    Returns:
        Pattern-level execution result combining retrieval, reasoning, and
        write-back details.
    """
    memory_read_result = workflow_result.step_results.get("memory_read")
    reason_result = workflow_result.step_results.get("reason")
    memory_write_result = workflow_result.step_results.get("memory_write")
    # Each step may be absent (step skipped or workflow failed early);
    # substitute an empty-but-well-shaped payload in that case.
    retrieval_output = (
        dict(memory_read_result.output)
        if memory_read_result is not None
        else {
            "query": {},
            "matches": [],
            "count": 0,
            "namespace": memory_namespace,
        }
    )
    reasoning_output = dict(reason_result.output) if reason_result is not None else {}
    write_back_output = (
        dict(memory_write_result.output)
        if memory_write_result is not None
        else {"written": 0, "namespace": memory_namespace, "ids": []}
    )
    workflow_payload = workflow_result.to_dict()
    workflow_artifacts = workflow_result.output.get("artifacts", [])
    # Prefer the delegate's own output mapping; fall back to the raw step
    # payload when the delegate returned a non-mapping value.
    delegate_final_output = reasoning_output.get("output")
    final_output = (
        dict(delegate_final_output)
        if isinstance(delegate_final_output, Mapping)
        else dict(reasoning_output)
    )
    retrieval_details = dict(retrieval_output)
    retrieval_details["context"] = _build_retrieval_context(retrieval_output)
    result_success = workflow_result.success
    terminated_reason = "completed" if result_success else "workflow_failure"
    return build_pattern_execution_result(
        success=result_success,
        final_output=final_output,
        terminated_reason=terminated_reason,
        details={
            "retrieval": retrieval_details,
            "reasoning": reasoning_output,
            "write_back": write_back_output,
        },
        workflow_payload=workflow_payload,
        artifacts=workflow_artifacts,
        request_id=request_id,
        dependencies=dependencies,
        mode=MODE_RAG,
        metadata={
            "memory_namespace": memory_namespace,
            "memory_top_k": memory_top_k,
            "write_back": write_back,
        },
        tool_results=[],
        model_response=None,
        requested_mode=MODE_RAG,
        resolved_mode=MODE_RAG,
    )


def _extract_dependency_output(
    *,
    context: Mapping[str, object],
    dependency_id: str,
) -> Mapping[str, object]:
    """Extract one dependency output mapping from workflow step context.

    Args:
        context: Step context mapping.
        dependency_id: Dependency step identifier.

    Returns:
        Dependency output mapping when present, otherwise empty mapping.
    """
    dependency_results = context.get("dependency_results")
    if not isinstance(dependency_results, Mapping):
        return {}
    dependency_payload = dependency_results.get(dependency_id)
    if not isinstance(dependency_payload, Mapping):
        return {}
    output = dependency_payload.get("output")
    if isinstance(output, Mapping):
        return output
    return {}


def _build_reasoning_prompt(
    *,
    task_prompt: str,
    memory_read_step_output: Mapping[str, object],
) -> str:
    """Build explicit prompt with retrieved context injection.

    Args:
        task_prompt: Task prompt.
        memory_read_step_output: Output payload from memory read step.

    Returns:
        Prompt string passed to the reasoning delegate.
    """
    matches = memory_read_step_output.get("matches")
    normalized_matches = matches if isinstance(matches, list) else []
    # Machine-readable view of the retrieval payload; sorted keys keep the
    # prompt deterministic for identical retrievals.
    context_json_block = json.dumps(
        {
            "namespace": memory_read_step_output.get("namespace", "default"),
            "count": _safe_int(memory_read_step_output.get("count")),
            "matches": normalized_matches,
        },
        ensure_ascii=True,
        indent=2,
        sort_keys=True,
    )
    # Human-readable view: one bullet per non-empty match.
    context_text_lines: list[str] = []
    for match in normalized_matches:
        if not isinstance(match, Mapping):
            continue
        item_id = str(match.get("item_id", ""))
        score = match.get("score")
        content = str(match.get("content", "")).strip()
        if not content:
            continue
        score_text = f" score={score}" if isinstance(score, (int, float)) else ""
        context_text_lines.append(f"- [{item_id}]{score_text} {content}")
    context_text = "\n".join(context_text_lines) if context_text_lines else "(none)"
    prompt_lines = [
        f"Task: {task_prompt}",
        "",
        "Retrieved context (JSON):",
        context_json_block,
        "",
        "Retrieved context (text):",
        context_text,
        "",
        "Use the retrieved context when relevant, but reason independently when context is sparse.",
    ]
    return "\n".join(prompt_lines)


def _build_write_back_records(
    *,
    task_prompt: str,
    reason_step_output: Mapping[str, object],
    memory_read_step_output: Mapping[str, object],
) -> list[dict[str, object]]:
    """Build write-back records from reasoning output and retrieval context.

    Args:
        task_prompt: Task prompt.
        reason_step_output: Reasoning step output payload.
        memory_read_step_output: Memory read step output payload.

    Returns:
        Memory write payloads for optional persistence.
    """
    reasoning_payload = reason_step_output.get("output")
    normalized_reasoning = (
        dict(reasoning_payload) if isinstance(reasoning_payload, Mapping) else {}
    )
    retrieval_matches = memory_read_step_output.get("matches")
    retrieved_count = (
        len(retrieval_matches) if isinstance(retrieval_matches, list) else 0
    )
    content_payload = {
        "task": task_prompt,
        "retrieved_count": retrieved_count,
        "reasoning": normalized_reasoning,
    }
    # Exactly one summary record is written per run.
    return [
        {
            "content": json.dumps(content_payload, ensure_ascii=True, sort_keys=True),
            "metadata": {
                "kind": "rag_reasoning",
                "retrieved_count": retrieved_count,
                "task": task_prompt,
            },
        }
    ]


def _build_retrieval_context(retrieval_output: Mapping[str, object]) -> dict[str, object]:
    """Return one normalized retrieval-context payload derived from raw matches."""
    matches = retrieval_output.get("matches")
    normalized_matches = (
        [dict(match) for match in matches if isinstance(match, Mapping)]
        if isinstance(matches, list)
        else []
    )
    return {
        "namespace": retrieval_output.get("namespace", "default"),
        "count": _safe_int(retrieval_output.get("count")),
        "matches": normalized_matches,
    }


def _safe_int(value: object) -> int:
    """Convert values to int with deterministic fallback to zero.

    Args:
        value: Raw input value.

    Returns:
        Integer representation or ``0`` fallback.
    """
    import math

    # bool is an int subclass; handle it first so True/False map to 1/0.
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # int() raises ValueError/OverflowError on nan/inf; honor the
        # documented zero fallback instead of propagating.
        return int(value) if math.isfinite(value) else 0
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return 0
    return 0


__all__ = ["RAGPattern"]