node.gno
11.01 Kb · 569 lines
1package btreeset
2
3import "sort"
4
5type node struct {
6 keys []key
7 children []*node
8 count int
9}
10
11func newNode(maxKeys int) *node {
12 return &node{
13 keys: make([]key, 0, maxKeys),
14 }
15}
16
17func newNodeWithEntry(key key, maxKeys int) *node {
18 n := newNode(maxKeys)
19 n.insertAt(0, key)
20 n.refreshCount()
21 return n
22}
23
24func newNodeWithEntries(keys []key, maxKeys int) *node {
25 n := newNode(maxKeys)
26 n.keys = append(n.keys, keys...)
27 n.refreshCount()
28 return n
29}
30
31func (n *node) size() int {
32 return len(n.keys)
33}
34
35func (n *node) subtreeSize() int {
36 if n == nil {
37 return 0
38 }
39 return n.count
40}
41
42func (n *node) refreshCount() {
43 if n == nil {
44 return
45 }
46
47 count := n.size()
48 for _, child := range n.children {
49 count += child.subtreeSize()
50 }
51
52 n.count = count
53}
54
55func (n *node) numOfChildren() int {
56 return len(n.children)
57}
58
59func (n *node) isLeaf() bool {
60 return n.numOfChildren() == 0
61}
62
63func (n *node) firstChild() *node {
64 return n.children[0]
65}
66
67func (n *node) leftChild(keyIdx int) *node {
68 return n.children[keyIdx]
69}
70
71func (n *node) rightChild(keyIdx int) *node {
72 return n.children[keyIdx+1]
73}
74
75func (n *node) lastChild() *node {
76 return n.children[len(n.children)-1]
77}
78
79func (n *node) lastKey() key {
80 return n.keys[n.size()-1]
81}
82
83func (n *node) removeFirstChild() {
84 n.children = n.children[1:]
85}
86
87func (n *node) removeLastChild() {
88 n.children = n.children[:n.numOfChildren()-1]
89}
90
91func (n *node) has(key key) bool {
92 // idx is a child index when the key is not found.
93 idx, found := n.lowerBound(key)
94 if found {
95 return true
96 }
97
98 if len(n.children) == 0 {
99 return false
100 }
101
102 return n.children[idx].has(key)
103}
104
105func (n *node) updateAt(keyIdx int, key key) {
106 n.keys[keyIdx] = key
107}
108
109func (n *node) insertAt(keyIdx int, key key) {
110 n.keys = append(n.keys, nil)
111 copy(n.keys[keyIdx+1:], n.keys[keyIdx:])
112 n.keys[keyIdx] = key
113}
114
115func (n *node) insertChildAt(childIdx int, child *node) {
116 n.children = append(n.children, nil)
117 copy(n.children[childIdx+1:], n.children[childIdx:])
118 n.children[childIdx] = child
119}
120
121func (n *node) removeAt(keyIdx int) {
122 n.keys = append(n.keys[:keyIdx], n.keys[keyIdx+1:]...)
123}
124
125func (n *node) removeChildAt(childIdx int) {
126 n.children = append(n.children[:childIdx], n.children[childIdx+1:]...)
127}
128
129func (n *node) lowerBound(key key) (idx int, found bool) {
130 // idx is the key index when found, or the child/search index otherwise.
131 idx = sort.Search(n.size(), func(i int) bool {
132 return !n.keys[i].Less(key)
133 })
134
135 if idx >= n.size() || !equalKey(n.keys[idx], key) {
136 return idx, false
137 }
138
139 return idx, true
140}
141
142func (n *node) setNonFull(key key, maxKeys int) (inserted bool) {
143 // idx is the lower-bound result before it is refined into childIdx.
144 idx, found := n.lowerBound(key)
145 if found {
146 return false
147 }
148
149 if n.isLeaf() {
150 n.insertAt(idx, key)
151 n.count++
152 return true
153 }
154
155 childIdx := idx
156 if n.children[childIdx].size() >= maxKeys {
157 n.split(childIdx, maxKeys, n.isAppendToChild(childIdx, key))
158 childIdx, found = n.lowerBound(key)
159 if found {
160 return false
161 }
162 }
163
164 inserted = n.children[childIdx].setNonFull(key, maxKeys)
165 if inserted {
166 n.count++
167 }
168 return inserted
169}
170
171func (n *node) isAppendToChild(childIdx int, key key) bool {
172 if childIdx != n.numOfChildren()-1 {
173 return false
174 }
175
176 child := n.children[childIdx]
177 if child.size() == 0 {
178 return true
179 }
180 return child.lastKey().Less(key)
181}
182
183func (n *node) split(childIdx, maxKeys int, appendSplit bool) {
184 right, promotedKey := n.separate(childIdx, maxKeys, appendSplit)
185 n.insertAt(childIdx, promotedKey)
186 n.insertChildAt(childIdx+1, right)
187 n.refreshCount()
188}
189
190func (n *node) separate(childIdx, maxKeys int, appendSplit bool) (right *node, promoted key) {
191 mid := maxKeys / 2
192 if appendSplit {
193 mid = maxKeys - 1
194 }
195 left := n.children[childIdx]
196 right = newNodeWithEntries(left.keys[mid+1:], maxKeys)
197 promoted = left.keys[mid]
198 left.keys = left.keys[:mid]
199
200 if !left.isLeaf() {
201 right.children = append([]*node(nil), left.children[mid+1:]...)
202 left.children = left.children[:mid+1]
203 }
204
205 left.refreshCount()
206 right.refreshCount()
207 return right, promoted
208}
209
210func (n *node) remove(key key, minKeys int) bool {
211 // idx is the key index when found, or the child index to descend into.
212 idx, found := n.lowerBound(key)
213
214 if n.isLeaf() {
215 if found {
216 n.removeAt(idx)
217 n.count--
218 return true
219 }
220 return false
221 }
222
223 if found {
224 if n.leftChild(idx).size() > minKeys {
225 n.replaceWithLeftMax(idx, minKeys)
226 n.count--
227 return true
228 }
229
230 if idx < n.numOfChildren()-1 && n.rightChild(idx).size() > minKeys {
231 n.replaceWithRightMin(idx, minKeys)
232 n.count--
233 return true
234 }
235
236 n.mergeChildren(idx)
237 removed := n.children[idx].remove(key, minKeys)
238 if !removed {
239 panic("merged key not found")
240 }
241 n.count--
242 return true
243 }
244
245 childIdx := idx
246 if n.children[childIdx].size() <= minKeys {
247 childIdx = n.fill(childIdx, minKeys)
248 }
249
250 removed := n.children[childIdx].remove(key, minKeys)
251 if removed {
252 n.count--
253 }
254 return removed
255}
256
257func (n *node) replaceWithLeftMax(keyIdx, minKeys int) {
258 key := n.leftChild(keyIdx).removeMax(minKeys)
259 n.keys[keyIdx] = key
260}
261
262func (n *node) removeMax(minKeys int) (maxKey key) {
263 if n.isLeaf() {
264 lastKeyIdx := n.size() - 1
265 maxKey = n.keys[lastKeyIdx]
266 n.removeAt(lastKeyIdx)
267 n.count--
268 return maxKey
269 }
270
271 if n.lastChild().size() <= minKeys {
272 lastChildIdx := n.numOfChildren() - 1
273 n.fill(lastChildIdx, minKeys)
274 }
275
276 maxKey = n.lastChild().removeMax(minKeys)
277 n.count--
278 return maxKey
279}
280
281func (n *node) replaceWithRightMin(keyIdx, minKeys int) {
282 key := n.rightChild(keyIdx).removeMin(minKeys)
283 n.keys[keyIdx] = key
284}
285
286func (n *node) removeMin(minKeys int) (minKey key) {
287 if n.isLeaf() {
288 minKey = n.keys[0]
289 n.removeAt(0)
290 n.count--
291 return minKey
292 }
293
294 if n.firstChild().size() <= minKeys {
295 n.fill(0, minKeys)
296 }
297
298 minKey = n.firstChild().removeMin(minKeys)
299 n.count--
300 return minKey
301}
302
303func (n *node) fill(childIdx, minKeys int) (updatedIdx int) {
304 if childIdx-1 >= 0 && n.children[childIdx-1].size() > minKeys {
305 n.borrowFromLeft(childIdx)
306 return childIdx
307 } else if childIdx+1 < n.numOfChildren() && n.children[childIdx+1].size() > minKeys {
308 n.borrowFromRight(childIdx)
309 return childIdx
310 } else if childIdx+1 < n.numOfChildren() {
311 n.mergeChildren(childIdx)
312 return childIdx
313 } else {
314 n.mergeChildren(childIdx - 1)
315 return childIdx - 1
316 }
317}
318
319func (n *node) borrowFromLeft(childIdx int) {
320 parentKey := n.keys[childIdx-1]
321 leftSibling := n.children[childIdx-1]
322 leftLastKey := leftSibling.keys[leftSibling.size()-1]
323
324 child := n.children[childIdx]
325 child.insertAt(0, parentKey)
326
327 n.updateAt(childIdx-1, leftLastKey)
328 leftSibling.removeAt(leftSibling.size() - 1)
329
330 if !leftSibling.isLeaf() {
331 child.insertChildAt(0, leftSibling.lastChild())
332 leftSibling.removeLastChild()
333 }
334
335 leftSibling.refreshCount()
336 child.refreshCount()
337 n.refreshCount()
338}
339
340func (n *node) borrowFromRight(childIdx int) {
341 parentKey := n.keys[childIdx]
342 rightSibling := n.children[childIdx+1]
343 rightFirstKey := rightSibling.keys[0]
344
345 child := n.children[childIdx]
346 child.insertAt(child.size(), parentKey)
347
348 n.updateAt(childIdx, rightFirstKey)
349 rightSibling.removeAt(0)
350
351 if !rightSibling.isLeaf() {
352 child.insertChildAt(child.numOfChildren(), rightSibling.firstChild())
353 rightSibling.removeFirstChild()
354 }
355
356 rightSibling.refreshCount()
357 child.refreshCount()
358 n.refreshCount()
359}
360
361func (n *node) mergeChildren(keyIdx int) {
362 parentKey := n.keys[keyIdx]
363 left := n.children[keyIdx]
364 right := n.children[keyIdx+1]
365
366 left.keys = append(left.keys, parentKey)
367 left.keys = append(left.keys, right.keys...)
368
369 if !right.isLeaf() {
370 left.children = append(left.children, right.children...)
371 }
372
373 n.removeAt(keyIdx)
374 n.removeChildAt(keyIdx + 1)
375 left.refreshCount()
376 n.refreshCount()
377}
378
379func (n *node) getByIndex(index int) key {
380 rem := index
381 for i := 0; i < n.size(); i++ {
382 if !n.isLeaf() {
383 childSize := n.children[i].subtreeSize()
384 if rem < childSize {
385 return n.children[i].getByIndex(rem)
386 }
387 rem -= childSize
388 }
389 if rem == 0 {
390 return n.keys[i]
391 }
392 rem--
393 }
394 if !n.isLeaf() && rem < n.children[n.size()].subtreeSize() {
395 return n.children[n.size()].getByIndex(rem)
396 }
397 panic("GetByIndex asked for invalid index")
398}
399
400func (n *node) iterate(start, end key, cb iterCbFn) bool {
401 for i := 0; i < n.size(); i++ {
402 if !n.isLeaf() {
403 if n.children[i].iterate(start, end, cb) {
404 return true
405 }
406 }
407 key := n.keys[i]
408 if start != nil && key.Less(start) {
409 continue
410 }
411 if end != nil && !key.Less(end) {
412 return false
413 }
414 if cb(key) {
415 return true
416 }
417 }
418 if !n.isLeaf() {
419 return n.children[n.size()].iterate(start, end, cb)
420 }
421 return false
422}
423
424func (n *node) reverseIterate(start, end key, cb iterCbFn) bool {
425 for i := n.size(); i >= 0; i-- {
426 if !n.isLeaf() {
427 if n.children[i].reverseIterate(start, end, cb) {
428 return true
429 }
430 }
431 if i == 0 {
432 break
433 }
434
435 keyIdx := i - 1
436 key := n.keys[keyIdx]
437 if end != nil && end.Less(key) {
438 continue
439 }
440 if start != nil && !start.Less(key) {
441 return false
442 }
443 if cb(key) {
444 return true
445 }
446 }
447 return false
448}
449
450func (n *node) iterateByOffset(
451 offset int,
452 limit int,
453 reverse bool,
454 seen *int,
455 visited *int,
456 cb iterCbFn,
457) bool {
458 if n == nil || *visited >= limit {
459 return false
460 }
461 if reverse {
462 return n.reverseIterateByOffset(offset, limit, seen, visited, cb)
463 }
464 return n.forwardIterateByOffset(offset, limit, seen, visited, cb)
465}
466
467func (n *node) forwardIterateByOffset(
468 offset int,
469 limit int,
470 seen *int,
471 visited *int,
472 cb iterCbFn,
473) bool {
474 for i := 0; i < n.size(); i++ {
475 if !n.isLeaf() {
476 if visitChildByOffset(n.children[i], offset, limit, false, seen, visited, cb) {
477 return true
478 }
479 if *visited >= limit {
480 return false
481 }
482 }
483 if visitKeyByOffset(n.keys[i], offset, limit, seen, visited, cb) {
484 return true
485 }
486 if *visited >= limit {
487 return false
488 }
489 }
490 if !n.isLeaf() {
491 return visitChildByOffset(n.children[n.size()], offset, limit, false, seen, visited, cb)
492 }
493 return false
494}
495
496func (n *node) reverseIterateByOffset(
497 offset int,
498 limit int,
499 seen *int,
500 visited *int,
501 cb iterCbFn,
502) bool {
503 if !n.isLeaf() {
504 if visitChildByOffset(n.children[n.size()], offset, limit, true, seen, visited, cb) {
505 return true
506 }
507 if *visited >= limit {
508 return false
509 }
510 }
511 for i := n.size() - 1; i >= 0; i-- {
512 if visitKeyByOffset(n.keys[i], offset, limit, seen, visited, cb) {
513 return true
514 }
515 if *visited >= limit {
516 return false
517 }
518 if !n.isLeaf() {
519 if visitChildByOffset(n.children[i], offset, limit, true, seen, visited, cb) {
520 return true
521 }
522 if *visited >= limit {
523 return false
524 }
525 }
526 }
527 return false
528}
529
530func visitChildByOffset(
531 child *node,
532 offset int,
533 limit int,
534 reverse bool,
535 seen *int,
536 visited *int,
537 cb iterCbFn,
538) bool {
539 if child == nil || *visited >= limit {
540 return false
541 }
542 if *seen+child.subtreeSize() <= offset {
543 *seen += child.subtreeSize()
544 return false
545 }
546 return child.iterateByOffset(offset, limit, reverse, seen, visited, cb)
547}
548
549func visitKeyByOffset(
550 key key,
551 offset int,
552 limit int,
553 seen *int,
554 visited *int,
555 cb iterCbFn,
556) bool {
557 if *seen < offset {
558 *seen++
559 return false
560 }
561 if *visited >= limit {
562 return false
563 }
564 if cb(key) {
565 return true
566 }
567 *visited++
568 return false
569}