diff --git a/conf.json b/conf.json index 940c8bdd6..b97b29fca 100644 --- a/conf.json +++ b/conf.json @@ -10,7 +10,8 @@ "pangea", "promptsecurity", "panw-prisma-airs", - "walledai" + "walledai", + "zscaler" ], "credentials": { "portkey": { diff --git a/plugins/build.ts b/plugins/build.ts index 76b59022c..f3f646355 100644 --- a/plugins/build.ts +++ b/plugins/build.ts @@ -1,43 +1,59 @@ -import conf from '../conf.json'; import fs from 'fs'; +import path from 'path'; +import { fileURLToPath } from 'url'; -const pluginsEnabled = conf.plugins_enabled; - -let importStrings: any = []; -let funcStrings: any = {}; -let funcs: any = {}; - -for (const plugin of pluginsEnabled) { - const manifest = await import(`./${plugin}/manifest.json`); - const functions = manifest.functions.map((func: any) => func.id); - importStrings = [ - ...importStrings, - ...functions.map( - (func: any) => - `import { handler as ${manifest.id}${func} } from "./${plugin}/${func}"` - ), - ]; - - funcs[plugin] = {}; - functions.forEach((func: any) => { - funcs[plugin][func] = func; - }); - - funcStrings[plugin] = []; - for (let key in funcs[plugin]) { - funcStrings[plugin].push(`"${key}": ${manifest.id}${funcs[plugin][key]}`); +const __filename = fileURLToPath(import.meta.url); +const __dirname = path.dirname(__filename); + +async function buildPlugins() { + const conf = JSON.parse( + fs.readFileSync(path.join(__dirname, '../conf.json'), 'utf-8') + ); + const pluginsEnabled = conf.plugins_enabled; + + let importStrings: any = []; + let funcStrings: any = {}; + let funcs: any = {}; + + for (const plugin of pluginsEnabled) { + const manifestPath = path.join(__dirname, plugin, 'manifest.json'); + const manifestContent = fs.readFileSync(manifestPath, 'utf-8'); + const manifest = JSON.parse(manifestContent); + const safePluginId = manifest.id.replace(/-/g, ''); + const functions = manifest.functions.map((func: any) => func.id); + importStrings = [ + ...importStrings, + ...functions.map( + (func: any) => + `import { handler as ${safePluginId}${func.replace(/-/g, '')} } from "./${plugin}/${func}"` + ), + ]; + + funcs[plugin] = {}; + functions.forEach((func: any) => { + funcs[plugin][func] = func; + }); + + funcStrings[plugin] = []; + for (let key in funcs[plugin]) { + funcStrings[plugin].push( + `"${key}": ${safePluginId}${funcs[plugin][key].replace(/-/g, '')}` + ); + } } -} -const indexFilePath = './plugins/index.ts'; + const indexFilePath = './plugins/index.ts'; -let finalFuncStrings: any = []; -for (let key in funcStrings) { - finalFuncStrings.push( - `\n "${key}": {\n ${funcStrings[key].join(',\n ')}\n }` - ); -} + let finalFuncStrings: any = []; + for (let key in funcStrings) { + finalFuncStrings.push( + `\n "${key}": {\n ${funcStrings[key].join(',\n ')}\n }` + ); + } -const content = `${importStrings.join('\n')}\n\nexport const plugins = {${finalFuncStrings}\n};\n`; + const content = `${importStrings.join('\n')}\n\nexport const plugins = {${finalFuncStrings}\n};\n`; + + fs.writeFileSync(indexFilePath, content); +} -fs.writeFileSync(indexFilePath, content); +buildPlugins(); diff --git a/plugins/zscaler/main-function.ts b/plugins/zscaler/main-function.ts new file mode 100644 index 000000000..d490f493d --- /dev/null +++ b/plugins/zscaler/main-function.ts @@ -0,0 +1,148 @@ +import { + HookEventType, + PluginContext, + type PluginHandler, + PluginParameters, +} from '../types'; +import { post, getText } from '../utils'; + +const ZSCALER_EXECUTE_POLICY_URL = + 'https://api.zseclipse.net/v1/detection/execute-policy'; + +interface ZscalerCredentials { + zscalerApiKey: string; +} + +interface ZscalerPluginParameters { + policyId: string; +} + +interface ZscalerExecutePolicyRequest { + policyId: string; + direction: 'IN' | 'OUT'; + content: any; +} + +interface ZscalerExecutePolicyResponse { + action: 'ALLOW' | 'BLOCK'; + detectorResponses?: Record; +} + +export const handler: PluginHandler = async ( + context: PluginContext, + parameters: PluginParameters, + eventType: HookEventType +) => { + let error: Error | null = null; + let verdict = true; + let data: Record = {}; + + const credentials = parameters.credentials as ZscalerCredentials | undefined; + const pluginParams = parameters.parameters as ZscalerPluginParameters; + + // :white_check_mark: FAIL OPEN (aligned with other plugins) + if (!credentials?.zscalerApiKey) { + return { + error: new Error('Zscaler AI Guard API Key must be configured.'), + verdict: true, + data, + }; + } + + if (!pluginParams?.policyId) { + return { + error: new Error('Zscaler AI Guard Policy ID must be configured.'), + verdict: true, + data, + }; + } + + const contentToScan = getText(context, eventType); + + if (!contentToScan) { + return { + error: new Error('No content found to scan.'), + verdict: true, + data, + }; + } + + const direction = eventType === 'beforeRequestHook' ? 'IN' : 'OUT'; + + const zscalerRequest: ZscalerExecutePolicyRequest = { + policyId: pluginParams.policyId, + direction, + content: contentToScan, + }; + + try { + const headers = { + 'Content-Type': 'application/json', + Authorization: `Bearer ${credentials.zscalerApiKey}`, + }; + + const response: ZscalerExecutePolicyResponse = await post( + ZSCALER_EXECUTE_POLICY_URL, + zscalerRequest, + { headers }, + 10000 + ); + + data = { + zscalerAction: response.action, + detectorResponses: response.detectorResponses, + }; + + // :white_check_mark: Check top-level action + let isBlocked = response.action === 'BLOCK'; + + // :white_check_mark: Also check individual detectors (if present) + if (response.detectorResponses) { + const detectorBlocked = Object.values(response.detectorResponses).some( + (detector: any) => detector?.action === 'BLOCK' + ); + + isBlocked = isBlocked || detectorBlocked; + } + + verdict = !isBlocked; + + if (!verdict) { + error = new Error( + 'Zscaler AI Guard blocked the content with action: BLOCK' + ); + } + } catch (e: unknown) { + verdict = false; + + const maybeError = e as any; + const status = maybeError?.response?.status; + + // :white_check_mark: Proper 429 handling (your test will now pass) + if (status === 429) { + error = new Error('Zscaler AI Guard rate limit exceeded. Status: 429'); + } + // :white_check_mark: Proper 5xx handling + else if (status && status >= 500 && status < 600) { + error = new Error( + `Zscaler AI Guard API returned a server error. Status: ${status}` + ); + } + // :white_check_mark: Normal JS Error + else if (e instanceof Error) { + error = e; + } + // :white_check_mark: Fallback + else { + error = new Error('An unknown error occurred during Zscaler API call.'); + } + + data = { originalError: error.message }; + } + + return { + error, + verdict, + data, + }; +}; diff --git a/plugins/zscaler/manifest.json b/plugins/zscaler/manifest.json new file mode 100644 index 000000000..66a0fb6f4 --- /dev/null +++ b/plugins/zscaler/manifest.json @@ -0,0 +1,45 @@ +{ + "id": "zscalerAiGuard", + "name": "Zscaler AI Guard", + "version": "1.0.0", + "description": "Integrates Zscaler AI Guard for advanced GenAI security, including DLP and prompt injection detection.", + "author": "Zscaler", + "credentials": { + "type": "object", + "properties": { + "zscalerApiKey": { + "type": "string", + "label": "Zscaler AI Guard API Key", + "description": "The API Key generated from your Zscaler AI Guard tenant.", + "encrypted": true + } + }, + "required": ["zscalerApiKey"] + }, + "functions": [ + { + "name": "Zscaler AI Guard Check", + "id": "zscalerAiGuardCheck", + "supportedHooks": ["beforeRequestHook", "afterRequestHook"], + "type": "guardrail", + "description": [ + { + "type": "subHeading", + "text": "Performs Zscaler AI Guard policy checks on prompts and LLM responses." + } + ], + "parameters": { + "type": "object", + "properties": { + "policyId": { + "type": "string", + "label": "Zscaler Policy ID", + "description": "The ID of the Zscaler Detections Policy to execute.", + "required": true + } + }, + "required": ["policyId"] + } + } + ] +} diff --git a/plugins/zscaler/test-file.test.ts b/plugins/zscaler/test-file.test.ts new file mode 100644 index 000000000..ba6e2dc91 --- /dev/null +++ b/plugins/zscaler/test-file.test.ts @@ -0,0 +1,120 @@ +/// +import { handler } from './main-function'; +import type { PluginContext, PluginParameters } from '../types'; +import * as utils from '../utils'; + +jest.mock('../utils'); + +const mockedPost = utils.post as jest.Mock; +const mockedGetText = utils.getText as jest.Mock; + +describe('Zscaler AI Guard Plugin - Unit Tests', () => { + const baseParameters: PluginParameters<{ zscalerApiKey: string }> = { + credentials: { zscalerApiKey: 'test-key' }, + parameters: { policyId: 'test-policy' }, + }; + + const baseContext: PluginContext = { + request: { json: {} }, + requestType: 'chatComplete', + }; + + beforeEach(() => { + jest.clearAllMocks(); + mockedGetText.mockReturnValue('test content'); + }); + + it('should fail open when API key is missing', async () => { + const result = await handler( + baseContext, + { credentials: {}, parameters: { policyId: 'x' } } as any, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(true); + expect(result.error).toBeInstanceOf(Error); + }); + + it('should fail open when policyId is missing', async () => { + const result = await handler( + baseContext, + { credentials: { zscalerApiKey: 'x' }, parameters: {} } as any, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(true); + expect(result.error).toBeInstanceOf(Error); + }); + + it('should return allow for ALLOW response', async () => { + mockedPost.mockResolvedValue({ + action: 'ALLOW', + }); + + const result = await handler( + baseContext, + baseParameters, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(true); + }); + + it('should block when top-level action is BLOCK', async () => { + mockedPost.mockResolvedValue({ + action: 'BLOCK', + }); + + const result = await handler( + baseContext, + baseParameters, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(false); + }); + + it('should block when detector returns BLOCK', async () => { + mockedPost.mockResolvedValue({ + action: 'ALLOW', + detectorResponses: { + dlp: { action: 'BLOCK' }, + }, + }); + + const result = await handler( + baseContext, + baseParameters, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(false); + }); + + it('should handle 429 rate limit error', async () => { + mockedPost.mockRejectedValue({ + response: { status: 429 }, + }); + + const result = await handler( + baseContext, + baseParameters, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(false); + expect(result.error?.message).toContain('rate limit'); + }); + + it('should allow empty content', async () => { + mockedGetText.mockReturnValue(''); + + const result = await handler( + baseContext, + baseParameters, + 'beforeRequestHook' + ); + + expect(result.verdict).toBe(true); + }); +});