Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/semantic-router/pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,95 @@ milvus:
}
}

// TestHybridCacheSearchLayerPrunesWeakerBranch checks that searchLayerHybrid does
// not admit a neighbor that scores worse than the current frontier once the
// result set already holds ef entries.
func TestHybridCacheSearchLayerPrunesWeakerBranch(t *testing.T) {
	// Regression fixture. With the buggy comparison the saturated frontier still
	// accepted node 3, a far worse neighbor, which re-opened the path to node 4
	// and made the search walk every reachable node (slower, and at risk of a
	// worse match). The 3→4 edge is artificial and exists only to isolate the
	// pruning logic; production HNSW builders try to avoid such links.
	vectors := [][]float32{
		{0.80},  // node 0: entry point
		{0.79},  // node 1: near-tie neighbor
		{0.78},  // node 2: another strong neighbor
		{0.10},  // node 3: weak branch that should be pruned
		{0.995}, // node 4: hidden best reachable only via node 3
	}

	// adjacency[id] lists the layer-0 neighbors of node id.
	adjacency := [][]int{
		{1, 2, 3},
		{0},
		{0},
		{0, 4},
		{3},
	}

	// Build the node slice and the id→node lookup in a single pass.
	graph := make([]*HNSWNode, 0, len(adjacency))
	lookup := make(map[int]*HNSWNode, len(adjacency))
	for id, links := range adjacency {
		node := &HNSWNode{
			entryIndex: id,
			neighbors:  map[int][]int{0: links},
			maxLayer:   0,
		}
		graph = append(graph, node)
		lookup[id] = node
	}

	index := &HNSWIndex{
		nodes:          graph,
		nodeIndex:      lookup,
		entryPoint:     0,
		maxLayer:       0,
		efConstruction: 4,
		M:              4,
		Mmax:           4,
		Mmax0:          4,
		ml:             1,
	}
	cache := &HybridCache{
		hnswIndex:  index,
		embeddings: vectors,
		idMap:      map[int]string{},
	}

	// Search layer 0 from node 0 with ef = 3 and a one-dimensional query.
	query := []float32{1}
	entryPoints := []int{0}
	results := cache.searchLayerHybrid(query, 3, 0, entryPoints)

	// Fatalf stops the test, so at most one case reports, in the same order as
	// three sequential checks would.
	switch {
	case len(results) != 3:
		t.Fatalf("expected frontier to keep three best neighbors, got %v", results)
	case slices.Contains(results, 4):
		t.Fatalf("expected weaker branch to stay pruned, got %v", results)
	case !slices.Contains(results, 1):
		t.Fatalf("expected best neighbor 1 to remain in results, got %v", results)
	}
}

// BenchmarkHybridCacheAddEntry benchmarks adding entries to hybrid cache
func BenchmarkHybridCacheAddEntry(b *testing.B) {
if os.Getenv("MILVUS_URI") == "" {
Expand Down
12 changes: 6 additions & 6 deletions src/semantic-router/pkg/cache/hybrid_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -938,15 +938,15 @@ func (h *HybridCache) searchLayerHybrid(query []float32, ef int, layer int, entr
if ep < 0 || ep >= len(h.embeddings) {
continue
}
- dist := -dotProduct(query, h.embeddings[ep])
+ dist := -dotProduct(query, h.embeddings[ep]) // Negative product so that higher similarity = lower distance
candidates.push(ep, dist)
results.push(ep, dist)
visited[ep] = true
}

for len(candidates.data) > 0 {
currentIdx, currentDist := candidates.pop()
- if len(results.data) > 0 && currentDist > -results.data[0].dist {
+ if len(results.data) > 0 && currentDist > results.data[0].dist {
break
}

Expand All @@ -964,7 +964,7 @@ func (h *HybridCache) searchLayerHybrid(query []float32, ef int, layer int, entr

dist := -dotProduct(query, h.embeddings[neighborID])

- if len(results.data) < ef || dist < -results.data[0].dist {
+ if len(results.data) < ef || dist < results.data[0].dist {
candidates.push(neighborID, dist)
results.push(neighborID, dist)

Expand Down Expand Up @@ -1062,7 +1062,7 @@ func (h *HybridCache) searchLayerHybridWithEarlyStop(query []float32, ef int, la
if ep < 0 || ep >= len(h.embeddings) {
continue
}
- dist := -dotProductSIMD(query, h.embeddings[ep])
+ dist := -dotProductSIMD(query, h.embeddings[ep]) // Negative product so that higher similarity = lower distance
candidates.push(ep, dist)
results.push(ep, dist)
visited[ep] = true
Expand All @@ -1075,7 +1075,7 @@ func (h *HybridCache) searchLayerHybridWithEarlyStop(query []float32, ef int, la

for len(candidates.data) > 0 {
currentIdx, currentDist := candidates.pop()
- if len(results.data) > 0 && currentDist > -results.data[0].dist {
+ if len(results.data) > 0 && currentDist > results.data[0].dist {
break
}

Expand All @@ -1098,7 +1098,7 @@ func (h *HybridCache) searchLayerHybridWithEarlyStop(query []float32, ef int, la
return []int{neighborID}
}

- if len(results.data) < ef || dist < -results.data[0].dist {
+ if len(results.data) < ef || dist < results.data[0].dist {
candidates.push(neighborID, dist)
results.push(neighborID, dist)

Expand Down
Loading