diff --git a/core/langchain-decorator/package.json b/core/langchain-decorator/package.json index 5ac2f355..d8ab4bf3 100644 --- a/core/langchain-decorator/package.json +++ b/core/langchain-decorator/package.json @@ -57,7 +57,8 @@ "cross-env": "^7.0.3", "mocha": "^10.2.0", "ts-node": "^10.9.1", - "typescript": "^5.0.4" + "typescript": "^5.0.4", + "zod": "^3.24.4" }, "gitHead": "5f24bacd9131435188be15568d86ef4575f85636" } diff --git a/plugin/controller/test/http/request.test.ts b/plugin/controller/test/http/request.test.ts index b3171ff5..d4701678 100644 --- a/plugin/controller/test/http/request.test.ts +++ b/plugin/controller/test/http/request.test.ts @@ -30,7 +30,7 @@ describe('plugin/controller/test/http/request.test.ts', () => { }); const [ nodeMajor ] = process.versions.node.split('.').map(v => Number(v)); if (nodeMajor >= 16) { - it.only('Request should work', async () => { + it('Request should work', async () => { app.mockCsrf(); const param = { name: 'foo', diff --git a/plugin/langchain/app.ts b/plugin/langchain/app.ts index 49cef982..947823eb 100644 --- a/plugin/langchain/app.ts +++ b/plugin/langchain/app.ts @@ -7,6 +7,7 @@ import { CompiledStateGraphObject } from './lib/graph/CompiledStateGraphObject'; import { BoundModelObjectHook } from './lib/boundModel/BoundModelObjectHook'; import { GraphPrototypeHook } from './lib/graph/GraphPrototypeHook'; import { GraphBuildHook } from './lib/graph/GraphBuildHook'; +import { AgentHttpLoadUnitLifecycleHook } from './lib/agent/AgentHttpLoadUnitLifecycleHook'; export default class ModuleLangChainHook implements IBoot { readonly #app: Application; @@ -14,6 +15,7 @@ export default class ModuleLangChainHook implements IBoot { readonly #graphLoadUnitHook: GraphLoadUnitHook; readonly #boundModelObjectHook: BoundModelObjectHook; readonly #graphPrototypeHook: GraphPrototypeHook; + #agentHttpLoadUnitHook: AgentHttpLoadUnitLifecycleHook; constructor(app: Application) { this.#app = app; @@ -25,9 +27,11 @@ export default class ModuleLangChainHook implements IBoot { } configWillLoad() { + this.#agentHttpLoadUnitHook = new AgentHttpLoadUnitLifecycleHook(this.#app.moduleConfigs); + this.#app.loadUnitLifecycleUtil.registerLifecycle(this.#agentHttpLoadUnitHook); this.#app.eggObjectLifecycleUtil.registerLifecycle(this.#graphObjectHook); this.#app.eggObjectLifecycleUtil.registerLifecycle(this.#boundModelObjectHook); - this.#app.eggObjectFactory.registerEggObjectCreateMethod(CompiledStateGraphProto, CompiledStateGraphObject.createObject); + this.#app.eggObjectFactory.registerEggObjectCreateMethod(CompiledStateGraphProto, CompiledStateGraphObject.createObject(this.#app)); this.#app.eggPrototypeLifecycleUtil.registerLifecycle(this.#graphPrototypeHook); } @@ -36,6 +40,9 @@ export default class ModuleLangChainHook implements IBoot { } async beforeClose() { + if (this.#agentHttpLoadUnitHook) { + this.#app.loadUnitLifecycleUtil.deleteLifecycle(this.#agentHttpLoadUnitHook); + } this.#app.eggObjectLifecycleUtil.deleteLifecycle(this.#graphObjectHook); this.#app.eggObjectLifecycleUtil.deleteLifecycle(this.#boundModelObjectHook); this.#app.loadUnitLifecycleUtil.deleteLifecycle(this.#graphLoadUnitHook); diff --git a/plugin/langchain/lib/agent/AgentHttpLoadUnitLifecycleHook.ts b/plugin/langchain/lib/agent/AgentHttpLoadUnitLifecycleHook.ts new file mode 100644 index 00000000..1660feee --- /dev/null +++ b/plugin/langchain/lib/agent/AgentHttpLoadUnitLifecycleHook.ts @@ -0,0 +1,140 @@ +import { + ConfigSourceQualifier, + Context, + HTTPBody, + HTTPController, + HTTPMethod, + HTTPMethodEnum, + LifecycleHook, +} from '@eggjs/tegg'; +import { ClassProtoDescriptor, EggContainerFactory, EggPrototypeCreatorFactory, EggPrototypeFactory, ProtoDescriptorHelper } from '@eggjs/tegg/helper'; +import type { LoadUnit, LoadUnitLifecycleContext } from '@eggjs/tegg-metadata'; +import { ModuleConfig, ModuleReference } from 'egg'; +import { LangChainConfigSchemaType } from 'typings'; +import { Readable, Transform } from 'stream'; +import { CompiledStateGraph } from '@langchain/langgraph'; +import { AIMessage, HumanMessage, SystemMessage, ToolMessage } from '@langchain/core/messages'; + + +export interface ModuleConfigHolder { + name: string; + config: ModuleConfig; + reference: ModuleReference; +} + +type ValueOf = T[keyof T]; + +export class AgentHttpLoadUnitLifecycleHook implements LifecycleHook { + readonly moduleConfigs: Record; + + constructor(moduleConfigs: Record) { + this.moduleConfigs = moduleConfigs; + } + + async preCreate(_: LoadUnitLifecycleContext, loadUnit: LoadUnit): Promise { + const moduleConfigs = this.#getModuleConfig(loadUnit); + if (moduleConfigs.length > 0) { + for (const [ graphName, config ] of moduleConfigs) { + if (config?.type === 'http') { + const GraphHttpController = this.#createGraphHttpControllerClass(loadUnit, graphName, config); + const protoDescriptor = ProtoDescriptorHelper.createByInstanceClazz(GraphHttpController, { + moduleName: loadUnit.name, + unitPath: loadUnit.unitPath, + }) as ClassProtoDescriptor; + + const proto = await EggPrototypeCreatorFactory.createProtoByDescriptor(protoDescriptor, loadUnit); + EggPrototypeFactory.instance.registerPrototype(proto, loadUnit); + } + } + } + } + + #createGraphHttpControllerClass(loadUnit: LoadUnit, graphName: string, config: ValueOf) { + class GraphHttpController { + @HTTPMethod({ + path: config.path!, + method: HTTPMethodEnum.POST, + timeout: config.timeout, + }) + async invoke(@Context() ctx, @HTTPBody() args) { + const eggObj = await EggContainerFactory.getOrCreateEggObjectFromName(`compiled${graphName}`); + const invokeFunc = (eggObj.obj as CompiledStateGraph).invoke; + const streamFunc = (eggObj.obj as CompiledStateGraph).stream; + const genArgs = Object.entries(args).reduce((acc, [ key, value ]) => { + if (Array.isArray(value) && typeof value[0] === 'object') { + acc[key] = value.map(obj => { + switch (obj.role) { + case 'human': + return new HumanMessage(obj); + case 'ai': + return new AIMessage(obj); + case 'system': + return new SystemMessage(obj); + case 'tool': + return new ToolMessage(obj); + default: + throw new Error('unknown message type'); + } + }); + } else { + acc[key] = value; + } + return acc; + }, {}); + + const defaultConfig = { + configurable: { + thread_id: process.pid.toString(), + }, + }; + + const res = await Reflect.apply(config.stream ? streamFunc : invokeFunc, (eggObj.obj as CompiledStateGraph), [ genArgs, defaultConfig ]); + + if (config.stream) { + ctx.set({ + 'content-type': 'text/event-stream', + 'cache-control': 'no-cache', + 'transfer-encoding': 'chunked', + 'X-Accel-Buffering': 'no', + }); + const transformStream = new Transform({ + objectMode: true, + transform(chunk: any, _encoding: string, callback) { + try { + // 如果 chunk 是对象,转换为 JSON + let data: string; + if (typeof chunk === 'string') { + data = chunk; + } else if (typeof chunk === 'object') { + data = JSON.stringify(chunk); + } else { + data = String(chunk); + } + + // 格式化为 SSE 格式 + const sseFormatted = `data: ${data}\n\n`; + callback(null, sseFormatted); + } catch (error) { + callback(error); + } + }, + }); + return Readable.fromWeb(res as any, { objectMode: true }).pipe(transformStream); + } + return res; + } + } + HTTPController({ controllerName: `${graphName}HttpController`, protoName: `${graphName}HttpController` })(GraphHttpController); + ConfigSourceQualifier(loadUnit.name)(GraphHttpController.prototype, 'moduleConfig'); + + return GraphHttpController; + } + + #getModuleConfig(loadUnit: LoadUnit) { + const moduleConfig: LangChainConfigSchemaType = (this.moduleConfigs[loadUnit.name]?.config as any)?.langchain; + if (moduleConfig && Object.keys(moduleConfig?.agents || {}).length > 0) { + return Object.entries(moduleConfig.agents); + } + return []; + } +} diff --git a/plugin/langchain/lib/graph/CompiledStateGraphObject.ts b/plugin/langchain/lib/graph/CompiledStateGraphObject.ts index 2854ba2e..154536dd 100644 --- a/plugin/langchain/lib/graph/CompiledStateGraphObject.ts +++ b/plugin/langchain/lib/graph/CompiledStateGraphObject.ts @@ -11,6 +11,7 @@ import { EggPrototype } from '@eggjs/tegg-metadata'; import { ChatCheckpointSaverInjectName, ChatCheckpointSaverQualifierAttribute, GRAPH_EDGE_METADATA, GRAPH_NODE_METADATA, GraphEdgeMetadata, GraphMetadata, GraphNodeMetadata, IGraph, IGraphEdge, IGraphNode, TeggToolNode } from '@eggjs/tegg-langchain-decorator'; import { LangGraphTracer } from '../tracing/LangGraphTracer'; import { BaseCheckpointSaver, CompiledStateGraph } from '@langchain/langgraph'; +import { Application } from 'egg'; export class CompiledStateGraphObject implements EggObject { private status: EggObjectStatus = EggObjectStatus.PENDING; @@ -19,17 +20,19 @@ export class CompiledStateGraphObject implements EggObject { readonly proto: CompiledStateGraphProto; readonly ctx: EggContext; readonly daoName: string; - private _obj: object; + _obj: object; readonly graphMetadata: GraphMetadata; readonly graphName: string; + readonly app: Application; - constructor(name: EggObjectName, proto: CompiledStateGraphProto) { + constructor(name: EggObjectName, proto: CompiledStateGraphProto, app: Application) { this.name = name; this.proto = proto; this.ctx = ContextHandler.getContext()!; this.id = IdenticalUtil.createObjectId(this.proto.id, this.ctx?.id); this.graphMetadata = proto.graphMetadata; this.graphName = proto.graphName; + this.app = app; } async init() { @@ -122,9 +125,11 @@ export class CompiledStateGraphObject implements EggObject { return this._obj; } - static async createObject(name: EggObjectName, proto: EggPrototype): Promise { - const compiledStateGraphObject = new CompiledStateGraphObject(name, proto as CompiledStateGraphProto); - await compiledStateGraphObject.init(); - return compiledStateGraphObject; + static createObject(app: Application) { + return async function(name: EggObjectName, proto: EggPrototype): Promise { + const compiledStateGraphObject = new CompiledStateGraphObject(name, proto as CompiledStateGraphProto, app); + await compiledStateGraphObject.init(); + return compiledStateGraphObject; + }; } } diff --git a/plugin/langchain/lib/graph/CompiledStateGraphProto.ts b/plugin/langchain/lib/graph/CompiledStateGraphProto.ts index 3dd2a0a4..f7ce1135 100644 --- a/plugin/langchain/lib/graph/CompiledStateGraphProto.ts +++ b/plugin/langchain/lib/graph/CompiledStateGraphProto.ts @@ -24,6 +24,7 @@ export class CompiledStateGraphProto implements EggPrototype { readonly name: EggPrototypeName; readonly graphMetadata: GraphMetadata; readonly graphName: string; + readonly unitPath: string; constructor(loadUnit: LoadUnit, protoName: string, graphName: string, graphMetadata: GraphMetadata) { this.loadUnitId = loadUnit.id; @@ -31,6 +32,7 @@ export class CompiledStateGraphProto implements EggPrototype { this.name = protoName; this.graphMetadata = graphMetadata; this.graphName = graphName; + this.unitPath = loadUnit.unitPath; this.id = IdenticalUtil.createProtoId(loadUnit.id, protoName); } diff --git a/plugin/langchain/package.json b/plugin/langchain/package.json index 71869849..b77f4b45 100644 --- a/plugin/langchain/package.json +++ b/plugin/langchain/package.json @@ -52,7 +52,7 @@ "typescript": true }, "engines": { - "node": ">=18.0.0" + "node": ">=20.0.0" }, "dependencies": { "@eggjs/egg-module-common": "^3.64.2", diff --git a/plugin/langchain/test/fixtures/apps/langchain/app/modules/bar/module.yml b/plugin/langchain/test/fixtures/apps/langchain/app/modules/bar/module.yml index 46b277ac..2a548e0b 100644 --- a/plugin/langchain/test/fixtures/apps/langchain/app/modules/bar/module.yml +++ b/plugin/langchain/test/fixtures/apps/langchain/app/modules/bar/module.yml @@ -16,4 +16,11 @@ mcp: clientName: barSse version: 1.0.0 transportType: SSE - type: http \ No newline at end of file + type: http + +langchain: + agents: + FooGraph: + path: /graph/stream + type: http + stream: true \ No newline at end of file diff --git a/plugin/langchain/test/fixtures/apps/langchain/config/config.default.js b/plugin/langchain/test/fixtures/apps/langchain/config/config.default.js index c26c3118..db08ee8d 100644 --- a/plugin/langchain/test/fixtures/apps/langchain/config/config.default.js +++ b/plugin/langchain/test/fixtures/apps/langchain/config/config.default.js @@ -8,9 +8,6 @@ module.exports = function() { enable: false, }, }, - bodyParser: { - enable: false, - }, }; return config; }; diff --git a/plugin/langchain/test/llm.test.ts b/plugin/langchain/test/llm.test.ts index eca27953..87c53596 100644 --- a/plugin/langchain/test/llm.test.ts +++ b/plugin/langchain/test/llm.test.ts @@ -76,5 +76,60 @@ describe('plugin/langchain/test/llm.test.ts', () => { .get('/llm/graph') .expect(200, { value: 'hello graph toolhello world' }); }); + + it('should agent controller work', async () => { + const url = await app.httpRequest() + .post('/graph/stream').url; + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ messages: [{ role: 'human', content: 'hello world' }] }), + }); + + + if (!response.ok) { + throw new Error(`HTTP ${response.status}`); + } + + if (!response.body) { + throw new Error('Response body is null'); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + const messages: object[] = []; + + try { + // eslint-disable-next-line no-constant-condition + while (true) { + const { done, value } = await reader.read(); + + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop() || ''; + + lines.forEach(line => { + if (line.startsWith('data: ')) { + const data = line.slice(6); + try { + const parsed = JSON.parse(data); + messages.push(parsed); + } catch (e) { + throw e; + } + } + }); + } + } finally { + reader.releaseLock(); + } + assert(messages.length === 3); + }); } }); diff --git a/plugin/langchain/typings/index.d.ts b/plugin/langchain/typings/index.d.ts index d3673afa..0bb27649 100644 --- a/plugin/langchain/typings/index.d.ts +++ b/plugin/langchain/typings/index.d.ts @@ -108,11 +108,34 @@ export const ChatModelConfigModuleConfigSchema = Type.Object({ name: 'ChatModel', }); + +export const LangChainConfigSchema = Type.Object({ + agents: Type.Record(Type.String(), Type.Object({ + path: Type.Optional(Type.String({ + description: 'http path', + })), + stream: Type.Optional(Type.Boolean({ + description: '是否流式返回', + })), + type: Type.Optional(Type.String({ + description: 'Http', + })), + timeout: Type.Optional(Type.Number({ + description: '接口超时时间', + })), + })), +}, { + title: 'langchain 设置', + name: 'langchain', +}); + export type ChatModelConfigModuleConfigType = Static; +export type LangChainConfigSchemaType = Static; declare module '@eggjs/tegg' { export type LangChainModuleConfig = { ChatModel?: ChatModelConfigModuleConfigType; + langchain?: LangChainConfigSchema; }; export interface ModuleConfig extends LangChainModuleConfig {