From 8262a0306074f7b0eb1360e0dd626839feaa30a7 Mon Sep 17 00:00:00 2001 From: gerald Ruby Date: Thu, 29 Jan 2026 11:32:13 -0800 Subject: [PATCH] feat: add GitHub 2FA gate extension for sensitive tools Add a new extension that gates sensitive tool calls (exec, Bash, Write, Edit, NotebookEdit) behind GitHub Device Flow authentication. Users must approve on GitHub Mobile or enter a code at github.com/login/device before the bot can execute dangerous operations. Key changes: - Wire up before_tool_call hook in tool execution path (tool-hook-wrapper.ts) - Create 2fa-github extension with: - GitHub Device Authorization Flow implementation - File-based session store with TTL (~/.clawdbot/2fa-sessions.json) - Non-blocking flow: returns immediately with code, user retries after approval - Configurable tool list and session TTL (default 30 min) Configuration: plugins.entries.2fa-github.config.clientId: "Ov23..." # or GITHUB_2FA_CLIENT_ID env var Co-Authored-By: Claude Opus 4.5 --- extensions/2fa-github/index.ts | 78 +++++++ extensions/2fa-github/moltbot.plugin.json | 46 ++++ extensions/2fa-github/package.json | 14 ++ extensions/2fa-github/src/config.ts | 51 +++++ extensions/2fa-github/src/device-flow.ts | 210 ++++++++++++++++++ extensions/2fa-github/src/hook.ts | 152 +++++++++++++ extensions/2fa-github/src/session-store.ts | 173 +++++++++++++++ extensions/2fa-github/src/types.ts | 43 ++++ src/agents/pi-embedded-runner/run/attempt.ts | 10 +- .../run/tool-hook-wrapper.ts | 79 +++++++ 10 files changed, 855 insertions(+), 1 deletion(-) create mode 100644 extensions/2fa-github/index.ts create mode 100644 extensions/2fa-github/moltbot.plugin.json create mode 100644 extensions/2fa-github/package.json create mode 100644 extensions/2fa-github/src/config.ts create mode 100644 extensions/2fa-github/src/device-flow.ts create mode 100644 extensions/2fa-github/src/hook.ts create mode 100644 extensions/2fa-github/src/session-store.ts create mode 100644 extensions/2fa-github/src/types.ts create mode 100644 src/agents/pi-embedded-runner/run/tool-hook-wrapper.ts diff --git a/extensions/2fa-github/index.ts b/extensions/2fa-github/index.ts new file mode 100644 index 000000000..06574a719 --- /dev/null +++ b/extensions/2fa-github/index.ts @@ -0,0 +1,78 @@ +/** + * GitHub Mobile 2FA Gate Extension + * + * Gates sensitive tool calls behind GitHub Mobile push authentication. + * Users must approve on their phone before the bot can execute file writes, + * shell commands, or other dangerous operations. + * + * Configuration: + * ```yaml + * plugins: + * 2fa-github: + * enabled: true + * clientId: "Iv1.your_client_id_here" + * tokenTtlMinutes: 30 + * sensitiveTools: + * - Bash + * - Write + * - Edit + * - NotebookEdit + * gateAllTools: false + * ``` + * + * Or via environment variable: + * ```bash + * export GITHUB_2FA_CLIENT_ID="Iv1.your_client_id_here" + * ``` + * + * GitHub OAuth App Setup: + * 1. Go to GitHub Settings > Developer Settings > OAuth Apps + * 2. Click "New OAuth App" + * 3. Fill in application name and URLs (callback URL not used) + * 4. IMPORTANT: Check "Enable Device Flow" + * 5. Copy the Client ID (no secret needed for device flow) + */ + +import type { MoltbotPluginApi } from "clawdbot/plugin-sdk"; +import { register2FAHook } from "./src/hook.js"; +import { twoFactorConfigSchema } from "./src/config.js"; + +const plugin = { + id: "2fa-github", + name: "GitHub Mobile 2FA Gate", + description: "Gates sensitive tools behind GitHub Mobile push authentication", + configSchema: twoFactorConfigSchema, + + register(api: MoltbotPluginApi) { + register2FAHook(api); + + // Register CLI commands for managing 2FA sessions + api.registerCli( + ({ program }) => { + const twofa = program.command("2fa").description("GitHub 2FA gate commands"); + + twofa + .command("status") + .description("Show 2FA session status") + .action(async () => { + const { getStats } = await import("./src/session-store.js"); + const stats = getStats(); + console.log(`Active sessions: ${stats.sessionCount}`); + console.log(`Pending verifications: ${stats.pendingCount}`); + }); + + twofa + .command("clear") + .description("Clear all 2FA sessions") + .action(async () => { + const { clearAll } = await import("./src/session-store.js"); + clearAll(); + console.log("All 2FA sessions cleared"); + }); + }, + { commands: ["2fa"] }, + ); + }, +}; + +export default plugin; diff --git a/extensions/2fa-github/moltbot.plugin.json b/extensions/2fa-github/moltbot.plugin.json new file mode 100644 index 000000000..818004d10 --- /dev/null +++ b/extensions/2fa-github/moltbot.plugin.json @@ -0,0 +1,46 @@ +{ + "id": "2fa-github", + "name": "GitHub Mobile 2FA Gate", + "description": "Gates sensitive tools behind GitHub Mobile push authentication", + "uiHints": { + "clientId": { + "label": "GitHub OAuth App Client ID", + "placeholder": "Ov23xxxxxxxxxxxxxxxxxx", + "help": "Create at GitHub Settings > Developer Settings > OAuth Apps (enable Device Flow)" + }, + "tokenTtlMinutes": { + "label": "Session TTL (minutes)", + "placeholder": "30", + "help": "How long before re-authentication is required" + }, + "sensitiveTools": { + "label": "Sensitive Tools", + "help": "Tool names requiring 2FA (default: Bash, Write, Edit, NotebookEdit)" + }, + "gateAllTools": { + "label": "Gate All Tools", + "help": "Require 2FA for all tools, not just sensitive ones" + } + }, + "configSchema": { + "type": "object", + "additionalProperties": false, + "properties": { + "clientId": { + "type": "string" + }, + "tokenTtlMinutes": { + "type": "number" + }, + "sensitiveTools": { + "type": "array", + "items": { + "type": "string" + } + }, + "gateAllTools": { + "type": "boolean" + } + } + } +} diff --git a/extensions/2fa-github/package.json b/extensions/2fa-github/package.json new file mode 100644 index 000000000..ba4df4d69 --- /dev/null +++ b/extensions/2fa-github/package.json @@ -0,0 +1,14 @@ +{ + "name": "@moltbot/2fa-github", + "version": "0.1.0", + "type": "module", + "description": "GitHub Mobile 2FA gate for sensitive tool calls", + "moltbot": { + "extensions": [ + "./index.ts" + ] + }, + "devDependencies": { + "moltbot": "workspace:*" + } +} diff --git a/extensions/2fa-github/src/config.ts b/extensions/2fa-github/src/config.ts new file mode 100644 index 000000000..b93e523c6 --- /dev/null +++ b/extensions/2fa-github/src/config.ts @@ -0,0 +1,51 @@ +/** + * GitHub 2FA Extension Configuration + */ + +export type TwoFactorConfig = { + clientId?: string; + tokenTtlMinutes?: number; + sensitiveTools?: string[]; + gateAllTools?: boolean; +}; + +const DEFAULT_SENSITIVE_TOOLS = ["exec", "Bash", "Write", "Edit", "NotebookEdit"]; +const DEFAULT_TTL_MINUTES = 30; + +export function parseConfig(value: unknown): TwoFactorConfig { + if (!value || typeof value !== "object") return {}; + const cfg = value as Record; + return { + clientId: typeof cfg.clientId === "string" ? cfg.clientId : undefined, + tokenTtlMinutes: + typeof cfg.tokenTtlMinutes === "number" ? cfg.tokenTtlMinutes : DEFAULT_TTL_MINUTES, + sensitiveTools: Array.isArray(cfg.sensitiveTools) + ? cfg.sensitiveTools.filter((t): t is string => typeof t === "string") + : DEFAULT_SENSITIVE_TOOLS, + gateAllTools: typeof cfg.gateAllTools === "boolean" ? cfg.gateAllTools : false, + }; +} + +export const twoFactorConfigSchema = { + parse: parseConfig, + uiHints: { + clientId: { + label: "GitHub OAuth App Client ID", + placeholder: "Iv1.xxxxxxxxxxxxxxxx", + help: "Create at GitHub Settings > Developer Settings > OAuth Apps (enable Device Flow)", + }, + tokenTtlMinutes: { + label: "Session TTL (minutes)", + placeholder: "30", + help: "How long before re-authentication is required", + }, + sensitiveTools: { + label: "Sensitive Tools", + help: "Tool names requiring 2FA (default: Bash, Write, Edit, NotebookEdit)", + }, + gateAllTools: { + label: "Gate All Tools", + help: "Require 2FA for all tools, not just sensitive ones", + }, + }, +}; diff --git a/extensions/2fa-github/src/device-flow.ts b/extensions/2fa-github/src/device-flow.ts new file mode 100644 index 000000000..1333de657 --- /dev/null +++ b/extensions/2fa-github/src/device-flow.ts @@ -0,0 +1,210 @@ +/** + * GitHub Device Authorization Flow + * + * Implements the OAuth 2.0 Device Authorization Grant for GitHub. + * This allows authentication without a browser redirect, using GitHub Mobile + * push notifications or manual code entry at github.com/login/device. + * + * Reference: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#device-flow + */ + +import type { DeviceCodeResponse, DeviceTokenResponse } from "./types.js"; + +const DEVICE_CODE_URL = "https://github.com/login/device/code"; +const ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"; +const USER_API_URL = "https://api.github.com/user"; + +function parseJsonResponse(value: unknown): T { + if (!value || typeof value !== "object") { + throw new Error("Unexpected response from GitHub"); + } + return value as T; +} + +/** + * Request a device code from GitHub. + * The user will use this code to authorize at github.com/login/device. + */ +export async function requestDeviceCode(clientId: string): Promise { + const body = new URLSearchParams({ + client_id: clientId, + scope: "read:user", + }); + + const res = await fetch(DEVICE_CODE_URL, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/x-www-form-urlencoded", + }, + body, + }); + + if (!res.ok) { + throw new Error(`GitHub device code request failed: HTTP ${res.status}`); + } + + const json = parseJsonResponse(await res.json()); + if (!json.device_code || !json.user_code || !json.verification_uri) { + throw new Error("GitHub device code response missing required fields"); + } + + return json; +} + +/** + * Poll for access token after user has authorized the device. + * + * @param params.clientId - GitHub OAuth App client ID + * @param params.deviceCode - Device code from requestDeviceCode() + * @param params.intervalMs - Minimum polling interval in milliseconds + * @param params.expiresAt - Timestamp when device code expires + * @returns Access token and GitHub username + */ +export async function pollForAccessToken(params: { + clientId: string; + deviceCode: string; + intervalMs: number; + expiresAt: number; +}): Promise<{ accessToken: string; login: string }> { + const body = new URLSearchParams({ + client_id: params.clientId, + device_code: params.deviceCode, + grant_type: "urn:ietf:params:oauth:grant-type:device_code", + }); + + while (Date.now() < params.expiresAt) { + const res = await fetch(ACCESS_TOKEN_URL, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/x-www-form-urlencoded", + }, + body, + }); + + if (!res.ok) { + throw new Error(`GitHub device token request failed: HTTP ${res.status}`); + } + + const json = parseJsonResponse(await res.json()); + + // Check for successful token response + if ("access_token" in json && typeof json.access_token === "string") { + // Fetch user info to get the GitHub login + const userRes = await fetch(USER_API_URL, { + headers: { + Authorization: `Bearer ${json.access_token}`, + Accept: "application/json", + }, + }); + + if (!userRes.ok) { + throw new Error(`Failed to fetch GitHub user info: HTTP ${userRes.status}`); + } + + const userJson = (await userRes.json()) as { login?: string }; + const login = userJson.login; + if (!login || typeof login !== "string") { + throw new Error("GitHub user response missing login field"); + } + + return { accessToken: json.access_token, login }; + } + + // Handle error responses + const err = "error" in json ? json.error : "unknown"; + + if (err === "authorization_pending") { + // User hasn't authorized yet, wait and try again + await new Promise((r) => setTimeout(r, params.intervalMs)); + continue; + } + + if (err === "slow_down") { + // Rate limited, wait longer + await new Promise((r) => setTimeout(r, params.intervalMs + 2000)); + continue; + } + + if (err === "expired_token") { + throw new Error("Device code expired"); + } + + if (err === "access_denied") { + throw new Error("Authorization denied by user"); + } + + throw new Error(`GitHub device flow error: ${err}`); + } + + throw new Error("Device code expired"); +} + +/** + * Quick poll - tries once and returns immediately. + * Used when checking if user has already approved on retry. + */ +export async function quickPollForAccessToken(params: { + clientId: string; + deviceCode: string; +}): Promise<{ accessToken: string; login: string } | "pending" | "expired" | "denied"> { + const body = new URLSearchParams({ + client_id: params.clientId, + device_code: params.deviceCode, + grant_type: "urn:ietf:params:oauth:grant-type:device_code", + }); + + const res = await fetch(ACCESS_TOKEN_URL, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/x-www-form-urlencoded", + }, + body, + }); + + if (!res.ok) { + throw new Error(`GitHub device token request failed: HTTP ${res.status}`); + } + + const json = parseJsonResponse(await res.json()); + + if ("access_token" in json && typeof json.access_token === "string") { + // Fetch user info + const userRes = await fetch(USER_API_URL, { + headers: { + Authorization: `Bearer ${json.access_token}`, + Accept: "application/json", + }, + }); + + if (!userRes.ok) { + throw new Error(`Failed to fetch GitHub user info: HTTP ${userRes.status}`); + } + + const userJson = (await userRes.json()) as { login?: string }; + const login = userJson.login; + if (!login || typeof login !== "string") { + throw new Error("GitHub user response missing login field"); + } + + return { accessToken: json.access_token, login }; + } + + const err = "error" in json ? json.error : "unknown"; + + if (err === "authorization_pending" || err === "slow_down") { + return "pending"; + } + + if (err === "expired_token") { + return "expired"; + } + + if (err === "access_denied") { + return "denied"; + } + + throw new Error(`GitHub device flow error: ${err}`); +} diff --git a/extensions/2fa-github/src/hook.ts b/extensions/2fa-github/src/hook.ts new file mode 100644 index 000000000..a682b5970 --- /dev/null +++ b/extensions/2fa-github/src/hook.ts @@ -0,0 +1,152 @@ +/** + * 2FA Hook Handler + * + * Registers a before_tool_call hook that gates sensitive tools behind + * GitHub Mobile push authentication using the Device Authorization Flow. + * + * Flow (non-blocking): + * 1. Tool call triggers hook + * 2. Check for valid session -> allow if valid + * 3. Check for pending verification -> quick poll, if approved store session and allow + * 4. No session/pending -> initiate device flow, store pending, return block with instructions + * 5. User approves on phone, retries request -> step 3 succeeds + */ + +import type { MoltbotPluginApi } from "clawdbot/plugin-sdk"; +import { requestDeviceCode, quickPollForAccessToken } from "./device-flow.js"; +import { + getSession, + setSession, + getPending, + setPending, + clearPending, +} from "./session-store.js"; +import { parseConfig, type TwoFactorConfig } from "./config.js"; + +const DEFAULT_SENSITIVE_TOOLS = ["exec", "Bash", "Write", "Edit", "NotebookEdit"]; + +export function register2FAHook(api: MoltbotPluginApi): void { + const cfg = parseConfig(api.pluginConfig); + const clientId = cfg.clientId ?? process.env.GITHUB_2FA_CLIENT_ID; + const ttlMinutes = cfg.tokenTtlMinutes ?? 30; + const sensitiveTools = cfg.sensitiveTools ?? DEFAULT_SENSITIVE_TOOLS; + const gateAllTools = cfg.gateAllTools ?? false; + + if (!clientId) { + api.logger.warn("2fa-github: No clientId configured, plugin disabled"); + api.logger.warn( + "2fa-github: Set plugins.entries.2fa-github.config.clientId in config or GITHUB_2FA_CLIENT_ID env var", + ); + return; + } + + api.on("before_tool_call", async (event, ctx) => { + // Check if this tool requires 2FA + if (!gateAllTools && !sensitiveTools.includes(event.toolName)) { + return; // Allow without 2FA + } + + const sessionKey = ctx.sessionKey ?? "default"; + + // Check for valid session first + const session = getSession(sessionKey); + if (session) { + api.logger.debug?.(`2fa-github: Valid session for ${session.githubLogin}`); + return; // Allow - valid session exists + } + + // Check for pending verification (user might be retrying after approval) + const pending = getPending(sessionKey); + if (pending) { + api.logger.info?.("2fa-github: Found pending verification, checking..."); + + try { + const result = await quickPollForAccessToken({ + clientId, + deviceCode: pending.deviceCode, + }); + + if (result === "pending") { + // Still pending - remind user to approve + return { + block: true, + blockReason: [ + "2FA approval still pending.", + "", + `Visit: ${pending.verificationUri}`, + `Code: ${pending.userCode}`, + "", + "Approve on GitHub Mobile (or enter code on website), then retry your request.", + ].join("\n"), + }; + } + + if (result === "expired") { + clearPending(sessionKey); + // Fall through to create new verification + } else if (result === "denied") { + clearPending(sessionKey); + return { + block: true, + blockReason: "2FA authorization was denied. Please try again.", + }; + } else { + // Success! Store session and allow + const now = new Date(); + const expiry = new Date(now.getTime() + ttlMinutes * 60 * 1000); + setSession(sessionKey, { + githubLogin: result.login, + verifiedAt: now.toISOString(), + expiresAt: expiry.toISOString(), + }); + api.logger.info?.(`2fa-github: Verified as ${result.login}`); + return; // Allow execution + } + } catch (err) { + api.logger.warn?.(`2fa-github: Poll error: ${String(err)}`); + clearPending(sessionKey); + // Fall through to create new verification + } + } + + // No session, no valid pending - initiate new device flow + api.logger.info?.("2fa-github: Initiating GitHub device flow"); + + try { + const device = await requestDeviceCode(clientId); + + // Store pending verification for retry + const expiresAt = new Date(Date.now() + device.expires_in * 1000); + setPending(sessionKey, { + deviceCode: device.device_code, + userCode: device.user_code, + verificationUri: device.verification_uri, + expiresAt: expiresAt.toISOString(), + intervalMs: Math.max(1000, device.interval * 1000), + }); + + // Return block with instructions (non-blocking - returns immediately) + return { + block: true, + blockReason: [ + "2FA verification required for this operation.", + "", + `Visit: ${device.verification_uri}`, + `Code: ${device.user_code}`, + "", + "Approve on GitHub Mobile (or enter code on website), then retry your request.", + ].join("\n"), + }; + } catch (err) { + api.logger.error?.(`2fa-github: Failed to initiate device flow: ${String(err)}`); + return { + block: true, + blockReason: `2FA verification failed: ${String(err)}`, + }; + } + }); + + api.logger.info?.( + `2fa-github: Enabled (TTL: ${ttlMinutes}min, tools: ${sensitiveTools.join(", ")})`, + ); +} diff --git a/extensions/2fa-github/src/session-store.ts b/extensions/2fa-github/src/session-store.ts new file mode 100644 index 000000000..36319c1ec --- /dev/null +++ b/extensions/2fa-github/src/session-store.ts @@ -0,0 +1,173 @@ +/** + * Session Store + * + * File-based storage for 2FA sessions and pending verifications. + * Sessions are keyed by sessionKey and include TTL handling. + */ + +import * as fs from "node:fs"; +import * as path from "node:path"; +import * as os from "node:os"; + +import type { Session, PendingVerification, SessionStore } from "./types.js"; + +const STORE_FILENAME = "2fa-sessions.json"; + +function getStorePath(): string { + return path.join(os.homedir(), ".clawdbot", STORE_FILENAME); +} + +function loadStore(): SessionStore { + const storePath = getStorePath(); + + if (!fs.existsSync(storePath)) { + return { version: 1, sessions: {}, pending: {} }; + } + + try { + const data = JSON.parse(fs.readFileSync(storePath, "utf-8")); + return { + version: 1, + sessions: data.sessions ?? {}, + pending: data.pending ?? {}, + }; + } catch { + // Corrupted file, start fresh + return { version: 1, sessions: {}, pending: {} }; + } +} + +function saveStore(store: SessionStore): void { + const storePath = getStorePath(); + const dir = path.dirname(storePath); + + if (!fs.existsSync(dir)) { + fs.mkdirSync(dir, { recursive: true }); + } + + fs.writeFileSync(storePath, JSON.stringify(store, null, 2)); +} + +/** + * Prune expired entries from the store. + */ +function pruneExpired(store: SessionStore): void { + const now = new Date(); + + // Prune expired sessions + for (const [key, session] of Object.entries(store.sessions)) { + if (new Date(session.expiresAt) < now) { + delete store.sessions[key]; + } + } + + // Prune expired pending verifications + for (const [key, pending] of Object.entries(store.pending)) { + if (new Date(pending.expiresAt) < now) { + delete store.pending[key]; + } + } +} + +/** + * Get a valid session for the given key. + * Returns undefined if no valid session exists. + */ +export function getSession(sessionKey: string): Session | undefined { + const store = loadStore(); + const session = store.sessions[sessionKey]; + + if (!session) return undefined; + + // Check if expired + if (new Date(session.expiresAt) < new Date()) { + delete store.sessions[sessionKey]; + saveStore(store); + return undefined; + } + + return session; +} + +/** + * Set a session for the given key. + * Also clears any pending verification for this key. + */ +export function setSession(sessionKey: string, session: Session): void { + const store = loadStore(); + + // Store the new session + store.sessions[sessionKey] = session; + + // Clear pending verification on successful auth + delete store.pending[sessionKey]; + + // Prune expired entries + pruneExpired(store); + + saveStore(store); +} + +/** + * Get a pending verification for the given key. + * Returns undefined if no valid pending verification exists. + */ +export function getPending(sessionKey: string): PendingVerification | undefined { + const store = loadStore(); + const pending = store.pending[sessionKey]; + + if (!pending) return undefined; + + // Check if expired + if (new Date(pending.expiresAt) < new Date()) { + delete store.pending[sessionKey]; + saveStore(store); + return undefined; + } + + return pending; +} + +/** + * Set a pending verification for the given key. + */ +export function setPending(sessionKey: string, pending: PendingVerification): void { + const store = loadStore(); + store.pending[sessionKey] = pending; + pruneExpired(store); + saveStore(store); +} + +/** + * Clear a pending verification for the given key. + */ +export function clearPending(sessionKey: string): void { + const store = loadStore(); + + if (store.pending[sessionKey]) { + delete store.pending[sessionKey]; + saveStore(store); + } +} + +/** + * Clear all sessions and pending verifications. + * Useful for testing or manual reset. + */ +export function clearAll(): void { + const store = { version: 1 as const, sessions: {}, pending: {} }; + saveStore(store); +} + +/** + * Get statistics about the store. + */ +export function getStats(): { sessionCount: number; pendingCount: number } { + const store = loadStore(); + pruneExpired(store); + + return { + sessionCount: Object.keys(store.sessions).length, + pendingCount: Object.keys(store.pending).length, + }; +} diff --git a/extensions/2fa-github/src/types.ts b/extensions/2fa-github/src/types.ts new file mode 100644 index 000000000..6a7667dc1 --- /dev/null +++ b/extensions/2fa-github/src/types.ts @@ -0,0 +1,43 @@ +/** + * GitHub 2FA Extension Types + */ + +export type Session = { + githubLogin: string; + verifiedAt: string; + expiresAt: string; +}; + +export type PendingVerification = { + deviceCode: string; + userCode: string; + verificationUri: string; + expiresAt: string; + intervalMs: number; +}; + +export type SessionStore = { + version: 1; + sessions: Record; + pending: Record; +}; + +export type DeviceCodeResponse = { + device_code: string; + user_code: string; + verification_uri: string; + expires_in: number; + interval: number; +}; + +export type DeviceTokenResponse = + | { + access_token: string; + token_type: string; + scope?: string; + } + | { + error: string; + error_description?: string; + error_uri?: string; + }; diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index 46a53bd8f..717314bcf 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -85,6 +85,7 @@ import { getGlobalHookRunner } from "../../../plugins/hook-runner-global.js"; import { MAX_IMAGE_BYTES } from "../../../media/constants.js"; import type { EmbeddedRunAttemptParams, EmbeddedRunAttemptResult } from "./types.js"; import { detectAndLoadPromptImages } from "./images.js"; +import { wrapToolsWithHook } from "./tool-hook-wrapper.js"; export function injectHistoryImagesIntoMessages( messages: AgentMessage[], @@ -432,8 +433,15 @@ export async function runEmbeddedAttempt( model: params.model, }); + // Wrap tools with before_tool_call hook invocation before splitting + const toolHookCtx = { + agentId: params.sessionKey?.split(":")[0] ?? "main", + sessionKey: params.sessionKey, + }; + const wrappedTools = wrapToolsWithHook(tools, toolHookCtx); + const { builtInTools, customTools } = splitSdkTools({ - tools, + tools: wrappedTools, sandboxEnabled: !!sandbox?.enabled, }); diff --git a/src/agents/pi-embedded-runner/run/tool-hook-wrapper.ts b/src/agents/pi-embedded-runner/run/tool-hook-wrapper.ts new file mode 100644 index 000000000..6b1fd58b1 --- /dev/null +++ b/src/agents/pi-embedded-runner/run/tool-hook-wrapper.ts @@ -0,0 +1,79 @@ +/** + * Tool Hook Wrapper + * + * Wraps tool execute functions to invoke before_tool_call hooks before execution. + * If a hook returns { block: true }, the tool returns an error result instead of executing. + */ + +import type { AgentToolResult } from "@mariozechner/pi-agent-core"; + +import { getGlobalHookRunner } from "../../../plugins/hook-runner-global.js"; +import type { AnyAgentTool } from "../../pi-tools.types.js"; +import { log } from "../logger.js"; + +export type ToolHookContext = { + agentId?: string; + sessionKey?: string; +}; + +/** + * Create a blocked tool result with proper typing. + */ +function blockedResult(reason: string): AgentToolResult { + return { + content: [{ type: "text", text: reason }], + details: { blocked: true, reason }, + }; +} + +/** + * Wrap a tool with before_tool_call hook invocation. + * The hook can block execution or modify parameters. + */ +export function wrapToolWithHook(tool: AnyAgentTool, ctx: ToolHookContext): AnyAgentTool { + const originalExecute = tool.execute; + if (!originalExecute) return tool; + + return { + ...tool, + execute: async (toolCallId, params, signal, onUpdate) => { + const hookRunner = getGlobalHookRunner(); + + // Check if any before_tool_call hooks are registered + if (hookRunner?.hasHooks("before_tool_call")) { + try { + const hookResult = await hookRunner.runBeforeToolCall( + { toolName: tool.name, params: params as Record }, + { agentId: ctx.agentId, sessionKey: ctx.sessionKey, toolName: tool.name }, + ); + + // If hook wants to block execution + if (hookResult?.block) { + log.debug( + `Tool ${tool.name} blocked by before_tool_call hook: ${hookResult.blockReason ?? "no reason given"}`, + ); + return blockedResult(hookResult.blockReason ?? `Tool ${tool.name} blocked by plugin`); + } + + // If hook modified params, use the modified version + if (hookResult?.params) { + params = hookResult.params; + } + } catch (err) { + log.warn(`before_tool_call hook failed for ${tool.name}: ${String(err)}`); + // Continue with execution on hook error (fail-open for safety) + } + } + + // Execute the original tool + return originalExecute.call(tool, toolCallId, params, signal, onUpdate); + }, + }; +} + +/** + * Wrap multiple tools with hook invocation. + */ +export function wrapToolsWithHook(tools: AnyAgentTool[], ctx: ToolHookContext): AnyAgentTool[] { + return tools.map((tool) => wrapToolWithHook(tool, ctx)); +}