diff --git a/src/memory/batch-mistral.ts b/src/memory/batch-mistral.ts
new file mode 100644
index 000000000..911f84386
--- /dev/null
+++ b/src/memory/batch-mistral.ts
@@ -0,0 +1,222 @@
+import type { MistralEmbeddingClient } from "./embeddings-mistral.js";
+import { hashText } from "./internal.js";
+
+/** One text to embed, tagged with a caller-chosen id used to key the result. */
+export type MistralBatchRequest = {
+  custom_id: string;
+  text: string;
+};
+
+export type MistralBatchStatus = {
+  id?: string;
+  status?: string;
+  output?: Map<string, number[]>;
+  error?: string;
+};
+
+/** Maximum number of inputs sent in a single `/embeddings` call. */
+const MISTRAL_BATCH_MAX_REQUESTS = 100;
+
+/** Returns the client's base URL without a trailing slash ("" when unset). */
+function getMistralBaseUrl(mistral: MistralEmbeddingClient): string {
+  return mistral.baseUrl?.replace(/\/$/, "") ?? "";
+}
+
+/** Copies the client's headers, adding a JSON Content-Type when none is present. */
+function getMistralHeaders(mistral: MistralEmbeddingClient): Record<string, string> {
+  const headers: Record<string, string> = mistral.headers ? { ...mistral.headers } : {};
+  // Header names are case-insensitive, so look for any casing before adding ours.
+  const hasContentType = Object.keys(headers).some(
+    (key) => key.toLowerCase() === "content-type",
+  );
+  if (!hasContentType) {
+    headers["Content-Type"] = "application/json";
+  }
+  return headers;
+}
+
+/** Splits requests into consecutive groups of at most MISTRAL_BATCH_MAX_REQUESTS. */
+function splitMistralBatchRequests(requests: MistralBatchRequest[]): MistralBatchRequest[][] {
+  if (requests.length <= MISTRAL_BATCH_MAX_REQUESTS) return [requests];
+  const groups: MistralBatchRequest[][] = [];
+  for (let i = 0; i < requests.length; i += MISTRAL_BATCH_MAX_REQUESTS) {
+    groups.push(requests.slice(i, i + MISTRAL_BATCH_MAX_REQUESTS));
+  }
+  return groups;
+}
+
+/**
+ * Embeds one group of requests with a single POST to `/embeddings`.
+ *
+ * @returns embeddings keyed by each request's `custom_id`.
+ * @throws Error on a non-2xx response, an error payload, a result-count mismatch,
+ *   a repeated or out-of-range result index, or an empty embedding.
+ */
+async function submitMistralBatch(params: {
+  mistral: MistralEmbeddingClient;
+  requests: MistralBatchRequest[];
+}): Promise<Map<string, number[]>> {
+  if (params.requests.length === 0) return new Map();
+
+  const baseUrl = getMistralBaseUrl(params.mistral);
+  const url = `${baseUrl}/embeddings`;
+
+  const byCustomId = new Map<string, number[]>();
+
+  // Process all requests in one batch API call
+  const inputTexts = params.requests.map((req) => req.text);
+
+  const res = await fetch(url, {
+    method: "POST",
+    headers: getMistralHeaders(params.mistral),
+    body: JSON.stringify({
+      model: params.mistral.model,
+      input: inputTexts,
+    }),
+  });
+
+  if (!res.ok) {
+    const text = await res.text();
+    throw new Error(`mistral batch failed: ${res.status} ${text}`);
+  }
+
+  const payload = (await res.json()) as {
+    data?: Array<{ embedding?: number[]; index?: number }>;
+    error?: { message?: string };
+  };
+
+  if (payload.error?.message) {
+    throw new Error(`mistral batch failed: ${payload.error.message}`);
+  }
+
+  const data = payload.data ?? [];
+  if (data.length !== params.requests.length) {
+    throw new Error(
+      `mistral batch failed: expected ${params.requests.length} results, got ${data.length}`,
+    );
+  }
+
+  // Map results back to custom IDs. Each result's `index` names the input it
+  // belongs to; array position is only the fallback when the field is absent.
+  const seen = new Set<number>();
+  for (let i = 0; i < data.length; i++) {
+    const result = data[i];
+    const inputIndex = result.index ?? i;
+    const request = params.requests[inputIndex];
+    if (!request || seen.has(inputIndex)) {
+      throw new Error(`mistral batch failed: unexpected result index ${inputIndex}`);
+    }
+    seen.add(inputIndex);
+    const customId = request.custom_id;
+    const embedding = result.embedding ?? [];
+    if (embedding.length === 0) {
+      throw new Error(`mistral batch failed: empty embedding for ${customId}`);
+    }
+    byCustomId.set(customId, embedding);
+  }
+
+  return byCustomId;
+}
+
+/**
+ * Runs `tasks` with at most `limit` in flight and returns results in task order.
+ *
+ * After the first failure no further tasks are started; that error is rethrown
+ * once the tasks already running have settled.
+ */
+async function runWithConcurrency<T>(tasks: Array<() => Promise<T>>, limit: number): Promise<T[]> {
+  if (tasks.length === 0) return [];
+  // Math.max(1, NaN) is NaN, which would create zero workers and silently skip
+  // every task, so treat a NaN limit as 1.
+  const requested = Number.isNaN(limit) ? 1 : limit;
+  const resolvedLimit = Math.max(1, Math.min(requested, tasks.length));
+  const results: T[] = Array.from({ length: tasks.length });
+  let next = 0;
+  // Tracked apart from the error value so a falsy rejection still counts.
+  let failed = false;
+  let firstError: unknown = null;
+
+  const workers = Array.from({ length: resolvedLimit }, async () => {
+    while (true) {
+      if (failed) return;
+      const index = next;
+      next += 1;
+      if (index >= tasks.length) return;
+      try {
+        results[index] = await tasks[index]();
+      } catch (err) {
+        // Keep the earliest failure; later ones from in-flight tasks are dropped.
+        if (!failed) {
+          failed = true;
+          firstError = err;
+        }
+        return;
+      }
+    }
+  });
+
+  await Promise.allSettled(workers);
+  if (failed) throw firstError;
+  return results;
+}
+
+/**
+ * Embeds `requests` via the Mistral `/embeddings` endpoint, in groups of at most
+ * MISTRAL_BATCH_MAX_REQUESTS, running up to `concurrency` groups at once.
+ *
+ * `agentId` is not read here, and `wait`, `pollIntervalMs` and `timeoutMs` are only
+ * passed to `debug`: each group is one request/response, so nothing is polled.
+ *
+ * @returns embeddings keyed by `custom_id`.
+ * @throws the first error raised by any group.
+ */
+export async function runMistralEmbeddingBatches(params: {
+  mistral: MistralEmbeddingClient;
+  agentId: string;
+  requests: MistralBatchRequest[];
+  wait: boolean;
+  pollIntervalMs: number;
+  timeoutMs: number;
+  concurrency: number;
+  debug?: (message: string, data?: Record<string, unknown>) => void;
+}): Promise<Map<string, number[]>> {
+  if (params.requests.length === 0) return new Map();
+
+  const groups = splitMistralBatchRequests(params.requests);
+  const byCustomId = new Map<string, number[]>();
+
+  const tasks = groups.map((group, groupIndex) => async () => {
+    params.debug?.("memory embeddings: mistral batch start", {
+      group: groupIndex + 1,
+      groups: groups.length,
+      requests: group.length,
+    });
+
+    const results = await submitMistralBatch({
+      mistral: params.mistral,
+      requests: group,
+    });
+
+    params.debug?.("memory embeddings: mistral batch complete", {
+      group: groupIndex + 1,
+      results: results.size,
+    });
+
+    // Merge results into main map
+    for (const [customId, embedding] of results.entries()) {
+      byCustomId.set(customId, embedding);
+    }
+  });
+
+  params.debug?.("memory embeddings: mistral batch submit", {
+    requests: params.requests.length,
+    groups: groups.length,
+    wait: params.wait,
+    concurrency: params.concurrency,
+    pollIntervalMs: params.pollIntervalMs,
+    timeoutMs: params.timeoutMs,
+  });
+
+  await runWithConcurrency(tasks, params.concurrency);
+  return byCustomId;
+}
diff --git a/src/memory/embeddings-mistral.ts b/src/memory/embeddings-mistral.ts
new file mode 100644
index 000000000..f3a8ea50d
--- /dev/null
+++ b/src/memory/embeddings-mistral.ts
@@ -0,0 +1,110 @@
+import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
+import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
+
+export type MistralEmbeddingClient = {
+  baseUrl: string;
+  headers: Record<string, string>;
+  model: string;
+};
+
+export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
+const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";
+
+/**
+ * Trims `model`, falls back to DEFAULT_MISTRAL_EMBEDDING_MODEL when empty, and
+ * drops a leading "mistral/" prefix.
+ */
+export function normalizeMistralModel(model: string): string {
+  const trimmed = model.trim();
+  if (!trimmed) return DEFAULT_MISTRAL_EMBEDDING_MODEL;
+  if (trimmed.startsWith("mistral/")) return trimmed.slice("mistral/".length);
+  return trimmed;
+}
+
+/**
+ * Builds an embedding provider backed by POST `{baseUrl}/embeddings`.
+ *
+ * The provider's embed calls throw on a non-2xx response or an error payload;
+ * an entry without an `embedding` comes back as an empty vector.
+ *
+ * @returns the provider plus the resolved client (base URL, headers, model).
+ */
+export async function createMistralEmbeddingProvider(
+  options: EmbeddingProviderOptions,
+): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
+  const client = await resolveMistralEmbeddingClient(options);
+  const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
+
+  const embed = async (input: string[]): Promise<number[][]> => {
+    if (input.length === 0) return [];
+    const res = await fetch(url, {
+      method: "POST",
+      headers: client.headers,
+      body: JSON.stringify({ model: client.model, input }),
+    });
+    if (!res.ok) {
+      const text = await res.text();
+      throw new Error(`mistral embeddings failed: ${res.status} ${text}`);
+    }
+    const payload = (await res.json()) as {
+      data?: Array<{ embedding?: number[] }>;
+      error?: { message?: string };
+    };
+    if (payload.error?.message) {
+      throw new Error(`mistral embeddings failed: ${payload.error.message}`);
+    }
+    const data = payload.data ?? [];
+    return data.map((entry) => entry.embedding ?? []);
+  };
+
+  return {
+    provider: {
+      id: "mistral",
+      model: client.model,
+      embedQuery: async (text) => {
+        const [vec] = await embed([text]);
+        return vec ?? [];
+      },
+      embedBatch: embed,
+    },
+    client,
+  };
+}
+
+/**
+ * Resolves base URL, headers and model for Mistral embedding calls.
+ *
+ * A non-blank `remote.apiKey` / `remote.baseUrl` wins; otherwise the key comes from
+ * `resolveApiKeyForProvider` and the URL from the `mistral` provider config or the
+ * default. Header overrides (provider config, then `remote`) are spread last, so
+ * they can replace the default Content-Type and Authorization headers.
+ */
+export async function resolveMistralEmbeddingClient(
+  options: EmbeddingProviderOptions,
+): Promise<MistralEmbeddingClient> {
+  const remote = options.remote;
+  const remoteApiKey = remote?.apiKey?.trim();
+  const remoteBaseUrl = remote?.baseUrl?.trim();
+
+  const apiKey = remoteApiKey
+    ? remoteApiKey
+    : requireApiKey(
+        await resolveApiKeyForProvider({
+          provider: "mistral",
+          cfg: options.config,
+          agentDir: options.agentDir,
+        }),
+        "mistral",
+      );
+
+  const providerConfig = options.config.models?.providers?.mistral;
+  const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_MISTRAL_BASE_URL;
+  const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
+  const headers: Record<string, string> = {
+    "Content-Type": "application/json",
+    Authorization: `Bearer ${apiKey}`,
+    ...headerOverrides,
+  };
+  const model = normalizeMistralModel(options.model);
+  return { baseUrl, headers, model };
+}
diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts
index 1809b24b8..570edd23c 100644
--- a/src/memory/embeddings.test.ts
+++ b/src/memory/embeddings.test.ts
@@ -1,5 +1,6 @@
 import { afterEach, describe, expect, it, vi } from "vitest";
 
+import { DEFAULT_MISTRAL_EMBEDDING_MODEL } from "./embeddings-mistral.js";
 import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
 
 vi.mock("../agents/model-auth.js", () => ({
@@ -18,6 +19,109 @@ const createFetchMock = () =>
   })) as unknown as typeof fetch;
 
 describe("embedding provider remote overrides", () => {
+  it("builds Mistral embeddings requests with api key header", async () => {
+    const fetchMock = vi.fn(async () => ({
+      ok: true,
+      status: 200,
+      json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
+    })) as unknown as typeof fetch;
+    vi.stubGlobal("fetch", fetchMock);
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const authModule = await import("../agents/model-auth.js");
+    vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
+      apiKey: "mistral-key",
+      mode: "api-key",
+      source: "test",
+    });
+
+    const cfg = {
+      models: {
+        providers: {
+          mistral: {
+            baseUrl: "https://api.mistral.ai/v1",
+          },
+        },
+      },
+    };
+
+    const result = await createEmbeddingProvider({
+      config: cfg as never,
+      provider: "mistral",
+      remote: {
+        apiKey: "mistral-key",
+      },
+      model: "mistral-embed",
+      fallback: "openai",
+    });
+
+    await result.provider.embedQuery("hello");
+
+    const [url, init] = fetchMock.mock.calls[0] ?? [];
+    expect(url).toBe("https://api.mistral.ai/v1/embeddings");
+    const headers = (init?.headers ?? {}) as Record<string, string>;
+    expect(headers.Authorization).toBe("Bearer mistral-key");
+    expect(headers["Content-Type"]).toBe("application/json");
+  });
+
+  it("uses Mistral remote baseUrl/apiKey and merges headers", async () => {
+    const fetchMock = vi.fn(async () => ({
+      ok: true,
+      status: 200,
+      json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
+    })) as unknown as typeof fetch;
+    vi.stubGlobal("fetch", fetchMock);
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const authModule = await import("../agents/model-auth.js");
+    vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
+      apiKey: "provider-key",
+      mode: "api-key",
+      source: "test",
+    });
+
+    const cfg = {
+      models: {
+        providers: {
+          mistral: {
+            baseUrl: "https://provider.example/v1",
+            headers: {
+              "X-Provider": "p",
+              "X-Shared": "provider",
+            },
+          },
+        },
+      },
+    };
+
+    const result = await createEmbeddingProvider({
+      config: cfg as never,
+      provider: "mistral",
+      remote: {
+        baseUrl: "https://remote.example/v1",
+        apiKey: " remote-key ",
+        headers: {
+          "X-Shared": "remote",
+          "X-Remote": "r",
+        },
+      },
+      model: "mistral-embed",
+      fallback: "openai",
+    });
+
+    await result.provider.embedQuery("hello");
+
+    expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
+    const [url, init] = fetchMock.mock.calls[0] ?? [];
+    expect(url).toBe("https://remote.example/v1/embeddings");
+    const headers = (init?.headers ?? {}) as Record<string, string>;
+    expect(headers.Authorization).toBe("Bearer remote-key");
+    expect(headers["Content-Type"]).toBe("application/json");
+    expect(headers["X-Provider"]).toBe("p");
+    expect(headers["X-Shared"]).toBe("remote");
+    expect(headers["X-Remote"]).toBe("r");
+  });
+
   afterEach(() => {
     vi.resetAllMocks();
     vi.resetModules();
@@ -167,6 +271,27 @@ describe("embedding provider remote overrides", () => {
 });
 
 describe("embedding provider auto selection", () => {
+  it("prefers mistral when a key resolves", async () => {
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+    const authModule = await import("../agents/model-auth.js");
+    vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
+      if (provider === "mistral") {
+        return { apiKey: "mistral-key", source: "env: MISTRAL_API_KEY", mode: "api-key" };
+      }
+      throw new Error(`No API key found for provider "${provider}".`);
+    });
+
+    const result = await createEmbeddingProvider({
+      config: {} as never,
+      provider: "auto",
+      model: "",
+      fallback: "none",
+    });
+
+    expect(result.requestedProvider).toBe("auto");
+    expect(result.provider.id).toBe("mistral");
+  });
+
   afterEach(() => {
     vi.resetAllMocks();
     vi.resetModules();
diff --git a/src/memory/embeddings.ts b/src/memory/embeddings.ts
index 993fe8124..73d1e035c 100644
--- a/src/memory/embeddings.ts
+++ b/src/memory/embeddings.ts
@@ -4,10 +4,15 @@ import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
 import type { OpenClawConfig } from "../config/config.js";
 import { resolveUserPath } from "../utils.js";
 import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js";
+import {
+  createMistralEmbeddingProvider,
+  type MistralEmbeddingClient,
+} from "./embeddings-mistral.js";
 import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
 import { importNodeLlamaCpp } from "./node-llama.js";
 
 export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
+export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
 export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
 
 export type EmbeddingProvider = {
@@ -19,24 +24,25 @@ export type EmbeddingProvider = {
 
 export type EmbeddingProviderResult = {
   provider: EmbeddingProvider;
-  requestedProvider: "openai" | "local" | "gemini" | "auto";
-  fallbackFrom?: "openai" | "local" | "gemini";
+  requestedProvider: "openai" | "local" | "gemini" | "mistral" | "auto";
+  fallbackFrom?: "openai" | "local" | "gemini" | "mistral";
   fallbackReason?: string;
   openAi?: OpenAiEmbeddingClient;
   gemini?: GeminiEmbeddingClient;
+  mistral?: MistralEmbeddingClient;
 };
 
 export type EmbeddingProviderOptions = {
   config: OpenClawConfig;
   agentDir?: string;
-  provider: "openai" | "local" | "gemini" | "auto";
+  provider: "openai" | "local" | "gemini" | "mistral" | "auto";
   remote?: {
     baseUrl?: string;
     apiKey?: string;
     headers?: Record<string, string>;
   };
   model: string;
-  fallback: "openai" | "gemini" | "local" | "none";
+  fallback: "openai" | "gemini" | "mistral" | "local" | "none";
   local?: {
     modelPath?: string;
     modelCacheDir?: string;
@@ -116,7 +122,7 @@ export async function createEmbeddingProvider(
   const requestedProvider = options.provider;
   const fallback = options.fallback;
 
-  const createProvider = async (id: "openai" | "local" | "gemini") => {
+  const createProvider = async (id: "openai" | "local" | "gemini" | "mistral") => {
     if (id === "local") {
       const provider = await createLocalEmbeddingProvider(options);
       return { provider };
@@ -125,11 +131,15 @@ export async function createEmbeddingProvider(
       const { provider, client } = await createGeminiEmbeddingProvider(options);
       return { provider, gemini: client };
     }
+    if (id === "mistral") {
+      const { provider, client } = await createMistralEmbeddingProvider(options);
+      return { provider, mistral: client };
+    }
     const { provider, client } = await createOpenAiEmbeddingProvider(options);
     return { provider, openAi: client };
   };
 
-  const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini") =>
+  const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini" | "mistral") =>
     provider === "local" ? formatLocalSetupError(err) : formatError(err);
 
   if (requestedProvider === "auto") {
@@ -145,7 +155,7 @@ export async function createEmbeddingProvider(
       }
     }
 
-    for (const provider of ["openai", "gemini"] as const) {
+    for (const provider of ["openai", "gemini", "mistral"] as const) {
       try {
         const result = await createProvider(provider);
         return { ...result, requestedProvider };
@@ -224,3 +234,5 @@ function formatLocalSetupError(err: unknown): string {
     .filter(Boolean)
     .join("\n");
 }
+
+export { createMistralEmbeddingProvider } from "./embeddings-mistral.js";
diff --git a/src/memory/manager.ts b/src/memory/manager.ts
index aa3cb317d..2c8610855 100644
--- a/src/memory/manager.ts
+++ b/src/memory/manager.ts
@@ -19,9 +19,11 @@ import {
   type EmbeddingProvider,
   type EmbeddingProviderResult,
   type GeminiEmbeddingClient,
+  type MistralEmbeddingClient,
   type OpenAiEmbeddingClient,
 } from "./embeddings.js";
 import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
+import { DEFAULT_MISTRAL_EMBEDDING_MODEL } from "./embeddings-mistral.js";
 import { DEFAULT_OPENAI_EMBEDDING_MODEL } from "./embeddings-openai.js";
 import {
   OPENAI_BATCH_ENDPOINT,
@@ -29,6 +31,7 @@ import {
   runOpenAiEmbeddingBatches,
 } from "./batch-openai.js";
 import { runGeminiEmbeddingBatches, type GeminiBatchRequest } from "./batch-gemini.js";
+import { runMistralEmbeddingBatches, type MistralBatchRequest } from "./batch-mistral.js";
 import {
   buildFileEntry,
   chunkMarkdown,
@@ -123,11 +126,12 @@ export class MemoryIndexManager {
   private readonly workspaceDir: string;
   private readonly settings: ResolvedMemorySearchConfig;
   private provider: EmbeddingProvider;
-  private readonly requestedProvider: "openai" | "local" | "gemini" | "auto";
-  private fallbackFrom?: "openai" | "local" | "gemini";
+  private readonly requestedProvider: "openai" | "local" | "gemini" | "mistral" | "auto";
+  private fallbackFrom?: "openai" | "local" | "gemini" | "mistral";
   private fallbackReason?: string;
   private openAi?: OpenAiEmbeddingClient;
   private gemini?: GeminiEmbeddingClient;
+  private mistral?: MistralEmbeddingClient;
   private batch: {
     enabled: boolean;
     wait: boolean;
@@ -224,6 +228,7 @@ export class MemoryIndexManager {
     this.fallbackReason = params.providerResult.fallbackReason;
     this.openAi = params.providerResult.openAi;
     this.gemini = params.providerResult.gemini;
+    this.mistral = params.providerResult.mistral;
     this.sources = new Set(params.settings.sources);
     this.db = this.openDatabase();
     this.providerKey = this.computeProviderKey();
@@ -1303,7 +1308,8 @@ export class MemoryIndexManager {
     const enabled = Boolean(
       batch?.enabled &&
       ((this.openAi && this.provider.id === "openai") ||
-        (this.gemini && this.provider.id === "gemini")),
+        (this.gemini && this.provider.id === "gemini") ||
+        (this.mistral && this.provider.id === "mistral")),
     );
     return {
       enabled,
@@ -1318,14 +1324,16 @@ export class MemoryIndexManager {
     const fallback = this.settings.fallback;
     if (!fallback || fallback === "none" || fallback === this.provider.id) return false;
     if (this.fallbackFrom) return false;
-    const fallbackFrom = this.provider.id as "openai" | "gemini" | "local";
+    const fallbackFrom = this.provider.id as "openai" | "gemini" | "mistral" | "local";
 
     const fallbackModel =
       fallback === "gemini"
         ? DEFAULT_GEMINI_EMBEDDING_MODEL
-        : fallback === "openai"
-          ? DEFAULT_OPENAI_EMBEDDING_MODEL
-          : this.settings.model;
+        : fallback === "mistral"
+          ? DEFAULT_MISTRAL_EMBEDDING_MODEL
+          : fallback === "openai"
+            ? DEFAULT_OPENAI_EMBEDDING_MODEL
+            : this.settings.model;
 
     const fallbackResult = await createEmbeddingProvider({
       config: this.cfg,
@@ -1342,6 +1350,7 @@ export class MemoryIndexManager {
     this.provider = fallbackResult.provider;
     this.openAi = fallbackResult.openAi;
     this.gemini = fallbackResult.gemini;
+    this.mistral = fallbackResult.mistral;
     this.providerKey = this.computeProviderKey();
     this.batch = this.resolveBatchConfig();
     log.warn(`memory embeddings: switched to fallback provider (${fallback})`, { reason });
@@ -1758,6 +1767,20 @@ export class MemoryIndexManager {
         }),
       );
     }
+    if (this.provider.id === "mistral" && this.mistral) {
+      const entries = Object.entries(this.mistral.headers)
+        .filter(([key]) => key.toLowerCase() !== "authorization")
+        .sort(([a], [b]) => a.localeCompare(b))
+        .map(([key, value]) => [key, value]);
+      return hashText(
+        JSON.stringify({
+          provider: "mistral",
+          baseUrl: this.mistral.baseUrl,
+          model: this.mistral.model,
+          headers: entries,
+        }),
+      );
+    }
     return hashText(JSON.stringify({ provider: this.provider.id, model: this.provider.model }));
   }
 
@@ -1772,6 +1795,9 @@ export class MemoryIndexManager {
     if (this.provider.id === "gemini" && this.gemini) {
       return this.embedChunksWithGeminiBatch(chunks, entry, source);
     }
+    if (this.provider.id === "mistral" && this.mistral) {
+      return this.embedChunksWithMistralBatch(chunks, entry, source);
+    }
     return this.embedChunksInBatches(chunks);
   }
 
@@ -1918,6 +1944,75 @@ export class MemoryIndexManager {
     return embeddings;
   }
 
+  private async embedChunksWithMistralBatch(
+    chunks: MemoryChunk[],
+    entry: MemoryFileEntry | SessionFileEntry,
+    source: MemorySource,
+  ): Promise<number[][]> {
+    const mistral = this.mistral;
+    if (!mistral) {
+      return this.embedChunksInBatches(chunks);
+    }
+    if (chunks.length === 0) return [];
+    const cached = this.loadEmbeddingCache(chunks.map((chunk) => chunk.hash));
+    const embeddings: number[][] = Array.from({ length: chunks.length }, () => []);
+    const missing: Array<{ index: number; chunk: MemoryChunk }> = [];
+
+    for (let i = 0; i < chunks.length; i += 1) {
+      const chunk = chunks[i];
+      const hit = chunk?.hash ? cached.get(chunk.hash) : undefined;
+      if (hit && hit.length > 0) {
+        embeddings[i] = hit;
+      } else if (chunk) {
+        missing.push({ index: i, chunk });
+      }
+    }
+
+    if (missing.length === 0) return embeddings;
+
+    const requests: MistralBatchRequest[] = [];
+    const mapping = new Map<string, { index: number; hash: string }>();
+    for (const item of missing) {
+      const chunk = item.chunk;
+      const customId = hashText(
+        `${source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${item.index}`,
+      );
+      mapping.set(customId, { index: item.index, hash: chunk.hash });
+      requests.push({
+        custom_id: customId,
+        text: chunk.text,
+      });
+    }
+
+    const batchResult = await this.runBatchWithFallback({
+      provider: "mistral",
+      run: async () =>
+        await runMistralEmbeddingBatches({
+          mistral,
+          agentId: this.agentId,
+          requests,
+          wait: this.batch.wait,
+          concurrency: this.batch.concurrency,
+          pollIntervalMs: this.batch.pollIntervalMs,
+          timeoutMs: this.batch.timeoutMs,
+          debug: (message, data) => log.debug(message, { ...data, source, chunks: chunks.length }),
+        }),
+      fallback: async () => await this.embedChunksInBatches(chunks),
+    });
+    if (Array.isArray(batchResult)) return batchResult;
+    const byCustomId = batchResult;
+
+    const toCache: Array<{ hash: string; embedding: number[] }> = [];
+    for (const [customId, embedding] of byCustomId.entries()) {
+      const mapped = mapping.get(customId);
+      if (!mapped) continue;
+      embeddings[mapped.index] = embedding;
+      toCache.push({ hash: mapped.hash, embedding });
+    }
+    this.upsertEmbeddingCache(toCache);
+    return embeddings;
+  }
+
   private async embedBatchWithRetry(texts: string[]): Promise<number[][]> {
     if (texts.length === 0) return [];
     let attempt = 0;