# Source code for rath.flow.workflow
"""Workflow base type: assigns ``AgentParam`` attributes and orchestrates sessions."""
from __future__ import annotations
from typing import Any
from rath.flow.agent_param import AgentParam
from rath.session.session import Session
def _indent_child_module_repr(body: str, spaces: int = 2) -> str:
"""Indent a child ``repr`` like ``torch.nn.Module`` (first line unindented)."""
lines = body.split("\n")
if len(lines) <= 1:
return body
first, *rest = lines
pad = " " * spaces
return first + "\n" + "\n".join(pad + line for line in rest)
# [docs]
class Workflow:
    """Collects attached ``AgentParam`` instances; subclasses run sessions here.

    Any attribute assigned an ``AgentParam`` value is recorded in a private
    name -> agent registry that :meth:`named_agents` exposes. The registry
    follows the attribute: overwriting the attribute with a non-agent value
    or deleting it removes the entry. Subclasses implement :meth:`forward`
    and callers invoke the instance itself (``workflow(session)``).
    """

    # Only the registry is slotted. A subclass that does not declare its own
    # ``__slots__`` still gets an instance ``__dict__`` for its attributes.
    __slots__ = ("_agents",)

    _agents: dict[str, AgentParam]

    def __init__(self) -> None:
        # Go through ``object`` directly: the overridden ``__setattr__``
        # consults the registry that is only being created here.
        object.__setattr__(self, "_agents", {})

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign ``name`` and keep the agent registry in step with it.

        The real assignment happens first, so a rejected assignment (for
        example on an instance without a ``__dict__``) never leaves a
        phantom registry entry behind.

        Raises:
            AttributeError: if the assignment itself is rejected, or if an
                ``AgentParam`` is assigned before ``__init__`` has created
                the registry.
        """
        if isinstance(value, AgentParam):
            # Fetch the registry before assigning so an uninitialised
            # instance fails without being half-updated.
            agents: dict[str, AgentParam] = object.__getattribute__(self, "_agents")
            super().__setattr__(name, value)
            agents[name] = value
            return

        super().__setattr__(name, value)
        try:
            agents = object.__getattribute__(self, "_agents")
        except AttributeError:
            # Registry not created yet; plain attributes are still allowed.
            return
        # The attribute no longer holds an agent: drop any stale entry.
        agents.pop(name, None)

    def __delattr__(self, name: str) -> None:
        """Delete ``name`` and forget any agent registered under it.

        Raises:
            AttributeError: if the attribute does not exist.
        """
        super().__delattr__(name)
        try:
            agents = object.__getattribute__(self, "_agents")
        except AttributeError:
            # The registry itself was just deleted, or never existed.
            return
        agents.pop(name, None)

    def named_agents(self) -> tuple[tuple[str, AgentParam], ...]:
        """Agent params registered via attribute assignment.

        Returns:
            ``(name, agent)`` pairs sorted by attribute name.
        """
        agents: dict[str, AgentParam] = object.__getattribute__(self, "_agents")
        return tuple(sorted(agents.items(), key=lambda item: item[0]))

    def forward(self, session: Session) -> Session:
        """Subclasses orchestrate Sessions (blocking).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def __call__(self, session: Session) -> Session:
        """Synchronize ``session`` if it has pending work, then run ``forward``."""
        # Before forward, join any in-flight lazy materialization so
        # ``chunk_table`` is readable when ``forward`` runs.
        if session._pending is not None:
            session.synchronize()
        return self.forward(session)

    def __repr__(self) -> str:
        """Render the class name with one indented ``(name): repr`` line per agent."""
        cls_name = type(self).__name__
        agents = self.named_agents()
        if not agents:
            return f"{cls_name}()"
        lines = [f"{cls_name}("]
        for child_name, agent in agents:
            # Continuation lines of a multi-line child repr are pushed right
            # so they stay nested under the child's own first line.
            sub = _indent_child_module_repr(repr(agent), 2)
            lines.append(f"  ({child_name}): {sub}")
        lines.append(")")
        return "\n".join(lines)

    __str__ = __repr__
# Public API: ``Workflow`` is the only name a star-import of this module exports.
__all__ = ["Workflow"]