import type { IncomingMessage } from "node:http"; import os from "node:os"; import type { WebSocket } from "ws"; import { deriveDeviceIdFromPublicKey, normalizeDevicePublicKeyBase64Url, verifyDeviceSignature, } from "../../../infra/device-identity.js"; import { approveDevicePairing, ensureDeviceToken, getPairedDevice, requestDevicePairing, updatePairedDeviceMetadata, verifyDeviceToken, } from "../../../infra/device-pairing.js"; import { recordRemoteNodeInfo, refreshRemoteNodeBins } from "../../../infra/skills-remote.js"; import { loadVoiceWakeConfig } from "../../../infra/voicewake.js"; import { upsertPresence } from "../../../infra/system-presence.js"; import { rawDataToString } from "../../../infra/ws.js"; import type { createSubsystemLogger } from "../../../logging/subsystem.js"; import { isGatewayCliClient, isWebchatClient } from "../../../utils/message-channel.js"; import type { ResolvedGatewayAuth } from "../../auth.js"; import { authorizeGatewayConnect } from "../../auth.js"; import { loadConfig } from "../../../config/config.js"; import { buildDeviceAuthPayload } from "../../device-auth.js"; import { isLoopbackAddress } from "../../net.js"; import { resolveNodeCommandAllowlist } from "../../node-command-policy.js"; import { type ConnectParams, ErrorCodes, type ErrorShape, errorShape, formatValidationErrors, PROTOCOL_VERSION, type RequestFrame, validateConnectParams, validateRequestFrame, } from "../../protocol/index.js"; import { MAX_BUFFERED_BYTES, MAX_PAYLOAD_BYTES, TICK_INTERVAL_MS } from "../../server-constants.js"; import type { GatewayRequestContext, GatewayRequestHandlers } from "../../server-methods/types.js"; import { handleGatewayRequest } from "../../server-methods.js"; import { formatError } from "../../server-utils.js"; import { formatForLog, logWs } from "../../ws-log.js"; import { truncateCloseReason } from "../close-reason.js"; import { buildGatewaySnapshot, getHealthCache, getHealthVersion, incrementPresenceVersion, refreshGatewayHealthSnapshot, } from "../health-state.js"; import type { GatewayWsClient } from "../ws-types.js"; type SubsystemLogger = ReturnType; const DEVICE_SIGNATURE_SKEW_MS = 10 * 60 * 1000; export function attachGatewayWsMessageHandler(params: { socket: WebSocket; upgradeReq: IncomingMessage; connId: string; remoteAddr?: string; forwardedFor?: string; requestHost?: string; requestOrigin?: string; requestUserAgent?: string; canvasHostUrl?: string; connectNonce: string; resolvedAuth: ResolvedGatewayAuth; gatewayMethods: string[]; events: string[]; extraHandlers: GatewayRequestHandlers; buildRequestContext: () => GatewayRequestContext; send: (obj: unknown) => void; close: (code?: number, reason?: string) => void; isClosed: () => boolean; clearHandshakeTimer: () => void; getClient: () => GatewayWsClient | null; setClient: (next: GatewayWsClient) => void; setHandshakeState: (state: "pending" | "connected" | "failed") => void; setCloseCause: (cause: string, meta?: Record) => void; setLastFrameMeta: (meta: { type?: string; method?: string; id?: string }) => void; logGateway: SubsystemLogger; logHealth: SubsystemLogger; logWsControl: SubsystemLogger; }) { const { socket, upgradeReq, connId, remoteAddr, forwardedFor, requestHost, requestOrigin, requestUserAgent, canvasHostUrl, connectNonce, resolvedAuth, gatewayMethods, events, extraHandlers, buildRequestContext, send, close, isClosed, clearHandshakeTimer, getClient, setClient, setHandshakeState, setCloseCause, setLastFrameMeta, logGateway, logHealth, logWsControl, } = params; const isWebchatConnect = (p: ConnectParams | null | undefined) => isWebchatClient(p?.client); socket.on("message", async (data) => { if (isClosed()) return; const text = rawDataToString(data); try { const parsed = JSON.parse(text); const frameType = parsed && typeof parsed === "object" && "type" in parsed ? typeof (parsed as { type?: unknown }).type === "string" ? String((parsed as { type?: unknown }).type) : undefined : undefined; const frameMethod = parsed && typeof parsed === "object" && "method" in parsed ? typeof (parsed as { method?: unknown }).method === "string" ? String((parsed as { method?: unknown }).method) : undefined : undefined; const frameId = parsed && typeof parsed === "object" && "id" in parsed ? typeof (parsed as { id?: unknown }).id === "string" ? String((parsed as { id?: unknown }).id) : undefined : undefined; if (frameType || frameMethod || frameId) { setLastFrameMeta({ type: frameType, method: frameMethod, id: frameId }); } const client = getClient(); if (!client) { // Handshake must be a normal request: // { type:"req", method:"connect", params: ConnectParams }. const isRequestFrame = validateRequestFrame(parsed); if ( !isRequestFrame || (parsed as RequestFrame).method !== "connect" || !validateConnectParams((parsed as RequestFrame).params) ) { const handshakeError = isRequestFrame ? (parsed as RequestFrame).method === "connect" ? `invalid connect params: ${formatValidationErrors(validateConnectParams.errors)}` : "invalid handshake: first request must be connect" : "invalid request frame"; setHandshakeState("failed"); setCloseCause("invalid-handshake", { frameType, frameMethod, frameId, handshakeError, }); if (isRequestFrame) { const req = parsed as RequestFrame; send({ type: "res", id: req.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, handshakeError), }); } else { logWsControl.warn( `invalid handshake conn=${connId} remote=${remoteAddr ?? "?"} fwd=${forwardedFor ?? "n/a"} origin=${requestOrigin ?? "n/a"} host=${requestHost ?? "n/a"} ua=${requestUserAgent ?? "n/a"}`, ); } const closeReason = truncateCloseReason(handshakeError || "invalid handshake"); if (isRequestFrame) { queueMicrotask(() => close(1008, closeReason)); } else { close(1008, closeReason); } return; } const frame = parsed as RequestFrame; const connectParams = frame.params as ConnectParams; const clientLabel = connectParams.client.displayName ?? connectParams.client.id; // protocol negotiation const { minProtocol, maxProtocol } = connectParams; if (maxProtocol < PROTOCOL_VERSION || minProtocol > PROTOCOL_VERSION) { setHandshakeState("failed"); logWsControl.warn( `protocol mismatch conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`, ); setCloseCause("protocol-mismatch", { minProtocol, maxProtocol, expectedProtocol: PROTOCOL_VERSION, client: connectParams.client.id, clientDisplayName: connectParams.client.displayName, mode: connectParams.client.mode, version: connectParams.client.version, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "protocol mismatch", { details: { expectedProtocol: PROTOCOL_VERSION }, }), }); close(1002, "protocol mismatch"); return; } const roleRaw = connectParams.role ?? "operator"; const role = roleRaw === "operator" || roleRaw === "node" ? roleRaw : null; if (!role) { setHandshakeState("failed"); setCloseCause("invalid-role", { role: roleRaw, client: connectParams.client.id, clientDisplayName: connectParams.client.displayName, mode: connectParams.client.mode, version: connectParams.client.version, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "invalid role"), }); close(1008, "invalid role"); return; } const requestedScopes = Array.isArray(connectParams.scopes) ? connectParams.scopes : []; const scopes = requestedScopes.length > 0 ? requestedScopes : role === "operator" ? ["operator.admin"] : []; connectParams.role = role; connectParams.scopes = scopes; const device = connectParams.device; let devicePublicKey: string | null = null; // Allow token-authenticated connections (e.g., control-ui) to skip device identity const hasTokenAuth = !!connectParams.auth?.token; if (!device && !hasTokenAuth) { setHandshakeState("failed"); setCloseCause("device-required", { client: connectParams.client.id, clientDisplayName: connectParams.client.displayName, mode: connectParams.client.mode, version: connectParams.client.version, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.NOT_PAIRED, "device identity required"), }); close(1008, "device identity required"); return; } if (device) { const derivedId = deriveDeviceIdFromPublicKey(device.publicKey); if (!derivedId || derivedId !== device.id) { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-id-mismatch", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device identity mismatch"), }); close(1008, "device identity mismatch"); return; } const signedAt = device.signedAt; if ( typeof signedAt !== "number" || Math.abs(Date.now() - signedAt) > DEVICE_SIGNATURE_SKEW_MS ) { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-signature-stale", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device signature expired"), }); close(1008, "device signature expired"); return; } const nonceRequired = !isLoopbackAddress(remoteAddr); const providedNonce = typeof device.nonce === "string" ? device.nonce.trim() : ""; if (nonceRequired && !providedNonce) { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-nonce-missing", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device nonce required"), }); close(1008, "device nonce required"); return; } if (providedNonce && providedNonce !== connectNonce) { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-nonce-mismatch", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device nonce mismatch"), }); close(1008, "device nonce mismatch"); return; } const payload = buildDeviceAuthPayload({ deviceId: device.id, clientId: connectParams.client.id, clientMode: connectParams.client.mode, role, scopes: requestedScopes, signedAtMs: signedAt, token: connectParams.auth?.token ?? null, nonce: providedNonce || undefined, version: providedNonce ? "v2" : "v1", }); const signatureOk = verifyDeviceSignature(device.publicKey, payload, device.signature); const allowLegacy = !nonceRequired && !providedNonce; if (!signatureOk && allowLegacy) { const legacyPayload = buildDeviceAuthPayload({ deviceId: device.id, clientId: connectParams.client.id, clientMode: connectParams.client.mode, role, scopes: requestedScopes, signedAtMs: signedAt, token: connectParams.auth?.token ?? null, version: "v1", }); if (verifyDeviceSignature(device.publicKey, legacyPayload, device.signature)) { // accepted legacy loopback signature } else { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-signature", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device signature invalid"), }); close(1008, "device signature invalid"); return; } } else if (!signatureOk) { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-signature", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device signature invalid"), }); close(1008, "device signature invalid"); return; } devicePublicKey = normalizeDevicePublicKeyBase64Url(device.publicKey); if (!devicePublicKey) { setHandshakeState("failed"); setCloseCause("device-auth-invalid", { reason: "device-public-key", client: connectParams.client.id, deviceId: device.id, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "device public key invalid"), }); close(1008, "device public key invalid"); return; } } const authResult = await authorizeGatewayConnect({ auth: resolvedAuth, connectAuth: connectParams.auth, req: upgradeReq, }); let authOk = authResult.ok; let authMethod = authResult.method ?? "none"; if (!authOk && connectParams.auth?.token && device) { const tokenCheck = await verifyDeviceToken({ deviceId: device.id, token: connectParams.auth.token, role, scopes, }); if (tokenCheck.ok) { authOk = true; authMethod = "device-token"; } } if (!authOk) { setHandshakeState("failed"); logWsControl.warn( `unauthorized conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`, ); const authProvided = connectParams.auth?.token ? "token" : connectParams.auth?.password ? "password" : "none"; setCloseCause("unauthorized", { authMode: resolvedAuth.mode, authProvided, authReason: authResult.reason, allowTailscale: resolvedAuth.allowTailscale, client: connectParams.client.id, clientDisplayName: connectParams.client.displayName, mode: connectParams.client.mode, version: connectParams.client.version, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.INVALID_REQUEST, "unauthorized"), }); close(1008, "unauthorized"); return; } if (device && devicePublicKey) { const requirePairing = async (reason: string, _paired?: { deviceId: string }) => { const pairing = await requestDevicePairing({ deviceId: device.id, publicKey: devicePublicKey, displayName: connectParams.client.displayName, platform: connectParams.client.platform, clientId: connectParams.client.id, clientMode: connectParams.client.mode, role, scopes, remoteIp: remoteAddr, silent: isLoopbackAddress(remoteAddr), }); const context = buildRequestContext(); if (pairing.request.silent === true) { const approved = await approveDevicePairing(pairing.request.requestId); if (approved) { logGateway.info( `device pairing auto-approved device=${approved.device.deviceId} role=${approved.device.role ?? "unknown"}`, ); context.broadcast( "device.pair.resolved", { requestId: pairing.request.requestId, deviceId: approved.device.deviceId, decision: "approved", ts: Date.now(), }, { dropIfSlow: true }, ); } } else if (pairing.created) { context.broadcast("device.pair.requested", pairing.request, { dropIfSlow: true }); } if (pairing.request.silent !== true) { setHandshakeState("failed"); setCloseCause("pairing-required", { deviceId: device.id, requestId: pairing.request.requestId, reason, }); send({ type: "res", id: frame.id, ok: false, error: errorShape(ErrorCodes.NOT_PAIRED, "pairing required", { details: { requestId: pairing.request.requestId }, }), }); close(1008, "pairing required"); return false; } return true; }; const paired = await getPairedDevice(device.id); const isPaired = paired?.publicKey === devicePublicKey; if (!isPaired) { const ok = await requirePairing("not-paired"); if (!ok) return; } else { const allowedRoles = new Set( Array.isArray(paired.roles) ? paired.roles : paired.role ? [paired.role] : [], ); if (allowedRoles.size === 0) { const ok = await requirePairing("role-upgrade", paired); if (!ok) return; } else if (!allowedRoles.has(role)) { const ok = await requirePairing("role-upgrade", paired); if (!ok) return; } const pairedScopes = Array.isArray(paired.scopes) ? paired.scopes : []; if (scopes.length > 0) { if (pairedScopes.length === 0) { const ok = await requirePairing("scope-upgrade", paired); if (!ok) return; } else { const allowedScopes = new Set(pairedScopes); const missingScope = scopes.find((scope) => !allowedScopes.has(scope)); if (missingScope) { const ok = await requirePairing("scope-upgrade", paired); if (!ok) return; } } } await updatePairedDeviceMetadata(device.id, { displayName: connectParams.client.displayName, platform: connectParams.client.platform, clientId: connectParams.client.id, clientMode: connectParams.client.mode, role, scopes, remoteIp: remoteAddr, }); } } const deviceToken = device ? await ensureDeviceToken({ deviceId: device.id, role, scopes }) : null; if (role === "node") { const cfg = loadConfig(); const allowlist = resolveNodeCommandAllowlist(cfg, { platform: connectParams.client.platform, deviceFamily: connectParams.client.deviceFamily, }); const declared = Array.isArray(connectParams.commands) ? connectParams.commands : []; const filtered = declared .map((cmd) => cmd.trim()) .filter((cmd) => cmd.length > 0 && allowlist.has(cmd)); connectParams.commands = filtered; } const shouldTrackPresence = !isGatewayCliClient(connectParams.client); const clientId = connectParams.client.id; const instanceId = connectParams.client.instanceId; const presenceKey = shouldTrackPresence ? (connectParams.device?.id ?? instanceId ?? connId) : undefined; logWs("in", "connect", { connId, client: connectParams.client.id, clientDisplayName: connectParams.client.displayName, version: connectParams.client.version, mode: connectParams.client.mode, clientId, platform: connectParams.client.platform, auth: authMethod, }); if (isWebchatConnect(connectParams)) { logWsControl.info( `webchat connected conn=${connId} remote=${remoteAddr ?? "?"} client=${clientLabel} ${connectParams.client.mode} v${connectParams.client.version}`, ); } if (presenceKey) { upsertPresence(presenceKey, { host: connectParams.client.displayName ?? connectParams.client.id ?? os.hostname(), ip: isLoopbackAddress(remoteAddr) ? undefined : remoteAddr, version: connectParams.client.version, platform: connectParams.client.platform, deviceFamily: connectParams.client.deviceFamily, modelIdentifier: connectParams.client.modelIdentifier, mode: connectParams.client.mode, deviceId: connectParams.device?.id, roles: [role], scopes, instanceId: connectParams.device?.id ?? instanceId, reason: "connect", }); incrementPresenceVersion(); } const snapshot = buildGatewaySnapshot(); const cachedHealth = getHealthCache(); if (cachedHealth) { snapshot.health = cachedHealth; snapshot.stateVersion.health = getHealthVersion(); } const helloOk = { type: "hello-ok", protocol: PROTOCOL_VERSION, server: { version: process.env.CLAWDBOT_VERSION ?? process.env.npm_package_version ?? "dev", commit: process.env.GIT_COMMIT, host: os.hostname(), connId, }, features: { methods: gatewayMethods, events }, snapshot, canvasHostUrl, auth: deviceToken ? { deviceToken: deviceToken.token, role: deviceToken.role, scopes: deviceToken.scopes, issuedAtMs: deviceToken.rotatedAtMs ?? deviceToken.createdAtMs, } : undefined, policy: { maxPayload: MAX_PAYLOAD_BYTES, maxBufferedBytes: MAX_BUFFERED_BYTES, tickIntervalMs: TICK_INTERVAL_MS, }, }; clearHandshakeTimer(); const nextClient: GatewayWsClient = { socket, connect: connectParams, connId, presenceKey, }; setClient(nextClient); setHandshakeState("connected"); if (role === "node") { const context = buildRequestContext(); const nodeSession = context.nodeRegistry.register(nextClient, { remoteIp: remoteAddr }); recordRemoteNodeInfo({ nodeId: nodeSession.nodeId, displayName: nodeSession.displayName, platform: nodeSession.platform, deviceFamily: nodeSession.deviceFamily, commands: nodeSession.commands, remoteIp: nodeSession.remoteIp, }); void refreshRemoteNodeBins({ nodeId: nodeSession.nodeId, platform: nodeSession.platform, deviceFamily: nodeSession.deviceFamily, commands: nodeSession.commands, cfg: loadConfig(), }).catch((err) => logGateway.warn( `remote bin probe failed for ${nodeSession.nodeId}: ${formatForLog(err)}`, ), ); void loadVoiceWakeConfig() .then((cfg) => { context.nodeRegistry.sendEvent(nodeSession.nodeId, "voicewake.changed", { triggers: cfg.triggers, }); }) .catch((err) => logGateway.warn( `voicewake snapshot failed for ${nodeSession.nodeId}: ${formatForLog(err)}`, ), ); } logWs("out", "hello-ok", { connId, methods: gatewayMethods.length, events: events.length, presence: snapshot.presence.length, stateVersion: snapshot.stateVersion.presence, }); send({ type: "res", id: frame.id, ok: true, payload: helloOk }); void refreshGatewayHealthSnapshot({ probe: true }).catch((err) => logHealth.error(`post-connect health refresh failed: ${formatError(err)}`), ); return; } // After handshake, accept only req frames if (!validateRequestFrame(parsed)) { send({ type: "res", id: (parsed as { id?: unknown })?.id ?? "invalid", ok: false, error: errorShape( ErrorCodes.INVALID_REQUEST, `invalid request frame: ${formatValidationErrors(validateRequestFrame.errors)}`, ), }); return; } const req = parsed as RequestFrame; logWs("in", "req", { connId, id: req.id, method: req.method }); const respond = ( ok: boolean, payload?: unknown, error?: ErrorShape, meta?: Record, ) => { send({ type: "res", id: req.id, ok, payload, error }); logWs("out", "res", { connId, id: req.id, ok, method: req.method, errorCode: error?.code, errorMessage: error?.message, ...meta, }); }; void (async () => { await handleGatewayRequest({ req, respond, client, isWebchatConnect, extraHandlers, context: buildRequestContext(), }); })().catch((err) => { logGateway.error(`request handler failed: ${formatForLog(err)}`); respond(false, undefined, errorShape(ErrorCodes.UNAVAILABLE, formatForLog(err))); }); } catch (err) { logGateway.error(`parse/handle error: ${String(err)}`); logWs("out", "parse-error", { connId, error: formatForLog(err) }); if (!getClient()) { close(); } } }); }