diff --git a/plugin/langchain/lib/graph/CompiledStateGraphObject.ts b/plugin/langchain/lib/graph/CompiledStateGraphObject.ts index 08b0359b..7e686389 100644 --- a/plugin/langchain/lib/graph/CompiledStateGraphObject.ts +++ b/plugin/langchain/lib/graph/CompiledStateGraphObject.ts @@ -11,6 +11,7 @@ import { EggPrototype } from '@eggjs/tegg-metadata'; import { ChatCheckpointSaverInjectName, ChatCheckpointSaverQualifierAttribute, GRAPH_EDGE_METADATA, GRAPH_NODE_METADATA, GraphEdgeMetadata, GraphMetadata, GraphNodeMetadata, IGraph, IGraphEdge, IGraphNode, TeggToolNode } from '@eggjs/tegg-langchain-decorator'; import { LangGraphTracer } from '../tracing/LangGraphTracer'; import { BaseCheckpointSaver, CompiledStateGraph } from '@langchain/langgraph'; +import { EGG_CONTEXT } from '@eggjs/egg-module-common'; export class CompiledStateGraphObject implements EggObject { private status: EggObjectStatus = EggObjectStatus.PENDING; @@ -41,19 +42,13 @@ export class CompiledStateGraphObject implements EggObject { const langGraphTraceObj = await EggContainerFactory.getOrCreateEggObjectFromName('langGraphTracer'); const tracer = langGraphTraceObj.obj as LangGraphTracer; tracer.setName(this.graphName); - graph.invoke = async (input: any, config?: any) => { - if (config?.tags?.includes('trace-log')) { - config.callbacks = [ tracer, ...(config?.callbacks || []) ]; - } - return await originalInvoke.call(graph, input, config); - }; - graph.stream = async (input: any, config?: any) => { - if (config?.tags?.includes('trace-log')) { - config.callbacks = [ tracer, ...(config?.callbacks || []) ]; - } - return await originalStream.call(graph, input, config) as any; - }; + + graph.invoke = (input: any, config?: any) => + this.wrapGraphMethod(originalInvoke.bind(graph), input, config); + + graph.stream = (input: any, config?: any) => + this.wrapGraphMethod(originalStream.bind(graph), input, config); this.status = EggObjectStatus.READY; } @@ -118,6 +113,42 @@ export class CompiledStateGraphObject implements EggObject { } } + + /** + * 包装 graph 方法,添加 tracing + */ + async wrapGraphMethod( + originalMethod: (input: any, config?: any) => Promise, + input: any, + config?: any, + ) { + // 确保 config 对象存在 + const finalConfig = config || {}; + + // 准备 tracer + const shouldTrace = finalConfig.tags?.includes('trace-log'); + if (shouldTrace) { + const langGraphTraceObj = await EggContainerFactory.getOrCreateEggObjectFromClazz(LangGraphTracer); + const tracer = langGraphTraceObj.obj as LangGraphTracer; + tracer.setName(this.graphName); + + finalConfig.callbacks = [ tracer, ...(finalConfig.callbacks || []) ]; + } + + // 设置 runId + if (!finalConfig.runId) { + const trace = await this.getTracer(); + finalConfig.runId = trace?.traceId; + } + + return await originalMethod(input, finalConfig); + } + + async getTracer() { + const ctx = ContextHandler.getContext()!.get(EGG_CONTEXT); + return ctx.tracer; + } + injectProperty() { throw new Error('never call GraphObject#injectProperty'); } diff --git a/plugin/langchain/test/fixtures/apps/langchain/config/plugin.js b/plugin/langchain/test/fixtures/apps/langchain/config/plugin.js index 5d4ee32e..7329d547 100644 --- a/plugin/langchain/test/fixtures/apps/langchain/config/plugin.js +++ b/plugin/langchain/test/fixtures/apps/langchain/config/plugin.js @@ -28,4 +28,9 @@ exports.teggMcpClient = { package: '@eggjs/tegg-mcp-client', }; +exports.tracer = { + package: 'egg-tracer', + enable: true, +}; + exports.watcher = false; diff --git a/plugin/langchain/test/llm.test.ts b/plugin/langchain/test/llm.test.ts index ac15fad5..cfd43e62 100644 --- a/plugin/langchain/test/llm.test.ts +++ b/plugin/langchain/test/llm.test.ts @@ -1,7 +1,7 @@ import mm from 'egg-mock'; import path from 'path'; import assert from 'assert'; - +import Tracer from 'egg-tracer/lib/tracer'; describe('plugin/langchain/test/llm.test.ts', () => { // https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain/package.json#L9 @@ -73,11 +73,13 @@ describe('plugin/langchain/test/llm.test.ts', () => { it('should graph work', async () => { app.mockLog(); + mm(Tracer.prototype, 'traceId', 'test-trace-id'); await app.httpRequest() .get('/llm/graph') .expect(200, { value: 'hello graph toolhello world' }); app.expectLog(/agent_run/); app.expectLog(/Executing FooNode thread_id is 1/); + app.expectLog(/traceId=test-trace-id/); }); } });