Merge 8954e07cf3 into da71eaebd2
This commit is contained in:
commit
21da40d577
40
src/memory/manager-search.test.ts
Normal file
40
src/memory/manager-search.test.ts
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import { describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
|
import { searchVector } from "./manager-search.js";
|
||||||
|
|
||||||
|
describe("memory vector search SQL", () => {
|
||||||
|
it("uses sqlite-vec knn query (MATCH + k) when available", async () => {
|
||||||
|
const rows = [
|
||||||
|
{
|
||||||
|
id: "id-1",
|
||||||
|
path: "MEMORY.md",
|
||||||
|
start_line: 1,
|
||||||
|
end_line: 1,
|
||||||
|
text: "hello",
|
||||||
|
source: "memory",
|
||||||
|
dist: 0.1,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const all = vi.fn((..._args: unknown[]) => rows);
|
||||||
|
const prepare = vi.fn((_sql: string) => ({ all }));
|
||||||
|
const db = { prepare } as unknown as Parameters<typeof searchVector>[0]["db"];
|
||||||
|
|
||||||
|
const result = await searchVector({
|
||||||
|
db,
|
||||||
|
vectorTable: "chunks_vec",
|
||||||
|
providerModel: "mock-model",
|
||||||
|
queryVec: [1, 2, 3],
|
||||||
|
limit: 5,
|
||||||
|
snippetMaxChars: 100,
|
||||||
|
ensureVectorReady: async () => true,
|
||||||
|
sourceFilterVec: { sql: "", params: [] },
|
||||||
|
sourceFilterChunks: { sql: "", params: [] },
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toHaveLength(1);
|
||||||
|
expect(prepare).toHaveBeenCalledTimes(1);
|
||||||
|
const sql = String(prepare.mock.calls[0]?.[0] ?? "");
|
||||||
|
expect(sql).toContain("embedding MATCH ? AND k = ?");
|
||||||
|
expect(sql).toContain("WITH knn AS");
|
||||||
|
});
|
||||||
|
});
|
||||||
@ -31,19 +31,28 @@ export async function searchVector(params: {
|
|||||||
}): Promise<SearchRowResult[]> {
|
}): Promise<SearchRowResult[]> {
|
||||||
if (params.queryVec.length === 0 || params.limit <= 0) return [];
|
if (params.queryVec.length === 0 || params.limit <= 0) return [];
|
||||||
if (await params.ensureVectorReady(params.queryVec.length)) {
|
if (await params.ensureVectorReady(params.queryVec.length)) {
|
||||||
|
const query = vectorToBlob(params.queryVec);
|
||||||
const rows = params.db
|
const rows = params.db
|
||||||
.prepare(
|
.prepare(
|
||||||
`SELECT c.id, c.path, c.start_line, c.end_line, c.text,\n` +
|
`WITH knn AS (\n` +
|
||||||
|
` SELECT id\n` +
|
||||||
|
` FROM ${params.vectorTable}\n` +
|
||||||
|
` WHERE embedding MATCH ? AND k = ?\n` +
|
||||||
|
`)\n` +
|
||||||
|
`SELECT c.id, c.path, c.start_line, c.end_line, c.text,\n` +
|
||||||
` c.source,\n` +
|
` c.source,\n` +
|
||||||
` vec_distance_cosine(v.embedding, ?) AS dist\n` +
|
` vec_distance_cosine(v.embedding, ?) AS dist\n` +
|
||||||
` FROM ${params.vectorTable} v\n` +
|
` FROM knn\n` +
|
||||||
` JOIN chunks c ON c.id = v.id\n` +
|
` JOIN chunks c ON c.id = knn.id\n` +
|
||||||
|
` JOIN ${params.vectorTable} v ON v.id = knn.id\n` +
|
||||||
` WHERE c.model = ?${params.sourceFilterVec.sql}\n` +
|
` WHERE c.model = ?${params.sourceFilterVec.sql}\n` +
|
||||||
` ORDER BY dist ASC\n` +
|
` ORDER BY dist ASC\n` +
|
||||||
` LIMIT ?`,
|
` LIMIT ?`,
|
||||||
)
|
)
|
||||||
.all(
|
.all(
|
||||||
vectorToBlob(params.queryVec),
|
query,
|
||||||
|
params.limit,
|
||||||
|
query,
|
||||||
params.providerModel,
|
params.providerModel,
|
||||||
...params.sourceFilterVec.params,
|
...params.sourceFilterVec.params,
|
||||||
params.limit,
|
params.limit,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user