diff --git a/.gitignore b/.gitignore index ddcda146..9de51cab 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ run .DS_Store .tmp .vscode +.claude package-lock.json yarn.lock diff --git a/core/common-util/src/ModuleConfigs.ts b/core/common-util/src/ModuleConfigs.ts index b3f06789..6738bdb6 100644 --- a/core/common-util/src/ModuleConfigs.ts +++ b/core/common-util/src/ModuleConfigs.ts @@ -7,4 +7,8 @@ export class ModuleConfigs { get(moduleName: string): ModuleConfig | undefined { return this.inner[moduleName]?.config; } + + * [Symbol.iterator](): Iterator<[string, ModuleConfigHolder]> { + yield* Object.entries(this.inner); + } } diff --git a/core/common-util/test/ModuleConfig.test.ts b/core/common-util/test/ModuleConfig.test.ts index a611b412..d4e43872 100644 --- a/core/common-util/test/ModuleConfig.test.ts +++ b/core/common-util/test/ModuleConfig.test.ts @@ -1,6 +1,7 @@ import { strict as assert } from 'node:assert'; import path from 'node:path'; import { ModuleConfigUtil } from '../src/ModuleConfig'; +import { ModuleConfigs } from '../src/ModuleConfigs'; import type { ModuleReference } from '@eggjs/tegg-types'; describe('test/ModuleConfig.test.ts', () => { @@ -140,6 +141,45 @@ describe('test/ModuleConfig.test.ts', () => { }]); }); }); + + it('should iterate over all module configs', () => { + const mockInner = { + module1: { + name: 'module1', + reference: { path: '/path/to/module1', name: 'module1' }, + config: { foo: 'bar' }, + }, + module2: { + name: 'module2', + reference: { path: '/path/to/module2', name: 'module2' }, + config: { baz: 'qux' }, + }, + }; + + const moduleConfigs = new ModuleConfigs(mockInner); + const result: Array<[string, any]> = []; + + for (const [ name, holder ] of moduleConfigs) { + result.push([ name, holder ]); + } + + assert.strictEqual(result.length, 2); + assert.strictEqual(result[0][0], 'module1'); + assert.deepStrictEqual(result[0][1], mockInner.module1); + assert.strictEqual(result[1][0], 'module2'); + assert.deepStrictEqual(result[1][1], mockInner.module2); + }); + + it('should work with empty configs', () => { + const moduleConfigs = new ModuleConfigs({}); + const result: Array<[string, any]> = []; + + for (const [ name, holder ] of moduleConfigs) { + result.push([ name, holder ]); + } + + assert.strictEqual(result.length, 0); + }); }); describe('ModuleConfigUtil.deduplicateModules', () => { @@ -323,3 +363,4 @@ describe('ModuleConfigUtil.deduplicateModules', () => { }); }); }); + diff --git a/core/langchain-decorator/src/util/GraphInfoUtil.ts b/core/langchain-decorator/src/util/GraphInfoUtil.ts index b81f85d6..03e125d4 100644 --- a/core/langchain-decorator/src/util/GraphInfoUtil.ts +++ b/core/langchain-decorator/src/util/GraphInfoUtil.ts @@ -15,6 +15,15 @@ export class GraphInfoUtil { return MetadataUtil.getMetaData(GRAPH_GRAPH_METADATA, clazz); } + static getGraphByName(graphName: string): { clazz: EggProtoImplClass; metadata: IGraphMetadata } | undefined { + for (const [ clazz, metadata ] of GraphInfoUtil.graphMap.entries()) { + if (metadata.name === graphName) { + return { clazz, metadata }; + } + } + return undefined; + } + static getAllGraphMetadata(): Map { return GraphInfoUtil.graphMap; } diff --git a/plugin/langchain/app/controller/RunsController.ts b/plugin/langchain/app/controller/RunsController.ts new file mode 100644 index 00000000..9c603859 --- /dev/null +++ b/plugin/langchain/app/controller/RunsController.ts @@ -0,0 +1,145 @@ +import { + HTTPController, + HTTPMethod, + HTTPMethodEnum, + HTTPBody, + Context, + Middleware, + Inject, +} from '@eggjs/tegg'; +import type { EggContext } from '@eggjs/tegg'; +import type { RunCreateDTO } from './types'; +import { streamSSE } from '../../lib/sse'; +import { RunCreate } from './schemas'; +import { ZodErrorMiddleware } from '../middleware/ZodErrorMiddleware'; +import { RunsService } from '../../lib/runs/RunsService'; + +/** + * LangGraph Runs Controller + * 处理 Run 相关的 HTTP 请求 + */ +@HTTPController({ + path: '/api', +}) +@Middleware(ZodErrorMiddleware) +export class RunsController { + @Inject() + runsService: RunsService; + /** + * POST /api/runs/stream + * 流式创建无状态 Run (SSE) + * + * 对应 LangGraph runs.mts 的 api.post("/runs/stream", ...) 端点 + */ + @HTTPMethod({ + method: HTTPMethodEnum.POST, + path: '/runs/stream', + }) + async streamStatelessRun(@Context() ctx: EggContext, @HTTPBody() payload: RunCreateDTO) { + const validated = RunCreate.parse(payload); + + // 使用 RunsService 创建并验证 run + const run = await this.runsService.createValidRun( + undefined, // threadId (无状态 run) + validated, + { + // auth: ctx.auth, // TODO: 集成认证系统 + headers: ctx.headers, + }, + ); + + console.log('streamStatelessRun', { + run, + agentConfigs: this.runsService.getAllAgentConfigs(), + }); + + // 设置 Content-Location header + ctx.set('Content-Location', `/runs/${run.run_id}`); + + // 类型断言帮助访问 input 中的 messages + const inputData = validated.input as { messages?: Array<{ role: string; content: string }> } | undefined; + + // 使用 SSE 流式返回 + return streamSSE(ctx, async stream => { + // 如果需要在断开连接时取消,创建 AbortSignal + // const cancelOnDisconnect = validated.on_disconnect === 'cancel' + // ? getDisconnectAbortSignal(ctx, stream) + // : undefined; + + try { + // TODO: 调用 runs service 的 stream.join 方法获取运行结果 + // for await (const { event, data } of runs().stream.join( + // runId, + // undefined, + // { + // cancelOnDisconnect, + // lastEventId: validated.stream_resumable ? "-1" : undefined, + // ignore404: true, + // }, + // auth + // )) { + // await stream.writeSSE({ data: JSON.stringify(data), event }); + // } + + // Mock 实现:模拟 SSE 流式响应 + // 1. 发送 metadata 事件 + await stream.writeSSE({ + event: 'metadata', + data: JSON.stringify({ + run_id: run.run_id, + assistant_id: validated.assistant_id || 'mock_assistant', + }), + }); + + await stream.sleep(100); + + // 2. 发送 values 事件 - 模拟开始处理 + await stream.writeSSE({ + event: 'values', + data: JSON.stringify({ + messages: [ + { + role: 'user', + content: inputData?.messages?.[0]?.content || 'Hello', + }, + ], + }), + }); + + await stream.sleep(500); + + // 3. 发送 values 事件 - 模拟 AI 响应 + await stream.writeSSE({ + event: 'values', + data: JSON.stringify({ + messages: [ + { + role: 'user', + content: inputData?.messages?.[0]?.content || 'Hello', + }, + { + role: 'assistant', + content: `Mock response to: ${inputData?.messages?.[0]?.content || 'Hello'}`, + }, + ], + }), + }); + + await stream.sleep(200); + + // 4. 发送 end 事件 + await stream.writeSSE({ + event: 'end', + data: JSON.stringify({ + run_id: run.run_id, + status: 'completed', + }), + }); + } catch (error) { + console.error('Error streaming run:', error); + throw error; + } + }); + } + +} diff --git a/plugin/langchain/app/controller/schemas.ts b/plugin/langchain/app/controller/schemas.ts new file mode 100644 index 00000000..9e872550 --- /dev/null +++ b/plugin/langchain/app/controller/schemas.ts @@ -0,0 +1,340 @@ +/** + * LangGraph API Schemas + * 从 langgraphjs/libs/langgraph-api/src/schemas.mts 迁移 + * + * 注意:此文件需要 zod 依赖 + * 安装: npm install zod + */ + +import { z } from 'zod'; + +// 基础工具 schema +export const coercedBoolean = z.string().transform((val) => { + const lower = val.toLowerCase(); + return lower === 'true' || lower === '1' || lower === 'yes'; +}); + +// Config schemas +export const AssistantConfigurable = z + .object({ + thread_id: z.string().optional(), + thread_ts: z.string().optional(), + }) + .catchall(z.unknown()); + +export const AssistantConfig = z + .object({ + tags: z.array(z.string()).optional(), + recursion_limit: z.number().int().optional(), + configurable: AssistantConfigurable.optional(), + }) + .catchall(z.unknown()) + .describe('The configuration of an assistant.'); + +export const Config = z.object({ + tags: z.array(z.string()).optional(), + recursion_limit: z.number().int().optional(), + configurable: z.object({}).catchall(z.any()).optional(), +}); + +// Checkpoint schema +export const CheckpointSchema = z.object({ + checkpoint_id: z.string().uuid().optional(), + checkpoint_ns: z.string().nullish(), + checkpoint_map: z.record(z.unknown()).nullish(), +}); + +// Command schema +export const CommandSchema = z.object({ + goto: z + .union([ + z.union([ + z.string(), + z.object({ node: z.string(), input: z.unknown().optional() }), + ]), + z.array( + z.union([ + z.string(), + z.object({ node: z.string(), input: z.unknown().optional() }), + ]) + ), + ]) + .optional(), + update: z + .union([z.record(z.unknown()), z.array(z.tuple([z.string(), z.unknown()]))]) + .optional(), + resume: z.unknown().optional(), +}); + +// Langsmith tracer schema +export const LangsmithTracer = z.object({ + project_name: z.string().optional(), + example_id: z.string().optional(), +}); + +// Run schemas +export const Run = z.object({ + run_id: z.string().uuid(), + thread_id: z.string().uuid(), + assistant_id: z.string().uuid(), + created_at: z.string(), + updated_at: z.string(), + status: z.enum([ + 'pending', + 'running', + 'error', + 'success', + 'timeout', + 'interrupted', + ]), + metadata: z.object({}).catchall(z.any()), + kwargs: z.object({}).catchall(z.any()), + multitask_strategy: z.enum(['reject', 'rollback', 'interrupt', 'enqueue']), +}); + +export const RunCreate = z + .object({ + assistant_id: z.union([z.string().uuid(), z.string()]), + checkpoint_id: z.string().optional(), + checkpoint: CheckpointSchema.optional(), + input: z.union([z.unknown(), z.null()]).optional(), + command: CommandSchema.optional(), + metadata: z + .object({}) + .catchall(z.any()) + .describe('Metadata for the run.') + .optional(), + context: z.unknown().optional(), + config: AssistantConfig.optional(), + webhook: z.string().optional(), + interrupt_before: z.union([z.enum(['*']), z.array(z.string())]).optional(), + interrupt_after: z.union([z.enum(['*']), z.array(z.string())]).optional(), + on_disconnect: z + .enum(['cancel', 'continue']) + .optional() + .default('continue'), + multitask_strategy: z + .enum(['reject', 'rollback', 'interrupt', 'enqueue']) + .optional(), + stream_mode: z + .union([ + z.array( + z.enum([ + 'values', + 'messages', + 'messages-tuple', + 'updates', + 'events', + 'tasks', + 'checkpoints', + 'debug', + 'custom', + ]) + ), + z.enum([ + 'values', + 'messages', + 'messages-tuple', + 'updates', + 'events', + 'tasks', + 'checkpoints', + 'debug', + 'custom', + ]), + ]) + .optional(), + stream_subgraphs: z.boolean().optional(), + stream_resumable: z.boolean().optional(), + after_seconds: z.number().optional(), + if_not_exists: z.enum(['reject', 'create']).optional(), + on_completion: z.enum(['delete', 'keep']).optional(), + feedback_keys: z.array(z.string()).optional(), + langsmith_tracer: LangsmithTracer.optional(), + }) + .describe('Payload for creating a stateful run.'); + +export const RunBatchCreate = z + .array(RunCreate) + .min(1) + .describe('Payload for creating a batch of runs.'); + +export const SearchResult = z + .object({ + metadata: z + .object({}) + .catchall(z.any()) + .describe('Metadata to search for.') + .optional(), + limit: z + .number() + .int() + .gte(1) + .lte(1000) + .describe('Maximum number to return.') + .optional(), + offset: z + .number() + .int() + .gte(0) + .describe('Offset to start from.') + .optional(), + }) + .describe('Payload for listing runs.'); + +// Cron schemas +export const Cron = z.object({ + cron_id: z.string().uuid(), + thread_id: z.string().uuid(), + end_time: z.string(), + schedule: z.string(), + created_at: z.string(), + updated_at: z.string(), + payload: z.object({}).catchall(z.any()), +}); + +export const CronCreate = z + .object({ + thread_id: z.string().uuid(), + assistant_id: z.string().uuid(), + checkpoint_id: z.string().optional(), + input: z + .union([ + z.array(z.object({}).catchall(z.any())), + z.object({}).catchall(z.any()), + ]) + .optional(), + metadata: z + .object({}) + .catchall(z.any()) + .describe('Metadata for the run.') + .optional(), + config: AssistantConfig.optional(), + context: z.unknown().optional(), + webhook: z.string().optional(), + interrupt_before: z.union([z.enum(['*']), z.array(z.string())]).optional(), + interrupt_after: z.union([z.enum(['*']), z.array(z.string())]).optional(), + multitask_strategy: z + .enum(['reject', 'rollback', 'interrupt', 'enqueue']) + .optional(), + }) + .describe('Payload for creating a cron.'); + +export const CronSearch = z + .object({ + assistant_id: z.string().uuid().optional(), + thread_id: z.string().uuid().optional(), + limit: z + .number() + .int() + .gte(1) + .lte(1000) + .describe('Maximum number to return.') + .optional(), + offset: z + .number() + .int() + .gte(0) + .describe('Offset to start from.') + .optional(), + }) + .describe('Payload for listing crons'); + +// Thread schemas +export const Thread = z.object({ + thread_id: z.string().uuid(), + created_at: z.string(), + updated_at: z.string(), + metadata: z.record(z.unknown()).optional(), + status: z.enum(['idle', 'busy', 'interrupted', 'error']).optional(), +}); + +export const ThreadCreate = z + .object({ + supersteps: z + .array( + z.object({ + updates: z.array( + z.object({ + values: z.unknown().nullish(), + command: CommandSchema.nullish(), + as_node: z.string(), + }) + ), + }) + ) + .describe('The supersteps to apply to the thread.') + .optional(), + thread_id: z + .string() + .uuid() + .describe('The ID of the thread. If not provided, an ID is generated.') + .optional(), + metadata: z + .object({}) + .catchall(z.any()) + .describe('Metadata for the thread.') + .optional(), + if_exists: z + .union([z.literal('raise'), z.literal('do_nothing')]) + .optional(), + }) + .describe('Payload for creating a thread.'); + +export const ThreadPatch = z + .object({ + metadata: z + .object({}) + .catchall(z.any()) + .describe('Metadata to merge with existing thread metadata.') + .optional(), + }) + .describe('Payload for patching a thread.'); + +// Assistant schemas +export const Assistant = z.object({ + assistant_id: z.string().uuid(), + graph_id: z.string(), + config: AssistantConfig, + created_at: z.string(), + updated_at: z.string(), + metadata: z.object({}).catchall(z.any()), +}); + +export const AssistantCreate = z + .object({ + assistant_id: z + .string() + .uuid() + .describe('The ID of the assistant. If not provided, an ID is generated.') + .optional(), + graph_id: z.string().describe('The graph to use.'), + config: AssistantConfig.optional(), + context: z.unknown().optional(), + metadata: z + .object({}) + .catchall(z.unknown()) + .describe('Metadata for the assistant.') + .optional(), + if_exists: z + .union([z.literal('raise'), z.literal('do_nothing')]) + .optional(), + name: z.string().optional(), + description: z.string().optional(), + }) + .describe('Payload for creating an assistant.'); + +export const AssistantPatch = z + .object({ + graph_id: z.string().describe('The graph to use.').optional(), + config: AssistantConfig.optional(), + context: z.unknown().optional(), + name: z.string().optional(), + description: z.string().optional(), + metadata: z + .object({}) + .catchall(z.any()) + .describe('Metadata to merge with existing assistant metadata.') + .optional(), + }) + .describe('Payload for updating an assistant.'); diff --git a/plugin/langchain/app/controller/types.ts b/plugin/langchain/app/controller/types.ts new file mode 100644 index 00000000..ee683932 --- /dev/null +++ b/plugin/langchain/app/controller/types.ts @@ -0,0 +1,250 @@ +/** + * LangGraph API Types + * 从 langgraphjs/libs/langgraph-api/src/storage/types.mts 迁移 + */ + +export type Metadata = Record; + +export type AssistantSelectField = + | 'assistant_id' + | 'graph_id' + | 'name' + | 'description' + | 'config' + | 'context' + | 'created_at' + | 'updated_at' + | 'metadata' + | 'version'; + +export type ThreadSelectField = + | 'thread_id' + | 'created_at' + | 'updated_at' + | 'metadata' + | 'config' + | 'context' + | 'status' + | 'values' + | 'interrupts'; + +export type ThreadStatus = 'idle' | 'busy' | 'interrupted' | 'error'; + +export type RunStatus = + | 'pending' + | 'running' + | 'error' + | 'success' + | 'timeout' + | 'interrupted'; + +export type StreamMode = + | 'values' + | 'messages' + | 'messages-tuple' + | 'custom' + | 'updates' + | 'events' + | 'debug' + | 'tasks' + | 'checkpoints'; + +export type MultitaskStrategy = 'reject' | 'rollback' | 'interrupt' | 'enqueue'; + +export type OnConflictBehavior = 'raise' | 'do_nothing'; + +export type IfNotExists = 'create' | 'reject'; + +export type OnDisconnect = 'cancel' | 'continue'; + +export type OnCompletion = 'delete' | 'keep'; + +export interface RunnableConfig { + tags?: string[]; + recursion_limit?: number; + configurable?: { + thread_id?: string; + thread_ts?: string; + checkpoint_id?: string; + checkpoint_ns?: string; + checkpoint_map?: Record; + langgraph_auth_user?: unknown; + langgraph_auth_user_id?: string; + langgraph_auth_permissions?: string[]; + langsmith_project?: string; + langsmith_example_id?: string; + [key: string]: unknown; + }; + metadata?: Record; +} + +export interface RunCommand { + goto?: + | string + | { node: string; input?: unknown } + | Array; + update?: Record | Array<[string, unknown]>; + resume?: unknown; +} + +export interface CheckpointSchema { + checkpoint_id?: string; + checkpoint_ns?: string | null; + checkpoint_map?: Record | null; +} + +export interface LangsmithTracer { + project_name?: string; + example_id?: string; +} + +export interface RunKwargs { + input?: unknown; + command?: RunCommand; + stream_mode?: Array; + interrupt_before?: '*' | string[] | undefined; + interrupt_after?: '*' | string[] | undefined; + config?: RunnableConfig; + context?: unknown; + subgraphs?: boolean; + resumable?: boolean; + temporary?: boolean; + webhook?: unknown; + feedback_keys?: string[] | undefined; + [key: string]: unknown; +} + +export interface Run { + run_id: string; + thread_id: string; + assistant_id: string; + created_at: Date; + updated_at: Date; + status: RunStatus; + metadata: Metadata; + kwargs: RunKwargs; + multitask_strategy: MultitaskStrategy; +} + +export interface Assistant { + name: string; + description: string | null; + assistant_id: string; + graph_id: string; + created_at: Date; + updated_at: Date; + version: number; + config: RunnableConfig; + context: unknown; + metadata: Metadata; +} + +export interface Thread { + thread_id: string; + created_at: Date; + updated_at: Date; + metadata?: Metadata; + config?: RunnableConfig; + status: ThreadStatus; + values?: Record; + interrupts?: Record; +} + +export interface Checkpoint { + thread_id: string; + checkpoint_ns: string; + checkpoint_id: string | null; + checkpoint_map: Record | null; +} + +export interface CheckpointTask { + id: string; + name: string; + error?: string; + interrupts: Record; + state?: RunnableConfig; +} + +export interface ThreadTask { + id: string; + name: string; + error: string | null; + interrupts: Record[]; + checkpoint: Checkpoint | null; + state: ThreadState | null; + result: unknown | null; +} + +export interface ThreadState { + values: Record; + next: string[]; + checkpoint: Checkpoint | null; + metadata: Record | undefined; + created_at: Date | null; + parent_checkpoint: Checkpoint | null; + tasks: ThreadTask[]; +} + +// DTO 类型 - 用于 HTTP 请求/响应 + +export interface RunCreateDTO { + assistant_id: string; + checkpoint_id?: string; + checkpoint?: CheckpointSchema; + input?: unknown | null; + command?: RunCommand; + metadata?: Metadata; + context?: unknown; + config?: RunnableConfig; + webhook?: string; + interrupt_before?: '*' | string[]; + interrupt_after?: '*' | string[]; + on_disconnect?: OnDisconnect; + multitask_strategy?: MultitaskStrategy; + stream_mode?: StreamMode | StreamMode[]; + stream_subgraphs?: boolean; + stream_resumable?: boolean; + after_seconds?: number; + if_not_exists?: IfNotExists; + on_completion?: OnCompletion; + feedback_keys?: string[]; + langsmith_tracer?: LangsmithTracer; +} + +export interface RunSearchDTO { + limit?: number; + offset?: number; + status?: string; + metadata?: Metadata; +} + +export interface CronCreateDTO { + thread_id: string; + assistant_id: string; + checkpoint_id?: string; + input?: unknown[] | Record; + metadata?: Metadata; + config?: RunnableConfig; + context?: unknown; + webhook?: string; + interrupt_before?: '*' | string[]; + interrupt_after?: '*' | string[]; + multitask_strategy?: MultitaskStrategy; +} + +export interface CronSearchDTO { + assistant_id?: string; + thread_id?: string; + limit?: number; + offset?: number; +} + +// Auth Context (需要适配 Tegg 的认证系统) +export interface AuthContext { + user: { + id: string; + identity?: string; + [key: string]: unknown; + }; + scopes: string[]; +} diff --git a/plugin/langchain/app/middleware/ZodErrorMiddleware.ts b/plugin/langchain/app/middleware/ZodErrorMiddleware.ts new file mode 100644 index 00000000..fc1e3128 --- /dev/null +++ b/plugin/langchain/app/middleware/ZodErrorMiddleware.ts @@ -0,0 +1,31 @@ +/** + * Validation Error Middleware + * 统一处理 Zod 验证错误 + */ + +import type { EggContext, Next } from '@eggjs/tegg'; +import { ZodError } from 'zod'; + +export async function ZodErrorMiddleware(ctx: EggContext, next: Next): Promise { + try { + await next(); + } catch (error) { + console.log('ZodErrorMiddleware Catch Error', error); + // 捕获 ZodError 并返回 422 响应 + if (error instanceof ZodError) { + ctx.status = 422; + ctx.body = { + error: 'Validation failed', + details: error.errors.map(e => ({ + path: e.path.join('.'), + message: e.message, + code: e.code, + })), + }; + return; + } + + // 其他错误继续抛出 + throw error; + } +} diff --git a/plugin/langchain/lib/assistants/AssistantsInitService.ts b/plugin/langchain/lib/assistants/AssistantsInitService.ts new file mode 100644 index 00000000..17abf2da --- /dev/null +++ b/plugin/langchain/lib/assistants/AssistantsInitService.ts @@ -0,0 +1,146 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import { AccessLevel, Inject, LifecyclePostInject, ModuleConfigs, SingletonProto } from '@eggjs/tegg'; +import { GraphInfoUtil } from '@eggjs/tegg-langchain-decorator'; +import { AssistantsRepository } from './AssistantsRepository'; +import { v5 as uuidv5, parse as uuidParse } from 'uuid'; + +/** + * UUID Namespace for generating assistant IDs + * 使用与 langgraphjs 相同的 namespace + */ +const NAMESPACE_GRAPH = uuidParse('6ba7b821-9dad-11d1-80b4-00c04fd430c8'); + +/** + * Assistants 初始化服务 + * 在应用启动时,从 GraphInfoUtil 和 moduleConfigs 中加载所有 graphs + * 并将它们注册为 assistants + * + * 参考: langgraphjs/libs/langgraph-api/src/graph/load.mts 的 registerFromEnv + */ +@SingletonProto({ accessLevel: AccessLevel.PUBLIC }) +export class AssistantsInitService { + @Inject() + private readonly moduleConfigs: ModuleConfigs; + + @Inject() + private readonly assistantsRepository: AssistantsRepository; + + @LifecyclePostInject() + protected async init() { + console.log('🚀 Initializing Assistants from GraphInfoUtil and moduleConfigs...'); + + // 1. 从 GraphInfoUtil 加载所有注册的 graphs + await this.registerGraphsFromUtil(); + + // 2. 从 moduleConfigs 加载 agents 配置 + await this.registerGraphsFromModuleConfigs(); + + // 3. 输出注册的 assistants + const allAssistants = this.assistantsRepository.getAll(); + console.log(`✅ Registered ${allAssistants.length} assistants:`); + allAssistants.forEach(assistant => { + console.log(` - ${assistant.name} (graph_id: ${assistant.graph_id}, assistant_id: ${assistant.assistant_id})`); + }); + } + + /** + * 从 GraphInfoUtil 注册所有 graphs + */ + private async registerGraphsFromUtil() { + const graphMap = GraphInfoUtil.getAllGraphMetadata(); + + for (const [ _clazz, metadata ] of graphMap.entries()) { + if (!metadata.name) { + console.warn('⚠️ Graph metadata missing name, skipping registration'); + continue; + } + + const graphId = metadata.name; + const assistantId = this.generateAssistantId(graphId); + + console.log(`📦 Registering graph from GraphInfoUtil: ${graphId}`); + + await this.assistantsRepository.put( + assistantId, + { + graph_id: graphId, + metadata: { + created_by: 'system', + source: 'GraphInfoUtil', + }, + config: {}, + context: undefined, + if_exists: 'do_nothing', + name: graphId, + description: `Graph loaded from GraphInfoUtil: ${graphId}`, + }, + undefined, + ); + } + } + + /** + * 从 moduleConfigs 注册 agents 配置中的 graphs + */ + private async registerGraphsFromModuleConfigs() { + for (const [ moduleName, moduleInfo ] of this.moduleConfigs) { + if (!moduleInfo.config.agents) continue; + + const agents = moduleInfo.config.agents || {}; + + for (const [ agentName, agentConfig ] of Object.entries(agents)) { + const graphId = agentName; + const assistantId = this.generateAssistantId(graphId); + + console.log(`📦 Registering graph from moduleConfigs: ${graphId} (module: ${moduleName})`); + + await this.assistantsRepository.put( + assistantId, + { + graph_id: graphId, + metadata: { + created_by: 'system', + source: 'moduleConfigs', + module: moduleName, + }, + config: agentConfig as any, + context: undefined, + if_exists: 'do_nothing', + name: graphId, + description: `Graph loaded from moduleConfigs (module: ${moduleName})`, + }, + undefined, + ); + } + } + } + + /** + * 生成 assistant_id + * 使用与 langgraphjs 相同的方式: uuid.v5(graphId, NAMESPACE_GRAPH) + */ + private generateAssistantId(graphId: string): string { + return uuidv5(graphId, NAMESPACE_GRAPH); + } + + /** + * 根据 graphId 获取 assistantId + */ + public getAssistantId(graphId: string): string { + return this.generateAssistantId(graphId); + } + + /** + * 根据 assistantId 获取 assistant + */ + public async getAssistant(assistantId: string) { + return this.assistantsRepository.get(assistantId); + } + + /** + * 根据 graphId 获取 assistant + */ + public async getAssistantByGraphId(graphId: string) { + return this.assistantsRepository.getByGraphId(graphId); + } +} diff --git a/plugin/langchain/lib/assistants/AssistantsRepository.ts b/plugin/langchain/lib/assistants/AssistantsRepository.ts new file mode 100644 index 00000000..922fb565 --- /dev/null +++ b/plugin/langchain/lib/assistants/AssistantsRepository.ts @@ -0,0 +1,162 @@ +import { AccessLevel, SingletonProto } from '@eggjs/tegg'; +import type { + Assistant, + Metadata, + RunnableConfig, + OnConflictBehavior, + AuthContext, +} from '../../app/controller/types'; + +export interface AssistantsPutOptions { + config?: RunnableConfig; + context?: unknown; + graph_id: string; + metadata?: Metadata; + if_exists: OnConflictBehavior; + name?: string; + description?: string; +} + +/** + * Assistants 存储层接口 + * 从 langgraphjs/libs/langgraph-api/src/storage/types.mts 的 AssistantsRepo 移植 + */ +export interface IAssistantsRepository { + get(assistantId: string, auth?: AuthContext): Promise; + + put( + assistantId: string, + options: AssistantsPutOptions, + auth?: AuthContext, + ): Promise; + + search( + options: { + graph_id?: string; + name?: string; + metadata?: Metadata; + limit?: number; + offset?: number; + }, + auth?: AuthContext, + ): Promise; + + getByGraphId(graphId: string): Promise; +} + +/** + * 内存版本的 Assistants 存储实现 + * TODO: 后续可替换为数据库实现 + */ +@SingletonProto({ accessLevel: AccessLevel.PUBLIC }) +export class AssistantsRepository implements IAssistantsRepository { + // 存储所有 assistants: Map + private assistants: Map = new Map(); + + // 按 graph_id 索引: Map + private assistantsByGraphId: Map = new Map(); + + async put( + assistantId: string, + options: AssistantsPutOptions, + _auth?: AuthContext, + ): Promise { + const { + config = {}, + context, + graph_id: graphId, + metadata = {}, + if_exists: ifExists, + name, + description = null, + } = options; + + // 检查是否已存在 + const existing = this.assistants.get(assistantId); + if (existing) { + if (ifExists === 'raise') { + throw new Error(`Assistant with id "${assistantId}" already exists`); + } else if (ifExists === 'do_nothing') { + return existing; + } + } + + const now = new Date(); + const assistant: Assistant = { + assistant_id: assistantId, + graph_id: graphId, + name: name || graphId, + description, + config, + context, + metadata, + created_at: existing?.created_at ?? now, + updated_at: now, + version: (existing?.version ?? 0) + 1, + }; + + // 存储 assistant + this.assistants.set(assistantId, assistant); + this.assistantsByGraphId.set(graphId, assistantId); + + return assistant; + } + + async get(assistantId: string, _auth?: AuthContext): Promise { + return this.assistants.get(assistantId) || null; + } + + async getByGraphId(graphId: string): Promise { + const assistantId = this.assistantsByGraphId.get(graphId); + if (!assistantId) return null; + return this.assistants.get(assistantId) || null; + } + + async search( + options: { + graph_id?: string; + name?: string; + metadata?: Metadata; + limit?: number; + offset?: number; + }, + _auth?: AuthContext, + ): Promise { + const { graph_id, name, metadata, limit = 10, offset = 0 } = options; + + let results: Assistant[] = Array.from(this.assistants.values()); + + // 过滤 graph_id + if (graph_id) { + results = results.filter(a => a.graph_id === graph_id); + } + + // 过滤 name + if (name) { + results = results.filter(a => a.name === name); + } + + // 过滤 metadata + if (metadata) { + results = results.filter(a => { + for (const [ key, value ] of Object.entries(metadata)) { + if (a.metadata[key] !== value) return false; + } + return true; + }); + } + + // 按创建时间倒序排列 + results.sort((a, b) => b.created_at.getTime() - a.created_at.getTime()); + + // 分页 + return results.slice(offset, offset + limit); + } + + /** + * 获取所有 assistants(用于初始化检查) + */ + getAll(): Assistant[] { + return Array.from(this.assistants.values()); + } +} diff --git a/plugin/langchain/lib/runs/Graph.ts b/plugin/langchain/lib/runs/Graph.ts new file mode 100644 index 00000000..71ee06c2 --- /dev/null +++ b/plugin/langchain/lib/runs/Graph.ts @@ -0,0 +1,27 @@ +import { AccessLevel, SingletonProto } from '@eggjs/tegg'; +import { GraphInfoUtil } from '@eggjs/tegg-langchain-decorator'; +import * as uuid from 'uuid'; + +// Magic NAMESPACE from +// https://github.com/langchain-ai/langgraphjs/blob/main/libs/langgraph-api/src/graph/load.mts#L27 +// Dont ask me why... +export const NAMESPACE_GRAPH = uuid.parse( + '6ba7b821-9dad-11d1-80b4-00c04fd430c8', +); + +@SingletonProto({ accessLevel: AccessLevel.PRIVATE }) +export class Graph { + + public getAssistantId(graphId: string): string { + if (GraphInfoUtil.getGraphByName(graphId)) { + return uuid.v5(graphId, NAMESPACE_GRAPH); + } + + return graphId; + } + + public getGraph(graphId: string) { + return GraphInfoUtil.getGraphByName(graphId); + } +} + diff --git a/plugin/langchain/lib/runs/RunsRepository.ts b/plugin/langchain/lib/runs/RunsRepository.ts new file mode 100644 index 00000000..41d05235 --- /dev/null +++ b/plugin/langchain/lib/runs/RunsRepository.ts @@ -0,0 +1,279 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import { AccessLevel, SingletonProto } from '@eggjs/tegg'; +import type { + Run, + RunKwargs, + RunStatus, + Metadata, + MultitaskStrategy, + IfNotExists, + AuthContext, +} from '../../app/controller/types'; + +export interface RunsPutOptions { + threadId?: string; + userId?: string; + status?: RunStatus; + metadata?: Metadata; + preventInsertInInflight?: boolean; + multitaskStrategy?: MultitaskStrategy; + ifNotExists?: IfNotExists; + afterSeconds?: number; +} + +/** + * Runs 存储层接口 + * 从 langgraphjs/libs/langgraph-api/src/storage/types.mts 的 RunsRepo 移植 + */ +export interface IRunsRepository { + put( + runId: string, + assistantId: string, + kwargs: RunKwargs, + options: RunsPutOptions, + auth: AuthContext | undefined, + ): Promise; + + get( + runId: string, + threadId: string | undefined, + auth: AuthContext | undefined, + ): Promise; + + delete( + runId: string, + threadId: string | undefined, + auth: AuthContext | undefined, + ): Promise; + + cancel( + threadId: string | undefined, + runIds: string[], + options: { + action?: 'interrupt' | 'rollback'; + }, + auth: AuthContext | undefined, + ): Promise; + + search( + threadId: string, + options: { + limit?: number | null; + offset?: number | null; + status?: string | null; + metadata?: Metadata | null; + }, + auth: AuthContext | undefined, + ): Promise; +} + +/** + * 内存版本的 Runs 存储实现 + * TODO: 后续可替换为数据库实现(如 PostgreSQL, MongoDB 等) + */ +@SingletonProto({ accessLevel: AccessLevel.PUBLIC }) +export class RunsRepository implements IRunsRepository { + // 存储所有 runs: Map + private runs: Map = new Map(); + + // 按 threadId 索引: Map> + private runsByThread: Map> = new Map(); + + /** + * 创建并存储一个 Run + * 返回数组:[新创建的 run, ...冲突的 inflight runs] + */ + async put( + runId: string, + assistantId: string, + kwargs: RunKwargs, + options: RunsPutOptions, + _auth: AuthContext | undefined, + ): Promise { + const { + threadId, + status = 'pending', + metadata = {}, + preventInsertInInflight = false, + multitaskStrategy = 'reject', + ifNotExists, + } = options; + + // 检查是否已存在 + if (ifNotExists === 'reject' && this.runs.has(runId)) { + // 如果设置了 ifNotExists=reject,且 run 已存在,则返回空数组 + return []; + } + + // 如果有 threadId,查找该 thread 下正在运行的 runs + const inflightRuns: Run[] = []; + if (threadId && preventInsertInInflight) { + const threadRunIds = this.runsByThread.get(threadId) || new Set(); + for (const existingRunId of threadRunIds) { + const existingRun = this.runs.get(existingRunId); + if ( + existingRun && + (existingRun.status === 'pending' || existingRun.status === 'running') + ) { + inflightRuns.push(existingRun); + } + } + } + + // 创建新的 Run + const now = new Date(); + const run: Run = { + run_id: runId, + thread_id: threadId || '', + assistant_id: assistantId, + created_at: now, + updated_at: now, + status, + metadata, + kwargs, + multitask_strategy: multitaskStrategy, + }; + + // 存储 run + this.runs.set(runId, run); + + // 更新 threadId 索引 + if (threadId) { + if (!this.runsByThread.has(threadId)) { + this.runsByThread.set(threadId, new Set()); + } + this.runsByThread.get(threadId)!.add(runId); + } + + // 返回 [新 run, ...冲突的 runs] + return [ run, ...inflightRuns ]; + } + + async get( + runId: string, + threadId: string | undefined, + _auth: AuthContext | undefined, + ): Promise { + const run = this.runs.get(runId); + if (!run) return null; + + // 如果指定了 threadId,验证是否匹配 + if (threadId && run.thread_id !== threadId) { + return null; + } + + return run; + } + + async delete( + runId: string, + threadId: string | undefined, + _auth: AuthContext | undefined, + ): Promise { + const run = this.runs.get(runId); + if (!run) return null; + + // 如果指定了 threadId,验证是否匹配 + if (threadId && run.thread_id !== threadId) { + return null; + } + + // 删除 run + this.runs.delete(runId); + + // 从 threadId 索引中删除 + if (run.thread_id) { + const threadRuns = this.runsByThread.get(run.thread_id); + if (threadRuns) { + threadRuns.delete(runId); + if (threadRuns.size === 0) { + this.runsByThread.delete(run.thread_id); + } + } + } + + return runId; + } + + async cancel( + threadId: string | undefined, + runIds: string[], + _options: { + action?: 'interrupt' | 'rollback'; + }, + _auth: AuthContext | undefined, + ): Promise { + for (const runId of runIds) { + const run = this.runs.get(runId); + if (!run) continue; + + // 如果指定了 threadId,验证是否匹配 + if (threadId && run.thread_id !== threadId) { + continue; + } + + // 更新状态为 interrupted + run.status = 'interrupted'; + run.updated_at = new Date(); + } + } + + async search( + threadId: string, + options: { + limit?: number | null; + offset?: number | null; + status?: string | null; + metadata?: Metadata | null; + }, + _auth: AuthContext | undefined, + ): Promise { + const { limit = 10, offset = 0, status, metadata } = options; + + // 获取该 thread 下的所有 runs + const threadRunIds = this.runsByThread.get(threadId); + if (!threadRunIds) return []; + + const results: Run[] = []; + for (const runId of threadRunIds) { + const run = this.runs.get(runId); + if (!run) continue; + + // 过滤状态 + if (status && run.status !== status) continue; + + // 过滤 metadata(简单版本:检查所有 key 是否匹配) + if (metadata) { + let metadataMatch = true; + for (const [ key, value ] of Object.entries(metadata)) { + if (run.metadata[key] !== value) { + metadataMatch = false; + break; + } + } + if (!metadataMatch) continue; + } + + results.push(run); + } + + // 按创建时间倒序排列 + results.sort((a, b) => b.created_at.getTime() - a.created_at.getTime()); + + // 分页 + const start = offset || 0; + const end = limit ? start + limit : results.length; + return results.slice(start, end); + } + + /** + * 更新 run 状态 + */ + async setStatus(runId: string, status: RunStatus): Promise { + const run = this.runs.get(runId); + if (run) { + run.status = status; + run.updated_at = new Date(); + } + } +} diff --git a/plugin/langchain/lib/runs/RunsService.ts b/plugin/langchain/lib/runs/RunsService.ts new file mode 100644 index 00000000..4b4c0214 --- /dev/null +++ b/plugin/langchain/lib/runs/RunsService.ts @@ -0,0 +1,283 @@ +import { AccessLevel, IncomingHttpHeaders, Inject, LifecyclePostInject, ModuleConfigs, SingletonProto } from '@eggjs/tegg'; +import type { + RunCreateDTO, + Run, + RunKwargs, + RunnableConfig, + StreamMode, + AuthContext, +} from '../../app/controller/types'; +import { RunsRepository } from './RunsRepository'; +import { Graph } from './Graph'; +import { AssistantsInitService } from '../assistants/AssistantsInitService'; + + +@SingletonProto({ accessLevel: AccessLevel.PUBLIC }) +export class RunsService { + @Inject() + private readonly moduleConfigs: ModuleConfigs; + + @Inject() + private readonly runsRepository: RunsRepository; + + @Inject() + private readonly graph: Graph; + + @Inject() + private readonly assistantsInitService: AssistantsInitService; + + private agentConfigs: Map; + + @LifecyclePostInject() + protected async init() { + this.agentConfigs = new Map(); + for (const [ moduleName, moduleInfo ] of this.moduleConfigs) { + if (moduleInfo.config.agents) { + const agents = moduleInfo.config.agents || {}; + for (const [ agentName, agentConfig ] of Object.entries(agents)) { + this.agentConfigs.set(agentName, { moduleName, config: agentConfig }); + } + } + } + } + + public getAllAgentConfigs() { + return this.agentConfigs; + } + + /** + * 创建并验证一个 Run + * 从 langgraphjs/libs/langgraph-api/src/api/runs.mts 的 createValidRun 移植 + */ + public async createValidRun( + threadId: string | undefined, + payload: RunCreateDTO, + kwargs: { + auth?: AuthContext; + headers?: IncomingHttpHeaders; + } = {}, + ): Promise { + const { assistant_id: assistantId, ...run } = payload; + const { auth, headers } = kwargs; + + // 验证 assistant 是否存在 + const assistant = await this.assistantsInitService.getAssistant(assistantId); + if (!assistant) { + throw new Error(`Assistant "${assistantId}" not found`); + } + + console.log('📊 Creating run for assistant:', { + assistantId, + graphId: assistant.graph_id, + name: assistant.name, + }); + + // 获取对应的 graph + const graph = this.graph.getGraph(assistant.graph_id); + console.log('---> graph instance', graph); + + // 生成 run_id + const runId = this.generateRunId(); + + // 处理 stream_mode + const streamMode = Array.isArray(payload.stream_mode) + ? payload.stream_mode + : payload.stream_mode != null + ? [ payload.stream_mode ] + : []; + if (streamMode.length === 0) streamMode.push('values'); + + const multitaskStrategy = payload.multitask_strategy ?? 'reject'; + const preventInsertInInflight = multitaskStrategy === 'reject'; + + // 构建 config + const config: RunnableConfig = { ...run.config }; + + // 处理 checkpoint_id + if (run.checkpoint_id) { + config.configurable ??= {}; + config.configurable.checkpoint_id = run.checkpoint_id; + } + + // 处理 checkpoint + if (run.checkpoint) { + config.configurable ??= {}; + Object.assign(config.configurable, run.checkpoint); + } + + // 处理 langsmith_tracer + if (run.langsmith_tracer) { + config.configurable ??= {}; + Object.assign(config.configurable, { + langsmith_project: run.langsmith_tracer.project_name, + langsmith_example_id: run.langsmith_tracer.example_id, + }); + } + + // 处理 headers(提取 x- 开头的自定义 header) + if (headers) { + for (const [ rawKey, value ] of Object.entries(headers)) { + if (!value) continue; // 跳过 undefined 值 + const key = rawKey.toLowerCase(); + if (key.startsWith('x-')) { + // 跳过敏感的 API keys + if ([ 'x-api-key', 'x-tenant-id', 'x-service-key' ].includes(key)) { + continue; + } + + config.configurable ??= {}; + // 如果是数组,取第一个值 + config.configurable[key] = Array.isArray(value) ? value[0] : value; + } else if (key === 'user-agent') { + config.configurable ??= {}; + config.configurable[key] = Array.isArray(value) ? value[0] : value; + } + } + } + + // 处理认证信息 + let userId: string | undefined; + if (auth) { + userId = auth.user.identity ?? auth.user.id; + config.configurable ??= {}; + config.configurable.langgraph_auth_user = auth.user; + config.configurable.langgraph_auth_user_id = userId; + config.configurable.langgraph_auth_permissions = auth.scopes; + } + + // 处理 feedback_keys + let feedbackKeys = + run.feedback_keys != null + ? Array.isArray(run.feedback_keys) + ? run.feedback_keys + : [ run.feedback_keys ] + : undefined; + if (!feedbackKeys?.length) feedbackKeys = undefined; + + // 构建 RunKwargs + const runKwargs: RunKwargs = { + input: run.input, + command: run.command, + config, + context: run.context, + stream_mode: streamMode as StreamMode[], + interrupt_before: run.interrupt_before, + interrupt_after: run.interrupt_after, + webhook: run.webhook, + feedback_keys: feedbackKeys, + temporary: + threadId == null && (run.on_completion ?? 'delete') === 'delete', + subgraphs: run.stream_subgraphs ?? false, + resumable: run.stream_resumable ?? false, + }; + + // 存储 Run 到仓库 + const [ first, ...inflight ] = await this.runsRepository.put( + runId, + assistantId, + runKwargs, + { + threadId, + userId, + metadata: run.metadata, + status: 'pending', + multitaskStrategy, + preventInsertInInflight, + afterSeconds: payload.after_seconds, + ifNotExists: payload.if_not_exists, + }, + auth, + ); + + // 处理创建成功的情况 + if (first?.run_id === runId) { + console.log('Created run', { run_id: runId, thread_id: threadId }); + + // 处理 multitask_strategy: interrupt 或 rollback + if ( + (multitaskStrategy === 'interrupt' || multitaskStrategy === 'rollback') && + inflight.length > 0 + ) { + try { + await this.runsRepository.cancel( + threadId, + inflight.map(run => run.run_id), + { action: multitaskStrategy }, + auth, + ); + } catch (error) { + console.warn( + 'Failed to cancel inflight runs, might be already cancelled', + { + error, + run_ids: inflight.map(run => run.run_id), + thread_id: threadId, + }, + ); + } + } + + return first; + } else if (multitaskStrategy === 'reject') { + // 如果 multitask_strategy 是 reject,且有冲突,抛出错误 + throw new Error( + 'Thread is already running a task. Wait for it to finish or choose a different multitask strategy.', + ); + } + + throw new Error('Unreachable state when creating run'); + } + + /** + * 获取 Run + */ + public async getRun( + runId: string, + threadId: string | undefined, + auth?: AuthContext, + ): Promise { + return this.runsRepository.get(runId, threadId, auth); + } + + /** + * 删除 Run + */ + public async deleteRun( + runId: string, + threadId: string | undefined, + auth?: AuthContext, + ): Promise { + return this.runsRepository.delete(runId, threadId, auth); + } + + /** + * 搜索 Runs + */ + public async searchRuns( + threadId: string, + options: { + limit?: number | null; + offset?: number | null; + status?: string | null; + metadata?: Record | null; + }, + auth?: AuthContext, + ): Promise { + return this.runsRepository.search(threadId, options, auth); + } + + /** + * 生成 run_id + * 简单版本,实际项目中可能需要使用 uuid + */ + private generateRunId(): string { + return `run_${Date.now()}_${Math.random().toString(36).substring(2, 15)}`; + } + +} + + +// export const getAssistantId = (graphId: string) => { +// if (graphId in GRAPHS) return uuid.v5(graphId, NAMESPACE_GRAPH); +// return graphId; +// }; diff --git a/plugin/langchain/lib/sse/index.ts b/plugin/langchain/lib/sse/index.ts new file mode 100644 index 00000000..44c32435 --- /dev/null +++ b/plugin/langchain/lib/sse/index.ts @@ -0,0 +1,255 @@ +/** + * SSE (Server-Sent Events) 流式响应工具 + * 用于实现 LangGraph API 的流式端点 + */ + +import type { EggContext } from '@eggjs/tegg'; +import { PassThrough } from 'stream'; + +/** + * SSE 事件格式 + * 对应 Hono 的 SSEMessage 接口 + */ +export interface SSEEvent { + data: string | Promise; + event?: string; + id?: string; + retry?: number; +} + +/** + * SSE Stream 类 + * 提供类似 Hono 的 streamSSE API + */ +export class SSEStreamWriter { + private writable: NodeJS.WritableStream; + private closed = false; + + constructor(writable: NodeJS.WritableStream) { + this.writable = writable; + } + + /** + * 写入 SSE 事件 + * 参考 Hono 的 writeSSE 实现 + */ + async writeSSE(message: SSEEvent): Promise { + if (this.closed) { + throw new Error('Stream is closed'); + } + + // 等待 data(支持 Promise) + const data = await Promise.resolve(message.data); + + // 将数据转换为字符串(如果不是字符串则 JSON 序列化) + const dataString = typeof data === 'string' ? data : JSON.stringify(data); + + // 处理多行数据 + const dataLines = dataString + .split('\n') + .map((line) => `data: ${line}`) + .join('\n'); + + // 按照 Hono 的顺序组装 SSE 消息:event -> data -> id -> retry + const sseData = [ + message.event && `event: ${message.event}`, + dataLines, + message.id && `id: ${message.id}`, + message.retry && `retry: ${message.retry}`, + ] + .filter(Boolean) + .join('\n') + '\n\n'; + + return new Promise((resolve, reject) => { + this.writable.write(sseData, (error) => { + if (error) { + reject(error); + } else { + resolve(); + } + }); + }); + } + + /** + * 发送注释(用于保持连接) + */ + async writeComment(comment: string): Promise { + if (this.closed) { + throw new Error('Stream is closed'); + } + + return new Promise((resolve, reject) => { + this.writable.write(`: ${comment}\n\n`, (error) => { + if (error) { + reject(error); + } else { + resolve(); + } + }); + }); + } + + /** + * 睡眠指定时间(毫秒) + */ + async sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); + } + + /** + * 关闭流 + */ + close(): void { + if (!this.closed) { + this.closed = true; + this.writable.end(); + } + } + + /** + * 检查流是否已关闭 + */ + isClosed(): boolean { + return this.closed; + } +} + +/** + * 创建 SSE 流式响应 + * 参考 Hono 的 streamSSE API + * + * @param ctx - Egg Context + * @param cb - 流处理回调函数 + * @param onError - 可选的错误处理回调 + * + * @example + * ```ts + * return streamSSE(ctx, async (stream) => { + * await stream.writeSSE({ + * data: 'hello', + * event: 'message', + * id: '1' + * }); + * await stream.sleep(1000); + * await stream.writeSSE({ data: 'world' }); + * }); + * ``` + */ +export async function streamSSE( + ctx: EggContext, + cb: (stream: SSEStreamWriter) => Promise, + onError?: (e: Error, stream: SSEStreamWriter) => Promise +): Promise { + // 设置 SSE 响应头(按照 Hono 的顺序) + ctx.set('Transfer-Encoding', 'chunked'); + ctx.set('Content-Type', 'text/event-stream'); + ctx.set('Cache-Control', 'no-cache'); + ctx.set('Connection', 'keep-alive'); + ctx.set('X-Accel-Buffering', 'no'); // 禁用 nginx 缓冲 + + // 创建 PassThrough 流(既可读又可写) + const passThrough = new PassThrough(); + + // 设置响应体为流 + ctx.body = passThrough; + ctx.status = 200; + + // 创建 SSEStreamWriter 实例 + const stream = new SSEStreamWriter(passThrough); + + // 执行流处理逻辑(参考 Hono 的 run 函数) + const runStream = async () => { + try { + await cb(stream); + } catch (e) { + if (e instanceof Error && onError) { + // 调用自定义错误处理 + await onError(e, stream); + } + + // 发送错误事件 + if (!stream.isClosed()) { + try { + await stream.writeSSE({ + event: 'error', + data: e instanceof Error ? e.message : String(e), + }); + } catch (writeError) { + console.error('Failed to write error event:', writeError); + } + } + + // 如果没有自定义错误处理,输出到控制台 + if (!onError) { + console.error(e); + } + } finally { + // 关闭流 + stream.close(); + } + }; + + // 启动流处理(不阻塞) + runStream(); +} + +/** + * 获取断开连接的 AbortSignal + * 用于在客户端断开连接时取消操作 + * + * @param ctx - Egg Context + * @param stream - SSE Stream Writer (可选,用于清理) + * @returns AbortSignal + */ +export function getDisconnectAbortSignal( + ctx: EggContext, + stream?: SSEStreamWriter +): AbortSignal { + const controller = new AbortController(); + + // 监听请求关闭事件 + const onClose = () => { + if (!controller.signal.aborted) { + controller.abort(); + } + if (stream && !stream.isClosed()) { + stream.close(); + } + }; + + // 监听底层连接关闭 + ctx.req.on('close', onClose); + ctx.req.on('error', onClose); + + // 清理监听器(可选) + const cleanup = () => { + ctx.req.off('close', onClose); + ctx.req.off('error', onClose); + }; + + // 在 abort 时清理 + controller.signal.addEventListener('abort', cleanup, { once: true }); + + return controller.signal; +} + +/** + * 序列化数据为字典格式 + * 用于 SSE 数据传输 + * + * @param data - 要序列化的数据 + * @returns 序列化后的对象 + */ +export function serialiseAsDict(data: unknown): Record { + if (data === null || data === undefined) { + return {}; + } + + if (typeof data === 'object' && !Array.isArray(data)) { + return data as Record; + } + + // 如果不是对象,包装成对象 + return { value: data }; +} diff --git a/plugin/langchain/package.json b/plugin/langchain/package.json index 4714a434..a7d291ab 100644 --- a/plugin/langchain/package.json +++ b/plugin/langchain/package.json @@ -29,6 +29,8 @@ "lib/**/*.d.ts", "app/**/*.js", "app/**/*.d.ts", + "app/controller/**/*.js", + "app/controller/**/*.d.ts", "typings/*.d.ts" ], "types": "typings/index.d.ts", @@ -74,7 +76,9 @@ "koa-compose": "^3.2.1", "langchain": "^1.1.2", "sdk-base": "^4.2.0", - "urllib": "^4.4.0" + "urllib": "^4.4.0", + "uuid": "^11.0.3", + "zod": "^3.24.1" }, "devDependencies": { "@eggjs/module-test-util": "^3.67.2", @@ -84,6 +88,7 @@ "@eggjs/tegg-plugin": "^3.67.2", "@types/mocha": "^10.0.1", "@types/node": "^20.2.4", + "@types/uuid": "^10.0.0", "cross-env": "^7.0.3", "egg": "^3.9.1", "egg-mock": "^5.5.0", diff --git a/plugin/langchain/test/agent.test.ts b/plugin/langchain/test/agent.test.ts new file mode 100644 index 00000000..9d7a286c --- /dev/null +++ b/plugin/langchain/test/agent.test.ts @@ -0,0 +1,354 @@ +import mm from 'egg-mock'; +import path from 'path'; +import assert from 'assert'; + +describe.only('plugin/langchain/test/agent.test.ts', () => { + // https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain/package.json#L9 + if (parseInt(process.version.slice(1, 3)) > 19) { + let app; + + afterEach(() => { + mm.restore(); + }); + + before(async () => { + mm(process.env, 'EGG_TYPESCRIPT', true); + mm(process, 'cwd', () => { + return path.join(__dirname, '..'); + }); + app = mm.app({ + baseDir: path.join(__dirname, 'fixtures/apps/agent'), + framework: path.dirname(require.resolve('egg')), + }); + await app.ready(); + }); + + after(async () => { + await app.close(); + }); + + it.skip('should return 422 when validation fails', async () => { + const res = await app.httpRequest() + .post('/api/runs/stream') + .send({ + assistant_id: 123, // 应该是字符串,但传了数字 + input: { + messages: [ + { + role: 'human', + content: 'hello', + }, + ], + }, + stream_mode: 'invalid_mode', // 无效的 stream_mode + }) + .expect(422); + + // 验证响应格式 + assert(res.body.error, 'Should have error field'); + assert.strictEqual(res.body.error, 'Validation failed'); + assert(Array.isArray(res.body.details), 'Should have details array'); + assert(res.body.details.length > 0, 'Should have at least one error detail'); + + // 验证 details 结构 + const firstError = res.body.details[0]; + assert(firstError.path, 'Error should have path'); + assert(firstError.message, 'Error should have message'); + assert(firstError.code, 'Error should have code'); + }); + + it('should return SSE stream', async () => { + const res = await app.httpRequest() + .post('/api/runs/stream') + .send({ + assistant_id: 'test-assistant-id', + input: { + messages: [ + { + role: 'human', + content: 'hello', + }, + ], + }, + }) + .expect(200); + + // 验证 SSE 响应头 + assert.strictEqual(res.headers['content-type'], 'text/event-stream'); + assert.strictEqual(res.headers['cache-control'], 'no-cache'); + assert.strictEqual(res.headers.connection, 'keep-alive'); + + // 验证 Content-Location header 存在 + assert(res.headers['content-location'], 'Content-Location header should exist'); + assert(res.headers['content-location'].startsWith('/runs/'), 'Content-Location should point to a run'); + }); + + // it('should accept valid RunCreateDTO payload', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // stream_mode: 'values', + // multitask_strategy: 'reject', + // on_disconnect: 'continue', + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should handle array stream_mode', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // stream_mode: ['values', 'updates', 'events'], + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept checkpoint configuration', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // checkpoint_id: 'checkpoint-123', + // checkpoint: { + // checkpoint_id: 'checkpoint-123', + // checkpoint_ns: 'default', + // }, + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept config with metadata', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // config: { + // tags: ['test', 'stream'], + // recursion_limit: 10, + // configurable: { + // thread_id: 'thread-123', + // }, + // }, + // metadata: { + // user: 'test-user', + // session_id: 'session-123', + // }, + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should handle interrupt_before configuration', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // interrupt_before: ['step1', 'step2'], + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should handle interrupt_after configuration', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // interrupt_after: '*', + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept langsmith_tracer configuration', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // langsmith_tracer: { + // project_name: 'test-project', + // example_id: 'example-123', + // }, + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept command payload', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // command: { + // goto: 'step1', + // update: { key: 'value' }, + // }, + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept on_disconnect = cancel', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // on_disconnect: 'cancel', + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept stream_subgraphs option', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // stream_subgraphs: true, + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // it('should accept stream_resumable option', async () => { + // const payload = { + // assistant_id: 'test-assistant-id', + // input: { + // messages: [ + // { + // role: 'human', + // content: 'test', + // }, + // ], + // }, + // stream_resumable: true, + // }; + + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(200); + // }); + + // TODO: Add validation tests when schema validation is implemented + // it('should reject invalid assistant_id', async () => { + // const payload = { + // assistant_id: 'invalid', + // input: { message: 'test' }, + // }; + // + // await app.httpRequest() + // .post('/api/runs/stream') + // .send(payload) + // .expect(422); + // }); + + // TODO: Add tests for actual SSE stream content when service is implemented + // it('should stream events in SSE format', async () => { + // const res = await app.httpRequest() + // .post('/api/runs/stream') + // .send({ + // assistant_id: 'test-assistant-id', + // input: { message: 'test' }, + // }); + // + // // Parse SSE response + // const events = parseSSE(res.text); + // assert(events.length > 0, 'Should receive at least one SSE event'); + // }); + // }); + + // TODO: Add tests for other endpoints when implemented + // describe('POST /api/runs/wait', () => {}); + // describe('POST /api/runs', () => {}); + // describe('POST /api/runs/batch', () => {}); + // }); + } +}); diff --git a/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/controller/AppController.ts b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/controller/AppController.ts new file mode 100644 index 00000000..5c20bf76 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/controller/AppController.ts @@ -0,0 +1,62 @@ +import { + HTTPController, + HTTPMethod, + HTTPMethodEnum, + Inject, +} from '@eggjs/tegg'; +import { ChatModelQualifier, TeggBoundModel, TeggCompiledStateGraph } from '@eggjs/tegg-langchain-decorator'; +import { ChatOpenAIModel } from '../../../../../../../../lib/ChatOpenAI'; +import { BoundChatModel } from '../service/BoundChatModel'; +import { FooGraph } from '../service/Graph'; +import { AIMessage } from 'langchain'; + +@HTTPController({ + path: '/llm', +}) +export class AppController { + @Inject() + @ChatModelQualifier('chat') + chatModel: ChatOpenAIModel; + + @Inject() + boundChatModel: TeggBoundModel; + + @Inject() + compiledFooGraph: TeggCompiledStateGraph; + + @HTTPMethod({ + method: HTTPMethodEnum.GET, + path: '/hello', + }) + async hello() { + const res = await this.chatModel.invoke('hello'); + return res; + } + + @HTTPMethod({ + method: HTTPMethodEnum.GET, + path: '/bound-chat', + }) + async boundChat() { + const res = await this.boundChatModel.invoke('hello'); + return res; + } + + @HTTPMethod({ method: HTTPMethodEnum.GET, path: '/graph' }) + async get() { + const res = await this.compiledFooGraph.invoke({ + messages: [], + aggregate: [], + }, { + configurable: { + thread_id: '1', + }, + }); + + return { + value: res.messages.filter(msg => AIMessage.prototype.isPrototypeOf(msg)).reduce((pre, cur) => { + return cur.content + pre; + }, ''), + }; + } +} diff --git a/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/module.yml b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/module.yml new file mode 100644 index 00000000..7b7338c0 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/module.yml @@ -0,0 +1,22 @@ +ChatModel: + clients: + chat: + apiKey: mock_api_key + model: Qwen2_5_7B_Instruct + temperature: 0 + timeout: 10 + type: openai + configuration: + baseURL: https://antchat.alipay.com/v1 + +mcp: + clients: + bar: + url: http://127.0.0.1:17283/mcp/sse + clientName: barSse + version: 1.0.0 + transportType: SSE + type: http + +agents: + FooGraph: {} \ No newline at end of file diff --git a/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/package.json b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/package.json new file mode 100644 index 00000000..135968b9 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/package.json @@ -0,0 +1,6 @@ +{ + "name": "llm-test-module", + "eggModule": { + "name": "llmTestModule" + } +} diff --git a/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/service/BoundChatModel.ts b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/service/BoundChatModel.ts new file mode 100644 index 00000000..c3b66889 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/service/BoundChatModel.ts @@ -0,0 +1,8 @@ +import { BoundModel } from '@eggjs/tegg-langchain-decorator'; +import { FooTool } from './Graph'; + +@BoundModel({ + modelName: 'chat', + tools: [ FooTool ], +}) +export class BoundChatModel {} diff --git a/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/service/Graph.ts b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/service/Graph.ts new file mode 100644 index 00000000..ef4c2b33 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/app/modules/bar/service/Graph.ts @@ -0,0 +1,135 @@ +import { AccessLevel, SingletonProto, ToolArgs, ToolArgsSchema } from '@eggjs/tegg'; +import { Graph, GraphEdge, IGraphEdge, AbstractStateGraph, GraphNode, IGraphNode, GraphStateType, GraphTool, IGraphTool, TeggToolNode } from '@eggjs/tegg-langchain-decorator'; +import { Annotation, MemorySaver } from '@langchain/langgraph'; +// import { AIMessage, BaseMessage, ToolMessage } from '@langchain/core/messages'; +import * as z from 'zod/v4'; +import { AIMessage, BaseMessage, ToolMessage } from 'langchain'; + +export enum FooGraphNodeName { + START = '__start__', + END = '__end__', + ACTION = 'action', + AGENT = 'agent', + TOOLS = 'tools', + NODE_A = 'a', + NODE_B = 'b', + NODE_C = 'c', + NODE_D = 'd', +} + +@SingletonProto() +export class FooSaver extends MemorySaver {} + +// state +export const fooAnnotationStateDefinition = { + messages: Annotation({ + reducer: (x, y) => x.concat(y), + }), + aggregate: Annotation({ + reducer: (x, y) => x.concat(y), + }), +}; + +export type fooAnnotationStateDefinitionType = typeof fooAnnotationStateDefinition; + +export const ToolType = { + query: z.string().describe('npm package name'), +}; + +@GraphTool({ + toolName: 'search', + description: 'Call the foo tool', +}) +export class FooTool implements IGraphTool { + + async execute(@ToolArgsSchema(ToolType) args: ToolArgs) { + console.log('query: ', args.query); + return `hello ${args.query}`; + } +} + +@GraphNode({ + nodeName: FooGraphNodeName.ACTION, + tools: [ FooTool ], +}) +export class FooNode implements IGraphNode { + async execute(state: GraphStateType) { + console.log('response: ', state.messages); + const messages = state.messages; + const lastMessage = messages[messages.length - 1]; + if (ToolMessage.prototype.isPrototypeOf(lastMessage)) { + return { + messages: [ + new AIMessage(lastMessage!.text!), + ], + }; + } + return { + messages: [ + new AIMessage({ + tool_calls: [ + { + name: 'search', + args: { + query: 'graph tool', + }, + id: 'fc-6b565ce5-e0cf-4af3-8ed0-0ca75c509d9e', + type: 'tool_call', + }, + ], + content: 'hello world', + }), + ], + }; + } +} + +@GraphNode({ + nodeName: FooGraphNodeName.TOOLS, + tools: [ FooTool ], +}) +export class ToolNode extends TeggToolNode {} + +@GraphEdge({ + fromNodeName: FooGraphNodeName.ACTION, + toNodeNames: [ FooGraphNodeName.TOOLS, FooGraphNodeName.END ], +}) +export class FooContinueEdge implements IGraphEdge { + + async execute( + state: GraphStateType, + ): Promise { + console.log('response: ', state.messages); + const messages = state.messages; + const lastMessage = messages[messages.length - 1] as AIMessage; + if (lastMessage?.tool_calls?.length) { + return FooGraphNodeName.TOOLS; + } + return FooGraphNodeName.END; + } +} + +@GraphEdge({ + fromNodeName: FooGraphNodeName.TOOLS, + toNodeNames: [ FooGraphNodeName.ACTION ], +}) +export class ToolsContinueEdge implements IGraphEdge {} + +@GraphEdge({ + fromNodeName: FooGraphNodeName.START, + toNodeNames: [ FooGraphNodeName.ACTION ], +}) +export class FooStartContinueEdge implements IGraphEdge {} + + +@Graph({ + accessLevel: AccessLevel.PUBLIC, + nodes: [ FooNode, ToolNode ], + edges: [ FooContinueEdge, FooStartContinueEdge, ToolsContinueEdge ], + checkpoint: FooSaver, +}) +export class FooGraph extends AbstractStateGraph { + constructor() { + super(fooAnnotationStateDefinition); + } +} diff --git a/plugin/langchain/test/fixtures/apps/agent/config/config.default.js b/plugin/langchain/test/fixtures/apps/agent/config/config.default.js new file mode 100644 index 00000000..65b44e39 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/config/config.default.js @@ -0,0 +1,16 @@ +'use strict'; + +module.exports = function() { + const config = { + keys: 'test key', + security: { + csrf: { + enable: false, + }, + }, + bodyParser: { + enable: true, + }, + }; + return config; +}; diff --git a/plugin/langchain/test/fixtures/apps/agent/config/module.json b/plugin/langchain/test/fixtures/apps/agent/config/module.json new file mode 100644 index 00000000..34dd1b57 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/config/module.json @@ -0,0 +1,7 @@ +[ + { + "path": "../app/modules/bar" + }, { + "package": "../../../../" + } +] diff --git a/plugin/langchain/test/fixtures/apps/agent/config/plugin.js b/plugin/langchain/test/fixtures/apps/agent/config/plugin.js new file mode 100644 index 00000000..739844ce --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/config/plugin.js @@ -0,0 +1,26 @@ +'use strict'; + +// eslint-disable-next-line @typescript-eslint/no-var-requires +const path = require('node:path'); + +exports.tegg = { + package: '@eggjs/tegg-plugin', + enable: true, +}; + +exports.teggConfig = { + package: '@eggjs/tegg-config', + enable: true, +}; + +exports.teggLangChain = { + enable: true, + path: path.join(__dirname, '../../../../../'), +}; + +exports.teggController = { + package: '@eggjs/tegg-controller-plugin', + enable: true, +}; + +exports.watcher = false; diff --git a/plugin/langchain/test/fixtures/apps/agent/package.json b/plugin/langchain/test/fixtures/apps/agent/package.json new file mode 100644 index 00000000..978d31f2 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/package.json @@ -0,0 +1,3 @@ +{ + "name": "egg-app" +} diff --git a/plugin/langchain/test/fixtures/apps/agent/tsconfig.json b/plugin/langchain/test/fixtures/apps/agent/tsconfig.json new file mode 100644 index 00000000..bfa29259 --- /dev/null +++ b/plugin/langchain/test/fixtures/apps/agent/tsconfig.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "outDir": "dist", + "module": "node18", + "moduleResolution": "node16", + "experimentalDecorators": true, + "emitDecoratorMetadata": true, + "baseUrl": "./" + }, + "exclude": [ + "dist", + "node_modules", + "test" + ] +} diff --git a/plugin/langchain/test/sse-utils.test.ts b/plugin/langchain/test/sse-utils.test.ts new file mode 100644 index 00000000..02abc599 --- /dev/null +++ b/plugin/langchain/test/sse-utils.test.ts @@ -0,0 +1,298 @@ +import assert from 'assert'; +import { PassThrough } from 'stream'; +import { SSEStreamWriter, streamSSE, getDisconnectAbortSignal, serialiseAsDict } from '../lib/sse'; + +describe('test/sse-utils.test.ts', () => { + describe('SSEStreamWriter', () => { + it('should write SSE event correctly', async () => { + const passThrough = new PassThrough(); + const writer = new SSEStreamWriter(passThrough); + + let output = ''; + passThrough.on('data', (chunk) => { + output += chunk.toString(); + }); + + await writer.writeSSE({ + event: 'test', + data: 'hello world', + }); + + writer.close(); + + assert(output.includes('event: test'), 'Should include event'); + assert(output.includes('data: hello world'), 'Should include data'); + assert(output.endsWith('\n\n'), 'Should end with double newline'); + }); + + it('should write SSE event with id and retry', async () => { + const passThrough = new PassThrough(); + const writer = new SSEStreamWriter(passThrough); + + let output = ''; + passThrough.on('data', (chunk) => { + output += chunk.toString(); + }); + + await writer.writeSSE({ + event: 'message', + data: 'test', + id: '123', + retry: 5000, + }); + + writer.close(); + + assert(output.includes('event: message')); + assert(output.includes('data: test')); + assert(output.includes('id: 123')); + assert(output.includes('retry: 5000')); + }); + + it('should handle multiline data', async () => { + const passThrough = new PassThrough(); + const writer = new SSEStreamWriter(passThrough); + + let output = ''; + passThrough.on('data', (chunk) => { + output += chunk.toString(); + }); + + await writer.writeSSE({ + data: 'line1\nline2\nline3', + }); + + writer.close(); + + assert(output.includes('data: line1')); + assert(output.includes('data: line2')); + assert(output.includes('data: line3')); + }); + + it('should write comment', async () => { + const passThrough = new PassThrough(); + const writer = new SSEStreamWriter(passThrough); + + let output = ''; + passThrough.on('data', (chunk) => { + output += chunk.toString(); + }); + + await writer.writeComment('keep-alive'); + + writer.close(); + + assert.strictEqual(output, ': keep-alive\n\n'); + }); + + it('should sleep for specified time', async () => { + const writer = new SSEStreamWriter(new PassThrough()); + + const start = Date.now(); + await writer.sleep(100); + const elapsed = Date.now() - start; + + assert(elapsed >= 90, 'Should sleep for at least 90ms'); + assert(elapsed < 200, 'Should not sleep for more than 200ms'); + }); + + it('should track closed state', () => { + const writer = new SSEStreamWriter(new PassThrough()); + + assert.strictEqual(writer.isClosed(), false, 'Should not be closed initially'); + + writer.close(); + + assert.strictEqual(writer.isClosed(), true, 'Should be closed after close()'); + }); + + it('should throw error when writing to closed stream', async () => { + const writer = new SSEStreamWriter(new PassThrough()); + + writer.close(); + + await assert.rejects( + async () => { + await writer.writeSSE({ data: 'test' }); + }, + { message: 'Stream is closed' } + ); + }); + + it('should throw error when writing comment to closed stream', async () => { + const writer = new SSEStreamWriter(new PassThrough()); + + writer.close(); + + await assert.rejects( + async () => { + await writer.writeComment('test'); + }, + { message: 'Stream is closed' } + ); + }); + }); + + describe('streamSSE', () => { + it('should set correct SSE headers', async () => { + const ctx: any = { + set: (key: string, value: string) => { + ctx.headers = ctx.headers || {}; + ctx.headers[key.toLowerCase()] = value; + }, + headers: {}, + }; + + streamSSE(ctx, async (stream) => { + await stream.writeSSE({ data: 'test' }); + }); + + // 等待一下让 streamSSE 设置 headers + await new Promise(resolve => setTimeout(resolve, 10)); + + assert.strictEqual(ctx.headers['content-type'], 'text/event-stream'); + assert.strictEqual(ctx.headers['cache-control'], 'no-cache'); + assert.strictEqual(ctx.headers['connection'], 'keep-alive'); + assert.strictEqual(ctx.headers['transfer-encoding'], 'chunked'); + assert.strictEqual(ctx.headers['x-accel-buffering'], 'no'); + }); + + it('should call callback with stream writer', async () => { + const ctx: any = { + set: () => {}, + }; + + let streamReceived = false; + + streamSSE(ctx, async (stream) => { + streamReceived = true; + assert(stream.writeSSE, 'Stream should have writeSSE method'); + assert(stream.sleep, 'Stream should have sleep method'); + assert(stream.close, 'Stream should have close method'); + }); + + // 等待回调执行 + await new Promise(resolve => setTimeout(resolve, 10)); + + assert(streamReceived, 'Callback should be called'); + }); + + it('should handle errors in callback', async () => { + const ctx: any = { + set: () => {}, + }; + + const error = new Error('Test error'); + let errorReceived = false; + + await streamSSE(ctx, async () => { + throw error; + }, async (e) => { + errorReceived = true; + assert.strictEqual(e, error); + }); + + // 等待错误处理 + await new Promise(resolve => setTimeout(resolve, 10)); + + assert(errorReceived, 'Error handler should be called'); + }); + }); + + describe('getDisconnectAbortSignal', () => { + it('should return an AbortSignal', () => { + const ctx: any = { + req: { + on: () => {}, + off: () => {}, + }, + }; + + const signal = getDisconnectAbortSignal(ctx); + + assert(signal instanceof AbortSignal, 'Should return an AbortSignal'); + assert.strictEqual(signal.aborted, false, 'Should not be aborted initially'); + }); + + it('should abort when request closes', (done) => { + let closeHandler: Function; + const ctx: any = { + req: { + on: (event: string, handler: Function) => { + if (event === 'close') { + closeHandler = handler; + } + }, + off: () => {}, + }, + }; + + const signal = getDisconnectAbortSignal(ctx); + + signal.addEventListener('abort', () => { + assert.strictEqual(signal.aborted, true); + done(); + }); + + // 模拟请求关闭 + closeHandler!(); + }); + + it('should close stream when request closes', (done) => { + let closeHandler: Function; + const ctx: any = { + req: { + on: (event: string, handler: Function) => { + if (event === 'close') { + closeHandler = handler; + } + }, + off: () => {}, + }, + }; + + const stream = new SSEStreamWriter(new PassThrough()); + + getDisconnectAbortSignal(ctx, stream); + + // 模拟请求关闭 + closeHandler!(); + + // 稍等一下确保流已关闭 + setTimeout(() => { + assert.strictEqual(stream.isClosed(), true); + done(); + }, 10); + }); + }); + + describe('serialiseAsDict', () => { + it('should return empty object for null', () => { + const result = serialiseAsDict(null); + assert.deepStrictEqual(result, {}); + }); + + it('should return empty object for undefined', () => { + const result = serialiseAsDict(undefined); + assert.deepStrictEqual(result, {}); + }); + + it('should return object as-is', () => { + const obj = { key: 'value', nested: { data: 123 } }; + const result = serialiseAsDict(obj); + assert.deepStrictEqual(result, obj); + }); + + it('should wrap non-object values', () => { + assert.deepStrictEqual(serialiseAsDict('string'), { value: 'string' }); + assert.deepStrictEqual(serialiseAsDict(123), { value: 123 }); + assert.deepStrictEqual(serialiseAsDict(true), { value: true }); + }); + + it('should wrap array values', () => { + const arr = [1, 2, 3]; + const result = serialiseAsDict(arr); + assert.deepStrictEqual(result, { value: arr }); + }); + }); +}); diff --git a/plugin/langchain/typings/index.d.ts b/plugin/langchain/typings/index.d.ts index d3673afa..916555e8 100644 --- a/plugin/langchain/typings/index.d.ts +++ b/plugin/langchain/typings/index.d.ts @@ -113,6 +113,17 @@ export type ChatModelConfigModuleConfigType = Static; + }; + + export interface ModuleConfig extends LangChainModuleConfig { + } +} + +declare module '@eggjs/tegg-types' { + export type LangChainModuleConfig = { + ChatModel?: ChatModelConfigModuleConfigType; + agents?: Record; }; export interface ModuleConfig extends LangChainModuleConfig {