diff --git a/src/infra/outbound/deliver.hooks.test.ts b/src/infra/outbound/deliver.hooks.test.ts new file mode 100644 index 000000000..c65a6eed8 --- /dev/null +++ b/src/infra/outbound/deliver.hooks.test.ts @@ -0,0 +1,193 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { deliverOutboundPayloads } from "./deliver.js"; +import type { MoltbotConfig } from "../../config/config.js"; +import { + initializeGlobalHookRunner, + resetGlobalHookRunner, +} from "../../plugins/hook-runner-global.js"; +import { createPluginRegistry } from "../../plugins/registry.js"; +import type { ReplyPayload } from "../../auto-reply/types.js"; +import type { PluginRecord } from "../../plugins/types.js"; + +// Mock the channel adapter loader +vi.mock("../../channels/plugins/outbound/load.js", () => ({ + loadChannelOutboundAdapter: vi.fn().mockImplementation(async (channel) => { + return { + sendText: vi.fn().mockResolvedValue({ channel, messageId: "msg-123" }), + sendMedia: vi.fn().mockResolvedValue({ channel, messageId: "msg-123" }), + }; + }), +})); + +describe("deliverOutboundPayloads hooks", () => { + const cfg = { + channels: { + telegram: { enabled: true }, + }, + } as MoltbotConfig; + + // Mock plugin record for registration + const mockPluginRecord = { + id: "test-plugin", + name: "Test Plugin", + source: "test", + origin: "workspace", + enabled: true, + status: "loaded", + hookCount: 0, + } as PluginRecord; + + const createTestRegistry = () => { + return createPluginRegistry({ + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + runtime: {} as any, + }); + }; + + beforeEach(() => { + resetGlobalHookRunner(); + }); + + afterEach(() => { + resetGlobalHookRunner(); + vi.restoreAllMocks(); + }); + + it("should ignore hooks when no runner is initialized", async () => { + // NOT initializing global hook runner + + const payloads: ReplyPayload[] = [{ text: "hello" }]; + const results = await deliverOutboundPayloads({ + cfg, + channel: "telegram", + to: "123", + payloads, + }); + + expect(results).toHaveLength(1); + expect(results[0].messageId).toBe("msg-123"); + }); + + it("should call message_sending hook", async () => { + const registryHelper = createTestRegistry(); + const hookFn = vi.fn(); + + registryHelper.registerTypedHook(mockPluginRecord, "message_sending", hookFn); + + initializeGlobalHookRunner(registryHelper.registry); + + const payloads: ReplyPayload[] = [{ text: "hello" }]; + await deliverOutboundPayloads({ + cfg, + channel: "telegram", + to: "chat-123", + accountId: "acc-1", + payloads, + }); + + expect(hookFn).toHaveBeenCalledTimes(1); + expect(hookFn).toHaveBeenCalledWith( + expect.objectContaining({ + content: "hello", + to: "chat-123", + }), + expect.objectContaining({ + channelId: "telegram", + accountId: "acc-1", + }), + ); + }); + + it("should allow message_sending hook to modify content", async () => { + const registryHelper = createTestRegistry(); + + registryHelper.registerTypedHook(mockPluginRecord, "message_sending", async (event) => ({ + content: event.content + " world", + })); + + initializeGlobalHookRunner(registryHelper.registry); + + // We need to spy on the actual send function to verify the content + const { loadChannelOutboundAdapter } = await import("../../channels/plugins/outbound/load.js"); + const sendTextMock = vi.fn().mockResolvedValue({ channel: "telegram", messageId: "msg-1" }); + + vi.mocked(loadChannelOutboundAdapter).mockResolvedValue({ + sendText: sendTextMock, + sendMedia: vi.fn(), + } as any); + + await deliverOutboundPayloads({ + cfg, + channel: "telegram", + to: "123", + payloads: [{ text: "hello" }], + }); + + expect(sendTextMock).toHaveBeenCalledWith( + expect.objectContaining({ + text: "hello world", + }), + ); + }); + + it("should allow message_sending hook to cancel delivery", async () => { + const registryHelper = createTestRegistry(); + + registryHelper.registerTypedHook(mockPluginRecord, "message_sending", async () => ({ + cancel: true, + })); + + initializeGlobalHookRunner(registryHelper.registry); + + const { loadChannelOutboundAdapter } = await import("../../channels/plugins/outbound/load.js"); + const sendTextMock = vi.fn().mockResolvedValue({ channel: "telegram", messageId: "msg-1" }); + + vi.mocked(loadChannelOutboundAdapter).mockResolvedValue({ + sendText: sendTextMock, + sendMedia: vi.fn(), + } as any); + + const results = await deliverOutboundPayloads({ + cfg, + channel: "telegram", + to: "123", + payloads: [{ text: "hello" }], + }); + + expect(results).toHaveLength(0); + expect(sendTextMock).not.toHaveBeenCalled(); + }); + + it("should call message_sent hook after success", async () => { + const registryHelper = createTestRegistry(); + const sentHook = vi.fn(); + + registryHelper.registerTypedHook(mockPluginRecord, "message_sent", sentHook); + + initializeGlobalHookRunner(registryHelper.registry); + + await deliverOutboundPayloads({ + cfg, + channel: "telegram", + to: "chat-123", + payloads: [{ text: "hello" }], + }); + + expect(sentHook).toHaveBeenCalledTimes(1); + expect(sentHook).toHaveBeenCalledWith( + expect.objectContaining({ + to: "chat-123", + content: "hello", + success: true, + }), + expect.objectContaining({ + channelId: "telegram", + }), + ); + }); +}); diff --git a/src/infra/outbound/deliver.ts b/src/infra/outbound/deliver.ts index 1abbd3557..c016537ee 100644 --- a/src/infra/outbound/deliver.ts +++ b/src/infra/outbound/deliver.ts @@ -21,6 +21,8 @@ import { appendAssistantMessageToSessionTranscript, resolveMirroredTranscriptText, } from "../../config/sessions.js"; +import { getGlobalHookRunner } from "../../plugins/hook-runner-global.js"; +import { logVerbose } from "../../globals.js"; import type { NormalizedOutboundPayload } from "./payloads.js"; import { normalizeReplyPayloadsForDelivery } from "./payloads.js"; import type { OutboundChannel } from "./targets.js"; @@ -312,12 +314,51 @@ export async function deliverOutboundPayloads(params: { }; }; const normalizedPayloads = normalizeReplyPayloadsForDelivery(payloads); + const hookRunner = getGlobalHookRunner(); + for (const payload of normalizedPayloads) { const payloadSummary: NormalizedOutboundPayload = { text: payload.text ?? "", mediaUrls: payload.mediaUrls ?? (payload.mediaUrl ? [payload.mediaUrl] : []), channelData: payload.channelData, }; + + if (hookRunner?.hasHooks("message_sending")) { + try { + const hookResult = await hookRunner.runMessageSending( + { + to, + content: payloadSummary.text, + metadata: { + channelData: payloadSummary.channelData, + mediaUrls: payloadSummary.mediaUrls, + }, + }, + { + channelId: channel, + accountId, + conversationId: to, + }, + ); + + if (hookResult?.cancel) { + logVerbose("deliver: message cancelled by message_sending hook"); + continue; + } + + if (hookResult?.content !== undefined) { + payloadSummary.text = hookResult.content; + // Update original payload text as well since it might be used by sendPayload + payload.text = hookResult.content; + } + } catch (err) { + logVerbose(`deliver: message_sending hook failed: ${String(err)}`); + } + } + + const startResultCount = results.length; + let deliveryError: string | undefined; + try { throwIfAborted(abortSignal); params.onPayload?.(payloadSummary); @@ -347,7 +388,29 @@ export async function deliverOutboundPayloads(params: { } } catch (err) { if (!params.bestEffort) throw err; + deliveryError = String(err); params.onError?.(err, payloadSummary); + } finally { + if (hookRunner?.hasHooks("message_sent")) { + const success = results.length > startResultCount; + void hookRunner + .runMessageSent( + { + to, + content: payloadSummary.text, + success, + error: deliveryError, + }, + { + channelId: channel, + accountId, + conversationId: to, + }, + ) + .catch((err) => { + logVerbose(`deliver: message_sent hook failed: ${String(err)}`); + }); + } } } if (params.mirror && results.length > 0) {