package me.eater.threedom.utils

import me.eater.threedom.dom.IDocument
import me.eater.threedom.dom.INode
import me.eater.threedom.utils.joml.Vector3d
import me.eater.threedom.utils.joml.compareTo
import me.eater.threedom.utils.joml.getTranslation
import org.joml.Vector3dc

/**
 * A 3-dimensional k-d tree that indexes [INode] ids by the translation of their `absolute`
 * transform, so nodes can be looked up by exact position.
 *
 * The tree is not rebalanced automatically: [add]/[update] insert in place and [remove] only
 * clears ids, leaving empty tree nodes behind. Call [rebalance] to rebuild a balanced tree from
 * the currently tracked locations.
 *
 * @param document used by [find] to resolve stored node ids back into nodes.
 * @param root initial tree; ids already stored in it are tracked as well.
 */
class KDTree(private val document: IDocument, private var root: Node = Node(Vector3d(0, 0, 0))) {
    /**
     * Location each tracked node id was last indexed at. This is the source of truth for
     * [remove], [update] and [rebalance], because a node may have moved since it was indexed.
     */
    private val nodeLocMap = mutableMapOf<Long, Vector3dc>()

    init {
        // Track every id already present in the initial tree (e.g. one built by Node.create).
        // Without this, update/remove cannot find their old location and rebalance() drops them.
        val pending = mutableListOf(root)
        while (pending.isNotEmpty()) {
            val node = pending.removeAt(pending.lastIndex)
            node.nodeIds.forEach { nodeLocMap[it] = node.vertex }
            node.left?.let { pending.add(it) }
            node.right?.let { pending.add(it) }
        }
    }

    /**
     * A tree node holding every node id located exactly at [vertex].
     *
     * The split [axis] cycles with [depth]. A point belongs to the [left] subtree when it
     * precedes [vertex] on that axis; ties on the axis are broken by the full-vector
     * `compareTo` ordering. Every other point (other than [vertex] itself) belongs [right].
     */
    data class Node(
        val vertex: Vector3dc = Vector3d(0, 0, 0),
        val nodeIds: MutableSet<Long> = mutableSetOf(),
        val depth: Long = 0,
        var left: Node? = null,
        var right: Node? = null
    ) {
        /** Component index (0, 1 or 2) this node splits on, cycling with [depth]. */
        val axis: Int
            get() = (depth % 3).toInt()

        /** Whether [translation] belongs in this node's left subtree. */
        private fun goesLeft(translation: Vector3dc): Boolean = precedes(translation, vertex, axis)

        /**
         * Inserts [nodeId] at [translation]. If no tree node exists at exactly that position,
         * a new leaf is created one level below the last node visited.
         */
        fun add(translation: Vector3dc, nodeId: Long) {
            var current = this
            while (true) {
                if (translation == current.vertex) {
                    current.nodeIds.add(nodeId)
                    return
                }
                if (current.goesLeft(translation)) {
                    val left = current.left
                    if (left == null) {
                        current.left = Node(translation, mutableSetOf(nodeId), current.depth + 1)
                        return
                    }
                    current = left
                } else {
                    val right = current.right
                    if (right == null) {
                        current.right = Node(translation, mutableSetOf(nodeId), current.depth + 1)
                        return
                    }
                    current = right
                }
            }
        }

        /**
         * Removes [nodeId] from the tree node at exactly [translation], if there is one.
         * The tree node itself stays in place, even when its id set becomes empty.
         */
        fun remove(translation: Vector3dc, nodeId: Long) {
            find(translation)?.nodeIds?.remove(nodeId)
        }

        /** Returns the tree node located exactly at [translation], or `null` if there is none. */
        fun find(translation: Vector3dc): Node? {
            var current: Node? = this
            while (current != null) {
                if (translation == current.vertex) {
                    return current
                }
                current = if (current.goesLeft(translation)) current.left else current.right
            }
            return null
        }

        companion object {
            /**
             * Lexicographic "a comes before b" on (a[axis], a), using the imported vector
             * `compareTo` as the tie-breaker. Both navigation ([add]/[find]) and [create] use
             * this ordering, so they always agree on which side a point belongs to.
             */
            private fun precedes(a: Vector3dc, b: Vector3dc, axis: Int): Boolean =
                a[axis] < b[axis] || (a[axis] == b[axis] && a < b)

            /** Builds a balanced tree from [nodes], grouping ids that share a translation. */
            fun create(nodes: Collection<INode<*>>, depth: Long = 0): Node =
                create(
                    nodes.groupBy<INode<*>, Vector3dc, Long>({ it.absolute.getTranslation() }, { it.nodeId }),
                    depth
                )

            /**
             * Builds a balanced tree from a location -> node ids map, rooted at [depth].
             * An empty map yields a single empty node at the origin.
             */
            fun create(nodes: Map<Vector3dc, List<Long>>, depth: Long = 0): Node {
                if (nodes.isEmpty()) {
                    return Node(depth = depth)
                }
                val axis = (depth % 3).toInt()
                // Sort with the same ordering that add/find navigate by. Sorting on the axis value
                // alone places ties arbitrarily around the median, so find() could miss them.
                val sorted = nodes.keys.sortedWith(Comparator<Vector3dc> { a, b ->
                    when {
                        precedes(a, b, axis) -> -1
                        precedes(b, a, axis) -> 1
                        else -> 0
                    }
                })
                val median = sorted.size / 2
                val selected = sorted[median]
                return Node(
                    selected,
                    nodes.getValue(selected).toMutableSet(),
                    depth,
                    subtree(sorted.subList(0, median), nodes, depth + 1),
                    subtree(sorted.subList(median + 1, sorted.size), nodes, depth + 1)
                )
            }

            /** Builds the subtree for [keys] (a slice of the sorted keys), or `null` when it is empty. */
            private fun subtree(keys: List<Vector3dc>, nodes: Map<Vector3dc, List<Long>>, depth: Long): Node? =
                if (keys.isEmpty()) null else create(keys.associateWith { nodes.getValue(it) }, depth)
        }
    }

    /** Builds a balanced tree containing [nodes]; their ids are tracked by the `init` block. */
    constructor(document: IDocument, nodes: Collection<INode<*>>) : this(document, Node.create(nodes))

    /**
     * Indexes [node] at its current translation. If the node was already tracked, its previous
     * entry is removed first, so the tree never keeps a stale id for a node that moved.
     */
    fun add(node: INode<*>) {
        val vec = node.absolute.getTranslation()
        nodeLocMap.put(node.nodeId, vec)?.let { previous -> root.remove(previous, node.nodeId) }
        root.add(vec, node.nodeId)
    }

    /**
     * Stops indexing [node]. The location it was last indexed at is used, so this works even
     * if the node has moved since. For an untracked node, its current translation is used.
     */
    fun remove(node: INode<*>) {
        val location = nodeLocMap.remove(node.nodeId) ?: node.absolute.getTranslation()
        root.remove(location, node.nodeId)
    }

    /**
     * Returns the nodes located exactly at [vertex]. Ids for which the document returns `null`
     * are skipped.
     */
    fun find(vertex: Vector3dc) = root.find(vertex)?.nodeIds?.mapNotNull(document::getNodeByNodeId) ?: emptyList()

    /** Re-indexes [node] at its current translation, dropping its previous entry. */
    fun update(node: INode<*>) {
        add(node)
    }

    /**
     * Rebuilds a balanced tree from the tracked locations. This also drops tree nodes whose
     * ids have all been removed.
     */
    fun rebalance() {
        root = Node.create(nodeLocMap.entries.groupBy({ it.value }) { it.key })
    }
}