node.gno
12.51 Kb · 609 lines
1package btree
2
3import "sort"
4
// node is a single B-tree node. keys and values are parallel slices: the
// entry at position i is (keys[i], values[i]).
type node struct {
	// keys holds the keys stored directly in this node; lookups binary-search
	// it, so it is expected to be in ascending order.
	keys []key
	// values holds the value paired with the key at the same position.
	values []any
	// children holds the child subtrees; it is empty for a leaf.
	children []*node
	// count caches the number of entries in the whole subtree rooted here
	// (own keys plus every child's count); see refreshCount.
	count int
}
11
12func newNode(maxKeys int) *node {
13 return &node{
14 keys: make([]key, 0, maxKeys),
15 values: make([]any, 0, maxKeys),
16 }
17}
18
19func newNodeWithEntry(key key, value any, maxKeys int) *node {
20 n := newNode(maxKeys)
21 n.insertAt(0, key, value)
22 n.refreshCount()
23 return n
24}
25
26func newNodeWithEntries(keys []key, values []any, maxKeys int) *node {
27 n := newNode(maxKeys)
28 n.keys = append(n.keys, keys...)
29 n.values = append(n.values, values...)
30 n.refreshCount()
31 return n
32}
33
34func (n *node) size() int {
35 return len(n.keys)
36}
37
38func (n *node) subtreeSize() int {
39 if n == nil {
40 return 0
41 }
42 return n.count
43}
44
45func (n *node) refreshCount() {
46 if n == nil {
47 return
48 }
49 count := n.size()
50 for _, child := range n.children {
51 count += child.subtreeSize()
52 }
53 n.count = count
54}
55
56func (n *node) numOfChildren() int {
57 return len(n.children)
58}
59
60func (n *node) isLeaf() bool {
61 return n.numOfChildren() == 0
62}
63
64func (n *node) firstChild() *node {
65 return n.children[0]
66}
67
68func (n *node) leftChild(keyIdx int) *node {
69 return n.children[keyIdx]
70}
71
72func (n *node) rightChild(keyIdx int) *node {
73 return n.children[keyIdx+1]
74}
75
76func (n *node) lastChild() *node {
77 return n.children[len(n.children)-1]
78}
79
80func (n *node) lastKey() key {
81 return n.keys[n.size()-1]
82}
83
84func (n *node) removeFirstChild() {
85 n.children = n.children[1:]
86}
87
88func (n *node) removeLastChild() {
89 n.children = n.children[:n.numOfChildren()-1]
90}
91
92func (n *node) lowerBound(key key) (idx int, found bool) {
93 // idx is the key index when found, or the child/search index otherwise.
94 idx = sort.Search(n.size(), func(i int) bool {
95 return !n.keys[i].Less(key)
96 })
97
98 if idx >= n.size() || !equalKey(n.keys[idx], key) {
99 return idx, false
100 }
101
102 return idx, true
103}
104
105func (n *node) get(key key) (any, bool) {
106 // idx is a child index when the key is not found.
107 idx, found := n.lowerBound(key)
108 if found {
109 return n.values[idx], true
110 }
111 if len(n.children) == 0 {
112 return nil, false
113 }
114 return n.children[idx].get(key)
115}
116
117func (n *node) getAt(keyIdx int) (key, any) {
118 return n.keys[keyIdx], n.values[keyIdx]
119}
120
121func (n *node) updateIfFound(key key, value any) (int, bool) {
122 idx, found := n.lowerBound(key)
123 if found {
124 n.values[idx] = value
125 return idx, true
126 }
127 return idx, false
128}
129
130func (n *node) update(key key, value any) bool {
131 idx, found := n.updateIfFound(key, value)
132 if found {
133 return true
134 }
135 if n.isLeaf() {
136 return false
137 }
138 return n.children[idx].update(key, value)
139}
140
141func (n *node) updateAt(keyIdx int, key key, value any) {
142 n.keys[keyIdx] = key
143 n.values[keyIdx] = value
144}
145
146func (n *node) insertAt(keyIdx int, key key, value any) {
147 n.keys = append(n.keys, nil)
148 n.values = append(n.values, nil)
149 copy(n.keys[keyIdx+1:], n.keys[keyIdx:])
150 copy(n.values[keyIdx+1:], n.values[keyIdx:])
151 n.keys[keyIdx] = key
152 n.values[keyIdx] = value
153}
154
155func (n *node) removeAt(keyIdx int) {
156 n.keys = append(n.keys[:keyIdx], n.keys[keyIdx+1:]...)
157 n.values = append(n.values[:keyIdx], n.values[keyIdx+1:]...)
158}
159
160func (n *node) removeChildAt(childIdx int) {
161 n.children = append(n.children[:childIdx], n.children[childIdx+1:]...)
162}
163
164func (n *node) insertChildAt(childIdx int, child *node) {
165 n.children = append(n.children, nil)
166 copy(n.children[childIdx+1:], n.children[childIdx:])
167 n.children[childIdx] = child
168}
169
170func (n *node) setNonFull(key key, value any, maxKeys int) bool {
171 // idx is the lower-bound result before it is refined into childIdx.
172 idx, updated := n.updateIfFound(key, value)
173 if updated {
174 return false
175 }
176
177 if n.isLeaf() {
178 n.insertAt(idx, key, value)
179 n.count++
180 return true
181 }
182
183 childIdx := idx
184 if n.children[childIdx].size() >= maxKeys {
185 n.split(childIdx, maxKeys, n.isAppendToChild(childIdx, key))
186 childIdx, updated = n.updateIfFound(key, value)
187 if updated {
188 return false
189 }
190 }
191
192 inserted := n.children[childIdx].setNonFull(key, value, maxKeys)
193 if inserted {
194 n.count++
195 }
196 return inserted
197}
198
199func (n *node) isAppendToChild(childIdx int, key key) bool {
200 if childIdx != n.numOfChildren()-1 {
201 return false
202 }
203
204 child := n.children[childIdx]
205 if child.size() == 0 {
206 return true
207 }
208 return child.lastKey().Less(key)
209}
210
211func (n *node) split(childIdx int, maxKeys int, appendSplit bool) {
212 mid := maxKeys / 2
213 if appendSplit {
214 mid = maxKeys - 1
215 }
216 right, promotedKey, promotedValue := n.separate(childIdx, mid, maxKeys)
217 n.insertAt(childIdx, promotedKey, promotedValue)
218 n.insertChildAt(childIdx+1, right)
219 n.refreshCount()
220}
221
222func (n *node) separate(childIdx, mid, maxKeys int) (*node, key, any) {
223 left := n.children[childIdx]
224 right := newNodeWithEntries(left.keys[mid+1:], left.values[mid+1:], maxKeys)
225
226 promotedKey, promotedValue := left.getAt(mid)
227 left.keys = left.keys[:mid]
228 left.values = left.values[:mid]
229
230 if !left.isLeaf() {
231 right.children = append([]*node(nil), left.children[mid+1:]...)
232 left.children = left.children[:mid+1]
233 }
234
235 left.refreshCount()
236 right.refreshCount()
237 return right, promotedKey, promotedValue
238}
239
240func (n *node) remove(key key, minKeys int) (any, bool) {
241 // idx is the key index when found, or the child index to descend into.
242 idx, found := n.lowerBound(key)
243
244 if n.isLeaf() {
245 if found {
246 v := n.values[idx]
247 n.removeAt(idx)
248 n.count--
249 return v, true
250 }
251 return nil, false
252 }
253
254 if found {
255 if n.leftChild(idx).size() > minKeys {
256 v, removed := n.replaceWithLeftMax(idx, minKeys)
257 if removed {
258 n.count--
259 }
260 return v, removed
261 }
262 if idx < n.numOfChildren()-1 && n.rightChild(idx).size() > minKeys {
263 v, removed := n.replaceWithRightMin(idx, minKeys)
264 if removed {
265 n.count--
266 }
267 return v, removed
268 }
269 n.mergeChildren(idx)
270 v, removed := n.children[idx].remove(key, minKeys)
271 if !removed {
272 panic("merged key not found")
273 }
274 n.count--
275 return v, true
276 }
277
278 childIdx := idx
279 if n.children[childIdx].size() <= minKeys {
280 childIdx = n.fill(childIdx, minKeys)
281 }
282 v, removed := n.children[childIdx].remove(key, minKeys)
283 if removed {
284 n.count--
285 }
286 return v, removed
287}
288
289func (n *node) replaceWithLeftMax(keyIdx, minKeys int) (any, bool) {
290 old := n.values[keyIdx]
291 k, v := n.leftChild(keyIdx).removeMax(minKeys)
292 n.updateAt(keyIdx, k, v)
293 return old, true
294}
295
296func (n *node) removeMax(minKeys int) (key, any) {
297 if n.isLeaf() {
298 lastKeyIdx := n.size() - 1
299 k, v := n.getAt(lastKeyIdx)
300 n.removeAt(lastKeyIdx)
301 n.count--
302 return k, v
303 }
304
305 if n.lastChild().size() <= minKeys {
306 lastChildIdx := n.numOfChildren() - 1
307 n.fill(lastChildIdx, minKeys)
308 }
309 k, v := n.lastChild().removeMax(minKeys)
310 n.count--
311 return k, v
312}
313
314func (n *node) replaceWithRightMin(keyIdx, minKeys int) (any, bool) {
315 old := n.values[keyIdx]
316 k, v := n.rightChild(keyIdx).removeMin(minKeys)
317 n.updateAt(keyIdx, k, v)
318 return old, true
319}
320
321func (n *node) removeMin(minKeys int) (key, any) {
322 if n.isLeaf() {
323 k, v := n.getAt(0)
324 n.removeAt(0)
325 n.count--
326 return k, v
327 }
328
329 if n.firstChild().size() <= minKeys {
330 n.fill(0, minKeys)
331 }
332 k, v := n.firstChild().removeMin(minKeys)
333 n.count--
334 return k, v
335}
336
337func (n *node) fill(childIdx int, minKeys int) int {
338 if childIdx-1 >= 0 && n.children[childIdx-1].size() > minKeys {
339 n.borrowFromLeft(childIdx)
340 return childIdx
341 } else if childIdx+1 < n.numOfChildren() && n.children[childIdx+1].size() > minKeys {
342 n.borrowFromRight(childIdx)
343 return childIdx
344 } else if childIdx+1 < n.numOfChildren() {
345 n.mergeChildren(childIdx)
346 return childIdx
347 } else {
348 n.mergeChildren(childIdx - 1)
349 return childIdx - 1
350 }
351}
352
353func (n *node) borrowFromLeft(childIdx int) {
354 parentKey, parentValue := n.getAt(childIdx - 1)
355 leftSibling := n.children[childIdx-1]
356 leftLast := leftSibling.size() - 1
357 leftLastKey, leftLastValue := leftSibling.getAt(leftLast)
358
359 child := n.children[childIdx]
360 child.insertAt(0, parentKey, parentValue)
361
362 n.updateAt(childIdx-1, leftLastKey, leftLastValue)
363 leftSibling.removeAt(leftLast)
364
365 if !leftSibling.isLeaf() {
366 child.insertChildAt(0, leftSibling.lastChild())
367 leftSibling.removeLastChild()
368 }
369
370 leftSibling.refreshCount()
371 child.refreshCount()
372 n.refreshCount()
373}
374
375func (n *node) borrowFromRight(childIdx int) {
376 parentKey, parentValue := n.getAt(childIdx)
377 rightSibling := n.children[childIdx+1]
378 rightFirstKey, rightFirstValue := rightSibling.getAt(0)
379
380 child := n.children[childIdx]
381 child.insertAt(child.size(), parentKey, parentValue)
382
383 n.updateAt(childIdx, rightFirstKey, rightFirstValue)
384 rightSibling.removeAt(0)
385
386 if !rightSibling.isLeaf() {
387 child.insertChildAt(child.numOfChildren(), rightSibling.firstChild())
388 rightSibling.removeFirstChild()
389 }
390
391 rightSibling.refreshCount()
392 child.refreshCount()
393 n.refreshCount()
394}
395
396func (n *node) mergeChildren(keyIdx int) {
397 parentKey := n.keys[keyIdx]
398 parentValue := n.values[keyIdx]
399 left := n.children[keyIdx]
400 right := n.children[keyIdx+1]
401
402 left.keys = append(left.keys, parentKey)
403 left.values = append(left.values, parentValue)
404
405 left.keys = append(left.keys, right.keys...)
406 left.values = append(left.values, right.values...)
407
408 if !right.isLeaf() {
409 left.children = append(left.children, right.children...)
410 }
411
412 n.removeAt(keyIdx)
413 n.removeChildAt(keyIdx + 1)
414 left.refreshCount()
415 n.refreshCount()
416}
417
418func (n *node) getByIndex(index int) (key, any) {
419 rem := index
420 for i := 0; i < n.size(); i++ {
421 if !n.isLeaf() {
422 childSize := n.children[i].subtreeSize()
423 if rem < childSize {
424 return n.children[i].getByIndex(rem)
425 }
426 rem -= childSize
427 }
428 if rem == 0 {
429 return n.getAt(i)
430 }
431 rem--
432 }
433 if !n.isLeaf() && rem < n.children[n.size()].subtreeSize() {
434 return n.children[n.size()].getByIndex(rem)
435 }
436 panic("GetByIndex asked for invalid index")
437}
438
439func (n *node) iterate(start, end key, cb iterCbFn) bool {
440 for i := 0; i < n.size(); i++ {
441 if !n.isLeaf() {
442 if n.children[i].iterate(start, end, cb) {
443 return true
444 }
445 }
446 key, value := n.getAt(i)
447 if start != nil && key.Less(start) {
448 continue
449 }
450 if end != nil && !key.Less(end) {
451 return false
452 }
453 if cb(key, value) {
454 return true
455 }
456 }
457 if !n.isLeaf() {
458 return n.children[n.size()].iterate(start, end, cb)
459 }
460 return false
461}
462
463func (n *node) reverseIterate(start, end key, cb iterCbFn) bool {
464 for i := n.size(); i >= 0; i-- {
465 if !n.isLeaf() {
466 if n.children[i].reverseIterate(start, end, cb) {
467 return true
468 }
469 }
470 if i == 0 {
471 break
472 }
473
474 keyIdx := i - 1
475 key, value := n.getAt(keyIdx)
476 if end != nil && end.Less(key) {
477 continue
478 }
479 if start != nil && key.Less(start) {
480 return false
481 }
482 if cb(key, value) {
483 return true
484 }
485 }
486 return false
487}
488
489func (n *node) iterateByOffset(
490 offset int,
491 limit int,
492 reverse bool,
493 seen *int,
494 visited *int,
495 cb iterCbFn,
496) bool {
497 if n == nil || *visited >= limit {
498 return false
499 }
500 if reverse {
501 return n.reverseIterateByOffset(offset, limit, seen, visited, cb)
502 }
503 return n.forwardIterateByOffset(offset, limit, seen, visited, cb)
504}
505
506func (n *node) forwardIterateByOffset(
507 offset int,
508 limit int,
509 seen *int,
510 visited *int,
511 cb iterCbFn,
512) bool {
513 for i := 0; i < n.size(); i++ {
514 if !n.isLeaf() {
515 if visitChildByOffset(n.children[i], offset, limit, false, seen, visited, cb) {
516 return true
517 }
518 if *visited >= limit {
519 return false
520 }
521 }
522 if visitKeyByOffset(n.keys[i], n.values[i], offset, limit, seen, visited, cb) {
523 return true
524 }
525 if *visited >= limit {
526 return false
527 }
528 }
529 if !n.isLeaf() {
530 return visitChildByOffset(n.children[n.size()], offset, limit, false, seen, visited, cb)
531 }
532 return false
533}
534
535func (n *node) reverseIterateByOffset(
536 offset int,
537 limit int,
538 seen *int,
539 visited *int,
540 cb iterCbFn,
541) bool {
542 if !n.isLeaf() {
543 if visitChildByOffset(n.children[n.size()], offset, limit, true, seen, visited, cb) {
544 return true
545 }
546 if *visited >= limit {
547 return false
548 }
549 }
550 for i := n.size() - 1; i >= 0; i-- {
551 if visitKeyByOffset(n.keys[i], n.values[i], offset, limit, seen, visited, cb) {
552 return true
553 }
554 if *visited >= limit {
555 return false
556 }
557 if !n.isLeaf() {
558 if visitChildByOffset(n.children[i], offset, limit, true, seen, visited, cb) {
559 return true
560 }
561 if *visited >= limit {
562 return false
563 }
564 }
565 }
566 return false
567}
568
569func visitChildByOffset(
570 child *node,
571 offset int,
572 limit int,
573 reverse bool,
574 seen *int,
575 visited *int,
576 cb iterCbFn,
577) bool {
578 if child == nil || *visited >= limit {
579 return false
580 }
581 if *seen+child.subtreeSize() <= offset {
582 *seen += child.subtreeSize()
583 return false
584 }
585 return child.iterateByOffset(offset, limit, reverse, seen, visited, cb)
586}
587
588func visitKeyByOffset(
589 key key,
590 value any,
591 offset int,
592 limit int,
593 seen *int,
594 visited *int,
595 cb iterCbFn,
596) bool {
597 if *seen < offset {
598 *seen++
599 return false
600 }
601 if *visited >= limit {
602 return false
603 }
604 if cb(key, value) {
605 return true
606 }
607 *visited++
608 return false
609}