feat(security): integrate security shield with gateway

This commit is contained in:
Ulrich Diedrichsen 2026-01-30 10:45:59 +01:00
parent 79597b7a98
commit 18a01881c5
4 changed files with 62 additions and 2 deletions

View File

@ -3,6 +3,7 @@ import type { IncomingMessage } from "node:http";
import type { GatewayAuthConfig, GatewayTailscaleMode } from "../config/config.js"; import type { GatewayAuthConfig, GatewayTailscaleMode } from "../config/config.js";
import { readTailscaleWhoisIdentity, type TailscaleWhoisIdentity } from "../infra/tailscale.js"; import { readTailscaleWhoisIdentity, type TailscaleWhoisIdentity } from "../infra/tailscale.js";
import { isTrustedProxyAddress, parseForwardedForClientIp, resolveGatewayClientIp } from "./net.js"; import { isTrustedProxyAddress, parseForwardedForClientIp, resolveGatewayClientIp } from "./net.js";
import { checkAuthRateLimit, logAuthFailure } from "../security/middleware.js";
export type ResolvedGatewayAuthMode = "token" | "password"; export type ResolvedGatewayAuthMode = "token" | "password";
export type ResolvedGatewayAuth = { export type ResolvedGatewayAuth = {
@ -207,11 +208,23 @@ export async function authorizeGatewayConnect(params: {
req?: IncomingMessage; req?: IncomingMessage;
trustedProxies?: string[]; trustedProxies?: string[];
tailscaleWhois?: TailscaleWhoisLookup; tailscaleWhois?: TailscaleWhoisLookup;
deviceId?: string;
}): Promise<GatewayAuthResult> { }): Promise<GatewayAuthResult> {
const { auth, connectAuth, req, trustedProxies } = params; const { auth, connectAuth, req, trustedProxies, deviceId } = params;
const tailscaleWhois = params.tailscaleWhois ?? readTailscaleWhoisIdentity; const tailscaleWhois = params.tailscaleWhois ?? readTailscaleWhoisIdentity;
const localDirect = isLocalDirectRequest(req, trustedProxies); const localDirect = isLocalDirectRequest(req, trustedProxies);
// Security: Check auth rate limit
if (req) {
const rateCheck = checkAuthRateLimit(req, deviceId);
if (!rateCheck.allowed) {
return {
ok: false,
reason: rateCheck.reason ?? "rate_limit_exceeded",
};
}
}
if (auth.allowTailscale && !localDirect) { if (auth.allowTailscale && !localDirect) {
const tailscaleCheck = await resolveVerifiedTailscaleUser({ const tailscaleCheck = await resolveVerifiedTailscaleUser({
req, req,
@ -234,6 +247,10 @@ export async function authorizeGatewayConnect(params: {
return { ok: false, reason: "token_missing" }; return { ok: false, reason: "token_missing" };
} }
if (!safeEqual(connectAuth.token, auth.token)) { if (!safeEqual(connectAuth.token, auth.token)) {
// Security: Log failed auth for intrusion detection
if (req) {
logAuthFailure(req, "token_mismatch", deviceId);
}
return { ok: false, reason: "token_mismatch" }; return { ok: false, reason: "token_mismatch" };
} }
return { ok: true, method: "token" }; return { ok: true, method: "token" };
@ -248,10 +265,18 @@ export async function authorizeGatewayConnect(params: {
return { ok: false, reason: "password_missing" }; return { ok: false, reason: "password_missing" };
} }
if (!safeEqual(password, auth.password)) { if (!safeEqual(password, auth.password)) {
// Security: Log failed auth for intrusion detection
if (req) {
logAuthFailure(req, "password_mismatch", deviceId);
}
return { ok: false, reason: "password_mismatch" }; return { ok: false, reason: "password_mismatch" };
} }
return { ok: true, method: "password" }; return { ok: true, method: "password" };
} }
// Security: Log unauthorized attempts
if (req) {
logAuthFailure(req, "unauthorized", deviceId);
}
return { ok: false, reason: "unauthorized" }; return { ok: false, reason: "unauthorized" };
} }

View File

@ -28,6 +28,8 @@ import {
} from "./hooks.js"; } from "./hooks.js";
import { applyHookMappings } from "./hooks-mapping.js"; import { applyHookMappings } from "./hooks-mapping.js";
import { handleOpenAiHttpRequest } from "./openai-http.js"; import { handleOpenAiHttpRequest } from "./openai-http.js";
import { checkWebhookRateLimit } from "../security/middleware.js";
import { SecurityShield } from "../security/shield.js";
import { handleOpenResponsesHttpRequest } from "./openresponses-http.js"; import { handleOpenResponsesHttpRequest } from "./openresponses-http.js";
import { handleToolsInvokeHttpRequest } from "./tools-invoke-http.js"; import { handleToolsInvokeHttpRequest } from "./tools-invoke-http.js";
@ -91,6 +93,21 @@ export function createHooksRequestHandler(
); );
} }
// Security: Check webhook rate limit
const subPath = url.pathname.slice(basePath.length).replace(/^\/+/, "");
const rateCheck = checkWebhookRateLimit({
token: token,
path: subPath,
ip: SecurityShield.extractIp(req),
});
if (!rateCheck.allowed) {
res.statusCode = 429;
res.setHeader("Retry-After", String(Math.ceil((rateCheck.retryAfterMs ?? 60000) / 1000)));
res.setHeader("Content-Type", "text/plain; charset=utf-8");
res.end("Too Many Requests");
return true;
}
if (req.method !== "POST") { if (req.method !== "POST") {
res.statusCode = 405; res.statusCode = 405;
res.setHeader("Allow", "POST"); res.setHeader("Allow", "POST");
@ -99,7 +116,6 @@ export function createHooksRequestHandler(
return true; return true;
} }
const subPath = url.pathname.slice(basePath.length).replace(/^\/+/, "");
if (!subPath) { if (!subPath) {
res.statusCode = 404; res.statusCode = 404;
res.setHeader("Content-Type", "text/plain; charset=utf-8"); res.setHeader("Content-Type", "text/plain; charset=utf-8");

View File

@ -58,6 +58,7 @@ import { loadGatewayModelCatalog } from "./server-model-catalog.js";
import { NodeRegistry } from "./node-registry.js"; import { NodeRegistry } from "./node-registry.js";
import { createNodeSubscriptionManager } from "./server-node-subscriptions.js"; import { createNodeSubscriptionManager } from "./server-node-subscriptions.js";
import { safeParseJson } from "./server-methods/nodes.helpers.js"; import { safeParseJson } from "./server-methods/nodes.helpers.js";
import { initSecurityShield } from "../security/shield.js";
import { loadGatewayPlugins } from "./server-plugins.js"; import { loadGatewayPlugins } from "./server-plugins.js";
import { createGatewayReloadHandlers } from "./server-reload-handlers.js"; import { createGatewayReloadHandlers } from "./server-reload-handlers.js";
import { resolveGatewayRuntimeConfig } from "./server-runtime-config.js"; import { resolveGatewayRuntimeConfig } from "./server-runtime-config.js";
@ -215,6 +216,10 @@ export async function startGatewayServer(
startDiagnosticHeartbeat(); startDiagnosticHeartbeat();
} }
setGatewaySigusr1RestartPolicy({ allowExternal: cfgAtStart.commands?.restart === true }); setGatewaySigusr1RestartPolicy({ allowExternal: cfgAtStart.commands?.restart === true });
// Initialize security shield with configuration
initSecurityShield(cfgAtStart.security?.shield);
initSubagentRegistry(); initSubagentRegistry();
const defaultAgentId = resolveDefaultAgentId(cfgAtStart); const defaultAgentId = resolveDefaultAgentId(cfgAtStart);
const defaultWorkspaceDir = resolveAgentWorkspaceDir(cfgAtStart, defaultAgentId); const defaultWorkspaceDir = resolveAgentWorkspaceDir(cfgAtStart, defaultAgentId);

View File

@ -7,6 +7,7 @@ import lockfile from "proper-lockfile";
import { getPairingAdapter } from "../channels/plugins/pairing.js"; import { getPairingAdapter } from "../channels/plugins/pairing.js";
import type { ChannelId, ChannelPairingAdapter } from "../channels/plugins/types.js"; import type { ChannelId, ChannelPairingAdapter } from "../channels/plugins/types.js";
import { resolveOAuthDir, resolveStateDir } from "../config/paths.js"; import { resolveOAuthDir, resolveStateDir } from "../config/paths.js";
import { checkPairingRateLimit } from "../security/middleware.js";
const PAIRING_CODE_LENGTH = 8; const PAIRING_CODE_LENGTH = 8;
const PAIRING_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; const PAIRING_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789";
@ -328,6 +329,19 @@ export async function upsertChannelPairingRequest(params: {
pairingAdapter?: ChannelPairingAdapter; pairingAdapter?: ChannelPairingAdapter;
}): Promise<{ code: string; created: boolean }> { }): Promise<{ code: string; created: boolean }> {
const env = params.env ?? process.env; const env = params.env ?? process.env;
// Security: Check pairing rate limit
const sender = normalizeId(params.id);
const rateCheck = checkPairingRateLimit({
channel: String(params.channel),
sender,
ip: "unknown", // Pairing happens at channel level, not HTTP
});
if (!rateCheck.allowed) {
// Rate limited - return empty code without creating request
return { code: "", created: false };
}
const filePath = resolvePairingPath(params.channel, env); const filePath = resolvePairingPath(params.channel, env);
return await withFileLock( return await withFileLock(
filePath, filePath,