From 10e54dc55c7295ca2239ea8537b47d4cb69116f9 Mon Sep 17 00:00:00 2001 From: TideFinder Date: Wed, 28 Jan 2026 15:54:52 +0900 Subject: [PATCH] feat: add RRF reranker for hybrid search Amp-Thread-ID: https://ampcode.com/threads/T-019c032e-16c5-7589-b68b-4e2651a1e631 Co-authored-by: Amp --- src/agents/memory-search.ts | 26 ++++++++ src/config/types.tools.ts | 7 +++ src/memory/hybrid.ts | 78 ++++++++++++++++++++++-- src/memory/manager.ts | 13 ++++ src/memory/reranker.test.ts | 115 ++++++++++++++++++++++++++++++++++++ src/memory/reranker.ts | 86 +++++++++++++++++++++++++++ 6 files changed, 319 insertions(+), 6 deletions(-) create mode 100644 src/memory/reranker.test.ts create mode 100644 src/memory/reranker.ts diff --git a/src/agents/memory-search.ts b/src/agents/memory-search.ts index c08161d4f..b0a4e9e5b 100644 --- a/src/agents/memory-search.ts +++ b/src/agents/memory-search.ts @@ -63,6 +63,12 @@ export type ResolvedMemorySearchConfig = { textWeight: number; candidateMultiplier: number; }; + reranker: { + method: "rrf" | "weighted" | "none"; + rrf: { + k: number; + }; + }; }; cache: { enabled: boolean; @@ -83,6 +89,8 @@ const DEFAULT_HYBRID_ENABLED = true; const DEFAULT_HYBRID_VECTOR_WEIGHT = 0.7; const DEFAULT_HYBRID_TEXT_WEIGHT = 0.3; const DEFAULT_HYBRID_CANDIDATE_MULTIPLIER = 4; +const DEFAULT_RERANKER_METHOD: "rrf" | "weighted" | "none" = "weighted"; +const DEFAULT_RRF_K = 60; const DEFAULT_CACHE_ENABLED = true; const DEFAULT_SOURCES: Array<"memory" | "sessions"> = ["memory"]; @@ -218,6 +226,18 @@ function mergeConfig( defaults?.query?.hybrid?.candidateMultiplier ?? DEFAULT_HYBRID_CANDIDATE_MULTIPLIER, }; + const reranker = { + method: + overrides?.query?.reranker?.method ?? + defaults?.query?.reranker?.method ?? + DEFAULT_RERANKER_METHOD, + rrf: { + k: + overrides?.query?.reranker?.rrf?.k ?? + defaults?.query?.reranker?.rrf?.k ?? + DEFAULT_RRF_K, + }, + }; const cache = { enabled: overrides?.cache?.enabled ?? defaults?.cache?.enabled ?? 
DEFAULT_CACHE_ENABLED, maxEntries: overrides?.cache?.maxEntries ?? defaults?.cache?.maxEntries, @@ -262,6 +282,12 @@ function mergeConfig( textWeight: normalizedTextWeight, candidateMultiplier, }, + reranker: { + method: reranker.method, + rrf: { + k: Math.max(1, reranker.rrf.k), + }, + }, }, cache: { enabled: Boolean(cache.enabled), diff --git a/src/config/types.tools.ts b/src/config/types.tools.ts index bb1d45bf0..1b52765be 100644 --- a/src/config/types.tools.ts +++ b/src/config/types.tools.ts @@ -311,6 +311,13 @@ export type MemorySearchConfig = { /** Multiplier for candidate pool size (default: 4). */ candidateMultiplier?: number; }; + reranker?: { + /**Reranking method: "rrf" (Reciprocal Rank Fusion), "weighted" (legacy), or "none". */ + method?: "rrf" | "weighted" | "none"; + rrf?: { + k?: number; + }; + }; }; /** Index cache behavior. */ cache?: { diff --git a/src/memory/hybrid.ts b/src/memory/hybrid.ts index 753748bf9..dfad82a77 100644 --- a/src/memory/hybrid.ts +++ b/src/memory/hybrid.ts @@ -1,3 +1,5 @@ +import { rerank, type RerankerInput } from "./reranker.js"; + export type HybridSource = string; export type HybridVectorResult = { @@ -20,6 +22,8 @@ export type HybridKeywordResult = { textScore: number; }; +export type RerankerMethod = "rrf" | "weighted" | "none"; + export function buildFtsQuery(raw: string): string | null { const tokens = raw @@ -36,19 +40,28 @@ export function bm25RankToScore(rank: number): number { return 1 / (1 + normalized); } -export function mergeHybridResults(params: { +export type MergeHybridParams = { vector: HybridVectorResult[]; keyword: HybridKeywordResult[]; vectorWeight: number; textWeight: number; -}): Array<{ + reranker?: { + method: RerankerMethod; + rrf?: { k: number }; + }; +}; + +export type MergedResult = { + id: string; path: string; startLine: number; endLine: number; score: number; snippet: string; source: HybridSource; -}> { +}; + +export function mergeHybridResults(params: MergeHybridParams): MergedResult[] { 
const byId = new Map< string, { @@ -60,10 +73,13 @@ export function mergeHybridResults(params: { snippet: string; vectorScore: number; textScore: number; + vectorRank?: number; + keywordRank?: number; } >(); - for (const r of params.vector) { + for (let i = 0; i < params.vector.length; i++) { + const r = params.vector[i]; byId.set(r.id, { id: r.id, path: r.path, @@ -73,13 +89,16 @@ export function mergeHybridResults(params: { snippet: r.snippet, vectorScore: r.vectorScore, textScore: 0, + vectorRank: i + 1, }); } - for (const r of params.keyword) { + for (let i = 0; i < params.keyword.length; i++) { + const r = params.keyword[i]; const existing = byId.get(r.id); if (existing) { existing.textScore = r.textScore; + existing.keywordRank = i + 1; if (r.snippet && r.snippet.length > 0) existing.snippet = r.snippet; } else { byId.set(r.id, { @@ -91,13 +110,60 @@ export function mergeHybridResults(params: { snippet: r.snippet, vectorScore: 0, textScore: r.textScore, + keywordRank: i + 1, }); } } - const merged = Array.from(byId.values()).map((entry) => { + const entries = Array.from(byId.values()); + const method = params.reranker?.method ?? "weighted"; + + if (method === "rrf") { + const rerankerInputs: RerankerInput[] = entries.map((e) => ({ + id: e.id, + vectorRank: e.vectorRank, + keywordRank: e.keywordRank, + })); + const reranked = rerank(rerankerInputs, { + method: "rrf", + rrf: { k: params.reranker?.rrf?.k ?? 60 }, + }); + + const idToRank = new Map(reranked.map((r) => [r.id, r])); + return entries + .map((entry) => { + const ranked = idToRank.get(entry.id); + return { + id: entry.id, + path: entry.path, + startLine: entry.startLine, + endLine: entry.endLine, + score: ranked?.rrfScore ?? 0, + snippet: entry.snippet, + source: entry.source, + rank: ranked?.rank ?? 
entries.length, + }; + }) + .sort((a, b) => a.rank - b.rank) + .map(({ rank: _, ...rest }) => rest); + } + + if (method === "none") { + return entries.map((entry) => ({ + id: entry.id, + path: entry.path, + startLine: entry.startLine, + endLine: entry.endLine, + score: entry.vectorScore || entry.textScore, + snippet: entry.snippet, + source: entry.source, + })); + } + + const merged = entries.map((entry) => { const score = params.vectorWeight * entry.vectorScore + params.textWeight * entry.textScore; return { + id: entry.id, path: entry.path, startLine: entry.startLine, endLine: entry.endLine, diff --git a/src/memory/manager.ts b/src/memory/manager.ts index 9a9991d10..5d3ab8a03 100644 --- a/src/memory/manager.ts +++ b/src/memory/manager.ts @@ -296,13 +296,21 @@ export class MemoryIndexManager { return vectorResults.filter((entry) => entry.score >= minScore).slice(0, maxResults); } + const reranker = this.settings.query.reranker; const merged = this.mergeHybridResults({ vector: vectorResults, keyword: keywordResults, vectorWeight: hybrid.vectorWeight, textWeight: hybrid.textWeight, + reranker: { + method: reranker.method, + rrf: reranker.rrf, + }, }); + if (reranker.method === "rrf") { + return merged.slice(0, maxResults); + } return merged.filter((entry) => entry.score >= minScore).slice(0, maxResults); } @@ -353,6 +361,10 @@ export class MemoryIndexManager { keyword: Array; vectorWeight: number; textWeight: number; + reranker?: { + method: "rrf" | "weighted" | "none"; + rrf?: { k: number }; + }; }): MemorySearchResult[] { const merged = mergeHybridResults({ vector: params.vector.map((r) => ({ @@ -375,6 +387,7 @@ export class MemoryIndexManager { })), vectorWeight: params.vectorWeight, textWeight: params.textWeight, + reranker: params.reranker, }); return merged.map((entry) => entry as MemorySearchResult); } diff --git a/src/memory/reranker.test.ts b/src/memory/reranker.test.ts new file mode 100644 index 000000000..317fd620b --- /dev/null +++ 
b/src/memory/reranker.test.ts @@ -0,0 +1,115 @@ +import { describe, test, expect } from "vitest"; +import { rerank, rerankRRF, type RerankerInput } from "./reranker.js"; + +describe("reranker", () => { + describe("rerankRRF", () => { + test("combines vector and keyword ranks", () => { + const inputs: RerankerInput[] = [ + { id: "a", vectorRank: 1, keywordRank: 3 }, + { id: "b", vectorRank: 2, keywordRank: 1 }, + { id: "c", vectorRank: 3, keywordRank: 2 }, + ]; + + const results = rerankRRF(inputs, 60); + + expect(results).toHaveLength(3); + expect(results[0].rank).toBe(1); + expect(results[1].rank).toBe(2); + expect(results[2].rank).toBe(3); + // All should have rrfScore defined + expect(results.every((r) => r.rrfScore !== undefined)).toBe(true); + }); + + test("handles vector-only results", () => { + const inputs: RerankerInput[] = [ + { id: "a", vectorRank: 1 }, + { id: "b", vectorRank: 2 }, + ]; + + const results = rerankRRF(inputs); + + expect(results[0].id).toBe("a"); + expect(results[1].id).toBe("b"); + }); + + test("handles keyword-only results", () => { + const inputs: RerankerInput[] = [ + { id: "a", keywordRank: 2 }, + { id: "b", keywordRank: 1 }, + ]; + + const results = rerankRRF(inputs); + + expect(results[0].id).toBe("b"); + expect(results[1].id).toBe("a"); + }); + + test("item in both sources ranks higher than single-source", () => { + const inputs: RerankerInput[] = [ + { id: "both", vectorRank: 2, keywordRank: 2 }, + { id: "vector-only", vectorRank: 1 }, + { id: "keyword-only", keywordRank: 1 }, + ]; + + const results = rerankRRF(inputs, 60); + + // "both" should rank higher due to contributions from two sources + expect(results[0].id).toBe("both"); + }); + + test("returns empty array for empty input", () => { + expect(rerankRRF([])).toEqual([]); + }); + + test("k parameter affects score distribution", () => { + const inputs: RerankerInput[] = [ + { id: "a", vectorRank: 1 }, + { id: "b", vectorRank: 2 }, + ]; + + const lowK = rerankRRF(inputs, 1); 
+ const highK = rerankRRF(inputs, 100); + + // With lower k, the score difference should be larger + const lowKDiff = (lowK[0].rrfScore ?? 0) - (lowK[1].rrfScore ?? 0); + const highKDiff = (highK[0].rrfScore ?? 0) - (highK[1].rrfScore ?? 0); + expect(lowKDiff).toBeGreaterThan(highKDiff); + }); + }); + + describe("rerank", () => { + test("method=none preserves input order", () => { + const inputs: RerankerInput[] = [ + { id: "c", vectorRank: 3 }, + { id: "a", vectorRank: 1 }, + { id: "b", vectorRank: 2 }, + ]; + + const results = rerank(inputs, { method: "none" }); + + expect(results.map((r) => r.id)).toEqual(["c", "a", "b"]); + expect(results.map((r) => r.rank)).toEqual([1, 2, 3]); + }); + + test("method=rrf applies RRF reranking", () => { + const inputs: RerankerInput[] = [ + { id: "a", vectorRank: 2, keywordRank: 1 }, + { id: "b", vectorRank: 1, keywordRank: 2 }, + ]; + + const results = rerank(inputs, { method: "rrf" }); + + expect(results).toHaveLength(2); + expect(results.every((r) => r.rrfScore !== undefined)).toBe(true); + }); + + test("respects custom k value", () => { + const inputs: RerankerInput[] = [{ id: "a", vectorRank: 1 }]; + + const result = rerank(inputs, { method: "rrf", rrf: { k: 10 } }); + + // With k=10, score for rank 1 is 1/(10+1) = 0.0909... + expect(result[0].rrfScore).toBeCloseTo(1 / 11, 4); + }); + }); +}); diff --git a/src/memory/reranker.ts b/src/memory/reranker.ts new file mode 100644 index 000000000..e6343c0f0 --- /dev/null +++ b/src/memory/reranker.ts @@ -0,0 +1,86 @@ +/** + * Reranker module for memory search results. + * + * Provides lightweight, non-LLM reranking strategies to combine + * and reorder results from multiple retrieval sources. 
/**
 * One candidate to rerank, identified by `id`, together with its position in
 * each retrieval source. Ranks are 1-based; a rank is left undefined when the
 * corresponding source did not return the candidate. `snippet` is carried on
 * the type but is not read by any function in this module.
 */
export type RerankerInput = {
  id: string;
  vectorRank?: number;
  keywordRank?: number;
  snippet?: string;
};

/**
 * A reranked candidate. `rank` is the final 1-based position; `rrfScore` is
 * present only when the RRF strategy produced the ordering.
 */
export type RerankerOutput = {
  id: string;
  rank: number;
  rrfScore?: number;
};

/**
 * Strategies implemented by this module. The "weighted" strategy accepted by
 * the hybrid merge step is not implemented here.
 */
export type RerankerMethod = "rrf" | "none";

/** Configuration accepted by {@link rerank}. */
export type RerankerOptions = {
  method: RerankerMethod;
  rrf?: {
    k?: number; // smoothing constant (default: 60)
  };
};

const DEFAULT_RRF_K = 60;

/**
 * Reciprocal Rank Fusion (RRF).
 *
 * Each candidate is scored as the sum of `1 / (k + rank)` over every source
 * that returned it, so incompatible raw scores never need to be normalized.
 * Only ranks greater than zero contribute; a missing rank adds nothing.
 *
 * Candidates are returned in descending score order. Ties keep their input
 * order because `Array.prototype.sort` is stable. The input array is not
 * mutated.
 *
 * @param inputs - Candidates with per-source ranks.
 * @param k - Smoothing constant (default 60). Higher values flatten the score
 *   difference between neighbouring ranks. A non-finite or negative value
 *   falls back to the default, because it would otherwise yield NaN, infinite
 *   or negative scores.
 * @returns Candidates with their fused score and final 1-based rank.
 */
export function rerankRRF(inputs: RerankerInput[], k: number = DEFAULT_RRF_K): RerankerOutput[] {
  // Guard the denominator: NaN poisons every score and a negative k can make
  // `k + rank` zero or negative.
  const smoothing = Number.isFinite(k) && k >= 0 ? k : DEFAULT_RRF_K;

  const contribution = (rank: number | undefined): number =>
    rank !== undefined && rank > 0 ? 1 / (smoothing + rank) : 0;

  return inputs
    .map((input) => ({
      id: input.id,
      rrfScore: contribution(input.vectorRank) + contribution(input.keywordRank),
    }))
    .sort((a, b) => b.rrfScore - a.rrfScore)
    .map((item, index) => ({
      id: item.id,
      rank: index + 1,
      rrfScore: item.rrfScore,
    }));
}

/**
 * Main reranker entry point.
 *
 * @param inputs - Results with rank info from each source.
 * @param options - Reranking configuration; `options.rrf.k` defaults to 60.
 * @returns For "rrf", the output of {@link rerankRRF}. For "none", or any
 *   method value this module does not implement, the inputs in their original
 *   order with ranks 1..n and no `rrfScore`.
 */
export function rerank(inputs: RerankerInput[], options: RerankerOptions): RerankerOutput[] {
  if (options.method === "rrf") {
    return rerankRRF(inputs, options.rrf?.k ?? DEFAULT_RRF_K);
  }

  // Pass-through: positions are simply numbered in input order.
  return inputs.map((input, index) => ({ id: input.id, rank: index + 1 }));
}