// Copyright 2020 Joshua J Baker. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

package btree

import (
	"sync"
)

// Node capacity limits.
const (
	// maxItems is the most items a single node may hold; a branch node has
	// room for one more child than that.
	maxItems = 255
	// minItems is 40% of maxItems (integer arithmetic). delete rebalances a
	// child whose item count has dropped below this value.
	minItems = maxItems * 40 / 100
)

// cow is an ownership marker. The tree and every node hold a *cow, and
// cowLoad compares the two pointers to decide whether a node has to be
// copied before it is used for a write.
//
// The padding field keeps the struct from being zero-sized; presumably this
// is because pointers to distinct zero-size values are not guaranteed to
// compare unequal in Go.
type cow struct {
	_ int // it cannot be an empty struct
}

// node is one node of the tree.
//
//   - cow is the ownership marker that cowLoad compares with the tree's.
//   - leaf is true for nodes created without a children array (see newNode).
//   - numItems is the number of occupied slots at the front of items.
//   - count is the number of items stored in this node plus everything
//     below it (see updateCount).
//   - items holds the node's items in ascending order.
//   - children is nil for leaves. In a branch, children[i] is the subtree
//     searched when a key sorts before items[i], and children[numItems] the
//     subtree searched when it sorts after the last item.
type node struct {
	cow      *cow
	leaf     bool
	numItems int16
	count    int
	items    [maxItems]interface{}
	children *[maxItems + 1]*node
}

// BTree is an ordered set of items.
//
// Items are ordered by less. When locks is true, the methods that call lock
// or rlock hold mu for the duration of the call. count is the total number
// of items in the tree, root is nil while the tree is empty, and cow is the
// ownership marker that cowLoad compares against each node's marker.
type BTree struct {
	mu    *sync.RWMutex
	cow   *cow
	root  *node
	count int
	less  func(a, b interface{}) bool
	locks bool
}

// newNode allocates an empty node stamped with the tree's current ownership
// marker. Only branch (non-leaf) nodes receive a children array; a leaf's
// children pointer stays nil.
func (tr *BTree) newNode(leaf bool) *node {
	var kids *[maxItems + 1]*node
	if !leaf {
		kids = new([maxItems + 1]*node)
	}
	return &node{
		cow:      tr.cow,
		leaf:     leaf,
		children: kids,
	}
}

// PathHint is a utility type used with the *Hint() functions. Hints provide
// faster operations for clustered keys.
//
// For each of the first 8 levels of the tree a hint remembers the slot at
// which the previous search ended (path) and whether that level has been
// recorded yet (used). The zero value is an empty hint: no level is marked
// as used, so the first search ignores it and then fills it in.
type PathHint struct {
	used [8]bool
	path [8]uint8
}

// New returns a new, empty BTree that orders its items with less.
//
// The returned tree has locking enabled, so its locking methods take the
// internal mutex on every call; see NewNonConcurrent for a tree without
// locking. New panics if less is nil.
func New(less func(a, b interface{}) bool) *BTree {
	return newBTree(less, true)
}

// NewNonConcurrent returns a new BTree which is not safe for concurrent
// write operations by multiple goroutines.
//
// This is useful for when you do not need the BTree to manage the locking,
// but would rather do it yourself.
//
// The tree is created with locking disabled, so none of its methods touch
// the internal mutex. NewNonConcurrent panics if less is nil.
func NewNonConcurrent(less func(a, b interface{}) bool) *BTree {
	return newBTree(less, false)
}

// newBTree builds an empty tree ordered by less. locks selects whether the
// tree's methods take the internal mutex. It panics if less is nil.
func newBTree(less func(a, b interface{}) bool, locks bool) *BTree {
	if less == nil {
		panic("nil less")
	}
	return &BTree{
		mu:    new(sync.RWMutex),
		less:  less,
		locks: locks,
	}
}

// Less is a convenience function that performs a comparison of two items
// using the same "less" function provided to New.
//
// It only calls the comparison function and does not take the tree's lock.
func (tr *BTree) Less(a, b interface{}) bool {
	return tr.less(a, b)
}

// find searches the items of n for key.
//
// When an item equal to key exists (neither less(key, item) nor
// less(item, key) holds) it returns that item's index and found=true.
// Otherwise it returns found=false and the index at which key would be
// inserted, which is also the index of the child to descend into.
//
// When hint is non-nil and depth < 8, the slot recorded in the hint for this
// depth (if any) is probed before the binary search, and the hint is updated
// with the outcome before returning.
func (n *node) find(key interface{}, less func(a, b interface{}) bool,
	hint *PathHint, depth int,
) (index int16, found bool) {
	low := int16(0)
	high := n.numItems - 1
	if hint != nil && depth < 8 && hint.used[depth] {
		// Probe the slot remembered for this depth first.
		index = int16(hint.path[depth])
		if index >= n.numItems {
			// tail item
			if less(n.items[n.numItems-1], key) {
				// NOTE(review): for a strict ordering this inner test
				// cannot be true while the outer one is, so the "found"
				// arm looks unreachable — confirm before removing.
				if less(key, n.items[n.numItems-1]) {
					index = n.numItems - 1
					found = true
					goto path_match
				} else {
					// key sorts after every item in this node
					index = n.numItems
					goto path_match
				}
			}
			// clamp the hinted slot to the last item
			index = n.numItems - 1
		}
		if less(key, n.items[index]) {
			if index == 0 || less(n.items[index-1], key) {
				// key sorts directly before items[index]: it is not in
				// this node and index is its insertion slot
				goto path_match
			}
			// key is somewhere left of the hinted slot
			high = index - 1
		} else if less(n.items[index], key) {
			// key is somewhere right of the hinted slot
			low = index + 1
		} else {
			found = true
			goto path_match
		}
	}
	// Binary search over [low, high]. On exit every item before low is not
	// greater than key.
	for low <= high {
		mid := low + ((high+1)-low)/2
		if !less(key, n.items[mid]) {
			low = mid + 1
		} else {
			high = mid - 1
		}
	}
	// items[low-1] is not greater than key; if it is not smaller either, it
	// is the match.
	if low > 0 && !less(n.items[low-1], key) {
		index = low - 1
		found = true
	} else {
		index = low
		found = false
	}
	if hint == nil || depth >= 8 {
		return index, found
	}

path_match:
	// Record the outcome for this depth. A hit in a leaf stores the slot
	// after the match — presumably so that a following, slightly larger key
	// lands on the hinted slot directly.
	hint.used[depth] = true
	if n.leaf && found {
		hint.path[depth] = byte(index + 1)
	} else {
		hint.path[depth] = byte(index)
	}
	return index, found
}

// SetHint sets or replaces a value for a key using a path hint.
//
// It returns the item that was replaced, or nil when item was newly
// inserted. hint may be nil. The write lock is held for the duration of the
// call when the tree manages its own locking. SetHint panics if item is nil.
func (tr *BTree) SetHint(item interface{}, hint *PathHint) (prev interface{}) {
	if item == nil {
		panic("nil item")
	}
	if tr.lock() {
		defer tr.unlock()
	}
	return tr.setHint(item, hint)
}

// setHint inserts item, or replaces the stored item that compares equal to
// it, and returns the replaced item (nil when item was newly inserted). It
// does no locking of its own.
func (tr *BTree) setHint(item interface{}, hint *PathHint) (prev interface{}) {
	if tr.root == nil {
		// First item: the whole tree is a single leaf.
		root := tr.newNode(true)
		root.items[0] = item
		root.numItems = 1
		root.count = 1
		tr.root = root
		tr.count = 1
		return nil
	}
	replaced := tr.nodeSet(&tr.root, item, tr.less, hint, 0)
	if replaced != nil {
		// An existing item was overwritten; the size is unchanged.
		return replaced
	}
	if tr.root.numItems == maxItems {
		// The root filled up: split it and grow the tree by one level.
		left := tr.cowLoad(&tr.root)
		right, median := tr.nodeSplit(left)
		top := tr.newNode(false)
		top.items[0] = median
		top.children[0] = left
		top.children[1] = right
		top.numItems = 1
		top.count = left.count + 1 + right.count
		tr.root = top
	}
	tr.count++
	return nil
}

// Set sets or replaces a value for a key. It is SetHint without a hint:
// the replaced item is returned, or nil when item was newly inserted, and a
// nil item causes a panic.
func (tr *BTree) Set(item interface{}) (prev interface{}) {
	return tr.SetHint(item, nil)
}

// nodeSplit splits a full node around its middle item. The items (and, for
// a branch, the children) above the middle move into a newly allocated right
// sibling, the vacated slots of n are cleared, and both nodes get their
// counts recomputed. It returns the new sibling and the middle item, which
// the caller is expected to place in the parent.
func (tr *BTree) nodeSplit(n *node) (right *node, median interface{}) {
	const mid = maxItems / 2

	right = tr.newNode(n.leaf)
	median = n.items[mid]

	// Hand the upper half over to the new sibling.
	copy(right.items[:mid], n.items[mid+1:])
	right.numItems = mid
	if !n.leaf {
		copy(right.children[:mid+1], n.children[mid+1:])
		// Drop the moved child pointers from n.
		movedKids := n.children[mid+1:]
		for i := range movedKids {
			movedKids[i] = nil
		}
	}

	// Clear the median and everything that moved out of n.
	movedItems := n.items[mid:]
	for i := range movedItems {
		movedItems[i] = nil
	}
	n.numItems = mid

	n.updateCount()
	right.updateCount()
	return right, median
}

// updateCount recomputes n.count as the number of items held directly by n
// plus, for a branch, the counts of its numItems+1 children.
func (n *node) updateCount() {
	total := int(n.numItems)
	if !n.leaf {
		for _, child := range n.children[:n.numItems+1] {
			total += child.count
		}
	}
	n.count = total
}

// copy returns a duplicate of n that carries the tree's current ownership
// marker. The duplicate has its own items array and, for a branch, its own
// children array; the child nodes themselves are shared with n.
//
// This operation should not be inlined because it's expensive and rarely
// called outside of heavy copy-on-write situations. Marking it "noinline"
// allows for the parent cowLoad to be inlined.
//
//go:noinline
func (tr *BTree) copy(n *node) *node {
	// Assigning the struct already copies the items array by value; only the
	// children array, which is held through a pointer, needs an explicit
	// duplicate.
	n2 := *n
	n2.cow = tr.cow
	if n.children != nil {
		n2.children = new([maxItems + 1]*node)
		copy(n2.children[:], n.children[:])
	}
	return &n2
}

// cowLoad loads the provided node and, if needed, performs a copy-on-write.
//
// When the node's ownership marker differs from the tree's, the node is
// replaced in place (through cn) by a copy carrying the tree's marker. The
// node now stored in *cn is returned either way.
func (tr *BTree) cowLoad(cn **node) *node {
	if (*cn).cow != tr.cow {
		*cn = tr.copy(*cn)
	}
	return *cn
}

// nodeSet inserts item into the subtree rooted at *cn, or overwrites the
// stored item that compares equal to it. It returns the overwritten item, or
// nil when item was newly inserted.
//
// A child that becomes full through the insertion is split here, in its
// parent; a full root is left for the caller to split.
func (tr *BTree) nodeSet(cn **node, item interface{},
	less func(a, b interface{}) bool, hint *PathHint, depth int,
) (prev interface{}) {
	n := tr.cowLoad(cn)
	pos, exists := n.find(item, less, hint, depth)
	switch {
	case exists:
		// Overwrite in place; counts do not change.
		prev = n.items[pos]
		n.items[pos] = item
		return prev
	case n.leaf:
		// Open a gap at pos and drop the item in.
		copy(n.items[pos+1:n.numItems+1], n.items[pos:n.numItems])
		n.items[pos] = item
		n.numItems++
		n.count++
		return nil
	}

	prev = tr.nodeSet(&n.children[pos], item, less, hint, depth+1)
	if prev != nil {
		return prev
	}
	if child := n.children[pos]; child.numItems == maxItems {
		// The child filled up: split it and pull its median into n.
		right, median := tr.nodeSplit(child)
		copy(n.children[pos+1:], n.children[pos:])
		copy(n.items[pos+1:], n.items[pos:])
		n.items[pos] = median
		n.children[pos+1] = right
		n.numItems++
	}
	n.count++
	return nil
}

// scan visits every item of the subtree rooted at n in ascending order,
// calling iter for each one. It stops and returns false as soon as iter
// returns false, and returns true when the whole subtree was visited.
func (n *node) scan(iter func(item interface{}) bool) bool {
	for i := int16(0); i < n.numItems; i++ {
		// In a branch, the subtree left of an item comes before the item.
		if !n.leaf && !n.children[i].scan(iter) {
			return false
		}
		if !iter(n.items[i]) {
			return false
		}
	}
	if n.leaf {
		return true
	}
	// Finish with the subtree right of the last item.
	return n.children[n.numItems].scan(iter)
}

// Get returns the stored item that compares equal to key, or nil if there
// is none. It is GetHint without a hint.
func (tr *BTree) Get(key interface{}) interface{} {
	return tr.GetHint(key, nil)
}

// GetHint gets a value for key using a path hint.
//
// It returns the stored item that compares equal to key, or nil when the
// tree is empty, key is nil, or no such item exists. hint may be nil.
func (tr *BTree) GetHint(key interface{}, hint *PathHint) interface{} {
	if tr.rlock() {
		defer tr.runlock()
	}
	if tr.root == nil || key == nil {
		return nil
	}
	// Descend one level per iteration until the key is found or a leaf
	// rules it out.
	for n, depth := tr.root, 0; ; depth++ {
		pos, ok := n.find(key, tr.less, hint, depth)
		if ok {
			return n.items[pos]
		}
		if n.leaf {
			return nil
		}
		n = n.children[pos]
	}
}

// Len returns the number of items in the tree.
//
// It reads the count directly and, unlike the other read methods in this
// file, does not take the tree's read lock.
func (tr *BTree) Len() int {
	return tr.count
}

// Delete removes the item that compares equal to key and returns it, or
// returns nil if there is no such item. It is DeleteHint without a hint.
func (tr *BTree) Delete(key interface{}) interface{} {
	return tr.DeleteHint(key, nil)
}

// DeleteHint deletes a value for a key using a path hint.
//
// It returns the removed item, or nil when the tree is empty, key is nil, or
// no item compares equal to key. hint may be nil. The write lock is held for
// the duration of the call when the tree manages its own locking.
func (tr *BTree) DeleteHint(key interface{}, hint *PathHint) interface{} {
	if tr.lock() {
		defer tr.unlock()
	}
	return tr.deleteHint(key, hint)
}

// deleteHint removes the item that compares equal to key and returns it, or
// returns nil when nothing was removed. It does no locking of its own.
//
// After a removal the root is dropped entirely when the tree became empty,
// or replaced by its only child when it is a branch left without items.
func (tr *BTree) deleteHint(key interface{}, hint *PathHint) interface{} {
	if tr.root == nil || key == nil {
		return nil
	}
	removed := tr.delete(&tr.root, false, key, tr.less, hint, 0)
	if removed == nil {
		return nil
	}
	tr.count--
	switch {
	case tr.count == 0:
		tr.root = nil
	case tr.root.numItems == 0 && !tr.root.leaf:
		tr.root = tr.root.children[0]
	}
	return removed
}

// delete removes one item from the subtree rooted at *cn and returns it, or
// returns nil when nothing was removed.
//
// When max is false the item that compares equal to key is removed. When max
// is true, key is ignored and the largest item of the subtree is removed
// instead; this is how an item found in a branch gets replaced by the
// largest item of its left subtree.
//
// After a removal below n, the affected child is merged with, or refilled
// from, a neighbouring sibling if it dropped below minItems.
func (tr *BTree) delete(cn **node, max bool, key interface{},
	less func(a, b interface{}) bool, hint *PathHint, depth int,
) interface{} {
	n := tr.cowLoad(cn)
	var i int16
	var found bool
	if max {
		// target the last item of this node
		i, found = n.numItems-1, true
	} else {
		i, found = n.find(key, less, hint, depth)
	}
	if n.leaf {
		if found {
			prev := n.items[i]
			// found the item at the leaf, remove it and return.
			copy(n.items[i:], n.items[i+1:n.numItems])
			n.items[n.numItems-1] = nil
			n.numItems--
			n.count--
			return prev
		}
		return nil
	}

	var prev interface{}
	if found {
		if max {
			// the largest item lives in the last child
			i++
			prev = tr.delete(&n.children[i], true, "", less, nil, 0)
		} else {
			// replace the found item with the largest item of the subtree
			// to its left
			prev = n.items[i]
			maxItem := tr.delete(&n.children[i], true, "", less, nil, 0)
			n.items[i] = maxItem
		}
	} else {
		prev = tr.delete(&n.children[i], max, key, less, hint, depth+1)
	}
	if prev == nil {
		return nil
	}
	n.count--
	if n.children[i].numItems >= minItems {
		// the child is still full enough; nothing to rebalance
		return prev
	}

	// merge / rebalance nodes
	// Work on the pair (i, i+1); when the affected child is the last one,
	// pair it with its left neighbour instead.
	if i == n.numItems {
		i--
	}
	n.children[i] = tr.cowLoad(&n.children[i])
	n.children[i+1] = tr.cowLoad(&n.children[i+1])
	if n.children[i].numItems+n.children[i+1].numItems+1 < maxItems {
		// merge left + item + right
		n.children[i].items[n.children[i].numItems] = n.items[i]
		copy(n.children[i].items[n.children[i].numItems+1:],
			n.children[i+1].items[:n.children[i+1].numItems])
		if !n.children[0].leaf {
			copy(n.children[i].children[n.children[i].numItems+1:],
				n.children[i+1].children[:n.children[i+1].numItems+1])
		}
		n.children[i].numItems += n.children[i+1].numItems + 1
		n.children[i].count += n.children[i+1].count + 1
		// close the gap left by the separator and the right child
		copy(n.items[i:], n.items[i+1:n.numItems])
		copy(n.children[i+1:], n.children[i+2:n.numItems+1])
		n.items[n.numItems-1] = nil
		n.children[n.numItems] = nil
		n.numItems--
	} else if n.children[i].numItems > n.children[i+1].numItems {
		// move left -> right
		// The separator goes to the front of the right child and the last
		// item of the left child becomes the new separator.
		copy(n.children[i+1].items[1:],
			n.children[i+1].items[:n.children[i+1].numItems])
		if !n.children[0].leaf {
			copy(n.children[i+1].children[1:],
				n.children[i+1].children[:n.children[i+1].numItems+1])
		}
		n.children[i+1].items[0] = n.items[i]
		if !n.children[0].leaf {
			n.children[i+1].children[0] =
				n.children[i].children[n.children[i].numItems]
			n.children[i+1].count += n.children[i+1].children[0].count
		}
		n.children[i+1].numItems++
		n.children[i+1].count++
		n.items[i] = n.children[i].items[n.children[i].numItems-1]
		n.children[i].items[n.children[i].numItems-1] = nil
		if !n.children[0].leaf {
			n.children[i].children[n.children[i].numItems] = nil
			n.children[i].count -= n.children[i+1].children[0].count
		}
		n.children[i].numItems--
		n.children[i].count--
	} else {
		// move left <- right
		// The separator goes to the end of the left child and the first
		// item of the right child becomes the new separator.
		n.children[i].items[n.children[i].numItems] = n.items[i]
		if !n.children[0].leaf {
			n.children[i].children[n.children[i].numItems+1] =
				n.children[i+1].children[0]
			n.children[i].count +=
				n.children[i].children[n.children[i].numItems+1].count
		}
		n.children[i].numItems++
		n.children[i].count++
		n.items[i] = n.children[i+1].items[0]
		copy(n.children[i+1].items[:],
			n.children[i+1].items[1:n.children[i+1].numItems])
		n.children[i+1].items[n.children[i+1].numItems-1] = nil
		if !n.children[0].leaf {
			copy(n.children[i+1].children[:],
				n.children[i+1].children[1:n.children[i+1].numItems+1])
			n.children[i+1].children[n.children[i+1].numItems] = nil
			n.children[i+1].count -=
				n.children[i].children[n.children[i].numItems].count
		}
		n.children[i+1].numItems--
		n.children[i+1].count--
	}
	return prev
}

// Ascend iterates over the tree in ascending order within the range
// [pivot, last], calling iter for each item.
// Pass nil for pivot to scan all items in ascending order.
// Return false from iter to stop iterating.
func (tr *BTree) Ascend(pivot interface{}, iter func(item interface{}) bool) {
	if tr.rlock() {
		defer tr.runlock()
	}
	root := tr.root
	switch {
	case root == nil:
		return
	case pivot == nil:
		root.scan(iter)
	default:
		root.ascend(pivot, tr.less, nil, 0, iter)
	}
}

// ascend visits, in ascending order, every item of the subtree rooted at n
// that is not smaller than pivot. It returns false as soon as iter returns
// false and true once the subtree has been exhausted.
func (n *node) ascend(pivot interface{}, less func(a, b interface{}) bool,
	hint *PathHint, depth int, iter func(item interface{}) bool,
) bool {
	pos, exact := n.find(pivot, less, hint, depth)
	branch := !n.leaf
	if branch && !exact {
		// The range starts somewhere inside the child left of pos.
		if !n.children[pos].ascend(pivot, less, hint, depth+1, iter) {
			return false
		}
	}
	for pos < n.numItems {
		if !iter(n.items[pos]) {
			return false
		}
		// Everything right of an in-range item is in range as well.
		if branch && !n.children[pos+1].scan(iter) {
			return false
		}
		pos++
	}
	return true
}

// reverse visits every item of the subtree rooted at n in descending order,
// calling iter for each one. It stops and returns false as soon as iter
// returns false, and returns true when the whole subtree was visited.
func (n *node) reverse(iter func(item interface{}) bool) bool {
	// In a branch, start with the subtree right of the last item.
	if !n.leaf && !n.children[n.numItems].reverse(iter) {
		return false
	}
	for i := n.numItems - 1; i >= 0; i-- {
		if !iter(n.items[i]) {
			return false
		}
		// The subtree left of an item follows the item.
		if !n.leaf && !n.children[i].reverse(iter) {
			return false
		}
	}
	return true
}

// Descend iterates over the tree in descending order within the range
// [pivot, first], calling iter for each item.
// Pass nil for pivot to scan all items in descending order.
// Return false from iter to stop iterating.
func (tr *BTree) Descend(pivot interface{}, iter func(item interface{}) bool) {
	if tr.rlock() {
		defer tr.runlock()
	}
	root := tr.root
	switch {
	case root == nil:
		return
	case pivot == nil:
		root.reverse(iter)
	default:
		root.descend(pivot, tr.less, nil, 0, iter)
	}
}

// descend visits, in descending order, every item of the subtree rooted at
// n that is not greater than pivot. It returns false as soon as iter returns
// false and true once the subtree has been exhausted.
func (n *node) descend(pivot interface{}, less func(a, b interface{}) bool,
	hint *PathHint, depth int, iter func(item interface{}) bool,
) bool {
	pos, exact := n.find(pivot, less, hint, depth)
	if !exact {
		// The range starts somewhere inside the child at pos; after that
		// the first in-range item of this node is the one left of pos.
		if !n.leaf && !n.children[pos].descend(pivot, less, hint, depth+1, iter) {
			return false
		}
		pos--
	}
	for pos >= 0 {
		if !iter(n.items[pos]) {
			return false
		}
		// Everything left of an in-range item is in range as well.
		if !n.leaf && !n.children[pos].reverse(iter) {
			return false
		}
		pos--
	}
	return true
}

// Load is for bulk loading pre-sorted items.
//
// It first tries to append item to the right-most leaf, which succeeds when
// that leaf has spare room and item sorts after the leaf's last item. In
// every other case it falls back to a regular insert. The return value is
// nil for an append; otherwise it is whatever the regular insert returns
// (the replaced item, or nil). Load panics if item is nil.
func (tr *BTree) Load(item interface{}) interface{} {
	if item == nil {
		panic("nil item")
	}
	if tr.lock() {
		defer tr.unlock()
	}
	if tr.root == nil {
		return tr.setHint(item, nil)
	}

	// Walk down the right edge, bumping the counts as if the append were
	// going to succeed.
	n := tr.cowLoad(&tr.root)
	n.count++
	for !n.leaf {
		n = tr.cowLoad(&n.children[n.numItems])
		n.count++
	}
	if n.numItems < maxItems-2 && tr.less(n.items[n.numItems-1], item) {
		n.items[n.numItems] = item
		n.numItems++
		tr.count++
		return nil
	}

	// The fast path did not apply: undo the optimistic counts along the same
	// edge and insert the regular way.
	for n = tr.root; ; n = n.children[n.numItems] {
		n.count--
		if n.leaf {
			break
		}
	}
	return tr.setHint(item, nil)
}

// Min returns the minimum item in the tree.
// Returns nil if the tree has no items.
func (tr *BTree) Min() interface{} {
	if tr.rlock() {
		defer tr.runlock()
	}
	if tr.root == nil {
		return nil
	}
	// The smallest item is the first item of the left-most leaf.
	n := tr.root
	for !n.leaf {
		n = n.children[0]
	}
	return n.items[0]
}

// Max returns the maximum item in the tree.
// Returns nil if the tree has no items.
func (tr *BTree) Max() interface{} {
	if tr.rlock() {
		defer tr.runlock()
	}
	if tr.root == nil {
		return nil
	}
	// The largest item is the last item of the right-most leaf.
	n := tr.root
	for !n.leaf {
		n = n.children[n.numItems]
	}
	return n.items[n.numItems-1]
}

// PopMin removes the minimum item in tree and returns it.
// Returns nil if the tree has no items.
//
// The fast path walks down the left edge and removes the first item of the
// left-most leaf directly. When that leaf holds exactly minItems items the
// removal could require rebalancing, so the optimistic count changes are
// undone and the item is removed through the regular delete path instead.
func (tr *BTree) PopMin() interface{} {
	if tr.lock() {
		defer tr.unlock()
	}
	if tr.root == nil {
		return nil
	}
	n := tr.cowLoad(&tr.root)
	var item interface{}
	for {
		n.count-- // optimistically update counts
		if n.leaf {
			item = n.items[0]
			if n.numItems == minItems {
				break
			}
			// Shift only the occupied slots down by one, as delete and
			// DeleteAt do, rather than the whole fixed-size array.
			copy(n.items[:], n.items[1:n.numItems])
			n.items[n.numItems-1] = nil
			n.numItems--
			tr.count--
			if tr.count == 0 {
				tr.root = nil
			}
			return item
		}
		n = tr.cowLoad(&n.children[0])
	}
	// revert the counts
	n = tr.root
	for {
		n.count++
		if n.leaf {
			break
		}
		n = n.children[0]
	}
	return tr.deleteHint(item, nil)
}

// PopMax removes the maximum item in tree and returns it.
// Returns nil if the tree has no items.
//
// The fast path walks down the right edge and removes the last item of the
// right-most leaf directly. When that leaf holds exactly minItems items the
// optimistic count changes are undone and the item is removed through the
// regular delete path instead.
func (tr *BTree) PopMax() interface{} {
	if tr.lock() {
		defer tr.unlock()
	}
	if tr.root == nil {
		return nil
	}

	// Walk down the right edge, lowering the counts as if the fast path
	// were going to succeed.
	n := tr.cowLoad(&tr.root)
	n.count--
	for !n.leaf {
		n = tr.cowLoad(&n.children[n.numItems])
		n.count--
	}
	last := n.numItems - 1
	item := n.items[last]
	if n.numItems != minItems {
		n.items[last] = nil
		n.numItems = last
		tr.count--
		if tr.count == 0 {
			tr.root = nil
		}
		return item
	}

	// Restore the counts along the same edge and delete the regular way.
	for n = tr.root; ; n = n.children[n.numItems] {
		n.count++
		if n.leaf {
			break
		}
	}
	return tr.deleteHint(item, nil)
}

// GetAt returns the item at position index in ascending order.
// Returns nil if the tree is empty or the index is out of bounds.
func (tr *BTree) GetAt(index int) interface{} {
	if tr.rlock() {
		defer tr.runlock()
	}
	if tr.root == nil || index < 0 || index >= tr.count {
		return nil
	}
	n := tr.root
	for !n.leaf {
		// Skip whole subtrees (and their separators) that end before the
		// wanted position; index stays relative to the current subtree.
		slot := 0
		for ; slot < int(n.numItems); slot++ {
			below := n.children[slot].count
			if index == below {
				// exactly the separator after this child
				return n.items[slot]
			}
			if index < below {
				break
			}
			index -= below + 1
		}
		n = n.children[slot]
	}
	// In a leaf the remaining index is the item position.
	return n.items[index]
}

// DeleteAt deletes the item at index and returns it.
// Return nil if the tree is empty or the index is out of bounds.
//
// The fast path descends to the leaf holding the position and removes the
// item there directly. When the item sits in a branch node, or its leaf
// holds exactly minItems items, the optimistic count changes are undone and
// the item is deleted by key instead, using a hint built from the path that
// was just walked.
func (tr *BTree) DeleteAt(index int) interface{} {
	if tr.lock() {
		defer tr.unlock()
	}
	if tr.root == nil || index < 0 || index >= tr.count {
		return nil
	}
	var pathbuf [8]uint8 // track the path
	path := pathbuf[:0]
	var item interface{}
	n := tr.cowLoad(&tr.root)
outer:
	for {
		n.count-- // optimistically update counts
		if n.leaf {
			// the index is the item position
			item = n.items[index]
			if n.numItems == minItems {
				// removing here could require rebalancing; take the slow
				// path
				path = append(path, uint8(index))
				break outer
			}
			copy(n.items[index:], n.items[index+1:n.numItems])
			n.items[n.numItems-1] = nil
			n.numItems--
			tr.count--
			if tr.count == 0 {
				tr.root = nil
			}
			return item
		}
		// Skip whole subtrees (and their separators) that end before the
		// wanted position; index stays relative to the current subtree.
		i := 0
		for ; i < int(n.numItems); i++ {
			if index < n.children[i].count {
				break
			} else if index == n.children[i].count {
				// the position is the separator itself; take the slow path
				item = n.items[i]
				path = append(path, uint8(i))
				break outer
			}
			index -= n.children[i].count + 1
		}
		path = append(path, uint8(i))
		n = tr.cowLoad(&n.children[i])
	}
	// revert the counts
	// One path entry was recorded per visited node, so walking the path
	// again restores every count that was lowered above. The same entries
	// seed the hint used by the delete below.
	var hint PathHint
	n = tr.root
	for i := 0; i < len(path); i++ {
		if i < len(hint.path) {
			hint.path[i] = path[i]
			hint.used[i] = true
		}
		n.count++
		if !n.leaf {
			n = n.children[uint8(path[i])]
		}
	}
	return tr.deleteHint(item, &hint)
}

// Height returns the height of the tree, counted in node levels from the
// root down to a leaf.
// Returns zero if tree has no items.
func (tr *BTree) Height() int {
	if tr.rlock() {
		defer tr.runlock()
	}
	if tr.root == nil {
		return 0
	}
	// The root is level one; every step down the right edge adds a level.
	levels := 1
	for n := tr.root; !n.leaf; n = n.children[n.numItems] {
		levels++
	}
	return levels
}

// Walk iterates over all items in tree, in order.
// The items param will contain one or more items.
//
// Nothing is called for an empty tree.
func (tr *BTree) Walk(iter func(item []interface{})) {
	if tr.rlock() {
		defer tr.runlock()
	}
	root := tr.root
	if root == nil {
		return
	}
	root.walk(iter)
}

// walk hands the items of the subtree rooted at n to iter in ascending
// order: a leaf's items in one slice, a branch's items one at a time between
// its children. The slices are views into the nodes' own item arrays.
func (n *node) walk(iter func(item []interface{})) {
	if n.leaf {
		iter(n.items[:n.numItems])
		return
	}
	for i := int16(0); i < n.numItems; i++ {
		n.children[i].walk(iter)
		iter(n.items[i : i+1])
	}
	n.children[n.numItems].walk(iter)
}

// Copy returns a copy of the tree. This operation is very fast because no
// nodes are duplicated up front: the copy starts out sharing the root with
// tr.
//
// Both trees receive a fresh ownership marker, so neither matches the
// markers on the existing nodes any more; cowLoad then copies a node the
// first time either tree loads it for a write. The copy gets its own mutex.
// The write lock is taken because tr's own marker is replaced.
func (tr *BTree) Copy() *BTree {
	if tr.lock() {
		defer tr.unlock()
	}
	tr.cow = new(cow)
	tr2 := *tr
	tr2.mu = new(sync.RWMutex)
	tr2.cow = new(cow)
	return &tr2
}

// lock takes the write lock when the tree manages its own locking. It
// reports whether the lock was taken, which callers use to decide whether to
// defer unlock.
func (tr *BTree) lock() bool {
	if tr.locks {
		tr.mu.Lock()
	}
	return tr.locks
}

// unlock releases the write lock unconditionally. Callers in this file only
// invoke it after lock has returned true.
func (tr *BTree) unlock() {
	tr.mu.Unlock()
}

// rlock takes the read lock when the tree manages its own locking. It
// reports whether the lock was taken, which callers use to decide whether to
// defer runlock.
func (tr *BTree) rlock() bool {
	if tr.locks {
		tr.mu.RLock()
	}
	return tr.locks
}

// runlock releases the read lock unconditionally. Callers in this file only
// invoke it after rlock has returned true.
func (tr *BTree) runlock() {
	tr.mu.RUnlock()
}