Compare commits

...

2 Commits

Author SHA1 Message Date
Peter Steinberger
d1ccc1e542 fix: align prompt failover fallback gating (#1136) (thanks @cheeeee) 2026-01-18 03:31:38 +00:00
Mykyta Bozhenko
62324eed0b fix(agent): Enable model fallback for prompt-phase quota/rate limit errors
When a prompt submission fails with quota or rate limit errors, throw
FailoverError instead of the raw promptError. This enables the model
fallback system to try alternative models.

Previously, rate limit errors during the prompt phase (before streaming)
were thrown directly, bypassing fallback. Only response-phase errors
triggered model fallback.

Now checks whether fallback models are configured and the error is
failover-eligible. If so, wraps it in FailoverError to trigger the fallback chain.
2026-01-18 03:15:28 +00:00
5 changed files with 184 additions and 18 deletions

View File

@ -10,6 +10,9 @@ Docs: https://docs.clawd.bot
- Docs: document plugin slots and memory plugin behavior. - Docs: document plugin slots and memory plugin behavior.
- Plugins: migrate bundled messaging extensions to the plugin SDK; resolve plugin-sdk imports in loader. - Plugins: migrate bundled messaging extensions to the plugin SDK; resolve plugin-sdk imports in loader.
### Fixes
- Agents: trigger model fallback for prompt-phase failover errors, respecting per-agent overrides. (#1136) — thanks @cheeeee.
## 2026.1.17-5 ## 2026.1.17-5
### Changes ### Changes

View File

@ -1,7 +1,7 @@
import { describe, expect, it, vi } from "vitest"; import { describe, expect, it, vi } from "vitest";
import type { ClawdbotConfig } from "../config/config.js"; import type { ClawdbotConfig } from "../config/config.js";
import { runWithModelFallback } from "./model-fallback.js"; import { hasModelFallbackCandidates, runWithModelFallback } from "./model-fallback.js";
function makeCfg(overrides: Partial<ClawdbotConfig> = {}): ClawdbotConfig { function makeCfg(overrides: Partial<ClawdbotConfig> = {}): ClawdbotConfig {
return { return {
@ -310,3 +310,73 @@ describe("runWithModelFallback", () => {
expect(result.model).toBe("gpt-4.1-mini"); expect(result.model).toBe("gpt-4.1-mini");
}); });
}); });
describe("hasModelFallbackCandidates", () => {
  // Builds a config whose default model declares a primary and an explicitly
  // empty fallback list; shared by the first two cases.
  const buildPrimaryOnlyCfg = () =>
    makeCfg({
      agents: {
        defaults: {
          model: { primary: "openai/gpt-4.1-mini", fallbacks: [] },
        },
      },
    });

  it("returns false when only the primary candidate is available", () => {
    // The requested provider/model is the same as the configured primary.
    const outcome = hasModelFallbackCandidates({
      cfg: buildPrimaryOnlyCfg(),
      provider: "openai",
      model: "gpt-4.1-mini",
    });

    expect(outcome).toBe(false);
  });

  it("returns true when the configured primary differs from the requested model", () => {
    // The requested provider/model is NOT the configured primary.
    const outcome = hasModelFallbackCandidates({
      cfg: buildPrimaryOnlyCfg(),
      provider: "openrouter",
      model: "meta-llama/llama-3.3-70b:free",
    });

    expect(outcome).toBe(true);
  });

  it("honors an explicit empty fallbacksOverride", () => {
    // An empty override list is passed explicitly alongside the default config.
    const outcome = hasModelFallbackCandidates({
      cfg: makeCfg(),
      provider: "openai",
      model: "gpt-4.1-mini",
      fallbacksOverride: [],
    });

    expect(outcome).toBe(false);
  });

  it("returns true when fallbacksOverride provides extra candidates", () => {
    // The override names one additional model beyond the requested one.
    const outcome = hasModelFallbackCandidates({
      cfg: makeCfg(),
      provider: "openai",
      model: "gpt-4.1-mini",
      fallbacksOverride: ["openai/gpt-4.1"],
    });

    expect(outcome).toBe(true);
  });
});

View File

@ -176,6 +176,16 @@ function resolveFallbackCandidates(params: {
return candidates; return candidates;
} }
/**
 * Reports whether a model run has at least one alternative model to fail over to.
 *
 * Resolves the candidate list via `resolveFallbackCandidates` using the same
 * parameters and returns `true` only when that list holds more than one entry.
 * NOTE(review): presumably the first resolved candidate is the requested
 * provider/model itself, so "more than one" means a real fallback exists —
 * confirm against `resolveFallbackCandidates`.
 *
 * @param params - Config, requested provider/model, and optional override list.
 * @returns `true` when more than one candidate is resolved, otherwise `false`.
 */
export function hasModelFallbackCandidates(params: {
  cfg: ClawdbotConfig | undefined;
  provider: string;
  model: string;
  /** Optional explicit fallbacks list; when provided (even empty), it replaces agents.defaults.model.fallbacks. */
  fallbacksOverride?: string[];
}): boolean {
  const resolved = resolveFallbackCandidates(params);
  return resolved.length >= 2;
}
export async function runWithModelFallback<T>(params: { export async function runWithModelFallback<T>(params: {
cfg: ClawdbotConfig | undefined; cfg: ClawdbotConfig | undefined;
provider: string; provider: string;

View File

@ -4,6 +4,7 @@ import { enqueueCommandInLane } from "../../process/command-queue.js";
import { resolveUserPath } from "../../utils.js"; import { resolveUserPath } from "../../utils.js";
import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js"; import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js";
import { resolveClawdbotAgentDir } from "../agent-paths.js"; import { resolveClawdbotAgentDir } from "../agent-paths.js";
import { resolveAgentModelFallbacksOverride, resolveSessionAgentId } from "../agent-scope.js";
import { import {
markAuthProfileFailure, markAuthProfileFailure,
markAuthProfileGood, markAuthProfileGood,
@ -16,12 +17,13 @@ import {
resolveContextWindowInfo, resolveContextWindowInfo,
} from "../context-window-guard.js"; } from "../context-window-guard.js";
import { DEFAULT_CONTEXT_TOKENS, DEFAULT_MODEL, DEFAULT_PROVIDER } from "../defaults.js"; import { DEFAULT_CONTEXT_TOKENS, DEFAULT_MODEL, DEFAULT_PROVIDER } from "../defaults.js";
import { FailoverError, resolveFailoverStatus } from "../failover-error.js"; import { FailoverError, coerceToFailoverError, resolveFailoverStatus } from "../failover-error.js";
import { import {
ensureAuthProfileStore, ensureAuthProfileStore,
getApiKeyForModel, getApiKeyForModel,
resolveAuthProfileOrder, resolveAuthProfileOrder,
} from "../model-auth.js"; } from "../model-auth.js";
import { hasModelFallbackCandidates } from "../model-fallback.js";
import { ensureClawdbotModelsJson } from "../models-config.js"; import { ensureClawdbotModelsJson } from "../models-config.js";
import { import {
classifyFailoverReason, classifyFailoverReason,
@ -30,7 +32,6 @@ import {
isCompactionFailureError, isCompactionFailureError,
isContextOverflowError, isContextOverflowError,
isFailoverAssistantError, isFailoverAssistantError,
isFailoverErrorMessage,
isRateLimitAssistantError, isRateLimitAssistantError,
isTimeoutErrorMessage, isTimeoutErrorMessage,
pickFallbackThinkingLevel, pickFallbackThinkingLevel,
@ -88,6 +89,20 @@ export async function runEmbeddedPiAgent(
if (!model) { if (!model) {
throw new Error(error ?? `Unknown model: ${provider}/${modelId}`); throw new Error(error ?? `Unknown model: ${provider}/${modelId}`);
} }
const sessionAgentId = resolveSessionAgentId({
sessionKey: params.sessionKey,
config: params.config,
});
const fallbacksOverride = resolveAgentModelFallbacksOverride(
params.config ?? {},
sessionAgentId,
);
const hasModelFallbacks = hasModelFallbackCandidates({
cfg: params.config,
provider,
model: modelId,
fallbacksOverride,
});
const ctxInfo = resolveContextWindowInfo({ const ctxInfo = resolveContextWindowInfo({
cfg: params.config, cfg: params.config,
@ -290,7 +305,14 @@ export async function runEmbeddedPiAgent(
}, },
}; };
} }
const promptFailoverReason = classifyFailoverReason(errorText); const promptFailoverError =
coerceToFailoverError(promptError, {
provider,
model: modelId,
profileId: lastProfileId,
}) ?? null;
const promptFailoverReason =
promptFailoverError?.reason ?? classifyFailoverReason(errorText);
if (promptFailoverReason && promptFailoverReason !== "timeout" && lastProfileId) { if (promptFailoverReason && promptFailoverReason !== "timeout" && lastProfileId) {
await markAuthProfileFailure({ await markAuthProfileFailure({
store: authStore, store: authStore,
@ -301,7 +323,7 @@ export async function runEmbeddedPiAgent(
}); });
} }
if ( if (
isFailoverErrorMessage(errorText) && promptFailoverReason &&
promptFailoverReason !== "timeout" && promptFailoverReason !== "timeout" &&
(await advanceAuthProfile()) (await advanceAuthProfile())
) { ) {
@ -318,6 +340,16 @@ export async function runEmbeddedPiAgent(
thinkLevel = fallbackThinking; thinkLevel = fallbackThinking;
continue; continue;
} }
if (promptFailoverReason && hasModelFallbacks) {
if (promptFailoverError) throw promptFailoverError;
throw new FailoverError(errorText, {
reason: promptFailoverReason,
provider,
model: modelId,
profileId: lastProfileId,
status: resolveFailoverStatus(promptFailoverReason),
});
}
throw promptError; throw promptError;
} }
@ -333,8 +365,6 @@ export async function runEmbeddedPiAgent(
continue; continue;
} }
const fallbackConfigured =
(params.config?.agents?.defaults?.model?.fallbacks?.length ?? 0) > 0;
const authFailure = isAuthAssistantError(lastAssistant); const authFailure = isAuthAssistantError(lastAssistant);
const rateLimitFailure = isRateLimitAssistantError(lastAssistant); const rateLimitFailure = isRateLimitAssistantError(lastAssistant);
const failoverFailure = isFailoverAssistantError(lastAssistant); const failoverFailure = isFailoverAssistantError(lastAssistant);
@ -372,7 +402,7 @@ export async function runEmbeddedPiAgent(
const rotated = await advanceAuthProfile(); const rotated = await advanceAuthProfile();
if (rotated) continue; if (rotated) continue;
if (fallbackConfigured) { if (hasModelFallbacks) {
// Prefer formatted error message (user-friendly) over raw errorMessage // Prefer formatted error message (user-friendly) over raw errorMessage
const message = const message =
(lastAssistant (lastAssistant

View File

@ -45,13 +45,21 @@ export function createPluginRuntime(): PluginRuntime {
hasControlCommand, hasControlCommand,
}, },
reply: { reply: {
dispatchReplyWithBufferedBlockDispatcher, dispatchReplyWithBufferedBlockDispatcher: async (params) => {
createReplyDispatcherWithTyping, await dispatchReplyWithBufferedBlockDispatcher(
params as Parameters<typeof dispatchReplyWithBufferedBlockDispatcher>[0],
);
},
createReplyDispatcherWithTyping: (...args) =>
createReplyDispatcherWithTyping(
...(args as Parameters<typeof createReplyDispatcherWithTyping>),
),
resolveEffectiveMessagesConfig, resolveEffectiveMessagesConfig,
resolveHumanDelayConfig, resolveHumanDelayConfig,
}, },
routing: { routing: {
resolveAgentRoute, resolveAgentRoute: (params) =>
resolveAgentRoute(params as Parameters<typeof resolveAgentRoute>[0]),
}, },
pairing: { pairing: {
buildPairingReply, buildPairingReply,
@ -60,19 +68,61 @@ export function createPluginRuntime(): PluginRuntime {
}, },
media: { media: {
fetchRemoteMedia, fetchRemoteMedia,
saveMediaBuffer, saveMediaBuffer: (buffer, contentType, direction, maxBytes) =>
saveMediaBuffer(
Buffer.isBuffer(buffer) ? buffer : Buffer.from(buffer),
contentType,
direction,
maxBytes,
),
}, },
mentions: { mentions: {
buildMentionRegexes, buildMentionRegexes,
matchesMentionPatterns, matchesMentionPatterns,
}, },
groups: { groups: {
resolveGroupPolicy: resolveChannelGroupPolicy, resolveGroupPolicy: (cfg, channel, accountId, groupId) =>
resolveRequireMention: resolveChannelGroupRequireMention, resolveChannelGroupPolicy({
cfg,
channel: channel as Parameters<typeof resolveChannelGroupPolicy>[0]["channel"],
accountId,
groupId,
}),
resolveRequireMention: (cfg, channel, accountId, groupId, override) =>
resolveChannelGroupRequireMention({
cfg,
channel: channel as Parameters<typeof resolveChannelGroupRequireMention>[0]["channel"],
accountId,
groupId,
requireMentionOverride: override,
}),
}, },
debounce: { debounce: {
createInboundDebouncer, createInboundDebouncer: (opts) => {
resolveInboundDebounceMs, const keys = new Set<string>();
const debouncer = createInboundDebouncer({
debounceMs: opts.debounceMs,
buildKey: (item) => {
const key = opts.buildKey(item);
if (key) keys.add(key);
return key;
},
shouldDebounce: opts.shouldDebounce,
onFlush: opts.onFlush,
onError: opts.onError ? (err) => opts.onError?.(err) : undefined,
});
return {
push: (value) => {
void debouncer.enqueue(value);
},
flush: async () => {
const pending = Array.from(keys);
keys.clear();
await Promise.all(pending.map((key) => debouncer.flushKey(key)));
},
};
},
resolveInboundDebounceMs: (cfg, channel) => resolveInboundDebounceMs({ cfg, channel }),
}, },
commands: { commands: {
resolveCommandAuthorizedFromAuthorizers, resolveCommandAuthorizedFromAuthorizers,
@ -81,7 +131,10 @@ export function createPluginRuntime(): PluginRuntime {
logging: { logging: {
shouldLogVerbose, shouldLogVerbose,
getChildLogger: (bindings, opts) => { getChildLogger: (bindings, opts) => {
const logger = getChildLogger(bindings, opts); const logger = getChildLogger(
bindings,
opts as Parameters<typeof getChildLogger>[1],
);
return { return {
debug: (message) => logger.debug?.(message), debug: (message) => logger.debug?.(message),
info: (message) => logger.info(message), info: (message) => logger.info(message),
@ -91,7 +144,7 @@ export function createPluginRuntime(): PluginRuntime {
}, },
}, },
state: { state: {
resolveStateDir, resolveStateDir: () => resolveStateDir(),
}, },
}; };
} }