Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions libs/langgraph/langgraph/_internal/_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

from collections import ChainMap
Expand Down Expand Up @@ -76,6 +76,32 @@
return config


def _merge_callbacks(base: Callbacks, new: Callbacks) -> Callbacks:
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if base is None:
if isinstance(new, (list, BaseCallbackManager)):
return new.copy()
return new
if isinstance(new, list):
if isinstance(base, list):
return base + new
if isinstance(base, BaseCallbackManager):
mngr = base.copy()
for cb in new:
mngr.add_handler(cb, inherit=True)
return mngr
elif isinstance(new, BaseCallbackManager):
if isinstance(base, list):
mngr = new.copy()
for cb in base:
mngr.add_handler(cb, inherit=True)
return mngr
if isinstance(base, BaseCallbackManager):
return base.merge(new)
raise NotImplementedError(f"Unsupported callback types: {type(base)}, {type(new)}")


def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
"""Merge multiple configs into one.

Expand Down Expand Up @@ -110,34 +136,9 @@
else:
base[key] = value
elif key == "callbacks":
base_callbacks = base.get("callbacks")
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if isinstance(value, list):
if base_callbacks is None:
base["callbacks"] = value.copy()
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + value
else:
# base_callbacks is a manager
mngr = base_callbacks.copy()
for callback in value:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
elif isinstance(value, BaseCallbackManager):
# value is a manager
if base_callbacks is None:
base["callbacks"] = value.copy()
elif isinstance(base_callbacks, list):
mngr = value.copy()
for callback in base_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
base["callbacks"] = base_callbacks.merge(value)
else:
raise NotImplementedError
base["callbacks"] = _merge_callbacks(
base.get("callbacks"), cast(Callbacks, value)
)
elif key == "recursion_limit":
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
base["recursion_limit"] = config["recursion_limit"]
Expand Down Expand Up @@ -303,6 +304,8 @@
if _is_not_empty(v) and k in CONFIG_KEYS:
if k == CONF:
empty[k] = cast(dict, v).copy()
elif k == "callbacks" and isinstance(v, (list, BaseCallbackManager)):
empty["callbacks"] = _merge_callbacks(empty.get("callbacks"), v)
else:
empty[k] = v # type: ignore[literal-required]
for k, v in config.items():
Expand Down
21 changes: 20 additions & 1 deletion libs/langgraph/tests/test_config_async.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from langchain_core.callbacks import AsyncCallbackManager
from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackHandler

from langgraph._internal._config import get_async_callback_manager_for_config
from langgraph.graph import StateGraph

pytestmark = pytest.mark.anyio

Expand All @@ -17,3 +18,21 @@ def test_new_async_manager_merges_tags_with_config() -> None:
config = {"callbacks": None, "tags": ["a"]}
manager = get_async_callback_manager_for_config(config, tags=["b"])
assert manager.inheritable_tags == ["a", "b"]


async def test_with_config_callbacks_preserved_in_astream_events() -> None:
class TrackingCallback(BaseCallbackHandler):
def __init__(self) -> None:
self.called = False

def on_chain_start(self, *args, **kwargs) -> None:
self.called = True

builder = StateGraph(dict)
builder.add_node("node", lambda state: state)
builder.add_edge("__start__", "node")
cb = TrackingCallback()
graph = builder.compile().with_config({"callbacks": [cb]})
async for _ in graph.astream_events({}, version="v2"):
pass
assert cb.called
Loading