diff --git a/server/index.ts b/server/index.ts index 120ee4445..2389048f0 100644 --- a/server/index.ts +++ b/server/index.ts @@ -14,6 +14,11 @@ import { readFileSync, existsSync } from "fs"; import { join, dirname, resolve } from "path"; import { fileURLToPath } from "url"; import { MCPClientManager } from "@/sdk"; +import { + buildCorsHeaders, + getAllowedOrigin, + isAllowedHost, +} from "./utils/cors"; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); @@ -143,6 +148,31 @@ const app = new Hono().onError((err, c) => { return c.json({ error: "Internal server error" }, 500); }); +// Reject requests that are not targeting the loopback host or that originate +// from unexpected origins. This mitigates the "0.0.0.0 day" class of attacks +// that coerce local services into serving cross-origin traffic. +app.use("*", async (c, next) => { + const hostHeader = c.req.header("x-forwarded-host") || c.req.header("host"); + if (!isAllowedHost(hostHeader)) { + appLogger.warn("Blocked request with disallowed host header", { + host: hostHeader, + path: c.req.path, + }); + return c.json({ error: "Host not allowed" }, 403); + } + + const originHeader = c.req.header("origin"); + if (originHeader && !getAllowedOrigin(originHeader)) { + appLogger.warn("Blocked request with disallowed origin", { + origin: originHeader, + path: c.req.path, + }); + return c.json({ error: "Origin not allowed" }, 403, { Vary: "Origin" }); + } + + await next(); +}); + // Load environment variables early so route handlers can read CONVEX_HTTP_URL const envFile = process.env.NODE_ENV === "production" @@ -218,14 +248,23 @@ app.route("/api/mcp", mcpRoutes); // We resolve the upstream messages endpoint via sessionId and forward with any injected auth. // CORS preflight app.options("/sse/message", (c) => { - return c.body(null, 204, { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST,OPTIONS", - "Access-Control-Allow-Headers": - "Authorization, Content-Type, Accept, Accept-Language", - "Access-Control-Max-Age": "86400", - Vary: "Origin, Access-Control-Request-Headers", + const originHeader = c.req.header("origin"); + const { headers, allowedOrigin } = buildCorsHeaders(originHeader, { + allowMethods: "POST,OPTIONS", + allowHeaders: + "Authorization, Content-Type, Accept, Accept-Language, X-MCPJam-Endpoint-Base", + maxAge: "86400", + allowPrivateNetwork: true, + requestPrivateNetwork: + c.req.header("access-control-request-private-network") === "true", }); + + if (originHeader && !allowedOrigin) { + return c.json({ error: "Origin not allowed" }, 403, { Vary: "Origin" }); + } + + headers.Vary = `${headers.Vary}, Access-Control-Request-Headers`; + return c.body(null, 204, headers); }); // Health check diff --git a/server/routes/mcp/elicitation.ts b/server/routes/mcp/elicitation.ts index abc607bfa..99eef0821 100644 --- a/server/routes/mcp/elicitation.ts +++ b/server/routes/mcp/elicitation.ts @@ -1,5 +1,6 @@ import { Hono } from "hono"; import type { ElicitResult } from "@modelcontextprotocol/sdk/types.js"; +import { buildCorsHeaders } from "../../utils/cors"; const elicitation = new Hono(); @@ -49,6 +50,10 @@ elicitation.use("*", async (c, next) => { // SSE stream for elicitation events elicitation.get("/stream", async (c) => { + const originHeader = c.req.header("origin"); + const { headers: corsHeaders } = buildCorsHeaders(originHeader, { + allowCredentials: true, + }); const encoder = new TextEncoder(); const stream = new ReadableStream({ start(controller) { @@ -85,11 +90,11 @@ elicitation.get("/stream", async (c) => { return new Response(stream as any, { status: 200, headers: { + ...corsHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache, no-transform", Connection: "keep-alive", "X-Accel-Buffering": "no", - "Access-Control-Allow-Origin": "*", }, }); }); diff --git a/server/routes/mcp/http-adapters.ts b/server/routes/mcp/http-adapters.ts index cbfadcbb9..f5a886fa5 100644 --- a/server/routes/mcp/http-adapters.ts +++ b/server/routes/mcp/http-adapters.ts @@ -1,6 +1,7 @@ import { Hono } from "hono"; import "../../types/hono"; import { handleJsonRpc, BridgeMode } from "../../services/mcp-http-bridge"; +import { buildCorsHeaders } from "../../utils/cors"; // In-memory SSE session store per serverId:sessionId type Session = { @@ -16,40 +17,49 @@ const latestSessionByServer: Map = new Map(); function createHttpHandler(mode: BridgeMode, routePrefix: string) { const router = new Hono(); - router.options("/:serverId", (c) => - c.body(null, 204, { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET,POST,HEAD,OPTIONS", - "Access-Control-Allow-Headers": - "*, Authorization, Content-Type, Accept, Accept-Language", - "Access-Control-Expose-Headers": "*", - "Access-Control-Max-Age": "86400", - }), - ); + const handlePreflight = (c: any) => { + const originHeader = c.req.header("origin"); + const { headers, allowedOrigin } = buildCorsHeaders(originHeader, { + allowMethods: "GET,POST,HEAD,OPTIONS", + allowHeaders: + "Authorization, Content-Type, Accept, Accept-Language, X-MCPJam-Endpoint-Base", + exposeHeaders: "*", + maxAge: "86400", + allowCredentials: true, + allowPrivateNetwork: true, + requestPrivateNetwork: + c.req.header("access-control-request-private-network") === "true", + }); + + if (originHeader && !allowedOrigin) { + return c.json({ error: "Origin not allowed" }, 403, { Vary: "Origin" }); + } + + headers.Vary = `${headers.Vary}, Access-Control-Request-Headers`; + return c.body(null, 204, headers); + }; + + router.options("/:serverId", handlePreflight); // Wildcard variants to tolerate trailing paths (e.g., /mcp) - router.options("/:serverId/*", (c) => - c.body(null, 204, { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET,POST,HEAD,OPTIONS", - "Access-Control-Allow-Headers": - "*, Authorization, Content-Type, Accept, Accept-Language", - "Access-Control-Expose-Headers": "*", - "Access-Control-Max-Age": "86400", - }), - ); + router.options("/:serverId/*", handlePreflight); async function handleHttp(c: any) { const serverId = c.req.param("serverId"); const method = c.req.method; + const originHeader = c.req.header("origin"); + const { headers: corsHeaders } = buildCorsHeaders(originHeader, { + exposeHeaders: "*", + allowCredentials: true, + }); // SSE endpoint for clients that probe/subscribe via GET; HEAD advertises event-stream if (method === "HEAD") { return c.body(null, 200, { + ...corsHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Access-Control-Allow-Origin": "*", "X-Accel-Buffering": "no", }); } @@ -127,18 +137,17 @@ function createHttpHandler(mode: BridgeMode, routePrefix: string) { }, }); return c.body(stream as any, 200, { + ...corsHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "*", "X-Accel-Buffering": "no", "Transfer-Encoding": "chunked", }); } if (method !== "POST") { - return c.json({ error: "Unsupported request" }, 400); + return c.json({ error: "Unsupported request" }, 400, corsHeaders); } // Parse JSON body (best effort) @@ -172,17 +181,21 @@ function createHttpHandler(mode: BridgeMode, routePrefix: string) { ); if (!response) { // Notification → 202 Accepted - return c.body("Accepted", 202, { "Access-Control-Allow-Origin": "*" }); + return c.body("Accepted", 202, corsHeaders); } return c.body(JSON.stringify(response), 200, { + ...corsHeaders, "Content-Type": "application/json", - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "*", }); } // Endpoint to receive client messages for SSE transport: /:serverId/messages?sessionId=... router.post("/:serverId/messages", async (c) => { + const originHeader = c.req.header("origin"); + const { headers: corsHeaders } = buildCorsHeaders(originHeader, { + exposeHeaders: "*", + allowCredentials: true, + }); const serverId = c.req.param("serverId"); const url = new URL(c.req.url); const sessionId = url.searchParams.get("sessionId") || ""; @@ -195,7 +208,7 @@ function createHttpHandler(mode: BridgeMode, routePrefix: string) { } } if (!sess) { - return c.json({ error: "Invalid session" }, 400); + return c.json({ error: "Invalid session" }, 400, corsHeaders); } let body: any; try { @@ -242,15 +255,9 @@ function createHttpHandler(mode: BridgeMode, routePrefix: string) { } catch {} } // 202 Accepted per SSE transport semantics - return c.body("Accepted", 202, { - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "*", - }); + return c.body("Accepted", 202, corsHeaders); } catch (e: any) { - return c.body("Error", 400, { - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "*", - }); + return c.body("Error", 400, corsHeaders); } }); diff --git a/server/routes/mcp/servers.ts b/server/routes/mcp/servers.ts index 0cba3725d..88e011e90 100644 --- a/server/routes/mcp/servers.ts +++ b/server/routes/mcp/servers.ts @@ -3,6 +3,7 @@ import type { MCPServerConfig } from "@/sdk"; import "../../types/hono"; // Type extensions import { rpcLogBus, type RpcLogEvent } from "../../services/rpc-log-bus"; import { logger } from "../../utils/logger"; +import { buildCorsHeaders } from "../../utils/cors"; const servers = new Hono(); @@ -48,7 +49,7 @@ servers.get("/status/:serverId", async (c) => { status, }); } catch (error) { - logger.error("Error getting server status", error, { serverId }); + logger.error("Error getting server status", error); return c.json( { success: false, @@ -82,7 +83,7 @@ servers.get("/init-info/:serverId", async (c) => { initInfo, }); } catch (error) { - logger.error("Error getting initialization info", error, { serverId }); + logger.error("Error getting initialization info", error); return c.json( { success: false, @@ -119,7 +120,7 @@ servers.delete("/:serverId", async (c) => { message: `Disconnected from server: ${serverId}`, }); } catch (error) { - logger.error("Error disconnecting server", error, { serverId }); + logger.error("Error disconnecting server", error); return c.json( { success: false, @@ -190,7 +191,7 @@ servers.post("/reconnect", async (c) => { ...(success ? {} : { error: message }), }); } catch (error) { - logger.error("Error reconnecting server", error, { serverId }); + logger.error("Error reconnecting server", error); return c.json( { success: false, @@ -203,6 +204,11 @@ servers.post("/reconnect", async (c) => { // Stream JSON-RPC messages over SSE for all servers. servers.get("/rpc/stream", async (c) => { + const originHeader = c.req.header("origin"); + const { headers: corsHeaders } = buildCorsHeaders(originHeader, { + exposeHeaders: "*", + allowCredentials: true, + }); const serverIds = c.mcpClientManager.listServers(); const url = new URL(c.req.url); const replay = parseInt(url.searchParams.get("replay") || "0", 10); @@ -256,11 +262,10 @@ servers.get("/rpc/stream", async (c) => { return new Response(stream, { headers: { + ...corsHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "*", }, }); }); diff --git a/server/utils/cors.ts b/server/utils/cors.ts new file mode 100644 index 000000000..7e3ec30ce --- /dev/null +++ b/server/utils/cors.ts @@ -0,0 +1,84 @@ +import { CORS_ORIGINS, SERVER_HOSTNAME } from "../config"; + +type CorsOptions = { + allowCredentials?: boolean; + allowHeaders?: string; + allowMethods?: string; + exposeHeaders?: string; + maxAge?: string; + requestPrivateNetwork?: boolean; + allowPrivateNetwork?: boolean; +}; + +const normalizedAllowedOrigins = new Set( + CORS_ORIGINS.map((origin) => normalizeOrigin(origin)), +); + +const allowedHostnames = new Set( + [SERVER_HOSTNAME, "localhost", "127.0.0.1", "::1"].map((host) => + host.toLowerCase(), + ), +); + +function normalizeOrigin(origin: string) { + return origin.trim().replace(/\/$/, "").toLowerCase(); +} + +function extractHostname(hostHeader: string | null | undefined) { + if (!hostHeader) return null; + const trimmed = hostHeader.trim(); + if (trimmed.startsWith("[")) { + const closing = trimmed.indexOf("]"); + if (closing !== -1) return trimmed.slice(1, closing).toLowerCase(); + } + return trimmed.split(":")[0]?.toLowerCase() ?? null; +} + +export function getAllowedOrigin(originHeader: string | null | undefined) { + if (!originHeader) return null; + const normalized = normalizeOrigin(originHeader); + return normalizedAllowedOrigins.has(normalized) ? normalized : null; +} + +export function isAllowedHost(hostHeader: string | null | undefined) { + const hostname = extractHostname(hostHeader); + if (!hostname) return true; + return allowedHostnames.has(hostname); +} + +export function buildCorsHeaders( + originHeader: string | null | undefined, + options: CorsOptions = {}, +) { + const allowedOrigin = getAllowedOrigin(originHeader); + const headers: Record = { Vary: "Origin" }; + + if (allowedOrigin) { + headers["Access-Control-Allow-Origin"] = allowedOrigin; + if (options.allowCredentials) { + headers["Access-Control-Allow-Credentials"] = "true"; + } + if ( + options.allowPrivateNetwork && + options.requestPrivateNetwork && + originHeader + ) { + headers["Access-Control-Allow-Private-Network"] = "true"; + } + } + + if (options.allowMethods) { + headers["Access-Control-Allow-Methods"] = options.allowMethods; + } + if (options.allowHeaders) { + headers["Access-Control-Allow-Headers"] = options.allowHeaders; + } + if (options.exposeHeaders) { + headers["Access-Control-Expose-Headers"] = options.exposeHeaders; + } + if (options.maxAge) { + headers["Access-Control-Max-Age"] = options.maxAge; + } + + return { headers, allowedOrigin }; +}