Search Apps Documentation Source Content File Folder Download Copy Actions Download

node.gno

12.51 Kb · 609 lines
  1package btree
  2
  3import "sort"
  4
  5type node struct {
  6	keys     []key
  7	values   []any
  8	children []*node
  9	count    int
 10}
 11
 12func newNode(maxKeys int) *node {
 13	return &node{
 14		keys:   make([]key, 0, maxKeys),
 15		values: make([]any, 0, maxKeys),
 16	}
 17}
 18
 19func newNodeWithEntry(key key, value any, maxKeys int) *node {
 20	n := newNode(maxKeys)
 21	n.insertAt(0, key, value)
 22	n.refreshCount()
 23	return n
 24}
 25
 26func newNodeWithEntries(keys []key, values []any, maxKeys int) *node {
 27	n := newNode(maxKeys)
 28	n.keys = append(n.keys, keys...)
 29	n.values = append(n.values, values...)
 30	n.refreshCount()
 31	return n
 32}
 33
 34func (n *node) size() int {
 35	return len(n.keys)
 36}
 37
 38func (n *node) subtreeSize() int {
 39	if n == nil {
 40		return 0
 41	}
 42	return n.count
 43}
 44
 45func (n *node) refreshCount() {
 46	if n == nil {
 47		return
 48	}
 49	count := n.size()
 50	for _, child := range n.children {
 51		count += child.subtreeSize()
 52	}
 53	n.count = count
 54}
 55
 56func (n *node) numOfChildren() int {
 57	return len(n.children)
 58}
 59
 60func (n *node) isLeaf() bool {
 61	return n.numOfChildren() == 0
 62}
 63
 64func (n *node) firstChild() *node {
 65	return n.children[0]
 66}
 67
 68func (n *node) leftChild(keyIdx int) *node {
 69	return n.children[keyIdx]
 70}
 71
 72func (n *node) rightChild(keyIdx int) *node {
 73	return n.children[keyIdx+1]
 74}
 75
 76func (n *node) lastChild() *node {
 77	return n.children[len(n.children)-1]
 78}
 79
 80func (n *node) lastKey() key {
 81	return n.keys[n.size()-1]
 82}
 83
 84func (n *node) removeFirstChild() {
 85	n.children = n.children[1:]
 86}
 87
 88func (n *node) removeLastChild() {
 89	n.children = n.children[:n.numOfChildren()-1]
 90}
 91
 92func (n *node) lowerBound(key key) (idx int, found bool) {
 93	// idx is the key index when found, or the child/search index otherwise.
 94	idx = sort.Search(n.size(), func(i int) bool {
 95		return !n.keys[i].Less(key)
 96	})
 97
 98	if idx >= n.size() || !equalKey(n.keys[idx], key) {
 99		return idx, false
100	}
101
102	return idx, true
103}
104
105func (n *node) get(key key) (any, bool) {
106	// idx is a child index when the key is not found.
107	idx, found := n.lowerBound(key)
108	if found {
109		return n.values[idx], true
110	}
111	if len(n.children) == 0 {
112		return nil, false
113	}
114	return n.children[idx].get(key)
115}
116
117func (n *node) getAt(keyIdx int) (key, any) {
118	return n.keys[keyIdx], n.values[keyIdx]
119}
120
121func (n *node) updateIfFound(key key, value any) (int, bool) {
122	idx, found := n.lowerBound(key)
123	if found {
124		n.values[idx] = value
125		return idx, true
126	}
127	return idx, false
128}
129
130func (n *node) update(key key, value any) bool {
131	idx, found := n.updateIfFound(key, value)
132	if found {
133		return true
134	}
135	if n.isLeaf() {
136		return false
137	}
138	return n.children[idx].update(key, value)
139}
140
141func (n *node) updateAt(keyIdx int, key key, value any) {
142	n.keys[keyIdx] = key
143	n.values[keyIdx] = value
144}
145
146func (n *node) insertAt(keyIdx int, key key, value any) {
147	n.keys = append(n.keys, nil)
148	n.values = append(n.values, nil)
149	copy(n.keys[keyIdx+1:], n.keys[keyIdx:])
150	copy(n.values[keyIdx+1:], n.values[keyIdx:])
151	n.keys[keyIdx] = key
152	n.values[keyIdx] = value
153}
154
155func (n *node) removeAt(keyIdx int) {
156	n.keys = append(n.keys[:keyIdx], n.keys[keyIdx+1:]...)
157	n.values = append(n.values[:keyIdx], n.values[keyIdx+1:]...)
158}
159
160func (n *node) removeChildAt(childIdx int) {
161	n.children = append(n.children[:childIdx], n.children[childIdx+1:]...)
162}
163
164func (n *node) insertChildAt(childIdx int, child *node) {
165	n.children = append(n.children, nil)
166	copy(n.children[childIdx+1:], n.children[childIdx:])
167	n.children[childIdx] = child
168}
169
170func (n *node) setNonFull(key key, value any, maxKeys int) bool {
171	// idx is the lower-bound result before it is refined into childIdx.
172	idx, updated := n.updateIfFound(key, value)
173	if updated {
174		return false
175	}
176
177	if n.isLeaf() {
178		n.insertAt(idx, key, value)
179		n.count++
180		return true
181	}
182
183	childIdx := idx
184	if n.children[childIdx].size() >= maxKeys {
185		n.split(childIdx, maxKeys, n.isAppendToChild(childIdx, key))
186		childIdx, updated = n.updateIfFound(key, value)
187		if updated {
188			return false
189		}
190	}
191
192	inserted := n.children[childIdx].setNonFull(key, value, maxKeys)
193	if inserted {
194		n.count++
195	}
196	return inserted
197}
198
199func (n *node) isAppendToChild(childIdx int, key key) bool {
200	if childIdx != n.numOfChildren()-1 {
201		return false
202	}
203
204	child := n.children[childIdx]
205	if child.size() == 0 {
206		return true
207	}
208	return child.lastKey().Less(key)
209}
210
211func (n *node) split(childIdx int, maxKeys int, appendSplit bool) {
212	mid := maxKeys / 2
213	if appendSplit {
214		mid = maxKeys - 1
215	}
216	right, promotedKey, promotedValue := n.separate(childIdx, mid, maxKeys)
217	n.insertAt(childIdx, promotedKey, promotedValue)
218	n.insertChildAt(childIdx+1, right)
219	n.refreshCount()
220}
221
222func (n *node) separate(childIdx, mid, maxKeys int) (*node, key, any) {
223	left := n.children[childIdx]
224	right := newNodeWithEntries(left.keys[mid+1:], left.values[mid+1:], maxKeys)
225
226	promotedKey, promotedValue := left.getAt(mid)
227	left.keys = left.keys[:mid]
228	left.values = left.values[:mid]
229
230	if !left.isLeaf() {
231		right.children = append([]*node(nil), left.children[mid+1:]...)
232		left.children = left.children[:mid+1]
233	}
234
235	left.refreshCount()
236	right.refreshCount()
237	return right, promotedKey, promotedValue
238}
239
240func (n *node) remove(key key, minKeys int) (any, bool) {
241	// idx is the key index when found, or the child index to descend into.
242	idx, found := n.lowerBound(key)
243
244	if n.isLeaf() {
245		if found {
246			v := n.values[idx]
247			n.removeAt(idx)
248			n.count--
249			return v, true
250		}
251		return nil, false
252	}
253
254	if found {
255		if n.leftChild(idx).size() > minKeys {
256			v, removed := n.replaceWithLeftMax(idx, minKeys)
257			if removed {
258				n.count--
259			}
260			return v, removed
261		}
262		if idx < n.numOfChildren()-1 && n.rightChild(idx).size() > minKeys {
263			v, removed := n.replaceWithRightMin(idx, minKeys)
264			if removed {
265				n.count--
266			}
267			return v, removed
268		}
269		n.mergeChildren(idx)
270		v, removed := n.children[idx].remove(key, minKeys)
271		if !removed {
272			panic("merged key not found")
273		}
274		n.count--
275		return v, true
276	}
277
278	childIdx := idx
279	if n.children[childIdx].size() <= minKeys {
280		childIdx = n.fill(childIdx, minKeys)
281	}
282	v, removed := n.children[childIdx].remove(key, minKeys)
283	if removed {
284		n.count--
285	}
286	return v, removed
287}
288
289func (n *node) replaceWithLeftMax(keyIdx, minKeys int) (any, bool) {
290	old := n.values[keyIdx]
291	k, v := n.leftChild(keyIdx).removeMax(minKeys)
292	n.updateAt(keyIdx, k, v)
293	return old, true
294}
295
296func (n *node) removeMax(minKeys int) (key, any) {
297	if n.isLeaf() {
298		lastKeyIdx := n.size() - 1
299		k, v := n.getAt(lastKeyIdx)
300		n.removeAt(lastKeyIdx)
301		n.count--
302		return k, v
303	}
304
305	if n.lastChild().size() <= minKeys {
306		lastChildIdx := n.numOfChildren() - 1
307		n.fill(lastChildIdx, minKeys)
308	}
309	k, v := n.lastChild().removeMax(minKeys)
310	n.count--
311	return k, v
312}
313
314func (n *node) replaceWithRightMin(keyIdx, minKeys int) (any, bool) {
315	old := n.values[keyIdx]
316	k, v := n.rightChild(keyIdx).removeMin(minKeys)
317	n.updateAt(keyIdx, k, v)
318	return old, true
319}
320
321func (n *node) removeMin(minKeys int) (key, any) {
322	if n.isLeaf() {
323		k, v := n.getAt(0)
324		n.removeAt(0)
325		n.count--
326		return k, v
327	}
328
329	if n.firstChild().size() <= minKeys {
330		n.fill(0, minKeys)
331	}
332	k, v := n.firstChild().removeMin(minKeys)
333	n.count--
334	return k, v
335}
336
337func (n *node) fill(childIdx int, minKeys int) int {
338	if childIdx-1 >= 0 && n.children[childIdx-1].size() > minKeys {
339		n.borrowFromLeft(childIdx)
340		return childIdx
341	} else if childIdx+1 < n.numOfChildren() && n.children[childIdx+1].size() > minKeys {
342		n.borrowFromRight(childIdx)
343		return childIdx
344	} else if childIdx+1 < n.numOfChildren() {
345		n.mergeChildren(childIdx)
346		return childIdx
347	} else {
348		n.mergeChildren(childIdx - 1)
349		return childIdx - 1
350	}
351}
352
353func (n *node) borrowFromLeft(childIdx int) {
354	parentKey, parentValue := n.getAt(childIdx - 1)
355	leftSibling := n.children[childIdx-1]
356	leftLast := leftSibling.size() - 1
357	leftLastKey, leftLastValue := leftSibling.getAt(leftLast)
358
359	child := n.children[childIdx]
360	child.insertAt(0, parentKey, parentValue)
361
362	n.updateAt(childIdx-1, leftLastKey, leftLastValue)
363	leftSibling.removeAt(leftLast)
364
365	if !leftSibling.isLeaf() {
366		child.insertChildAt(0, leftSibling.lastChild())
367		leftSibling.removeLastChild()
368	}
369
370	leftSibling.refreshCount()
371	child.refreshCount()
372	n.refreshCount()
373}
374
375func (n *node) borrowFromRight(childIdx int) {
376	parentKey, parentValue := n.getAt(childIdx)
377	rightSibling := n.children[childIdx+1]
378	rightFirstKey, rightFirstValue := rightSibling.getAt(0)
379
380	child := n.children[childIdx]
381	child.insertAt(child.size(), parentKey, parentValue)
382
383	n.updateAt(childIdx, rightFirstKey, rightFirstValue)
384	rightSibling.removeAt(0)
385
386	if !rightSibling.isLeaf() {
387		child.insertChildAt(child.numOfChildren(), rightSibling.firstChild())
388		rightSibling.removeFirstChild()
389	}
390
391	rightSibling.refreshCount()
392	child.refreshCount()
393	n.refreshCount()
394}
395
396func (n *node) mergeChildren(keyIdx int) {
397	parentKey := n.keys[keyIdx]
398	parentValue := n.values[keyIdx]
399	left := n.children[keyIdx]
400	right := n.children[keyIdx+1]
401
402	left.keys = append(left.keys, parentKey)
403	left.values = append(left.values, parentValue)
404
405	left.keys = append(left.keys, right.keys...)
406	left.values = append(left.values, right.values...)
407
408	if !right.isLeaf() {
409		left.children = append(left.children, right.children...)
410	}
411
412	n.removeAt(keyIdx)
413	n.removeChildAt(keyIdx + 1)
414	left.refreshCount()
415	n.refreshCount()
416}
417
418func (n *node) getByIndex(index int) (key, any) {
419	rem := index
420	for i := 0; i < n.size(); i++ {
421		if !n.isLeaf() {
422			childSize := n.children[i].subtreeSize()
423			if rem < childSize {
424				return n.children[i].getByIndex(rem)
425			}
426			rem -= childSize
427		}
428		if rem == 0 {
429			return n.getAt(i)
430		}
431		rem--
432	}
433	if !n.isLeaf() && rem < n.children[n.size()].subtreeSize() {
434		return n.children[n.size()].getByIndex(rem)
435	}
436	panic("GetByIndex asked for invalid index")
437}
438
439func (n *node) iterate(start, end key, cb iterCbFn) bool {
440	for i := 0; i < n.size(); i++ {
441		if !n.isLeaf() {
442			if n.children[i].iterate(start, end, cb) {
443				return true
444			}
445		}
446		key, value := n.getAt(i)
447		if start != nil && key.Less(start) {
448			continue
449		}
450		if end != nil && !key.Less(end) {
451			return false
452		}
453		if cb(key, value) {
454			return true
455		}
456	}
457	if !n.isLeaf() {
458		return n.children[n.size()].iterate(start, end, cb)
459	}
460	return false
461}
462
463func (n *node) reverseIterate(start, end key, cb iterCbFn) bool {
464	for i := n.size(); i >= 0; i-- {
465		if !n.isLeaf() {
466			if n.children[i].reverseIterate(start, end, cb) {
467				return true
468			}
469		}
470		if i == 0 {
471			break
472		}
473
474		keyIdx := i - 1
475		key, value := n.getAt(keyIdx)
476		if end != nil && end.Less(key) {
477			continue
478		}
479		if start != nil && key.Less(start) {
480			return false
481		}
482		if cb(key, value) {
483			return true
484		}
485	}
486	return false
487}
488
489func (n *node) iterateByOffset(
490	offset int,
491	limit int,
492	reverse bool,
493	seen *int,
494	visited *int,
495	cb iterCbFn,
496) bool {
497	if n == nil || *visited >= limit {
498		return false
499	}
500	if reverse {
501		return n.reverseIterateByOffset(offset, limit, seen, visited, cb)
502	}
503	return n.forwardIterateByOffset(offset, limit, seen, visited, cb)
504}
505
506func (n *node) forwardIterateByOffset(
507	offset int,
508	limit int,
509	seen *int,
510	visited *int,
511	cb iterCbFn,
512) bool {
513	for i := 0; i < n.size(); i++ {
514		if !n.isLeaf() {
515			if visitChildByOffset(n.children[i], offset, limit, false, seen, visited, cb) {
516				return true
517			}
518			if *visited >= limit {
519				return false
520			}
521		}
522		if visitKeyByOffset(n.keys[i], n.values[i], offset, limit, seen, visited, cb) {
523			return true
524		}
525		if *visited >= limit {
526			return false
527		}
528	}
529	if !n.isLeaf() {
530		return visitChildByOffset(n.children[n.size()], offset, limit, false, seen, visited, cb)
531	}
532	return false
533}
534
535func (n *node) reverseIterateByOffset(
536	offset int,
537	limit int,
538	seen *int,
539	visited *int,
540	cb iterCbFn,
541) bool {
542	if !n.isLeaf() {
543		if visitChildByOffset(n.children[n.size()], offset, limit, true, seen, visited, cb) {
544			return true
545		}
546		if *visited >= limit {
547			return false
548		}
549	}
550	for i := n.size() - 1; i >= 0; i-- {
551		if visitKeyByOffset(n.keys[i], n.values[i], offset, limit, seen, visited, cb) {
552			return true
553		}
554		if *visited >= limit {
555			return false
556		}
557		if !n.isLeaf() {
558			if visitChildByOffset(n.children[i], offset, limit, true, seen, visited, cb) {
559				return true
560			}
561			if *visited >= limit {
562				return false
563			}
564		}
565	}
566	return false
567}
568
569func visitChildByOffset(
570	child *node,
571	offset int,
572	limit int,
573	reverse bool,
574	seen *int,
575	visited *int,
576	cb iterCbFn,
577) bool {
578	if child == nil || *visited >= limit {
579		return false
580	}
581	if *seen+child.subtreeSize() <= offset {
582		*seen += child.subtreeSize()
583		return false
584	}
585	return child.iterateByOffset(offset, limit, reverse, seen, visited, cb)
586}
587
588func visitKeyByOffset(
589	key key,
590	value any,
591	offset int,
592	limit int,
593	seen *int,
594	visited *int,
595	cb iterCbFn,
596) bool {
597	if *seen < offset {
598		*seen++
599		return false
600	}
601	if *visited >= limit {
602		return false
603	}
604	if cb(key, value) {
605		return true
606	}
607	*visited++
608	return false
609}