From e9d6f3ecf2584f3427026bf5798cb39a9656a388 Mon Sep 17 00:00:00 2001 From: Gianfranco Valentino Date: Sun, 14 Sep 2025 15:31:52 -0700 Subject: [PATCH] WIP - 2 --- grafi/common/topics/topic.py | 1 + grafi/common/topics/topic_base.py | 1 + grafi/common/topics/topic_event_cache.py | 1 + grafi/workflows/decorators.py | 280 ++++++++++++++++++ grafi/workflows/impl/event_driven_workflow.py | 9 + grafi/workflows/impl/utils.py | 4 +- tests/workflow/test_workflow_decorators.py | 117 ++++++++ 7 files changed, 412 insertions(+), 1 deletion(-) create mode 100644 grafi/workflows/decorators.py create mode 100644 tests/workflow/test_workflow_decorators.py diff --git a/grafi/common/topics/topic.py b/grafi/common/topics/topic.py index 809e1e49..f5f85dc0 100644 --- a/grafi/common/topics/topic.py +++ b/grafi/common/topics/topic.py @@ -58,6 +58,7 @@ async def a_publish_data( self, publish_event: PublishToTopicEvent ) -> Optional[PublishToTopicEvent]: if self.condition(publish_event.data): + print("PUBLUSHING") event = publish_event.model_copy( update={ "name": self.name, diff --git a/grafi/common/topics/topic_base.py b/grafi/common/topics/topic_base.py index dec9c961..e8c4c3f0 100644 --- a/grafi/common/topics/topic_base.py +++ b/grafi/common/topics/topic_base.py @@ -146,6 +146,7 @@ async def a_add_event(self, event: TopicEvent) -> TopicEvent: This method should be used by subclasses when publishing events. """ if isinstance(event, PublishToTopicEvent): + print("ADDING EVENT" + str(event)) return await self.event_cache.a_put(event) def serialize_callable(self) -> dict: diff --git a/grafi/common/topics/topic_event_cache.py b/grafi/common/topics/topic_event_cache.py index f74ced70..5152ecb8 100644 --- a/grafi/common/topics/topic_event_cache.py +++ b/grafi/common/topics/topic_event_cache.py @@ -110,6 +110,7 @@ async def a_put(self, event: TopicEvent) -> TopicEvent: offset = len(self._records) event.offset = offset # Set the offset for the event self._records.append(event) + print("APPENDED!") self._cond.notify_all() # wake waiting consumers return event diff --git a/grafi/workflows/decorators.py b/grafi/workflows/decorators.py new file mode 100644 index 00000000..401a3b00 --- /dev/null +++ b/grafi/workflows/decorators.py @@ -0,0 +1,280 @@ +""" +This file is part of the Graphite project. + +Copyright (c) 2023-2025 Binome Dev and contributors + +This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. +If a copy of the MPL was not distributed with this file, You can obtain one at +https://mozilla.org/MPL/2.0/. +""" + + +from ast import ParamSpec +import asyncio +from collections.abc import Callable +import inspect +from typing import List, TypeVar +from uuid import uuid4 + +from grafi.common.events.topic_events.consume_from_topic_event import ConsumeFromTopicEvent +from openinference.semconv.trace import OpenInferenceSpanKindValues +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Messages, MsgsAGen +from grafi.common.topics.input_topic import InputTopic +from grafi.common.topics.output_topic import OutputTopic +from grafi.common.topics.subscription_builder import SubscriptionBuilder +from grafi.common.topics.topic import Topic +from grafi.nodes.node import Node +from grafi.nodes.node_base import NodeBaseBuilder +from grafi.tools.tool import Tool + + +class CallableTool(Tool): + """A Tool that wraps a callable function, that can be injected from a decorator context. This is an internal + detail and not part of the public API. + """ + + # Wrapper around callable. + _a_invoke_impl: Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen] = None + + def __init__(self, tool_invoke: Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen], **kwargs): + super().__init__(name="CallableTool", oi_span_type=OpenInferenceSpanKindValues.TOOL, type="CallableTool", **kwargs) + self._a_invoke_impl = tool_invoke + + + async def a_invoke( + self, + invoke_context: InvokeContext, + input_data: List[ConsumeFromTopicEvent], + ) -> MsgsAGen: + if self._a_invoke_impl is None: + raise NotImplementedError("Must provide `_a__invoke_impl` Callable.") + async for res in self._a_invoke_impl(self, invoke_context, input_data): + yield res + + def invoke(self, invoke_context, input_data) -> Messages: + """ Synchronously call the a_invoke_imnpl.""" + + async def a_invoke_bridge(invoke_context, input_data) -> Messages: + """ Async wrapper around async invoke that accumulates all the results.""" + results = [] + async for res in self.a_invoke(invoke_context, input_data): + results.extend(res) + return results + + inner_loop = asyncio.new_event_loop() + asyncio.set_event_loop(inner_loop) + results = inner_loop.run_until_complete(a_invoke_bridge(invoke_context, input_data)) + inner_loop.close() + return results + +class CallableNode(Node): + """A Node that wraps a callable tool, that can be injected from a decorator context. This is an internal + detail and not part of the public API. + + Unlike common Node, `condition` is evaluated at the output generation time on the node side, allowing nodes to covnerge + on the same topic, under different circumstances. + """ + + def __init__(self, **kwargs): + condition = kwargs.pop("condition", None) + super().__init__(**kwargs) + self._node_condition = condition + + + @classmethod + def builder(cls) -> NodeBaseBuilder: + """Return a builder for CallableNode.""" + return NodeBaseBuilder(cls) + + def can_invoke_with_topics(self, topic_with_messahes) -> bool: + if (self._node_condition is None): + return super().can_invoke_with_topics(topic_with_messahes) + + all_subscibed_topics = {topic.name: False for topic in self.subscribed_topics} + available_topics = all_subscibed_topics | {topic.name: True for topic in topic_with_messahes} + return eval(self._node_condiiton, __builtins__ ={}, globals=None, locals=available_topics) + + def can_invoke(self) -> bool: + # Evaluate each expression; if any is satisfied, we can invoke. + return self.can_invoke_with_topics([topic.name for topic in self.subscribed_topics if topic.can_consume(self.name)]) + + +def node(func): + """Decorator to mark a function as a node within a workflow. + + This decorator can be used to wrap functions that should be treated as nodes, providing a more declarative style + for defining a workflow. + """ + _node_id = uuid4().hex + + + def __wrapper(func: Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]) -> Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]: + """Generate the node information that wraps this and register it with the wrapping workflow. + """ + if hasattr(func, "__workflow_node") and func.__workflow_node.get("node_id"): + raise ValueError( + f"Function {func.__name__} is already registered as a node with id {func.__workflow_node['node_id']}. " + "Please ensure that the node decorator is applied only once." + ) + + # Don't expect any ordering on annotations, decorators will only set certain attributes and are composable. + func.__workflow_node = (func.__workflow_node or {}) | { + "node_tool_class": CallableTool, + "node_name": func.__name__, + "node_id": _node_id, + "node_type": "CallableToolNode", + "tool_invoke": func, + } + return func + + return __wrapper(func) + +def trigger_when(topic_expression: str): + """ + `trigger_when` defines the trigger condition for a node based on the topic names. Expressions are written as an expression that MUST + evaluate to a boolean. Contextual variables are the topic names, which evaluate to True if there is a new message on that topic. + """ + def __wrapper(func: Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]) -> Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]: + if hasattr(func, "__workflow_node") and func.__workflow_node.get("node_condition"): + raise ValueError( + f"Function {func.__name__} is already registered with a condition. " + "Please ensure that the condition decorator is applied only once." + ) + + # Don't expect any ordering on annotations, decorators will only set certain attributes and are composable. + func.__workflow_node = (func.__workflow_node or {}) | { + "node_condition": topic_expression, + } + return func + return __wrapper + +def publish_to(*args: str): + """ + `node` decorated functions can be further decorated with this to specify the topics they publish to. + """ + + def __wrapper(func: Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]) -> Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]: # Generate the node information that wraps this and register it with the wrapping workflow. + if (len(args) == 0): + raise ValueError("At least one topic must be specified for publishing.") + + registered_topics = [] + if hasattr(func, "__workflow_node"): + registered_topics += func.__workflow_node.get("node_publish_to", []) + registered_topics += [*args] + + if (len(registered_topics) != len(set(registered_topics))): + raise ValueError("Duplicate topics found in the publish_to decorator arguments.") + # Don't expect any ordering on annotations, decorators will only set certain attributes and are composable. + func.__workflow_node = getattr(func, "__workflow_node", {})| { + "node_publish_to": registered_topics, + } + return func + return __wrapper + + +def subscribe_to(*args: str): + """ + `node` decorated functions can be further decorated with this to specify the topics they subscribe to. + """ + + def __wrapper(func: Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]) -> Callable[[InvokeContext, List[ConsumeFromTopicEvent]], MsgsAGen]: # Generate the node information that wraps this and register it with the wrapping workflow. + if (len(args) == 0): + raise ValueError("At least one topic must be specified for publishing.") + + registered_topics = [] + if hasattr(func, "__workflow_node"): + registered_topics += func.__workflow_node.get("node_subscribe_to", []) + registered_topics += [*args] + + if (len(registered_topics) != len(set(registered_topics))): + raise ValueError("Duplicate topics found in the subscribe_to decorator arguments.") + + # Don't expect any ordering on annotations, decorators will only set certain attributes and are composable. + func.__workflow_node = (getattr(func, "__workflow_node", {})) | { + "node_subscribe_to": registered_topics, + } + return func + return __wrapper + +def workflow(workflow_class): + """ + This deocorator enhances the annotated class to behave as an event driven workflow and providing a class method that + allows generating the workflow from the decorators applied to its methods. + """ + + _workflow_id = uuid4().hex + _name = workflow_class.__name__ + _type = workflow_class.__name__ + _oi__span_type = workflow_class.oi_span_type if hasattr(workflow_class, 'oi_span_type') else OpenInferenceSpanKindValues.AGENT + + + @classmethod + def generate(cls, **kwargs): + builder = workflow_class.builder().oi_span_type(_oi__span_type).name(_name).type(_type) + methods = inspect.getmembers(workflow_class, predicate=inspect.isfunction) + + topics = {} + for name,method in methods: + node_data = getattr(method, "__workflow_node", None) + if node_data: + publishes_to = node_data.get("node_publish_to") + subscribes_to = node_data.get("node_subscribe_to") + + if not publish_to: + raise ValueError("Node {name} did not provide `publish_to` annotation.") + + if not subscribe_to: + raise ValueError("Node {name} did not provide `subscribe_to` annotation.") + + for topic_name in publishes_to + subscribes_to: + if topic_name == "output_topic": + # Special case for output_topic, which is always created + topics[topic_name] = OutputTopic(name=topic_name, condition=lambda x: True) + elif topic_name == "input_topic": + # Special case for output_topic, which is always created + topics[topic_name] = InputTopic(name=topic_name, condition=lambda x: True) + else: + topics[topic_name] = Topic(name=topic_name, condition=lambda x: True) + for _, method in methods: + node_data = getattr(method, "__workflow_node", None) + if node_data: + node_id = node_data["node_id"] + node_name = node_data["node_name"] + node_type = node_data["node_type"] + tool_invoke = node_data["tool_invoke"] + node_tool_class = node_data["node_tool_class"] + publishes_to = node_data["node_publish_to"] + subscribes_to = node_data["node_subscribe_to"] + node_condition = node_data.get("node_condition", None) + + node_subscribe_topics = [topics[topic_name] for topic_name in subscribes_to] + node_publish_topics = [topics[topic_name] for topic_name in publishes_to] + tool = node_tool_class(tool_invoke=tool_invoke) + + # Just meet the constructor requirements, the `CallableNode` will override the subscription evaluation. + subscriptions = [] + for subscribed_topic in node_subscribe_topics: + subscriptions += [SubscriptionBuilder().subscribed_to(subscribed_topic).build()] + + node = CallableNode( + node_id=node_id, + name = node_name, + type = node_type, + oi_span_type=_oi__span_type, + publish_to=node_publish_topics, + subscribed_expressions=subscriptions, + tool = tool, + condition=node_condition, + ) + builder.node(node) + return builder.build() + + # Don't expect any ordering on annotations, decorators will only set certain attributes and are composable. + workflow_class.__is_workflow = True + workflow_class.__workflow_id = _workflow_id + workflow_class.__workflow_name = _name + workflow_class.__workflow_type = _type + workflow_class.__workflow_oi_span_type = _oi__span_type + workflow_class.generate = generate + return workflow_class \ No newline at end of file diff --git a/grafi/workflows/impl/event_driven_workflow.py b/grafi/workflows/impl/event_driven_workflow.py index 82ef9d19..70e343a9 100644 --- a/grafi/workflows/impl/event_driven_workflow.py +++ b/grafi/workflows/impl/event_driven_workflow.py @@ -375,6 +375,7 @@ async def a_invoke( offset=event.offset, data=event.data, ) + print(consumed_event) yield consumed_event consumed_output_events.append(consumed_event) @@ -495,12 +496,14 @@ async def wait_node_invoke(node: NodeBase) -> None: # publish before commit node_output_events: List[PublishToTopicEvent] = [] if consumed_events: + print(node.name) async for event in node.a_invoke( invoke_context, consumed_events ): node_output_events.extend( await a_publish_events(node=node, publish_event=event) ) + print(node_output_events) await self._a_commit_events( consumer_name=node.name, topic_events=consumed_events @@ -540,13 +543,19 @@ def on_event(self, event: TopicEvent) -> None: """Handle topic publish events and trigger node invoke if conditions are met.""" if not isinstance(event, PublishToTopicEvent): return + + print("EVENT RECEIVED") name = event.name + print(name) if name not in self._topic_nodes: + print("No nodes subscribed to this topic") return # Get all nodes subscribed to this topic subscribed_nodes = self._topic_nodes[name] + print("NODES") + print(subscribed_nodes) for node_name in subscribed_nodes: node = self.nodes[node_name] diff --git a/grafi/workflows/impl/utils.py b/grafi/workflows/impl/utils.py index 3b4a8f32..c70c7c3e 100644 --- a/grafi/workflows/impl/utils.py +++ b/grafi/workflows/impl/utils.py @@ -103,8 +103,10 @@ async def a_publish_events( ) -> List[PublishToTopicEvent]: published_events: List[PublishToTopicEvent] = [] for topic in node.publish_to: + print(topic.name) + print(publish_event) event = await topic.a_publish_data(publish_event) - + print(event) if event: published_events.append(event) diff --git a/tests/workflow/test_workflow_decorators.py b/tests/workflow/test_workflow_decorators.py new file mode 100644 index 00000000..f5e93000 --- /dev/null +++ b/tests/workflow/test_workflow_decorators.py @@ -0,0 +1,117 @@ +""" +This file is part of the Graphite project. + +Copyright (c) 2023-2025 Binome Dev and contributors + +This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. +If a copy of the MPL was not distributed with this file, You can obtain one at +https://mozilla.org/MPL/2.0/. +""" + +from typing import List +import asyncio +from uuid import uuid4 +from collections.abc import AsyncGenerator + + +from grafi.common.events.topic_events.consume_from_topic_event import ConsumeFromTopicEvent +from grafi.common.events.topic_events.publish_to_topic_event import PublishToTopicEvent +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message, MsgsAGen +from grafi.workflows.decorators import publish_to, subscribe_to, node, workflow +from grafi.workflows.impl.event_driven_workflow import EventDrivenWorkflow + + +@workflow +class SingleNodeWorkflow(EventDrivenWorkflow): + """ + Example workflow that uses the decorators to define nodes and their interactions. This consist of a single node. + """ + + @node + @publish_to("output_topic") + @subscribe_to("input_topic") + async def hello(self, invoke_context: InvokeContext, node_input: List[Message]) -> MsgsAGen: + assert(len(node_input) == 1) + assert(node_input[0].content == "Test message") + output_data = [Message( + role="user", + content="hi", + )] + yield output_data + + +@workflow +class MultiNodeWorkflow(EventDrivenWorkflow): + """ + Example workflow that uses the decorators to define nodes and their interactions. This consist of a single node. + """ + + @node + @publish_to("foo_bar_topic") + @subscribe_to("input_topic") + async def hello(self, invoke_context: InvokeContext, node_input: List[Message]) -> MsgsAGen: + print("In hello 2") + assert(len(node_input) == 1) + assert(node_input[0].content == "Test message") + output_data = [Message( + role="user", + content="Got test message", + )] + yield output_data + + @node + @publish_to("output_topic") + @subscribe_to("foo_bar_topic") + async def bye(self, invoke_context: InvokeContext, node_input: List[Message]) -> MsgsAGen: + print("In bye 3") + assert(len(node_input) == 1) + assert(node_input[0].content == "Got test message") + + output_data = [Message( + role="user", + content="hi", + )] + yield output_data + +async def main(): + input_messages = [ + Message( + role="user", + content="Test message", + ) + ] + + invoke_context = InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid4().hex, + assistant_request_id=uuid4().hex, + ) + + event = PublishToTopicEvent( + invoke_context=invoke_context, + data=input_messages, + ) + + # One node from input to output +# wkflow = SingleNodeWorkflow.generate() +# has_messages = False +# async for output in wkflow.a_invoke(event): +# has_messages = True +# assert(len(output.data) == 1) +# assert(output.data[0].content == "hi") +# assert(has_messages) + + # Two nodes, from input to foo_bar to output + wkflow = MultiNodeWorkflow.generate() + print(wkflow) + has_messages = False + async for output in wkflow.a_invoke(event): + print(output) + has_messages = True + assert(len(output.data) == 1) + assert(output.data[0].content == "hi") + assert(has_messages) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file