diff --git a/src/memory/manager-search.test.ts b/src/memory/manager-search.test.ts new file mode 100644 index 000000000..9a188f3e6 --- /dev/null +++ b/src/memory/manager-search.test.ts @@ -0,0 +1,40 @@ +import { describe, expect, it, vi } from "vitest"; + +import { searchVector } from "./manager-search.js"; + +describe("memory vector search SQL", () => { + it("uses sqlite-vec knn query (MATCH + k) when available", async () => { + const rows = [ + { + id: "id-1", + path: "MEMORY.md", + start_line: 1, + end_line: 1, + text: "hello", + source: "memory", + dist: 0.1, + }, + ]; + const all = vi.fn((..._args: unknown[]) => rows); + const prepare = vi.fn((_sql: string) => ({ all })); + const db = { prepare } as unknown as Parameters[0]["db"]; + + const result = await searchVector({ + db, + vectorTable: "chunks_vec", + providerModel: "mock-model", + queryVec: [1, 2, 3], + limit: 5, + snippetMaxChars: 100, + ensureVectorReady: async () => true, + sourceFilterVec: { sql: "", params: [] }, + sourceFilterChunks: { sql: "", params: [] }, + }); + + expect(result).toHaveLength(1); + expect(prepare).toHaveBeenCalledTimes(1); + const sql = String(prepare.mock.calls[0]?.[0] ?? ""); + expect(sql).toContain("embedding MATCH ? AND k = ?"); + expect(sql).toContain("WITH knn AS"); + }); +}); diff --git a/src/memory/manager-search.ts b/src/memory/manager-search.ts index f065a96a5..18e0501af 100644 --- a/src/memory/manager-search.ts +++ b/src/memory/manager-search.ts @@ -31,19 +31,28 @@ export async function searchVector(params: { }): Promise { if (params.queryVec.length === 0 || params.limit <= 0) return []; if (await params.ensureVectorReady(params.queryVec.length)) { + const query = vectorToBlob(params.queryVec); const rows = params.db .prepare( - `SELECT c.id, c.path, c.start_line, c.end_line, c.text,\n` + + `WITH knn AS (\n` + + ` SELECT id\n` + + ` FROM ${params.vectorTable}\n` + + ` WHERE embedding MATCH ? AND k = ?\n` + + `)\n` + + `SELECT c.id, c.path, c.start_line, c.end_line, c.text,\n` + ` c.source,\n` + ` vec_distance_cosine(v.embedding, ?) AS dist\n` + - ` FROM ${params.vectorTable} v\n` + - ` JOIN chunks c ON c.id = v.id\n` + + ` FROM knn\n` + + ` JOIN chunks c ON c.id = knn.id\n` + + ` JOIN ${params.vectorTable} v ON v.id = knn.id\n` + ` WHERE c.model = ?${params.sourceFilterVec.sql}\n` + ` ORDER BY dist ASC\n` + ` LIMIT ?`, ) .all( - vectorToBlob(params.queryVec), + query, + params.limit, + query, params.providerModel, ...params.sourceFilterVec.params, params.limit,