diff --git a/src/auto-reply/commands-registry.data.ts b/src/auto-reply/commands-registry.data.ts index e0a3198c0..9bfa47640 100644 --- a/src/auto-reply/commands-registry.data.ts +++ b/src/auto-reply/commands-registry.data.ts @@ -511,6 +511,20 @@ function buildChatCommands(): ChatCommandDefinition[] { }, ], }), + defineChatCommand({ + key: "model_channel", + nativeName: "model_channel", + description: "Set or clear the default model for this channel.", + textAlias: "/model-channel", + category: "options", + args: [ + { + name: "model", + description: "Model id (provider/model or id), or 'clear' to remove override", + type: "string", + }, + ], + }), defineChatCommand({ key: "models", nativeName: "models", diff --git a/src/auto-reply/reply/commands-core.ts b/src/auto-reply/reply/commands-core.ts index a54f90b2b..352a922b7 100644 --- a/src/auto-reply/reply/commands-core.ts +++ b/src/auto-reply/reply/commands-core.ts @@ -16,7 +16,7 @@ import { import { handleAllowlistCommand } from "./commands-allowlist.js"; import { handleApproveCommand } from "./commands-approve.js"; import { handleSubagentsCommand } from "./commands-subagents.js"; -import { handleModelsCommand } from "./commands-models.js"; +import { handleModelsCommand, handleModelChannelCommand } from "./commands-models.js"; import { handleTtsCommands } from "./commands-tts.js"; import { handleAbortTrigger, @@ -53,6 +53,7 @@ const HANDLERS: CommandHandler[] = [ handleConfigCommand, handleDebugCommand, handleModelsCommand, + handleModelChannelCommand, handleStopCommand, handleCompactCommand, handleAbortTrigger, diff --git a/src/auto-reply/reply/commands-model-channel.ts b/src/auto-reply/reply/commands-model-channel.ts new file mode 100644 index 000000000..4f0d58c0e --- /dev/null +++ b/src/auto-reply/reply/commands-model-channel.ts @@ -0,0 +1,149 @@ +import { loadConfig, writeConfigFile } from "../../config/config.js"; +import { resolveChannelConfigWrites } from "../../channels/plugins/config-writes.js"; +import { resolveModelRefFromString, buildModelAliasIndex } from "../../agents/model-selection.js"; +import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "../../agents/defaults.js"; +import type { MoltbotConfig } from "../../config/config.js"; +import type { ReplyPayload } from "../types.js"; +import type { CommandHandler } from "./commands-types.js"; + +export type ModelChannelCommandContext = { + cfg: MoltbotConfig; + commandBodyNormalized: string; + provider?: string; + surface?: string; + accountId?: string; + groupSpace?: string; + groupChannel?: string; + channelId?: string; + commandAuthorized?: boolean; +}; + +function extractChannelId(ctx: ModelChannelCommandContext): string | null { + // Try to extract channel ID from groupChannel (e.g., "#channel-name" -> id not available) + // or from channelId if passed directly + if (ctx.channelId) return ctx.channelId; + + // For Discord, the channel ID is typically in the "To" field as "channel:" + return null; +} + +export async function resolveModelChannelCommandReply( + ctx: ModelChannelCommandContext, +): Promise { + const body = ctx.commandBodyNormalized.trim(); + if (!body.startsWith("/model-channel") && !body.startsWith("/model_channel")) return null; + + const surface = ctx.surface?.toLowerCase() ?? ctx.provider?.toLowerCase(); + + // Only supported for Discord guild channels + if (surface !== "discord") { + return { + text: "The /model-channel command is only available in Discord servers.", + }; + } + + if (!ctx.groupSpace) { + return { + text: "The /model-channel command can only be used in server channels, not DMs.", + }; + } + + // Check if config writes are enabled + if ( + !resolveChannelConfigWrites({ cfg: ctx.cfg, channelId: "discord", accountId: ctx.accountId }) + ) { + return { + text: "Config writes are disabled for this account. Enable `configWrites` in your Discord config to use this command.", + }; + } + + // Check authorization + if (ctx.commandAuthorized === false) { + return { + text: "You are not authorized to use this command.", + }; + } + + const argText = body.replace(/^\/model[-_]channel\b/i, "").trim(); + const channelId = ctx.channelId; + const guildId = ctx.groupSpace; + + if (!channelId) { + return { + text: "Could not determine the current channel. Please try again or specify the channel in your config directly.", + }; + } + + // Handle clear/reset + if (!argText || argText.toLowerCase() === "clear" || argText.toLowerCase() === "reset") { + const currentConfig = loadConfig(); + const guildEntry = currentConfig.channels?.discord?.guilds?.[guildId]; + const channelEntry = guildEntry?.channels?.[channelId]; + + if (!channelEntry?.model) { + return { + text: `No model override is set for this channel.`, + }; + } + + // Remove the model override + delete channelEntry.model; + + // Clean up empty objects + if (Object.keys(channelEntry).length === 0 && guildEntry?.channels) { + delete guildEntry.channels[channelId]; + } + + await writeConfigFile(currentConfig); + return { + text: `Cleared model override for this channel. Messages will now use the guild or agent default model.`, + }; + } + + // Resolve the model reference + const aliasIndex = buildModelAliasIndex({ + cfg: ctx.cfg, + defaultProvider: DEFAULT_PROVIDER, + }); + + const resolved = resolveModelRefFromString({ + raw: argText, + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }); + + if (!resolved) { + return { + text: `Unknown model "${argText}". Use /models to see available models.`, + }; + } + + const modelId = `${resolved.ref.provider}/${resolved.ref.model}`; + + // Update the config + const currentConfig = loadConfig(); + + // Ensure the path exists + if (!currentConfig.channels) currentConfig.channels = {}; + if (!currentConfig.channels.discord) currentConfig.channels.discord = {}; + if (!currentConfig.channels.discord.guilds) currentConfig.channels.discord.guilds = {}; + if (!currentConfig.channels.discord.guilds[guildId]) { + currentConfig.channels.discord.guilds[guildId] = {}; + } + if (!currentConfig.channels.discord.guilds[guildId].channels) { + currentConfig.channels.discord.guilds[guildId].channels = {}; + } + if (!currentConfig.channels.discord.guilds[guildId].channels[channelId]) { + currentConfig.channels.discord.guilds[guildId].channels[channelId] = {}; + } + + // Set the model + currentConfig.channels.discord.guilds[guildId].channels[channelId].model = modelId; + + await writeConfigFile(currentConfig); + + const displayModel = resolved.alias ? `${resolved.alias} (${modelId})` : modelId; + return { + text: `Set default model for this channel to **${displayModel}**. Use \`/model-channel clear\` to remove this override.`, + }; +} diff --git a/src/auto-reply/reply/commands-models.ts b/src/auto-reply/reply/commands-models.ts index e92960596..0f56b4539 100644 --- a/src/auto-reply/reply/commands-models.ts +++ b/src/auto-reply/reply/commands-models.ts @@ -241,3 +241,35 @@ export const handleModelsCommand: CommandHandler = async (params, allowTextComma if (!reply) return null; return { reply, shouldContinue: false }; }; + +export const handleModelChannelCommand: CommandHandler = async (params, allowTextCommands) => { + if (!allowTextCommands) return null; + + const { resolveModelChannelCommandReply } = await import("./commands-model-channel.js"); + const reply = await resolveModelChannelCommandReply({ + cfg: params.cfg, + commandBodyNormalized: params.command.commandBodyNormalized, + provider: params.ctx.Provider, + surface: params.ctx.Surface, + accountId: params.ctx.AccountId, + groupSpace: params.ctx.GroupSpace, + groupChannel: params.ctx.GroupChannel, + channelId: extractChannelIdFromTo(params.ctx.To), + commandAuthorized: params.command.isAuthorizedSender, + }); + if (!reply) return null; + return { reply, shouldContinue: false }; +}; + +function extractChannelIdFromTo(to?: string): string | undefined { + if (!to) return undefined; + // Handle "channel:" format + if (to.startsWith("channel:")) { + return to.slice("channel:".length); + } + // Handle "discord:channel:" format + if (to.startsWith("discord:channel:")) { + return to.slice("discord:channel:".length); + } + return undefined; +} diff --git a/src/auto-reply/reply/get-reply-directives.ts b/src/auto-reply/reply/get-reply-directives.ts index a62d1c476..f0b75d899 100644 --- a/src/auto-reply/reply/get-reply-directives.ts +++ b/src/auto-reply/reply/get-reply-directives.ts @@ -106,6 +106,8 @@ export async function resolveReplyDirectives(params: { typing: TypingController; opts?: GetReplyOptions; skillFilter?: string[]; + /** Channel-level model override (fallback when no session override). */ + channelModelOverride?: string; }): Promise { const { ctx, @@ -386,6 +388,7 @@ export async function resolveReplyDirectives(params: { provider, model, hasModelDirective: directives.hasModelDirective, + channelModelOverride: params.channelModelOverride ?? opts?.channelModelOverride, }); provider = modelState.provider; model = modelState.model; diff --git a/src/auto-reply/reply/get-reply.ts b/src/auto-reply/reply/get-reply.ts index eefd3cf87..8630928e9 100644 --- a/src/auto-reply/reply/get-reply.ts +++ b/src/auto-reply/reply/get-reply.ts @@ -166,6 +166,7 @@ export async function getReplyFromConfig( typing, opts, skillFilter: opts?.skillFilter, + channelModelOverride: opts?.channelModelOverride, }); if (directiveResult.kind === "reply") { return directiveResult.reply; diff --git a/src/auto-reply/reply/model-selection.ts b/src/auto-reply/reply/model-selection.ts index 01df5d113..6e1ca26e3 100644 --- a/src/auto-reply/reply/model-selection.ts +++ b/src/auto-reply/reply/model-selection.ts @@ -231,6 +231,8 @@ export async function createModelSelectionState(params: { provider: string; model: string; hasModelDirective: boolean; + /** Channel-level model override (used as fallback when no session override). */ + channelModelOverride?: string; }): Promise { const { cfg, @@ -242,6 +244,7 @@ export async function createModelSelectionState(params: { storePath, defaultProvider, defaultModel, + channelModelOverride, } = params; let provider = params.provider; @@ -310,6 +313,20 @@ export async function createModelSelectionState(params: { provider = candidateProvider; model = storedOverride.model; } + } else if (channelModelOverride?.trim()) { + // Apply channel-level model override as fallback when no session override exists. + const channelRef = resolveModelRefFromString({ + raw: channelModelOverride.trim(), + defaultProvider, + aliasIndex: { byAlias: new Map(), byKey: new Map() }, + }); + if (channelRef) { + const key = modelKey(channelRef.ref.provider, channelRef.ref.model); + if (allowedModelKeys.size === 0 || allowedModelKeys.has(key)) { + provider = channelRef.ref.provider; + model = channelRef.ref.model; + } + } } if (sessionEntry && sessionStore && sessionKey && sessionEntry.authProfileOverride) { diff --git a/src/auto-reply/types.ts b/src/auto-reply/types.ts index 1aa0fe067..3bda31225 100644 --- a/src/auto-reply/types.ts +++ b/src/auto-reply/types.ts @@ -39,6 +39,8 @@ export type GetReplyOptions = { skillFilter?: string[]; /** Mutable ref to track if a reply was sent (for Slack "first" threading mode). */ hasRepliedRef?: { value: boolean }; + /** Channel-level model override (fallback when no session override). */ + channelModelOverride?: string; }; export type ReplyPayload = { diff --git a/src/config/types.discord.ts b/src/config/types.discord.ts index 07d4e658f..96ca69b11 100644 --- a/src/config/types.discord.ts +++ b/src/config/types.discord.ts @@ -37,6 +37,8 @@ export type DiscordGuildChannelConfig = { users?: Array; /** Optional system prompt snippet for this channel. */ systemPrompt?: string; + /** Optional model override for this channel. */ + model?: string; }; export type DiscordReactionNotificationMode = "off" | "own" | "all" | "allowlist"; @@ -51,6 +53,8 @@ export type DiscordGuildEntry = { reactionNotifications?: DiscordReactionNotificationMode; users?: Array; channels?: Record; + /** Optional model override for this guild. */ + model?: string; }; export type DiscordActionConfig = { diff --git a/src/config/zod-schema.providers-core.ts b/src/config/zod-schema.providers-core.ts index ed7dda22a..8cb331fa6 100644 --- a/src/config/zod-schema.providers-core.ts +++ b/src/config/zod-schema.providers-core.ts @@ -196,6 +196,7 @@ export const DiscordGuildChannelSchema = z users: z.array(z.union([z.string(), z.number()])).optional(), systemPrompt: z.string().optional(), autoThread: z.boolean().optional(), + model: z.string().optional(), }) .strict(); @@ -208,6 +209,7 @@ export const DiscordGuildSchema = z reactionNotifications: z.enum(["off", "own", "all", "allowlist"]).optional(), users: z.array(z.union([z.string(), z.number()])).optional(), channels: z.record(z.string(), DiscordGuildChannelSchema.optional()).optional(), + model: z.string().optional(), }) .strict(); diff --git a/src/discord/monitor/allow-list.ts b/src/discord/monitor/allow-list.ts index 12c2d1d39..803dcf5b7 100644 --- a/src/discord/monitor/allow-list.ts +++ b/src/discord/monitor/allow-list.ts @@ -33,8 +33,10 @@ export type DiscordGuildEntryResolved = { users?: Array; systemPrompt?: string; autoThread?: boolean; + model?: string; } >; + model?: string; }; export type DiscordChannelConfigResolved = { @@ -47,6 +49,7 @@ export type DiscordChannelConfigResolved = { autoThread?: boolean; matchKey?: string; matchSource?: ChannelMatchSource; + model?: string; }; export function normalizeDiscordAllowList( @@ -215,6 +218,7 @@ function resolveDiscordChannelConfigEntry( users: entry.users, systemPrompt: entry.systemPrompt, autoThread: entry.autoThread, + model: entry.model, }; return resolved; } diff --git a/src/discord/monitor/message-handler.process.ts b/src/discord/monitor/message-handler.process.ts index 6d502be21..a6a4700fe 100644 --- a/src/discord/monitor/message-handler.process.ts +++ b/src/discord/monitor/message-handler.process.ts @@ -368,6 +368,9 @@ export async function processDiscordMessage(ctx: DiscordMessagePreflightContext) }).onReplyStart, }); + // Resolve channel-level model override: channel config takes precedence over guild config. + const channelModelOverride = channelConfig?.model ?? guildInfo?.model; + const { queuedFinal, counts } = await dispatchInboundMessage({ ctx: ctxPayload, cfg, @@ -382,6 +385,7 @@ export async function processDiscordMessage(ctx: DiscordMessagePreflightContext) onModelSelected: (ctx) => { prefixContext.onModelSelected(ctx); }, + channelModelOverride, }, }); markDispatchIdle();