Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions langchain/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# LangChain Sample

This sample shows you how you can use Temporal to orchestrate workflows for [LangChain](https://www.langchain.com).
This sample shows you how you can use Temporal to orchestrate workflows for [LangChain](https://www.langchain.com). It includes an interceptor that makes LangSmith traces work seamlessly across Temporal clients, workflows and activities.

For this sample, the optional `langchain` dependency group must be included. To include, run:

Expand All @@ -21,8 +21,10 @@ This will start the worker. Then, in another terminal, run the following to exec

Then, in another terminal, run the following command to translate a phrase:

curl -X POST "http://localhost:8000/translate?phrase=hello%20world&language=Spanish"
curl -X POST "http://localhost:8000/translate?phrase=hello%20world&language1=Spanish&language2=French&language3=Russian"

Which should produce some output like:

{"translation":"Hola mundo"}
{"translations":{"French":"Bonjour tout le monde","Russian":"Привет, мир","Spanish":"Hola mundo"}}

Check [LangSmith](https://smith.langchain.com/) for the corresponding trace.
9 changes: 6 additions & 3 deletions langchain/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TranslateParams:


@activity.defn
async def translate_phrase(params: TranslateParams) -> dict:
async def translate_phrase(params: TranslateParams) -> str:
# LangChain setup
template = """You are a helpful assistant who translates between languages.
Translate the following phrase into the specified language: {phrase}
Expand All @@ -26,6 +26,9 @@ async def translate_phrase(params: TranslateParams) -> dict:
)
chain = chat_prompt | ChatOpenAI()
# Use the asynchronous invoke method
return dict(
await chain.ainvoke({"phrase": params.phrase, "language": params.language})
return (
dict(
await chain.ainvoke({"phrase": params.phrase, "language": params.language})
).get("content")
or ""
)
181 changes: 181 additions & 0 deletions langchain/langchain_interceptor.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we never really had tests for the langchain sample, so don't have to add as part of this, but we should probably make an issue to add them at some point. We've found cases where these integration-type examples become stale or stop working with new dependency versions and there are no tests to confirm that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how a local test would help, as they don't have an embeddable version of Langsmith.

Copy link
Member

@cretz cretz Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is unfortunate. Ideally we'd test the same way they encourage users to test using the in-memory representations or mocks they suggest. But we do often not test integrations in this repo.

Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import annotations

from typing import Any, Mapping, Protocol, Type

from temporalio import activity, api, client, converter, worker, workflow

with workflow.unsafe.imports_passed_through():
from contextlib import contextmanager

from langsmith import trace, tracing_context
from langsmith.run_helpers import get_current_run_tree

# Header key for LangChain context
LANGCHAIN_CONTEXT_KEY = "langchain-context"


class _InputWithHeaders(Protocol):
headers: Mapping[str, api.common.v1.Payload]


def set_header_from_context(
input: _InputWithHeaders, payload_converter: converter.PayloadConverter
) -> None:
# Get current LangChain run tree
run_tree = get_current_run_tree()
if run_tree:
headers = run_tree.to_headers()
input.headers = {
**input.headers,
LANGCHAIN_CONTEXT_KEY: payload_converter.to_payload(headers),
}


@contextmanager
def context_from_header(
input: _InputWithHeaders, payload_converter: converter.PayloadConverter
):
payload = input.headers.get(LANGCHAIN_CONTEXT_KEY)
if payload:
run_tree = payload_converter.from_payload(payload, dict)
# Set the run tree in the current context
with tracing_context(parent=run_tree):
yield
else:
yield


class LangChainContextPropagationInterceptor(client.Interceptor, worker.Interceptor):
"""Interceptor that propagates LangChain context through Temporal."""

def __init__(
self,
payload_converter: converter.PayloadConverter = converter.default().payload_converter,
) -> None:
self._payload_converter = payload_converter

def intercept_client(
self, next: client.OutboundInterceptor
) -> client.OutboundInterceptor:
return _LangChainContextPropagationClientOutboundInterceptor(
next, self._payload_converter
)

def intercept_activity(
self, next: worker.ActivityInboundInterceptor
) -> worker.ActivityInboundInterceptor:
return _LangChainContextPropagationActivityInboundInterceptor(next)

def workflow_interceptor_class(
self, input: worker.WorkflowInterceptorClassInput
) -> Type[_LangChainContextPropagationWorkflowInboundInterceptor]:
return _LangChainContextPropagationWorkflowInboundInterceptor


class _LangChainContextPropagationClientOutboundInterceptor(client.OutboundInterceptor):
def __init__(
self,
next: client.OutboundInterceptor,
payload_converter: converter.PayloadConverter,
) -> None:
super().__init__(next)
self._payload_converter = payload_converter

async def start_workflow(
self, input: client.StartWorkflowInput
) -> client.WorkflowHandle[Any, Any]:
with trace(name=f"start_workflow:{input.workflow}"):
set_header_from_context(input, self._payload_converter)
return await super().start_workflow(input)


class _LangChainContextPropagationActivityInboundInterceptor(
worker.ActivityInboundInterceptor
):
async def execute_activity(self, input: worker.ExecuteActivityInput) -> Any:
if isinstance(input.fn, str):
name = input.fn
elif callable(input.fn):
defn = activity._Definition.from_callable(input.fn)
name = (
defn.name if defn is not None and defn.name is not None else "unknown"
)
else:
name = "unknown"

with context_from_header(input, activity.payload_converter()):
with trace(name=f"execute_activity:{name}"):
return await self.next.execute_activity(input)


class _LangChainContextPropagationWorkflowInboundInterceptor(
worker.WorkflowInboundInterceptor
):
def init(self, outbound: worker.WorkflowOutboundInterceptor) -> None:
self.next.init(
_LangChainContextPropagationWorkflowOutboundInterceptor(outbound)
)

async def execute_workflow(self, input: worker.ExecuteWorkflowInput) -> Any:
if isinstance(input.run_fn, str):
name = input.run_fn
elif callable(input.run_fn):
defn = workflow._Definition.from_run_fn(input.run_fn)
name = (
defn.name if defn is not None and defn.name is not None else "unknown"
)
else:
name = "unknown"

with context_from_header(input, workflow.payload_converter()):
# This is a sandbox friendly way to write
# with trace(...):
# return await self.next.execute_workflow(input)
with workflow.unsafe.sandbox_unrestricted():
t = trace(
name=f"execute_workflow:{name}", run_id=workflow.info().run_id
)
with workflow.unsafe.imports_passed_through():
t.__enter__()
try:
return await self.next.execute_workflow(input)
finally:
with workflow.unsafe.sandbox_unrestricted():
# Cannot use __aexit__ because it's internally uses
# loop.run_in_executor which is not available in the sandbox
t.__exit__()


class _LangChainContextPropagationWorkflowOutboundInterceptor(
worker.WorkflowOutboundInterceptor
):
def start_activity(
self, input: worker.StartActivityInput
) -> workflow.ActivityHandle:
with workflow.unsafe.sandbox_unrestricted():
t = trace(name=f"start_activity:{input.activity}", run_id=workflow.uuid4())
with workflow.unsafe.imports_passed_through():
t.__enter__()
try:
set_header_from_context(input, workflow.payload_converter())
return self.next.start_activity(input)
finally:
with workflow.unsafe.sandbox_unrestricted():
t.__exit__()

async def start_child_workflow(
self, input: worker.StartChildWorkflowInput
) -> workflow.ChildWorkflowHandle:
with workflow.unsafe.sandbox_unrestricted():
t = trace(
name=f"start_child_workflow:{input.workflow}", run_id=workflow.uuid4()
)
with workflow.unsafe.imports_passed_through():
t.__enter__()

try:
set_header_from_context(input, workflow.payload_converter())
return await self.next.start_child_workflow(input)
finally:
with workflow.unsafe.sandbox_unrestricted():
t.__exit__()
16 changes: 10 additions & 6 deletions langchain/starter.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
from contextlib import asynccontextmanager
from typing import List
from uuid import uuid4

import uvicorn
from activities import TranslateParams
from fastapi import FastAPI, HTTPException
from langchain_interceptor import LangChainContextPropagationInterceptor
from temporalio.client import Client
from workflow import LangChainWorkflow
from workflow import LangChainWorkflow, TranslateWorkflowParams


@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.temporal_client = await Client.connect("localhost:7233")
app.state.temporal_client = await Client.connect(
"localhost:7233", interceptors=[LangChainContextPropagationInterceptor()]
)
yield


app = FastAPI(lifespan=lifespan)


@app.post("/translate")
async def translate(phrase: str, language: str):
async def translate(phrase: str, language1: str, language2: str, language3: str):
languages = [language1, language2, language3]
client = app.state.temporal_client
try:
result = await client.execute_workflow(
LangChainWorkflow.run,
TranslateParams(phrase, language),
TranslateWorkflowParams(phrase, languages),
id=f"langchain-translation-{uuid4()}",
task_queue="langchain-task-queue",
)
translation_content = result.get("content", "Translation not available")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

return {"translation": translation_content}
return {"translations": result}


if __name__ == "__main__":
Expand Down
9 changes: 6 additions & 3 deletions langchain/worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio

from activities import translate_phrase
from langchain_interceptor import LangChainContextPropagationInterceptor
from temporalio.client import Client
from temporalio.worker import Worker
from workflow import LangChainWorkflow
from workflow import LangChainChildWorkflow, LangChainWorkflow

interrupt_event = asyncio.Event()

Expand All @@ -13,8 +14,9 @@ async def main():
worker = Worker(
client,
task_queue="langchain-task-queue",
workflows=[LangChainWorkflow],
workflows=[LangChainWorkflow, LangChainChildWorkflow],
activities=[translate_phrase],
interceptors=[LangChainContextPropagationInterceptor()],
)

print("\nWorker started, ctrl+c to exit\n")
Expand All @@ -28,7 +30,8 @@ async def main():


if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
Expand Down
40 changes: 38 additions & 2 deletions langchain/workflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import asyncio
from dataclasses import dataclass
from datetime import timedelta
from typing import List

from temporalio import workflow

Expand All @@ -7,11 +10,44 @@


@workflow.defn
class LangChainWorkflow:
class LangChainChildWorkflow:
@workflow.run
async def run(self, params: TranslateParams) -> dict:
async def run(self, params: TranslateParams) -> str:
return await workflow.execute_activity(
translate_phrase,
params,
schedule_to_close_timeout=timedelta(seconds=30),
)


@dataclass
class TranslateWorkflowParams:
phrase: str
languages: List[str]


@workflow.defn
class LangChainWorkflow:
@workflow.run
async def run(self, params: TranslateWorkflowParams) -> dict:
result1, result2, result3 = await asyncio.gather(
workflow.execute_activity(
translate_phrase,
TranslateParams(params.phrase, params.languages[0]),
schedule_to_close_timeout=timedelta(seconds=30),
),
workflow.execute_activity(
translate_phrase,
TranslateParams(params.phrase, params.languages[1]),
schedule_to_close_timeout=timedelta(seconds=30),
),
workflow.execute_child_workflow(
LangChainChildWorkflow.run,
TranslateParams(params.phrase, params.languages[2]),
),
)
return {
params.languages[0]: result1,
params.languages[1]: result2,
params.languages[2]: result3,
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ gevent = ["gevent==25.4.2 ; python_version >= '3.8'"]
langchain = [
"langchain>=0.1.7,<0.2 ; python_version >= '3.8.1' and python_version < '4.0'",
"langchain-openai>=0.0.6,<0.0.7 ; python_version >= '3.8.1' and python_version < '4.0'",
"langsmith>=0.1.22,<0.2 ; python_version >= '3.8.1' and python_version < '4.0'",
"openai>=1.4.0,<2",
"fastapi>=0.105.0,<0.106",
"tqdm>=4.62.0,<5",
Expand Down
Loading
Loading