From ab9bdbe4dbb2f57ce037df84b6b022155e2b6a6d Mon Sep 17 00:00:00 2001 From: TideFinder Date: Thu, 29 Jan 2026 11:51:19 +0900 Subject: [PATCH] feat: add OpenRouter embedding batch support Amp-Thread-ID: https://ampcode.com/threads/T-019c032e-16c5-7589-b68b-4e2651a1e631 Co-authored-by: Amp --- msg.txt | 1 + src/agents/memory-search.ts | 15 +++- src/config/schema.ts | 5 +- src/config/types.tools.ts | 4 +- src/memory/embeddings-openrouter.ts | 86 +++++++++++++++++++ src/memory/embeddings.test.ts | 125 ++++++++++++++++++++++++++++ src/memory/embeddings.ts | 24 ++++-- src/memory/manager.ts | 29 ++++++- src/memory/provider-key.ts | 12 +++ 9 files changed, 282 insertions(+), 19 deletions(-) create mode 100644 msg.txt create mode 100644 src/memory/embeddings-openrouter.ts diff --git a/msg.txt b/msg.txt new file mode 100644 index 000000000..cfefd7bb0 --- /dev/null +++ b/msg.txt @@ -0,0 +1 @@ +feat: add OpenRouter embedding batch support diff --git a/src/agents/memory-search.ts b/src/agents/memory-search.ts index c08161d4f..f94800578 100644 --- a/src/agents/memory-search.ts +++ b/src/agents/memory-search.ts @@ -9,7 +9,7 @@ import { resolveAgentConfig } from "./agent-scope.js"; export type ResolvedMemorySearchConfig = { enabled: boolean; sources: Array<"memory" | "sessions">; - provider: "openai" | "local" | "gemini" | "auto"; + provider: "openai" | "local" | "gemini" | "openrouter" | "auto"; remote?: { baseUrl?: string; apiKey?: string; @@ -25,7 +25,7 @@ export type ResolvedMemorySearchConfig = { experimental: { sessionMemory: boolean; }; - fallback: "openai" | "gemini" | "local" | "none"; + fallback: "openai" | "gemini" | "openrouter" | "local" | "none"; model: string; local: { modelPath?: string; @@ -72,6 +72,7 @@ export type ResolvedMemorySearchConfig = { const DEFAULT_OPENAI_MODEL = "text-embedding-3-small"; const DEFAULT_GEMINI_MODEL = "gemini-embedding-001"; +const DEFAULT_OPENROUTER_MODEL = "openai/text-embedding-3-small"; const DEFAULT_CHUNK_TOKENS = 400; const 
DEFAULT_CHUNK_OVERLAP = 80; const DEFAULT_WATCH_DEBOUNCE_MS = 1500; @@ -128,7 +129,11 @@ function mergeConfig( defaultRemote?.headers, ); const includeRemote = - hasRemoteConfig || provider === "openai" || provider === "gemini" || provider === "auto"; + hasRemoteConfig || + provider === "openai" || + provider === "gemini" || + provider === "openrouter" || + provider === "auto"; const batch = { enabled: overrideRemote?.batch?.enabled ?? defaultRemote?.batch?.enabled ?? true, wait: overrideRemote?.batch?.wait ?? defaultRemote?.batch?.wait ?? true, @@ -155,7 +160,9 @@ function mergeConfig( ? DEFAULT_GEMINI_MODEL : provider === "openai" ? DEFAULT_OPENAI_MODEL - : undefined; + : provider === "openrouter" + ? DEFAULT_OPENROUTER_MODEL + : undefined; const model = overrides?.model ?? defaults?.model ?? modelDefault ?? ""; const local = { modelPath: overrides?.local?.modelPath ?? defaults?.local?.modelPath, diff --git a/src/config/schema.ts b/src/config/schema.ts index b4ec8723b..4d5466a22 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -501,7 +501,8 @@ const FIELD_HELP: Record = { 'Sources to index for memory search (default: ["memory"]; add "sessions" to include session transcripts).', "agents.defaults.memorySearch.experimental.sessionMemory": "Enable experimental session transcript indexing for memory search (default: false).", - "agents.defaults.memorySearch.provider": 'Embedding provider ("openai", "gemini", or "local").', + "agents.defaults.memorySearch.provider": + 'Embedding provider ("openai", "gemini", "openrouter", or "local").', "agents.defaults.memorySearch.remote.baseUrl": "Custom base URL for remote embeddings (OpenAI-compatible proxies or Gemini overrides).", "agents.defaults.memorySearch.remote.apiKey": "Custom API key for the remote embedding provider.", @@ -520,7 +521,7 @@ const FIELD_HELP: Record = { "agents.defaults.memorySearch.local.modelPath": "Local GGUF model path or hf: URI (node-llama-cpp).", 
"agents.defaults.memorySearch.fallback": - 'Fallback provider when embeddings fail ("openai", "gemini", "local", or "none").', + 'Fallback provider when embeddings fail ("openai", "gemini", "openrouter", "local", or "none").', "agents.defaults.memorySearch.store.path": "SQLite index path (default: ~/.clawdbot/memory/{agentId}.sqlite).", "agents.defaults.memorySearch.store.vector.enabled": diff --git a/src/config/types.tools.ts b/src/config/types.tools.ts index bb1d45bf0..1f65bd739 100644 --- a/src/config/types.tools.ts +++ b/src/config/types.tools.ts @@ -232,7 +232,7 @@ export type MemorySearchConfig = { sessionMemory?: boolean; }; /** Embedding provider mode. */ - provider?: "openai" | "gemini" | "local"; + provider?: "openai" | "gemini" | "openrouter" | "local"; remote?: { baseUrl?: string; apiKey?: string; @@ -251,7 +251,7 @@ export type MemorySearchConfig = { }; }; /** Fallback behavior when embeddings fail. */ - fallback?: "openai" | "gemini" | "local" | "none"; + fallback?: "openai" | "gemini" | "openrouter" | "local" | "none"; /** Embedding model id (remote) or alias (local). */ model?: string; /** Local embedding settings (node-llama-cpp). 
*/ diff --git a/src/memory/embeddings-openrouter.ts b/src/memory/embeddings-openrouter.ts new file mode 100644 index 000000000..b2e872bfe --- /dev/null +++ b/src/memory/embeddings-openrouter.ts @@ -0,0 +1,86 @@ +import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type OpenRouterEmbeddingClient = { + baseUrl: string; + headers: Record<string, string>; + model: string; +}; + +export const DEFAULT_OPENROUTER_EMBEDDING_MODEL = "openai/text-embedding-3-small"; +const DEFAULT_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"; + +export function normalizeOpenRouterModel(model: string): string { + const trimmed = model.trim(); + if (!trimmed) return DEFAULT_OPENROUTER_EMBEDDING_MODEL; + if (trimmed.startsWith("openrouter/")) return trimmed.slice("openrouter/".length); + return trimmed; +} + +export async function createOpenRouterEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: OpenRouterEmbeddingClient }> { + const client = await resolveOpenRouterEmbeddingClient(options); + const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; + + const embed = async (input: string[]): Promise<number[][]> => { + if (input.length === 0) return []; + const res = await fetch(url, { + method: "POST", + headers: client.headers, + body: JSON.stringify({ model: client.model, input }), + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(`openrouter embeddings failed: ${res.status} ${text}`); + } + const payload = (await res.json()) as { + data?: Array<{ embedding?: number[] }>; + }; + const data = payload.data ?? []; + return data.map((entry) => entry.embedding ?? []); + }; + + return { + provider: { + id: "openrouter", + model: client.model, + embedQuery: async (text) => { + const [vec] = await embed([text]); + return vec ?? 
[]; + }, + embedBatch: embed, + }, + client, + }; +} + +export async function resolveOpenRouterEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise<OpenRouterEmbeddingClient> { + const remote = options.remote; + const remoteApiKey = remote?.apiKey?.trim(); + const remoteBaseUrl = remote?.baseUrl?.trim(); + + const apiKey = remoteApiKey + ? remoteApiKey + : requireApiKey( + await resolveApiKeyForProvider({ + provider: "openrouter", + cfg: options.config, + agentDir: options.agentDir, + }), + "openrouter", + ); + + const providerConfig = options.config.models?.providers?.openrouter; + const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENROUTER_BASE_URL; + const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); + const headers: Record<string, string> = { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + ...headerOverrides, + }; + const model = normalizeOpenRouterModel(options.model); + return { baseUrl, headers, model }; +} diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts index 1809b24b8..86dfb35b5 100644 --- a/src/memory/embeddings.test.ts +++ b/src/memory/embeddings.test.ts @@ -264,6 +264,131 @@ describe("embedding provider auto selection", () => { }); }); +describe("embedding provider openrouter", () => { + afterEach(() => { + vi.resetAllMocks(); + vi.resetModules(); + vi.unstubAllGlobals(); + }); + + it("uses openrouter with default base url", async () => { + const fetchMock = createFetchMock(); + vi.stubGlobal("fetch", fetchMock); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + const authModule = await import("../agents/model-auth.js"); + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey: "openrouter-key", + mode: "api-key", + source: "env: OPENROUTER_API_KEY", + }); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "openrouter", + model: "openai/text-embedding-3-small", + fallback: "none", + }); + 
+ expect(result.provider.id).toBe("openrouter"); + expect(result.provider.model).toBe("openai/text-embedding-3-small"); + + await result.provider.embedQuery("hello"); + + const [url, init] = fetchMock.mock.calls[0] ?? []; + expect(url).toBe("https://openrouter.ai/api/v1/embeddings"); + const headers = (init?.headers ?? {}) as Record<string, string>; + expect(headers.Authorization).toBe("Bearer openrouter-key"); + }); + + it("normalizes model name by stripping openrouter/ prefix", async () => { + const fetchMock = createFetchMock(); + vi.stubGlobal("fetch", fetchMock); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + const authModule = await import("../agents/model-auth.js"); + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey: "openrouter-key", + mode: "api-key", + source: "env: OPENROUTER_API_KEY", + }); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "openrouter", + model: "openrouter/openai/text-embedding-3-small", + fallback: "none", + }); + + expect(result.provider.model).toBe("openai/text-embedding-3-small"); + }); + + it("uses openrouter in auto mode when openai/gemini are unavailable", async () => { + const fetchMock = createFetchMock(); + vi.stubGlobal("fetch", fetchMock); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + const authModule = await import("../agents/model-auth.js"); + vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => { + if (provider === "openai") { + throw new Error('No API key found for provider "openai".'); + } + if (provider === "google") { + throw new Error('No API key found for provider "google".'); + } + if (provider === "openrouter") { + return { apiKey: "openrouter-key", source: "env: OPENROUTER_API_KEY", mode: "api-key" }; + } + throw new Error(`Unexpected provider ${provider}`); + }); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "auto", + model: "", + 
fallback: "none", + }); + + expect(result.requestedProvider).toBe("auto"); + expect(result.provider.id).toBe("openrouter"); + }); + + it("batch embeds multiple texts in one request", async () => { + const fetchMock = vi.fn(async () => ({ + ok: true, + status: 200, + json: async () => ({ + data: [{ embedding: [1, 2, 3] }, { embedding: [4, 5, 6] }, { embedding: [7, 8, 9] }], + }), + })) as unknown as typeof fetch; + vi.stubGlobal("fetch", fetchMock); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + const authModule = await import("../agents/model-auth.js"); + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey: "openrouter-key", + mode: "api-key", + source: "env: OPENROUTER_API_KEY", + }); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "openrouter", + model: "openai/text-embedding-3-small", + fallback: "none", + }); + + const embeddings = await result.provider.embedBatch(["text1", "text2", "text3"]); + + expect(embeddings).toEqual([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); + expect(fetchMock).toHaveBeenCalledTimes(1); + const body = JSON.parse(String(fetchMock.mock.calls[0]?.[1]?.body ?? 
"{}")) as { + input?: string[]; + }; + expect(body.input).toEqual(["text1", "text2", "text3"]); + }); +}); + describe("embedding provider local fallback", () => { afterEach(() => { vi.resetAllMocks(); diff --git a/src/memory/embeddings.ts b/src/memory/embeddings.ts index 98de1ab42..be02f5d37 100644 --- a/src/memory/embeddings.ts +++ b/src/memory/embeddings.ts @@ -5,10 +5,15 @@ import type { MoltbotConfig } from "../config/config.js"; import { resolveUserPath } from "../utils.js"; import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js"; import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js"; +import { + createOpenRouterEmbeddingProvider, + type OpenRouterEmbeddingClient, +} from "./embeddings-openrouter.js"; import { importNodeLlamaCpp } from "./node-llama.js"; export type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; export type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; +export type { OpenRouterEmbeddingClient } from "./embeddings-openrouter.js"; export type EmbeddingProvider = { id: string; @@ -19,24 +24,25 @@ export type EmbeddingProvider = { export type EmbeddingProviderResult = { provider: EmbeddingProvider; - requestedProvider: "openai" | "local" | "gemini" | "auto"; - fallbackFrom?: "openai" | "local" | "gemini"; + requestedProvider: "openai" | "local" | "gemini" | "openrouter" | "auto"; + fallbackFrom?: "openai" | "local" | "gemini" | "openrouter"; fallbackReason?: string; openAi?: OpenAiEmbeddingClient; gemini?: GeminiEmbeddingClient; + openRouter?: OpenRouterEmbeddingClient; }; export type EmbeddingProviderOptions = { config: MoltbotConfig; agentDir?: string; - provider: "openai" | "local" | "gemini" | "auto"; + provider: "openai" | "local" | "gemini" | "openrouter" | "auto"; remote?: { baseUrl?: string; apiKey?: string; headers?: Record<string, string>; }; model: string; - fallback: "openai" | "gemini" | "local" | "none"; + fallback: "openai" | "gemini" | 
"openrouter" | "local" | "none"; local?: { modelPath?: string; modelCacheDir?: string; @@ -116,7 +122,7 @@ export async function createEmbeddingProvider( const requestedProvider = options.provider; const fallback = options.fallback; - const createProvider = async (id: "openai" | "local" | "gemini") => { + const createProvider = async (id: "openai" | "local" | "gemini" | "openrouter") => { if (id === "local") { const provider = await createLocalEmbeddingProvider(options); return { provider }; @@ -125,11 +131,15 @@ export async function createEmbeddingProvider( const { provider, client } = await createGeminiEmbeddingProvider(options); return { provider, gemini: client }; } + if (id === "openrouter") { + const { provider, client } = await createOpenRouterEmbeddingProvider(options); + return { provider, openRouter: client }; + } const { provider, client } = await createOpenAiEmbeddingProvider(options); return { provider, openAi: client }; }; - const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini") => + const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini" | "openrouter") => provider === "local" ? 
formatLocalSetupError(err) : formatError(err); if (requestedProvider === "auto") { @@ -145,7 +155,7 @@ export async function createEmbeddingProvider( } } - for (const provider of ["openai", "gemini"] as const) { + for (const provider of ["openai", "gemini", "openrouter"] as const) { try { const result = await createProvider(provider); return { ...result, requestedProvider }; diff --git a/src/memory/manager.ts b/src/memory/manager.ts index 9a9991d10..c4eb2d186 100644 --- a/src/memory/manager.ts +++ b/src/memory/manager.ts @@ -19,9 +19,11 @@ import { type EmbeddingProviderResult, type GeminiEmbeddingClient, type OpenAiEmbeddingClient, + type OpenRouterEmbeddingClient, } from "./embeddings.js"; import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js"; import { DEFAULT_OPENAI_EMBEDDING_MODEL } from "./embeddings-openai.js"; +import { DEFAULT_OPENROUTER_EMBEDDING_MODEL } from "./embeddings-openrouter.js"; import { OPENAI_BATCH_ENDPOINT, type OpenAiBatchRequest, @@ -122,11 +124,12 @@ export class MemoryIndexManager { private readonly workspaceDir: string; private readonly settings: ResolvedMemorySearchConfig; private provider: EmbeddingProvider; - private readonly requestedProvider: "openai" | "local" | "gemini" | "auto"; - private fallbackFrom?: "openai" | "local" | "gemini"; + private readonly requestedProvider: "openai" | "local" | "gemini" | "openrouter" | "auto"; + private fallbackFrom?: "openai" | "local" | "gemini" | "openrouter"; private fallbackReason?: string; private openAi?: OpenAiEmbeddingClient; private gemini?: GeminiEmbeddingClient; + private openRouter?: OpenRouterEmbeddingClient; private batch: { enabled: boolean; wait: boolean; @@ -223,6 +226,7 @@ export class MemoryIndexManager { this.fallbackReason = params.providerResult.fallbackReason; this.openAi = params.providerResult.openAi; this.gemini = params.providerResult.gemini; + this.openRouter = params.providerResult.openRouter; this.sources = new Set(params.settings.sources); this.db = 
this.openDatabase(); this.providerKey = this.computeProviderKey(); @@ -1264,14 +1268,16 @@ export class MemoryIndexManager { const fallback = this.settings.fallback; if (!fallback || fallback === "none" || fallback === this.provider.id) return false; if (this.fallbackFrom) return false; - const fallbackFrom = this.provider.id as "openai" | "gemini" | "local"; + const fallbackFrom = this.provider.id as "openai" | "gemini" | "openrouter" | "local"; const fallbackModel = fallback === "gemini" ? DEFAULT_GEMINI_EMBEDDING_MODEL : fallback === "openai" ? DEFAULT_OPENAI_EMBEDDING_MODEL - : this.settings.model; + : fallback === "openrouter" + ? DEFAULT_OPENROUTER_EMBEDDING_MODEL + : this.settings.model; const fallbackResult = await createEmbeddingProvider({ config: this.cfg, @@ -1288,6 +1294,7 @@ export class MemoryIndexManager { this.provider = fallbackResult.provider; this.openAi = fallbackResult.openAi; this.gemini = fallbackResult.gemini; + this.openRouter = fallbackResult.openRouter; this.providerKey = this.computeProviderKey(); this.batch = this.resolveBatchConfig(); log.warn(`memory embeddings: switched to fallback provider (${fallback})`, { reason }); @@ -1704,6 +1711,20 @@ export class MemoryIndexManager { }), ); } + if (this.provider.id === "openrouter" && this.openRouter) { + const entries = Object.entries(this.openRouter.headers) + .filter(([key]) => key.toLowerCase() !== "authorization") + .sort(([a], [b]) => a.localeCompare(b)) + .map(([key, value]) => [key, value]); + return hashText( + JSON.stringify({ + provider: "openrouter", + baseUrl: this.openRouter.baseUrl, + model: this.openRouter.model, + headers: entries, + }), + ); + } return hashText(JSON.stringify({ provider: this.provider.id, model: this.provider.model })); } diff --git a/src/memory/provider-key.ts b/src/memory/provider-key.ts index 09485c0f2..2b4b7e9e3 100644 --- a/src/memory/provider-key.ts +++ b/src/memory/provider-key.ts @@ -6,6 +6,7 @@ export function computeEmbeddingProviderKey(params: { 
providerModel: string; openAi?: { baseUrl: string; model: string; headers: Record<string, string> }; gemini?: { baseUrl: string; model: string; headers: Record<string, string> }; + openRouter?: { baseUrl: string; model: string; headers: Record<string, string> }; }): string { if (params.openAi) { const headerNames = fingerprintHeaderNames(params.openAi.headers); @@ -29,5 +30,16 @@ export function computeEmbeddingProviderKey(params: { }), ); } + if (params.openRouter) { + const headerNames = fingerprintHeaderNames(params.openRouter.headers); + return hashText( + JSON.stringify({ + provider: "openrouter", + baseUrl: params.openRouter.baseUrl, + model: params.openRouter.model, + headerNames, + }), + ); + } return hashText(JSON.stringify({ provider: params.providerId, model: params.providerModel })); }